Unverified Commit f57ef6f4 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python][sklearn] respect parameters for predictions in `init()` and `set_params()` methods (#4822)

* in predict(), respect params set via `set_params()` after fit()

* continue

* add test

* fix return name

* hotfix

* simplify
parent b31d5a43
......@@ -423,6 +423,16 @@ class _ConfigAliases:
ret |= cls.aliases.get(i, {i})
return ret
@classmethod
def get_by_alias(cls, *args):
    """Expand the given parameter names with every alias group they belong to.

    Parameters
    ----------
    *args
        Parameter names to look up.

    Returns
    -------
    expanded : set
        The input names together with all members of any alias group
        containing one of them.
    """
    expanded = set(args)
    for name in args:
        # Find the alias group (if any) that contains this name and merge it in.
        for group in cls.aliases.values():
            if name in group:
                expanded |= group
                break
    return expanded
def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]:
"""Get a single parameter value, accounting for aliases.
......
......@@ -2,7 +2,7 @@
"""Scikit-learn wrapper interface for LightGBM."""
import copy
from inspect import signature
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -582,21 +582,30 @@ class LGBMModel(_LGBMModelBase):
self._other_params[key] = value
return self
def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None,
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is set after definition, using a template."""
def _process_params(self, stage: str) -> Dict[str, Any]:
"""Process the parameters of this estimator based on its type, parameter aliases, etc.
Parameters
----------
stage : str
Name of the stage (can be ``fit`` or ``predict``) this method is called from.
Returns
-------
processed_params : dict
Processed parameter names mapped to their values.
"""
assert stage in {"fit", "predict"}
params = self.get_params()
params.pop('objective', None)
for alias in _ConfigAliases.get('objective'):
if alias in params:
self._objective = params.pop(alias)
obj = params.pop(alias)
_log_warning(f"Found '{alias}' in params. Will use it instead of 'objective' argument")
if stage == "fit":
self._objective = obj
if stage == "fit":
if self._objective is None:
if isinstance(self, LGBMRegressor):
self._objective = "regression"
......@@ -610,9 +619,11 @@ class LGBMModel(_LGBMModelBase):
else:
raise ValueError("Unknown LGBMModel type.")
if callable(self._objective):
if stage == "fit":
self._fobj = _ObjectiveFunctionWrapper(self._objective)
params['objective'] = 'None' # objective = nullptr for unknown objective
else:
if stage == "fit":
self._fobj = None
params['objective'] = self._objective
......@@ -634,16 +645,6 @@ class LGBMModel(_LGBMModelBase):
eval_at = params.pop(alias)
params['eval_at'] = eval_at
# Do not modify original args in fit function
# Refer to https://github.com/microsoft/LightGBM/pull/2619
eval_metric_list = copy.deepcopy(eval_metric)
if not isinstance(eval_metric_list, list):
eval_metric_list = [eval_metric_list]
# Separate built-in from callable evaluation metrics
eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]
# register default metric for consistency with callable eval_metric case
original_metric = self._objective if isinstance(self._objective, str) else None
if original_metric is None:
......@@ -658,6 +659,28 @@ class LGBMModel(_LGBMModelBase):
# overwrite default metric by explicitly set metric
params = _choose_param_value("metric", params, original_metric)
return params
def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None,
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is set after definition, using a template."""
params = self._process_params(stage="fit")
# Do not modify original args in fit function
# Refer to https://github.com/microsoft/LightGBM/pull/2619
eval_metric_list = copy.deepcopy(eval_metric)
if not isinstance(eval_metric_list, list):
eval_metric_list = [eval_metric_list]
# Separate built-in from callable evaluation metrics
eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]
# concatenate metric from params (or default if not provided in params) and eval_metric
params['metric'] = [params['metric']] if isinstance(params['metric'], (str, type(None))) else params['metric']
params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric']
......@@ -799,8 +822,23 @@ class LGBMModel(_LGBMModelBase):
raise ValueError("Number of features of the model must "
f"match the input. Model n_features_ is {self._n_features} and "
f"input n_features is {n_features}")
# retrieve original params that can possibly be used in both training and prediction
# and then overwrite them (considering aliases) with params that were passed directly in prediction
predict_params = self._process_params(stage="predict")
for alias in _ConfigAliases.get_by_alias(
"data",
"X",
"raw_score",
"start_iteration",
"num_iteration",
"pred_leaf",
"pred_contrib",
*kwargs.keys()
):
predict_params.pop(alias, None)
predict_params.update(kwargs)
return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs)
pred_leaf=pred_leaf, pred_contrib=pred_contrib, **predict_params)
predict.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted value for each sample.",
......
......@@ -612,7 +612,7 @@ def test_pandas_sparse():
def test_predict():
# With default params
iris = load_iris(return_X_y=False)
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,
X_train, X_test, y_train, _ = train_test_split(iris.data, iris.target,
test_size=0.2, random_state=42)
gbm = lgb.train({'objective': 'multiclass',
......@@ -689,6 +689,41 @@ def test_predict():
np.testing.assert_allclose(res_engine, res_sklearn_params)
def test_predict_with_params_from_init():
    """Prediction-time params given via ``__init__``/``set_params`` must be honored by ``predict()``."""
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)
    predict_params = {
        'pred_early_stop': True,
        'pred_early_stop_margin': 1.0
    }

    # Baseline: no prediction params anywhere.
    preds_default = lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train).predict(
        X_test, raw_score=True)

    # Params passed as predict() kwargs must change the output.
    preds_kwargs = lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train).predict(
        X_test, raw_score=True, **predict_params)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(preds_default, preds_kwargs)

    # set_params() before fit() is equivalent to passing them to predict().
    clf_before = lgb.LGBMClassifier(verbose=-1).set_params(**predict_params)
    preds_set_before_fit = clf_before.fit(X_train, y_train).predict(X_test, raw_score=True)
    np.testing.assert_allclose(preds_kwargs, preds_set_before_fit)

    # set_params() after fit() is equivalent as well.
    clf_after = lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train)
    preds_set_after_fit = clf_after.set_params(**predict_params).predict(X_test, raw_score=True)
    np.testing.assert_allclose(preds_kwargs, preds_set_after_fit)

    # Params passed in the constructor are equivalent too.
    preds_from_init = lgb.LGBMClassifier(verbose=-1, **predict_params).fit(X_train, y_train).predict(
        X_test, raw_score=True)
    np.testing.assert_allclose(preds_kwargs, preds_from_init)

    # Params given directly to predict() take precedence over those from __init__.
    preds_overridden = lgb.LGBMClassifier(verbose=-1, **predict_params).fit(X_train, y_train).predict(
        X_test, raw_score=True, pred_early_stop=False)
    np.testing.assert_allclose(preds_default, preds_overridden)
def test_evaluate_train_set():
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment