Unverified Commit 69798c3e authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python][tests] small Python tests cleanup (#3715)

parent d1014ea6
...@@ -423,7 +423,7 @@ lightgbm.Rcheck/ ...@@ -423,7 +423,7 @@ lightgbm.Rcheck/
miktex*.zip miktex*.zip
*.def *.def
# Files created by R examples and tests # Files created by R and Python examples and tests
**/lgb-Dataset.data **/lgb-Dataset.data
**/lgb.Dataset.data **/lgb.Dataset.data
**/model.txt **/model.txt
......
...@@ -126,8 +126,8 @@ class TestBasic(unittest.TestCase): ...@@ -126,8 +126,8 @@ class TestBasic(unittest.TestCase):
est_2 = lgb.train(params, train_data_2, num_boost_round=10) est_2 = lgb.train(params, train_data_2, num_boost_round=10)
pred_2 = est_2.predict(X_train) pred_2 = est_2.predict(X_train)
np.testing.assert_allclose(pred_1, pred_2) np.testing.assert_allclose(pred_1, pred_2)
est_2.save_model('temp_model.txt') est_2.save_model('model.txt')
est_3 = lgb.Booster(model_file='temp_model.txt') est_3 = lgb.Booster(model_file='model.txt')
pred_3 = est_3.predict(X_train) pred_3 = est_3.predict(X_train)
np.testing.assert_allclose(pred_2, pred_3) np.testing.assert_allclose(pred_2, pred_3)
......
...@@ -2612,7 +2612,8 @@ class TestEngine(unittest.TestCase): ...@@ -2612,7 +2612,8 @@ class TestEngine(unittest.TestCase):
def test_reset_params_works_with_metric_num_class_and_boosting(self): def test_reset_params_works_with_metric_num_class_and_boosting(self):
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
params = { dataset_params = {"max_bin": 150}
booster_params = {
'objective': 'multiclass', 'objective': 'multiclass',
'max_depth': 4, 'max_depth': 4,
'bagging_fraction': 0.8, 'bagging_fraction': 0.8,
...@@ -2620,26 +2621,18 @@ class TestEngine(unittest.TestCase): ...@@ -2620,26 +2621,18 @@ class TestEngine(unittest.TestCase):
'boosting': 'gbdt', 'boosting': 'gbdt',
'num_class': 5 'num_class': 5
} }
dtrain = lgb.Dataset(X, y, params={"max_bin": 150}) dtrain = lgb.Dataset(X, y, params=dataset_params)
bst = lgb.Booster( bst = lgb.Booster(
params=params, params=booster_params,
train_set=dtrain train_set=dtrain
) )
expected_params = {
'objective': 'multiclass',
'max_depth': 4,
'bagging_fraction': 0.8,
'metric': ['multi_logloss', 'multi_error'],
'boosting': 'gbdt',
'num_class': 5,
'max_bin': 150
}
assert bst.params == expected_params
params['bagging_fraction'] = 0.9 expected_params = dict(dataset_params, **booster_params)
ret_bst = bst.reset_parameter(params) self.assertDictEqual(bst.params, expected_params)
booster_params['bagging_fraction'] += 0.1
new_bst = bst.reset_parameter(booster_params)
expected_params['bagging_fraction'] = 0.9 expected_params = dict(dataset_params, **booster_params)
assert bst.params == expected_params self.assertDictEqual(bst.params, expected_params)
assert ret_bst.params == expected_params self.assertDictEqual(new_bst.params, expected_params)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment