Unverified Commit 06e94ada authored by Joaquín Ignacio Aramendía's avatar Joaquín Ignacio Aramendía Committed by GitHub
Browse files

[python-package] Load parameters from model string (#6852)



* [python-package] Test serialization and deserialization from in-memory string

Test case for #6851

* [python-package] Fill in `params` when loading from in-memory string

This fixes #6851 by using the same workaround as when loading the model
from a file.

* test_basic: use rng instead of legacy numpy RandomState

* test_basic: remove debug prints leftovers
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_basic: add boolean, array of float and array of integers to testcase
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_basic: make a cheaper model (2 rounds with 7 leaves each)
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_basic: bugfix typos

* python_package_test: move string load test from basic to engine

* test_engine: catch params ignored warnings

* test_engine: be explicit about parameters assertion

* test_engine: shush linter complaint

* test_basic: delete empty line
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_engine: even cheaper model with less features
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_engine: delete redundant assert
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* test_engine: run pre-commit and take it's word for it

* test_engine: be explicit in an E712 compliant way

* test_engine.py: pass different value as argument to make sure it is ignored

---------
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 0bbb02fd
......@@ -3715,6 +3715,9 @@ class Booster:
params = self._get_loaded_param()
elif model_str is not None:
self.model_from_string(model_str)
if params:
_log_warning("Ignoring params argument, using parameters from model string.")
params = self._get_loaded_param()
else:
raise TypeError(
"Need at least one training dataset or model file or model string to create Booster instance"
......
......@@ -1498,7 +1498,7 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
assert bst.params["categorical_feature"] == [1, 2]
# check that passing parameters to the constructor raises warning and ignores them
with pytest.warns(UserWarning, match="Ignoring params argument"):
with pytest.warns(UserWarning, match="Ignoring params argument, using parameters from model file."):
bst2 = lgb.Booster(params={"num_leaves": 7}, model_file=model_file)
assert bst.params == bst2.params
......@@ -1508,6 +1508,50 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
np.testing.assert_allclose(preds, orig_preds)
def test_string_serialized_params_retrieval(rng):
# Random train data
train_x = rng.random((500, 3))
train_y = rng.integers(0, 1, 500)
train_data = lgb.Dataset(train_x, train_y)
# Parameters
params = {
"boosting": "gbdt",
"deterministic": True,
"feature_contri": [0.5] * train_x.shape[1],
"interaction_constraints": [[0, 1], [0]],
"objective": "binary",
"metric": ["auc"],
"num_leaves": 7,
"learning_rate": 0.05,
"feature_fraction": 0.9,
"bagging_fraction": 0.8,
"bagging_freq": 5,
"verbosity": -100,
}
# train a model and serialize it to a string in memory
model = lgb.train(params, train_data, num_boost_round=2)
model_serialized = model.model_to_string()
# load a new model with the string
with pytest.warns(UserWarning, match="Ignoring params argument, using parameters from model string."):
new_model = lgb.Booster(params={"num_leaves": 32}, model_str=model_serialized)
assert new_model.params["boosting"] == "gbdt"
assert new_model.params["deterministic"] is True
assert new_model.params["feature_contri"] == [0.5] * train_x.shape[1]
assert new_model.params["interaction_constraints"] == [[0, 1], [0]]
assert new_model.params["objective"] == "binary"
assert new_model.params["metric"] == ["auc"]
assert new_model.params["num_leaves"] == 7
assert new_model.params["learning_rate"] == 0.05
assert new_model.params["feature_fraction"] == 0.9
assert new_model.params["bagging_fraction"] == 0.8
assert new_model.params["bagging_freq"] == 5
assert new_model.params["verbosity"] == -100
def test_save_load_copy_pickle(tmp_path):
def train_and_predict(init_model=None, return_model=False):
X, y = make_synthetic_regression()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment