Unverified Commit 59153b28 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] simplify param aliases handling (#3864)

* Update sklearn.py

* Update dask.py
parent 066720ef
...@@ -236,9 +236,8 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -236,9 +236,8 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
# * 'machine_list_filename': not relevant for the Dask interface # * 'machine_list_filename': not relevant for the Dask interface
# * 'num_machines': set automatically from Dask worker list # * 'num_machines': set automatically from Dask worker list
# * 'num_threads': overridden to match nthreads on each Dask process # * 'num_threads': overridden to match nthreads on each Dask process
for param_name in ['machines', 'machine_list_filename', 'num_machines', 'num_threads']: for param_alias in _ConfigAliases.get('machines', 'machine_list_filename', 'num_machines', 'num_threads'):
for param_alias in _ConfigAliases.get(param_name): params.pop(param_alias, None)
params.pop(param_alias, None)
# Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
data_parts = _split_to_parts(data=data, is_matrix=True) data_parts = _split_to_parts(data=data, is_matrix=True)
......
...@@ -6,7 +6,7 @@ from inspect import signature ...@@ -6,7 +6,7 @@ from inspect import signature
import numpy as np import numpy as np
from .basic import Dataset, LightGBMError, _ConfigAliases, _log_warning from .basic import Dataset, LightGBMError, _ConfigAliases, _choose_param_value, _log_warning
from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase, LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckSampleWeight, _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckSampleWeight,
...@@ -551,13 +551,11 @@ class LGBMModel(_LGBMModelBase): ...@@ -551,13 +551,11 @@ class LGBMModel(_LGBMModelBase):
original_metric = "ndcg" original_metric = "ndcg"
# overwrite default metric by explicitly set metric # overwrite default metric by explicitly set metric
for metric_alias in _ConfigAliases.get("metric"): params = _choose_param_value("metric", params, original_metric)
if metric_alias in params:
original_metric = params.pop(metric_alias)
# concatenate metric from params (or default if not provided in params) and eval_metric # concatenate metric from params (or default if not provided in params) and eval_metric
original_metric = [original_metric] if isinstance(original_metric, (str, type(None))) else original_metric params['metric'] = [params['metric']] if isinstance(params['metric'], (str, type(None))) else params['metric']
params['metric'] = [e for e in eval_metrics_builtin if e not in original_metric] + original_metric params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric']
params['metric'] = [metric for metric in params['metric'] if metric is not None] params['metric'] = [metric for metric in params['metric'] if metric is not None]
if not isinstance(X, (pd_DataFrame, dt_DataTable)): if not isinstance(X, (pd_DataFrame, dt_DataTable)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment