Unverified commit e063dad2 authored by José Morales, committed by GitHub
Browse files

[python-package] ignore training set on early stopping callback (fixes #5354) (#5412)



* ignore training set on early stopping callback

* fixes

* lint

* Apply suggestions from code review
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* trigger ci
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent 581d53c6
...@@ -258,11 +258,24 @@ class _EarlyStoppingCallback: ...@@ -258,11 +258,24 @@ class _EarlyStoppingCallback:
def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool: def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta return curr_score < best_score - delta
def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
def _init(self, env: CallbackEnv) -> None: def _init(self, env: CallbackEnv) -> None:
self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
in _ConfigAliases.get("boosting")) only_train_set = (
len(env.evaluation_result_list) == 1
and self._is_train_set(
ds_name=env.evaluation_result_list[0][0],
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
train_name=env.model._train_data_name)
)
self.enabled = not is_dart and not only_train_set
if not self.enabled: if not self.enabled:
_log_warning('Early stopping is not available in dart mode') if is_dart:
_log_warning('Early stopping is not available in dart mode')
elif only_train_set:
_log_warning('Only training set found, disabling early stopping.')
return return
if not env.evaluation_result_list: if not env.evaluation_result_list:
raise ValueError('For early stopping, ' raise ValueError('For early stopping, '
...@@ -339,9 +352,7 @@ class _EarlyStoppingCallback: ...@@ -339,9 +352,7 @@ class _EarlyStoppingCallback:
eval_name_splitted = env.evaluation_result_list[i][1].split(" ") eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]: if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping continue # use only the first metric for early stopping
if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train" if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
or env.evaluation_result_list[i][0] == env.model._train_data_name)):
self._final_iteration_check(env, eval_name_splitted, i)
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train) continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - self.best_iter[i] >= self.stopping_rounds: elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
if self.verbose: if self.verbose:
......
...@@ -765,6 +765,43 @@ def test_early_stopping(): ...@@ -765,6 +765,43 @@ def test_early_stopping():
assert 'binary_logloss' in gbm.best_score[valid_set_name] assert 'binary_logloss' in gbm.best_score[valid_set_name]
@pytest.mark.parametrize('use_valid', [True, False])
def test_early_stopping_ignores_training_set(use_valid):
    """Early stopping must be driven by validation data only.

    With a real validation set present, stopping follows that set even while
    the training loss keeps improving. With the training set alone, early
    stopping is disabled entirely and a warning is emitted.
    """
    xs = np.linspace(-1, 1, 100)
    features = xs.reshape(-1, 1)
    target = xs ** 2
    train_ds = lgb.Dataset(features[:80], target[:80])
    valid_sets = [train_ds]
    valid_names = ['train']
    if use_valid:
        valid_sets.append(lgb.Dataset(features[80:], target[80:]))
        valid_names.append('valid')
    eval_result = {}

    def fit_booster():
        return lgb.train(
            {'num_leaves': 5},
            train_ds,
            num_boost_round=2,
            valid_sets=valid_sets,
            valid_names=valid_names,
            callbacks=[lgb.early_stopping(1), lgb.record_evaluation(eval_result)]
        )

    if use_valid:
        booster = fit_booster()
        # train keeps improving but valid does not, so stopping fires at round 1
        assert booster.best_iteration == 1
        assert eval_result['train']['l2'][1] < eval_result['train']['l2'][0]
        assert eval_result['valid']['l2'][1] > eval_result['valid']['l2'][0]
    else:
        # only the training set was supplied: early stopping is turned off
        with pytest.warns(UserWarning, match='Only training set found, disabling early stopping.'):
            booster = fit_booster()
        assert booster.current_iteration() == 2
        assert booster.best_iteration == 0
@pytest.mark.parametrize('first_metric_only', [True, False]) @pytest.mark.parametrize('first_metric_only', [True, False])
def test_early_stopping_via_global_params(first_metric_only): def test_early_stopping_via_global_params(first_metric_only):
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
......
...@@ -1124,11 +1124,6 @@ def test_first_metric_only(): ...@@ -1124,11 +1124,6 @@ def test_first_metric_only():
iter_min = min([iter_min_l1, iter_min_l2]) iter_min = min([iter_min_l1, iter_min_l2])
iter_min_valid1 = min([iter_valid1_l1, iter_valid1_l2]) iter_min_valid1 = min([iter_valid1_l1, iter_valid1_l2])
# training data as eval_set
params_fit['eval_set'] = (X_train, y_train)
fit_and_check(['training'], ['l2'], 30, False)
fit_and_check(['training'], ['l2'], 30, True)
# feval # feval
params['metric'] = 'None' params['metric'] = 'None'
params_fit['eval_metric'] = lambda preds, train_data: [decreasing_metric(preds, train_data), params_fit['eval_metric'] = lambda preds, train_data: [decreasing_metric(preds, train_data),
......
...@@ -29,17 +29,18 @@ def test_register_logger(tmp_path): ...@@ -29,17 +29,18 @@ def test_register_logger(tmp_path):
[1, 2, 3]], [1, 2, 3]],
dtype=np.float32) dtype=np.float32)
y = np.array([0, 1, 1, 0]) y = np.array([0, 1, 1, 0])
lgb_data = lgb.Dataset(X, y) lgb_train = lgb.Dataset(X, y)
lgb_valid = lgb.Dataset(X, y) # different object for early-stopping
eval_records = {} eval_records = {}
callbacks = [ callbacks = [
lgb.record_evaluation(eval_records), lgb.record_evaluation(eval_records),
lgb.log_evaluation(2), lgb.log_evaluation(2),
lgb.early_stopping(4) lgb.early_stopping(10)
] ]
lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']}, lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']},
lgb_data, num_boost_round=10, feval=dummy_metric, lgb_train, num_boost_round=10, feval=dummy_metric,
valid_sets=[lgb_data], categorical_feature=[1], callbacks=callbacks) valid_sets=[lgb_valid], categorical_feature=[1], callbacks=callbacks)
lgb.plot_metric(eval_records) lgb.plot_metric(eval_records)
...@@ -51,32 +52,32 @@ INFO | [LightGBM] [Info] Number of data points in the train set: 4, number of us ...@@ -51,32 +52,32 @@ INFO | [LightGBM] [Info] Number of data points in the train set: 4, number of us
INFO | [LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000 INFO | [LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | Training until validation scores don't improve for 4 rounds INFO | Training until validation scores don't improve for 10 rounds
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [2] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1 INFO | [2] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [4] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1 INFO | [4] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [6] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1 INFO | [6] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [8] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1 INFO | [8] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
DEBUG | In dummy_metric DEBUG | In dummy_metric
INFO | [10] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1 INFO | [10] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
INFO | Did not meet early stopping. Best iteration is: INFO | Did not meet early stopping. Best iteration is:
[1] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1 [1] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
WARNING | More than one metric available, picking one to plot. WARNING | More than one metric available, picking one to plot.
""".strip() """.strip()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment