Unverified commit e063dad2, authored by José Morales, committed by GitHub
Browse files

[python-package] ignore training set on early stopping callback (fixes #5354) (#5412)



* ignore training set on early stopping callback

* fixes

* lint

* Apply suggestions from code review
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* trigger ci
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent 581d53c6
......@@ -258,11 +258,24 @@ class _EarlyStoppingCallback:
def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta
def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
# NOTE(review): this span is a raw diff hunk scraped from a commit page —
# indentation has been stripped and removed (old) lines are interleaved with
# added (new) ones, so it is not runnable Python as-is. Comments annotate
# intent only; verify against the actual repository file.
def _init(self, env: CallbackEnv) -> None:
# old (pre-change) dart detection + enable flag, superseded by `is_dart` below
self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
in _ConfigAliases.get("boosting"))
# new: detect dart boosting mode via any of its config aliases
is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
# new: True when the single evaluation entry is the training set itself
only_train_set = (
len(env.evaluation_result_list) == 1
and self._is_train_set(
ds_name=env.evaluation_result_list[0][0],
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
train_name=env.model._train_data_name)
)
# early stopping is disabled for dart, or when only train data is evaluated
self.enabled = not is_dart and not only_train_set
if not self.enabled:
if is_dart:
_log_warning('Early stopping is not available in dart mode')
elif only_train_set:
_log_warning('Only training set found, disabling early stopping.')
return
# (hunk is truncated here — the ValueError message continues past this view)
if not env.evaluation_result_list:
raise ValueError('For early stopping, '
......@@ -339,9 +352,7 @@ class _EarlyStoppingCallback:
# NOTE(review): raw diff hunk (indentation stripped) from the per-iteration
# early-stopping loop; the removed inline train-set test and the new
# `self._is_train_set(...)` call both appear, so this is not runnable as-is.
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
# old (removed) inline check for cv_agg/train or the train data name
if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
or env.evaluation_result_list[i][0] == env.model._train_data_name)):
self._final_iteration_check(env, eval_name_splitted, i)
# new: delegate the train-set detection to the shared helper and skip it
if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
if self.verbose:
......
......@@ -765,6 +765,43 @@ def test_early_stopping():
assert 'binary_logloss' in gbm.best_score[valid_set_name]
@pytest.mark.parametrize('use_valid', [True, False])
def test_early_stopping_ignores_training_set(use_valid):
    """Training-set scores alone must never trigger (or enable) early stopping."""
    grid = np.linspace(-1, 1, 100)
    features = grid.reshape(-1, 1)
    target = grid**2
    train_ds = lgb.Dataset(features[:80], target[:80])
    valid_ds = lgb.Dataset(features[80:], target[80:])
    valid_sets, valid_names = [train_ds], ['train']
    if use_valid:
        valid_sets.append(valid_ds)
        valid_names.append('valid')
    eval_result = {}

    def train_fn():
        return lgb.train(
            {'num_leaves': 5},
            train_ds,
            num_boost_round=2,
            valid_sets=valid_sets,
            valid_names=valid_names,
            callbacks=[lgb.early_stopping(1), lgb.record_evaluation(eval_result)]
        )

    if use_valid:
        bst = train_fn()
        # the validation score worsens after round 1, so stopping fires there
        assert bst.best_iteration == 1
        assert eval_result['train']['l2'][1] < eval_result['train']['l2'][0]  # train improved
        assert eval_result['valid']['l2'][1] > eval_result['valid']['l2'][0]  # valid didn't
    else:
        # with only the train set evaluated, early stopping must disable itself
        with pytest.warns(UserWarning, match='Only training set found, disabling early stopping.'):
            bst = train_fn()
        assert bst.current_iteration() == 2
        assert bst.best_iteration == 0
@pytest.mark.parametrize('first_metric_only', [True, False])
def test_early_stopping_via_global_params(first_metric_only):
X, y = load_breast_cancer(return_X_y=True)
......
......@@ -1124,11 +1124,6 @@ def test_first_metric_only():
iter_min = min([iter_min_l1, iter_min_l2])
iter_min_valid1 = min([iter_valid1_l1, iter_valid1_l2])
# training data as eval_set
params_fit['eval_set'] = (X_train, y_train)
fit_and_check(['training'], ['l2'], 30, False)
fit_and_check(['training'], ['l2'], 30, True)
# feval
params['metric'] = 'None'
params_fit['eval_metric'] = lambda preds, train_data: [decreasing_metric(preds, train_data),
......
......@@ -29,17 +29,18 @@ def test_register_logger(tmp_path):
# NOTE(review): raw diff hunk from the middle of test_register_logger — the
# function header lies outside this view and old/new lines are interleaved
# (e.g. `lgb_data` vs `lgb_train`/`lgb_valid`, `early_stopping(4)` vs `(10)`),
# so this span is not runnable as-is.
[1, 2, 3]],
dtype=np.float32)
y = np.array([0, 1, 1, 0])
# old: one Dataset object reused as both train and valid
lgb_data = lgb.Dataset(X, y)
# new: distinct objects so early stopping still sees a non-train valid set
lgb_train = lgb.Dataset(X, y)
lgb_valid = lgb.Dataset(X, y) # different object for early-stopping
eval_records = {}
callbacks = [
lgb.record_evaluation(eval_records),
lgb.log_evaluation(2),
lgb.early_stopping(4)
lgb.early_stopping(10)
]
lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']},
lgb_data, num_boost_round=10, feval=dummy_metric,
valid_sets=[lgb_data], categorical_feature=[1], callbacks=callbacks)
lgb_train, num_boost_round=10, feval=dummy_metric,
valid_sets=[lgb_valid], categorical_feature=[1], callbacks=callbacks)
lgb.plot_metric(eval_records)
......@@ -51,32 +52,32 @@ INFO | [LightGBM] [Info] Number of data points in the train set: 4, number of us
INFO | [LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | Training until validation scores don't improve for 4 rounds
INFO | Training until validation scores don't improve for 10 rounds
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [2] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
INFO | [2] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [4] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
INFO | [4] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [6] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
INFO | [6] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [8] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
INFO | [8] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric
INFO | [10] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
INFO | [10] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | Did not meet early stopping. Best iteration is:
[1] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
[1] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
WARNING | More than one metric available, picking one to plot.
""".strip()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment