Unverified Commit 399b746b authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] fix the bug when use different params with reference (#2907)



* fix the bug when use different params with reference

* fix

* Update basic.py

* Apply suggestions from code review
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>

* Update basic.py

* add test

* Apply suggestions from code review

* added asserts in test
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarStrikerRUS <nekit94-12@hotmail.com>
parent f248170d
...@@ -1097,6 +1097,10 @@ class Dataset(object): ...@@ -1097,6 +1097,10 @@ class Dataset(object):
""" """
if self.handle is None: if self.handle is None:
if self.reference is not None: if self.reference is not None:
reference_params = self.reference.get_params()
if self.get_params() != reference_params:
warnings.warn('Overriding the parameters from Reference Dataset.')
self._update_params(reference_params)
if self.used_indices is None: if self.used_indices is None:
# create valid # create valid
self._lazy_init(self.data, label=self.label, reference=self.reference, self._lazy_init(self.data, label=self.label, reference=self.reference,
...@@ -1222,6 +1226,8 @@ class Dataset(object): ...@@ -1222,6 +1226,8 @@ class Dataset(object):
return self return self
def _update_params(self, params): def _update_params(self, params):
if not params:
return self
params = copy.deepcopy(params) params = copy.deepcopy(params)
def update(): def update():
......
...@@ -1958,6 +1958,18 @@ class TestEngine(unittest.TestCase): ...@@ -1958,6 +1958,18 @@ class TestEngine(unittest.TestCase):
with np.testing.assert_raises_regex(lgb.basic.LightGBMError, err_msg): with np.testing.assert_raises_regex(lgb.basic.LightGBMError, err_msg):
lgb.train(new_params, lgb_data, num_boost_round=3) lgb.train(new_params, lgb_data, num_boost_round=3)
def test_dataset_params_with_reference(self):
default_params = {"max_bin": 100}
X = np.random.random((100, 2))
y = np.random.random(100)
X_val = np.random.random((100, 2))
y_val = np.random.random(100)
lgb_train = lgb.Dataset(X, y, params=default_params, free_raw_data=False).construct()
lgb_val = lgb.Dataset(X_val, y_val, reference=lgb_train, free_raw_data=False).construct()
self.assertDictEqual(lgb_train.get_params(), default_params)
self.assertDictEqual(lgb_val.get_params(), default_params)
model = lgb.train(default_params, lgb_train, valid_sets=[lgb_val])
def test_extra_trees(self): def test_extra_trees(self):
# check extra trees increases regularization # check extra trees increases regularization
X, y = load_boston(True) X, y = load_boston(True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment