Unverified Commit 70654048 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python] preserve None in `_choose_param_value()` (#5289)



* [python] preserve None in _choose_param_value()

* Update python-package/lightgbm/basic.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 7f43767a
...@@ -408,21 +408,27 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va ...@@ -408,21 +408,27 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
# avoid side effects on passed-in parameters # avoid side effects on passed-in parameters
params = deepcopy(params) params = deepcopy(params)
# find a value, and remove other aliases with .pop() aliases = _ConfigAliases.get(main_param_name) - {main_param_name}
# prefer the value of 'main_param_name' if it exists, otherwise search the aliases
found_value = None # if main_param_name was provided, keep that value and remove all aliases
if main_param_name in params.keys(): if main_param_name in params.keys():
found_value = params[main_param_name] for param in aliases:
params.pop(param, None)
return params
for param in _ConfigAliases.get(main_param_name): # if main param name was not found, search for an alias
val = params.pop(param, None) for param in aliases:
if found_value is None and val is not None: if param in params.keys():
found_value = val params[main_param_name] = params[param]
break
if found_value is not None: if main_param_name in params.keys():
params[main_param_name] = found_value for param in aliases:
else: params.pop(param, None)
params[main_param_name] = default_value return params
# neither of main_param_name, aliases were found
params[main_param_name] = default_value
return params return params
......
...@@ -513,6 +513,42 @@ def test_choose_param_value(): ...@@ -513,6 +513,42 @@ def test_choose_param_value():
assert original_params == expected_params assert original_params == expected_params
def test_choose_param_value_preserves_nones():
# preserves None found for main param and still removes aliases
params = lgb.basic._choose_param_value(
main_param_name="num_threads",
params={
"num_threads": None,
"n_jobs": 4,
"objective": "regression"
},
default_value=2
)
assert params == {"num_threads": None, "objective": "regression"}
# correctly chooses value when only an alias is provided
params = lgb.basic._choose_param_value(
main_param_name="num_threads",
params={
"n_jobs": None,
"objective": "regression"
},
default_value=2
)
assert params == {"num_threads": None, "objective": "regression"}
# adds None if that's given as the default and param not found
params = lgb.basic._choose_param_value(
main_param_name="min_data_in_leaf",
params={
"objective": "regression"
},
default_value=None
)
assert params == {"objective": "regression", "min_data_in_leaf": None}
@pytest.mark.parametrize("objective_alias", lgb.basic._ConfigAliases.get("objective")) @pytest.mark.parametrize("objective_alias", lgb.basic._ConfigAliases.get("objective"))
def test_choose_param_value_objective(objective_alias): def test_choose_param_value_objective(objective_alias):
# If callable is found in objective # If callable is found in objective
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment