Unverified Commit 70654048 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python] preserve None in `_choose_param_value()` (#5289)



* [python] preserve None in _choose_param_value()

* Update python-package/lightgbm/basic.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 7f43767a
......@@ -408,21 +408,27 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
# avoid side effects on passed-in parameters
params = deepcopy(params)
# find a value, and remove other aliases with .pop()
# prefer the value of 'main_param_name' if it exists, otherwise search the aliases
found_value = None
aliases = _ConfigAliases.get(main_param_name) - {main_param_name}
# if main_param_name was provided, keep that value and remove all aliases
if main_param_name in params.keys():
found_value = params[main_param_name]
for param in aliases:
params.pop(param, None)
return params
for param in _ConfigAliases.get(main_param_name):
val = params.pop(param, None)
if found_value is None and val is not None:
found_value = val
# if main param name was not found, search for an alias
for param in aliases:
if param in params.keys():
params[main_param_name] = params[param]
break
if found_value is not None:
params[main_param_name] = found_value
else:
params[main_param_name] = default_value
if main_param_name in params.keys():
for param in aliases:
params.pop(param, None)
return params
# neither of main_param_name, aliases were found
params[main_param_name] = default_value
return params
......
......@@ -513,6 +513,42 @@ def test_choose_param_value():
assert original_params == expected_params
def test_choose_param_value_preserves_nones():
# preserves None found for main param and still removes aliases
params = lgb.basic._choose_param_value(
main_param_name="num_threads",
params={
"num_threads": None,
"n_jobs": 4,
"objective": "regression"
},
default_value=2
)
assert params == {"num_threads": None, "objective": "regression"}
# correctly chooses value when only an alias is provided
params = lgb.basic._choose_param_value(
main_param_name="num_threads",
params={
"n_jobs": None,
"objective": "regression"
},
default_value=2
)
assert params == {"num_threads": None, "objective": "regression"}
# adds None if that's given as the default and param not found
params = lgb.basic._choose_param_value(
main_param_name="min_data_in_leaf",
params={
"objective": "regression"
},
default_value=None
)
assert params == {"objective": "regression", "min_data_in_leaf": None}
@pytest.mark.parametrize("objective_alias", lgb.basic._ConfigAliases.get("objective"))
def test_choose_param_value_objective(objective_alias):
# If callable is found in objective
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment