Unverified Commit 59153b28 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] simplify param aliases handling (#3864)

* Update sklearn.py

* Update dask.py
parent 066720ef
......@@ -236,9 +236,8 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
# * 'machine_list_filename': not relevant for the Dask interface
# * 'num_machines': set automatically from Dask worker list
# * 'num_threads': overridden to match nthreads on each Dask process
for param_name in ['machines', 'machine_list_filename', 'num_machines', 'num_threads']:
for param_alias in _ConfigAliases.get(param_name):
params.pop(param_alias, None)
for param_alias in _ConfigAliases.get('machines', 'machine_list_filename', 'num_machines', 'num_threads'):
params.pop(param_alias, None)
# Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
data_parts = _split_to_parts(data=data, is_matrix=True)
......
......@@ -6,7 +6,7 @@ from inspect import signature
import numpy as np
from .basic import Dataset, LightGBMError, _ConfigAliases, _log_warning
from .basic import Dataset, LightGBMError, _ConfigAliases, _choose_param_value, _log_warning
from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckSampleWeight,
......@@ -551,13 +551,11 @@ class LGBMModel(_LGBMModelBase):
original_metric = "ndcg"
# overwrite default metric by explicitly set metric
for metric_alias in _ConfigAliases.get("metric"):
if metric_alias in params:
original_metric = params.pop(metric_alias)
params = _choose_param_value("metric", params, original_metric)
# concatenate metric from params (or default if not provided in params) and eval_metric
original_metric = [original_metric] if isinstance(original_metric, (str, type(None))) else original_metric
params['metric'] = [e for e in eval_metrics_builtin if e not in original_metric] + original_metric
params['metric'] = [params['metric']] if isinstance(params['metric'], (str, type(None))) else params['metric']
params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric']
params['metric'] = [metric for metric in params['metric'] if metric is not None]
if not isinstance(X, (pd_DataFrame, dt_DataTable)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment