Unverified Commit f71328d4 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python][sklearn] Remove `early_stopping_rounds` argument of `fit()` method (#4846)

parent 1114ec80
......@@ -27,7 +27,7 @@ gbm = lgb.LGBMRegressor(num_leaves=31,
gbm.fit(X_train, y_train,
eval_set=[(X_test, y_test)],
eval_metric='l1',
early_stopping_rounds=5)
callbacks=[lgb.early_stopping(5)])
print('Starting predicting...')
# predict
......@@ -52,7 +52,7 @@ print('Starting training with custom eval function...')
gbm.fit(X_train, y_train,
eval_set=[(X_test, y_test)],
eval_metric=rmsle,
early_stopping_rounds=5)
callbacks=[lgb.early_stopping(5)])
# another self-defined eval metric
......@@ -67,7 +67,7 @@ print('Starting training with multiple custom eval functions...')
gbm.fit(X_train, y_train,
eval_set=[(X_test, y_test)],
eval_metric=[rmsle, rae],
early_stopping_rounds=5)
callbacks=[lgb.early_stopping(5)])
print('Starting predicting...')
# predict
......
......@@ -1038,15 +1038,11 @@ class _DaskLGBMModel:
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "_DaskLGBMModel":
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if early_stopping_rounds is not None:
raise RuntimeError('early_stopping_rounds is not currently supported in lightgbm.dask')
params = self.get_params(True)
params.pop("client", None)
......@@ -1171,13 +1167,9 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
if early_stopping_rounds is not None:
raise RuntimeError('early_stopping_rounds is not currently supported in lightgbm.dask')
return self._lgb_dask_fit(
model_factory=LGBMClassifier,
X=X,
......@@ -1204,16 +1196,13 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)"
)
# DaskLGBMClassifier does not support group, eval_group, early_stopping_rounds.
# DaskLGBMClassifier does not support group, eval_group.
_base_doc = (_base_doc[:_base_doc.find('group :')]
+ _base_doc[_base_doc.find('eval_set :'):])
_base_doc = (_base_doc[:_base_doc.find('eval_group :')]
+ _base_doc[_base_doc.find('eval_metric :'):])
_base_doc = (_base_doc[:_base_doc.find('early_stopping_rounds :')]
+ _base_doc[_base_doc.find('feature_name :'):])
# DaskLGBMClassifier support for callbacks and init_model is not tested
fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
Other parameters passed through to ``LGBMClassifier.fit()``.
......@@ -1352,13 +1341,9 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
if early_stopping_rounds is not None:
raise RuntimeError('early_stopping_rounds is not currently supported in lightgbm.dask')
return self._lgb_dask_fit(
model_factory=LGBMRegressor,
X=X,
......@@ -1384,7 +1369,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)"
)
# DaskLGBMRegressor does not support group, eval_class_weight, eval_group, early_stopping_rounds.
# DaskLGBMRegressor does not support group, eval_class_weight, eval_group.
_base_doc = (_base_doc[:_base_doc.find('group :')]
+ _base_doc[_base_doc.find('eval_set :'):])
......@@ -1394,9 +1379,6 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
_base_doc = (_base_doc[:_base_doc.find('eval_group :')]
+ _base_doc[_base_doc.find('eval_metric :'):])
_base_doc = (_base_doc[:_base_doc.find('early_stopping_rounds :')]
+ _base_doc[_base_doc.find('feature_name :'):])
# DaskLGBMRegressor support for callbacks and init_model is not tested
fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
Other parameters passed through to ``LGBMRegressor.fit()``.
......@@ -1519,13 +1501,9 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
if early_stopping_rounds is not None:
raise RuntimeError('early_stopping_rounds is not currently supported in lightgbm.dask')
return self._lgb_dask_fit(
model_factory=LGBMRanker,
X=X,
......@@ -1558,7 +1536,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
_base_doc = (_base_doc[:_base_doc.find('eval_class_weight :')]
+ _base_doc[_base_doc.find('eval_init_score :'):])
_base_doc = (_base_doc[:_base_doc.find('early_stopping_rounds :')]
_base_doc = (_base_doc[:_base_doc.find('feature_name :')]
+ "eval_at : iterable of int, optional (default=(1, 2, 3, 4, 5))\n"
+ f"{' ':8}The evaluation positions of the specified metric.\n"
+ f"{' ':4}{_base_doc[_base_doc.find('feature_name :'):]}")
......
......@@ -250,14 +250,6 @@ _lgbmmodel_doc_fit = (
If list, it can be a list of built-in metrics, a list of custom evaluation metrics, or a mix of both.
In either case, the ``metric`` from the model parameters will be evaluated and used as well.
Default: 'l2' for LGBMRegressor, 'logloss' for LGBMClassifier, 'ndcg' for LGBMRanker.
early_stopping_rounds : int or None, optional (default=None)
Activates early stopping. The model will train until the validation score stops improving.
Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
to continue training.
Requires at least one validation data and one metric.
If there's more than one, will check all of them. But the training data is ignored anyway.
To check only the first metric, set the ``first_metric_only`` parameter to ``True``
in additional parameters ``**kwargs`` of the model constructor.
feature_name : list of str, or 'auto', optional (default='auto')
Feature names.
If 'auto' and data is pandas DataFrame, data columns names are used.
......@@ -661,13 +653,25 @@ class LGBMModel(_LGBMModelBase):
return params
def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None,
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
def fit(
self,
X,
y,
sample_weight=None,
init_score=None,
group=None,
eval_set=None,
eval_names=None,
eval_sample_weight=None,
eval_class_weight=None,
eval_init_score=None,
eval_group=None,
eval_metric=None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
):
"""Docstring is set after definition, using a template."""
params = self._process_params(stage="fit")
......@@ -754,11 +758,6 @@ class LGBMModel(_LGBMModelBase):
if isinstance(init_model, LGBMModel):
init_model = init_model.booster_
if early_stopping_rounds is not None and early_stopping_rounds > 0:
_log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'early_stopping()' callback via 'callbacks' argument instead.")
params['early_stopping_rounds'] = early_stopping_rounds
if callbacks is None:
callbacks = []
else:
......@@ -940,18 +939,38 @@ class LGBMModel(_LGBMModelBase):
class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
"""LightGBM regressor."""
def fit(
    self,
    X,
    y,
    sample_weight=None,
    init_score=None,
    eval_set=None,
    eval_names=None,
    eval_sample_weight=None,
    eval_init_score=None,
    eval_metric=None,
    feature_name='auto',
    categorical_feature='auto',
    callbacks=None,
    init_model=None
):
    """Docstring is inherited from the LGBMModel."""
    # The `early_stopping_rounds` argument was removed; early stopping is now
    # requested by passing lgb.early_stopping(...) through `callbacks`.
    # All work is delegated to LGBMModel.fit; this override exists only to
    # expose the regressor-specific signature (no group/eval_group params)
    # and to return `self` for sklearn-style chaining.
    super().fit(
        X,
        y,
        sample_weight=sample_weight,
        init_score=init_score,
        eval_set=eval_set,
        eval_names=eval_names,
        eval_sample_weight=eval_sample_weight,
        eval_init_score=eval_init_score,
        eval_metric=eval_metric,
        feature_name=feature_name,
        categorical_feature=categorical_feature,
        callbacks=callbacks,
        init_model=init_model
    )
    return self
_base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMRegressor") # type: ignore
......@@ -966,13 +985,23 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
"""LightGBM classifier."""
def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None,
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
def fit(
self,
X,
y,
sample_weight=None,
init_score=None,
eval_set=None,
eval_names=None,
eval_sample_weight=None,
eval_class_weight=None,
eval_init_score=None,
eval_metric=None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
):
"""Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y)
_LGBMCheckClassificationTargets(y)
......@@ -1013,12 +1042,22 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
else:
valid_sets[i] = (valid_x, self._le.transform(valid_y))
super().fit(X, _y, sample_weight=sample_weight, init_score=init_score, eval_set=valid_sets,
eval_names=eval_names, eval_sample_weight=eval_sample_weight,
eval_class_weight=eval_class_weight, eval_init_score=eval_init_score,
eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds,
feature_name=feature_name, categorical_feature=categorical_feature,
callbacks=callbacks, init_model=init_model)
super().fit(
X,
_y,
sample_weight=sample_weight,
init_score=init_score,
eval_set=valid_sets,
eval_names=eval_names,
eval_sample_weight=eval_sample_weight,
eval_class_weight=eval_class_weight,
eval_init_score=eval_init_score,
eval_metric=eval_metric,
feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks,
init_model=init_model
)
return self
_base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMClassifier") # type: ignore
......@@ -1088,13 +1127,25 @@ class LGBMRanker(LGBMModel):
Please use this class mainly for training and applying ranking models in common sklearnish way.
"""
def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None,
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
def fit(
self,
X,
y,
sample_weight=None,
init_score=None,
group=None,
eval_set=None,
eval_names=None,
eval_sample_weight=None,
eval_init_score=None,
eval_group=None,
eval_metric=None,
eval_at=(1, 2, 3, 4, 5),
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
):
"""Docstring is inherited from the LGBMModel."""
# check group data
if group is None:
......@@ -1113,18 +1164,30 @@ class LGBMRanker(LGBMModel):
"if you use dict, the index should start from 0")
self._eval_at = eval_at
super().fit(X, y, sample_weight=sample_weight, init_score=init_score, group=group,
eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight,
eval_init_score=eval_init_score, eval_group=eval_group, eval_metric=eval_metric,
early_stopping_rounds=early_stopping_rounds, feature_name=feature_name,
categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model)
super().fit(
X,
y,
sample_weight=sample_weight,
init_score=init_score,
group=group,
eval_set=eval_set,
eval_names=eval_names,
eval_sample_weight=eval_sample_weight,
eval_init_score=eval_init_score,
eval_group=eval_group,
eval_metric=eval_metric,
feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks,
init_model=init_model
)
return self
_base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMRanker") # type: ignore
fit.__doc__ = (_base_doc[:_base_doc.find('eval_class_weight :')] # type: ignore
+ _base_doc[_base_doc.find('eval_init_score :'):]) # type: ignore
_base_doc = fit.__doc__
_before_early_stop, _early_stop, _after_early_stop = _base_doc.partition('early_stopping_rounds :')
fit.__doc__ = f"""{_before_early_stop}eval_at : iterable of int, optional (default=(1, 2, 3, 4, 5))
_before_feature_name, _feature_name, _after_feature_name = _base_doc.partition('feature_name :')
fit.__doc__ = f"""{_before_feature_name}eval_at : iterable of int, optional (default=(1, 2, 3, 4, 5))
The evaluation positions of the specified metric.
{_early_stop}{_after_early_stop}"""
{_feature_name}{_after_feature_name}"""
......@@ -92,7 +92,7 @@ def test_binary():
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(5)])
ret = log_loss(y_test, gbm.predict_proba(X_test))
assert ret < 0.12
assert gbm.evals_result_['valid_0']['binary_logloss'][gbm.best_iteration_ - 1] == pytest.approx(ret)
......@@ -102,7 +102,7 @@ def test_regression():
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
gbm = lgb.LGBMRegressor(n_estimators=50, verbose=-1)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(5)])
ret = mean_squared_error(y_test, gbm.predict(X_test))
assert ret < 7
assert gbm.evals_result_['valid_0']['l2'][gbm.best_iteration_ - 1] == pytest.approx(ret)
......@@ -112,7 +112,7 @@ def test_multiclass():
X, y = load_digits(n_class=10, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(5)])
ret = multi_error(y_test, gbm.predict(X_test))
assert ret < 0.05
ret = multi_logloss(y_test, gbm.predict_proba(X_test))
......@@ -127,9 +127,18 @@ def test_lambdarank():
q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query'))
q_test = np.loadtxt(str(rank_example_dir / 'rank.test.query'))
gbm = lgb.LGBMRanker(n_estimators=50)
gbm.fit(X_train, y_train, group=q_train, eval_set=[(X_test, y_test)],
eval_group=[q_test], eval_at=[1, 3], early_stopping_rounds=10,
callbacks=[lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))])
gbm.fit(
X_train,
y_train,
group=q_train,
eval_set=[(X_test, y_test)],
eval_group=[q_test],
eval_at=[1, 3],
callbacks=[
lgb.early_stopping(10),
lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))
]
)
assert gbm.best_iteration_ <= 24
assert gbm.best_score_['valid_0']['ndcg@1'] > 0.5674
assert gbm.best_score_['valid_0']['ndcg@3'] > 0.578
......@@ -142,10 +151,19 @@ def test_xendcg():
q_train = np.loadtxt(str(xendcg_example_dir / 'rank.train.query'))
q_test = np.loadtxt(str(xendcg_example_dir / 'rank.test.query'))
gbm = lgb.LGBMRanker(n_estimators=50, objective='rank_xendcg', random_state=5, n_jobs=1)
gbm.fit(X_train, y_train, group=q_train, eval_set=[(X_test, y_test)],
eval_group=[q_test], eval_at=[1, 3], early_stopping_rounds=10,
eval_metric='ndcg',
callbacks=[lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))])
gbm.fit(
X_train,
y_train,
group=q_train,
eval_set=[(X_test, y_test)],
eval_group=[q_test],
eval_at=[1, 3],
eval_metric='ndcg',
callbacks=[
lgb.early_stopping(10),
lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))
]
)
assert gbm.best_iteration_ <= 24
assert gbm.best_score_['valid_0']['ndcg@1'] > 0.6211
assert gbm.best_score_['valid_0']['ndcg@3'] > 0.6253
......@@ -196,7 +214,7 @@ def test_regression_with_custom_objective():
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
gbm = lgb.LGBMRegressor(n_estimators=50, verbose=-1, objective=objective_ls)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(5)])
ret = mean_squared_error(y_test, gbm.predict(X_test))
assert ret < 7.0
assert gbm.evals_result_['valid_0']['l2'][gbm.best_iteration_ - 1] == pytest.approx(ret)
......@@ -206,7 +224,7 @@ def test_binary_classification_with_custom_objective():
X, y = load_digits(n_class=2, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
gbm = lgb.LGBMClassifier(n_estimators=50, verbose=-1, objective=logregobj)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=[lgb.early_stopping(5)])
# prediction result is actually not transformed (is raw) due to custom objective
y_pred_raw = gbm.predict_proba(X_test)
assert not np.all(y_pred_raw >= 0)
......@@ -318,7 +336,7 @@ def test_random_search():
reg_alpha=[np.random.uniform(low=0.01, high=0.06) for i in range(n_iter)])
fit_params = dict(eval_set=[(X_val, y_val)],
eval_metric=constant_metric,
early_stopping_rounds=2)
callbacks=[lgb.early_stopping(2)])
rand = RandomizedSearchCV(estimator=lgb.LGBMClassifier(**params),
param_distributions=param_dist, cv=2,
n_iter=n_iter, random_state=42)
......@@ -440,9 +458,19 @@ def test_joblib():
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
gbm = lgb.LGBMRegressor(n_estimators=10, objective=custom_asymmetric_obj,
verbose=-1, importance_type='split')
gbm.fit(X_train, y_train, eval_set=[(X_train, y_train), (X_test, y_test)],
eval_metric=mse, early_stopping_rounds=5,
callbacks=[lgb.reset_parameter(learning_rate=list(np.arange(1, 0, -0.1)))])
gbm.fit(
X_train,
y_train,
eval_set=[
(X_train, y_train),
(X_test, y_test)
],
eval_metric=mse,
callbacks=[
lgb.early_stopping(5),
lgb.reset_parameter(learning_rate=list(np.arange(1, 0, -0.1)))
]
)
joblib.dump(gbm, 'lgb.pkl') # test model with custom functions
gbm_pickle = joblib.load('lgb.pkl')
......@@ -1048,7 +1076,7 @@ def test_inf_handle():
weight = np.full(nrows, 1e10)
params = {'n_estimators': 20, 'verbose': -1}
params_fit = {'X': X, 'y': y, 'sample_weight': weight, 'eval_set': (X, y),
'early_stopping_rounds': 5}
'callbacks': [lgb.early_stopping(5)]}
gbm = lgb.LGBMRegressor(**params).fit(**params_fit)
np.testing.assert_allclose(gbm.evals_result_['training']['l2'], np.inf)
......@@ -1061,7 +1089,7 @@ def test_nan_handle():
weight = np.zeros(nrows)
params = {'n_estimators': 20, 'verbose': -1}
params_fit = {'X': X, 'y': y, 'sample_weight': weight, 'eval_set': (X, y),
'early_stopping_rounds': 5}
'callbacks': [lgb.early_stopping(5)]}
gbm = lgb.LGBMRegressor(**params).fit(**params_fit)
np.testing.assert_allclose(gbm.evals_result_['training']['l2'], np.nan)
......@@ -1079,7 +1107,7 @@ def test_first_metric_only():
assert metric_name in gbm.evals_result_[eval_set_name]
actual = len(gbm.evals_result_[eval_set_name][metric_name])
expected = assumed_iteration + (params_fit['early_stopping_rounds']
expected = assumed_iteration + (params['early_stopping_rounds']
if eval_set_name != 'training'
and assumed_iteration != gbm.n_estimators else 0)
assert expected == actual
......@@ -1095,10 +1123,10 @@ def test_first_metric_only():
'learning_rate': 0.8,
'num_leaves': 15,
'verbose': -1,
'seed': 123}
'seed': 123,
'early_stopping_rounds': 5} # early stop should be supported via global LightGBM parameter
params_fit = {'X': X_train,
'y': y_train,
'early_stopping_rounds': 5}
'y': y_train}
iter_valid1_l1 = 3
iter_valid1_l2 = 18
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment