Unverified Commit 06e94ada authored by Joaquín Ignacio Aramendía's avatar Joaquín Ignacio Aramendía Committed by GitHub
Browse files

[python-package] Load parameters from model string (#6852)



* [python-package] Test serialization and deserialization from in-memory string

Test case for #6851

* [python-package] Fill in `params` when loading from in-memory string

This fixes #6851 by using the same workaround as when loading the model
from a file.

* test_basic: use rng instead of legacy numpy RandomState

* test_basic: remove debug prints leftovers
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_basic: add boolean, array of float and array of integers to testcase
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_basic: make a cheaper model (2 rounds with 7 leaves each)
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_basic: bugfix typos

* python_package_test: move string load test from basic to engine

* test_engine: catch params ignored warnings

* test_engine: be explicit about parameters assertion

* test_engine: shush linter complaint

* test_basic: delete empty line
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_engine: even cheaper model with less features
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_engine: delete redundant assert
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_engine: run pre-commit and take it's word for it

* test_engine: be explicit in an E712 compliant way

* test_engine.py: pass different value as argument to make sure it is ignored

---------
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 0bbb02fd
...@@ -3715,6 +3715,9 @@ class Booster: ...@@ -3715,6 +3715,9 @@ class Booster:
params = self._get_loaded_param() params = self._get_loaded_param()
elif model_str is not None: elif model_str is not None:
self.model_from_string(model_str) self.model_from_string(model_str)
if params:
_log_warning("Ignoring params argument, using parameters from model string.")
params = self._get_loaded_param()
else: else:
raise TypeError( raise TypeError(
"Need at least one training dataset or model file or model string to create Booster instance" "Need at least one training dataset or model file or model string to create Booster instance"
......
...@@ -1498,7 +1498,7 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng): ...@@ -1498,7 +1498,7 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
assert bst.params["categorical_feature"] == [1, 2] assert bst.params["categorical_feature"] == [1, 2]
# check that passing parameters to the constructor raises warning and ignores them # check that passing parameters to the constructor raises warning and ignores them
with pytest.warns(UserWarning, match="Ignoring params argument"): with pytest.warns(UserWarning, match="Ignoring params argument, using parameters from model file."):
bst2 = lgb.Booster(params={"num_leaves": 7}, model_file=model_file) bst2 = lgb.Booster(params={"num_leaves": 7}, model_file=model_file)
assert bst.params == bst2.params assert bst.params == bst2.params
...@@ -1508,6 +1508,50 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng): ...@@ -1508,6 +1508,50 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
np.testing.assert_allclose(preds, orig_preds) np.testing.assert_allclose(preds, orig_preds)
def test_string_serialized_params_retrieval(rng):
# Random train data
train_x = rng.random((500, 3))
train_y = rng.integers(0, 1, 500)
train_data = lgb.Dataset(train_x, train_y)
# Parameters
params = {
"boosting": "gbdt",
"deterministic": True,
"feature_contri": [0.5] * train_x.shape[1],
"interaction_constraints": [[0, 1], [0]],
"objective": "binary",
"metric": ["auc"],
"num_leaves": 7,
"learning_rate": 0.05,
"feature_fraction": 0.9,
"bagging_fraction": 0.8,
"bagging_freq": 5,
"verbosity": -100,
}
# train a model and serialize it to a string in memory
model = lgb.train(params, train_data, num_boost_round=2)
model_serialized = model.model_to_string()
# load a new model with the string
with pytest.warns(UserWarning, match="Ignoring params argument, using parameters from model string."):
new_model = lgb.Booster(params={"num_leaves": 32}, model_str=model_serialized)
assert new_model.params["boosting"] == "gbdt"
assert new_model.params["deterministic"] is True
assert new_model.params["feature_contri"] == [0.5] * train_x.shape[1]
assert new_model.params["interaction_constraints"] == [[0, 1], [0]]
assert new_model.params["objective"] == "binary"
assert new_model.params["metric"] == ["auc"]
assert new_model.params["num_leaves"] == 7
assert new_model.params["learning_rate"] == 0.05
assert new_model.params["feature_fraction"] == 0.9
assert new_model.params["bagging_fraction"] == 0.8
assert new_model.params["bagging_freq"] == 5
assert new_model.params["verbosity"] == -100
def test_save_load_copy_pickle(tmp_path): def test_save_load_copy_pickle(tmp_path):
def train_and_predict(init_model=None, return_model=False): def train_and_predict(init_model=None, return_model=False):
X, y = make_synthetic_regression() X, y = make_synthetic_regression()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment