Commit ed1e4f8e authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python] Dataset params back up before training (#786)

* params back up

* refine logic
parent 2367b463
......@@ -4,6 +4,7 @@
"""Wrapper c_api of LightGBM"""
from __future__ import absolute_import
import copy
import ctypes
import os
import warnings
......@@ -591,11 +592,12 @@ class Dataset(object):
self.silent = silent
self.feature_name = feature_name
self.categorical_feature = categorical_feature
self.params = params
self.params = copy.deepcopy(params)
self.free_raw_data = free_raw_data
self.used_indices = None
self._predictor = None
self.pandas_categorical = None
self.params_back_up = None
def __del__(self):
self._free_handle()
......@@ -872,8 +874,13 @@ class Dataset(object):
if not self.params:
self.params = params
else:
self.params_back_up = copy.deepcopy(self.params)
self.params.update(params)
def _reverse_update_params(self):
self.params = copy.deepcopy(self.params_back_up)
self.params_back_up = None
def set_field(self, field_name, data):
"""Set property into the Dataset.
......
......@@ -128,14 +128,13 @@ def train(params, train_set, num_boost_round=100,
continue
if not isinstance(valid_data, Dataset):
raise TypeError("Traninig only accepts Dataset object")
valid_data._update_params(params)
valid_data.set_reference(train_set)
reduced_valid_sets.append(valid_data)
if valid_names is not None and len(valid_names) > i:
name_valid_sets.append(valid_names[i])
else:
name_valid_sets.append('valid_' + str(i))
for valid_data in valid_sets:
valid_data._update_params(params)
"""process callbacks"""
if callbacks is None:
callbacks = set()
......@@ -165,11 +164,16 @@ def train(params, train_set, num_boost_round=100,
callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order'))
"""construct booster"""
try:
booster = Booster(params=params, train_set=train_set)
if is_valid_contain_train:
booster.set_train_data_name(train_data_name)
for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
booster.add_valid(valid_set, name_valid_set)
finally:
train_set._reverse_update_params()
for valid_set in reduced_valid_sets:
valid_set._reverse_update_params()
booster.best_iteration = 0
"""start training"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment