Commit ed1e4f8e authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python] Dataset params back up before training (#786)

* params back up

* refine logic
parent 2367b463
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""Wrapper c_api of LightGBM""" """Wrapper c_api of LightGBM"""
from __future__ import absolute_import from __future__ import absolute_import
import copy
import ctypes import ctypes
import os import os
import warnings import warnings
...@@ -591,11 +592,12 @@ class Dataset(object): ...@@ -591,11 +592,12 @@ class Dataset(object):
self.silent = silent self.silent = silent
self.feature_name = feature_name self.feature_name = feature_name
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
self.params = params self.params = copy.deepcopy(params)
self.free_raw_data = free_raw_data self.free_raw_data = free_raw_data
self.used_indices = None self.used_indices = None
self._predictor = None self._predictor = None
self.pandas_categorical = None self.pandas_categorical = None
self.params_back_up = None
def __del__(self): def __del__(self):
self._free_handle() self._free_handle()
...@@ -872,8 +874,13 @@ class Dataset(object): ...@@ -872,8 +874,13 @@ class Dataset(object):
if not self.params: if not self.params:
self.params = params self.params = params
else: else:
self.params_back_up = copy.deepcopy(self.params)
self.params.update(params) self.params.update(params)
def _reverse_update_params(self):
self.params = copy.deepcopy(self.params_back_up)
self.params_back_up = None
def set_field(self, field_name, data): def set_field(self, field_name, data):
"""Set property into the Dataset. """Set property into the Dataset.
......
...@@ -128,14 +128,13 @@ def train(params, train_set, num_boost_round=100, ...@@ -128,14 +128,13 @@ def train(params, train_set, num_boost_round=100,
continue continue
if not isinstance(valid_data, Dataset): if not isinstance(valid_data, Dataset):
raise TypeError("Traninig only accepts Dataset object") raise TypeError("Traninig only accepts Dataset object")
valid_data._update_params(params)
valid_data.set_reference(train_set) valid_data.set_reference(train_set)
reduced_valid_sets.append(valid_data) reduced_valid_sets.append(valid_data)
if valid_names is not None and len(valid_names) > i: if valid_names is not None and len(valid_names) > i:
name_valid_sets.append(valid_names[i]) name_valid_sets.append(valid_names[i])
else: else:
name_valid_sets.append('valid_' + str(i)) name_valid_sets.append('valid_' + str(i))
for valid_data in valid_sets:
valid_data._update_params(params)
"""process callbacks""" """process callbacks"""
if callbacks is None: if callbacks is None:
callbacks = set() callbacks = set()
...@@ -165,11 +164,16 @@ def train(params, train_set, num_boost_round=100, ...@@ -165,11 +164,16 @@ def train(params, train_set, num_boost_round=100,
callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order')) callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order'))
"""construct booster""" """construct booster"""
booster = Booster(params=params, train_set=train_set) try:
if is_valid_contain_train: booster = Booster(params=params, train_set=train_set)
booster.set_train_data_name(train_data_name) if is_valid_contain_train:
for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets): booster.set_train_data_name(train_data_name)
booster.add_valid(valid_set, name_valid_set) for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
booster.add_valid(valid_set, name_valid_set)
finally:
train_set._reverse_update_params()
for valid_set in reduced_valid_sets:
valid_set._reverse_update_params()
booster.best_iteration = 0 booster.best_iteration = 0
"""start training""" """start training"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment