Unverified Commit 0b3d9da2 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] mark EarlyStopException as part of public API (#6095)

parent 1a6e6ff9
......@@ -6,7 +6,7 @@ Contributors: https://github.com/microsoft/LightGBM/graphs/contributors.
from pathlib import Path
from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
from .callback import EarlyStopException, early_stopping, log_evaluation, record_evaluation, reset_parameter
from .engine import CVBooster, cv, train
try:
......@@ -32,5 +32,5 @@ __all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'EarlyStopException',
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']
......@@ -12,6 +12,7 @@ if TYPE_CHECKING:
from .engine import CVBooster
__all__ = [
'EarlyStopException',
'early_stopping',
'log_evaluation',
'record_evaluation',
......@@ -30,7 +31,11 @@ _ListOfEvalResultTuples = Union[
class EarlyStopException(Exception):
"""Exception of early stopping."""
"""Exception of early stopping.
Raise this from a callback passed in via keyword argument ``callbacks``
in ``cv()`` or ``train()`` to trigger early stopping.
"""
def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
"""Create early stopping exception.
......@@ -39,6 +44,7 @@ class EarlyStopException(Exception):
----------
best_iteration : int
The best iteration stopped.
0-based... pass ``best_iteration=2`` to indicate that the third iteration was the best one.
best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
Scores for each metric, on each validation set, as of the best iteration.
"""
......
......@@ -1092,6 +1092,33 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
assert np.greater_equal(last_score, best_score - min_delta).any()
def test_early_stopping_can_be_triggered_via_custom_callback():
    """Raising EarlyStopException from a user callback should stop training early.

    The callback raises on the 7th boosting round (``env.iteration == 6``,
    0-based), so the resulting booster must hold exactly 7 trees and report
    the best iteration/score carried by the exception (1-based externally).
    """
    X, y = make_synthetic_regression()

    def _stop_on_seventh_round(env):
        # env.iteration is 0-based; 6 == the seventh boosting round.
        if env.iteration == 6:
            raise lgb.EarlyStopException(
                best_iteration=6,
                best_score=[("some_validation_set", "some_metric", 0.708, True)]
            )

    training_params = {
        "objective": "regression",
        "verbose": -1,
        "num_leaves": 2
    }
    bst = lgb.train(
        params=training_params,
        train_set=lgb.Dataset(X, label=y),
        num_boost_round=23,
        callbacks=[_stop_on_seventh_round]
    )

    # Training halted after 7 of the requested 23 rounds.
    assert bst.num_trees() == 7
    assert bst.best_score["some_validation_set"]["some_metric"] == 0.708
    # best_iteration is exposed 1-based, hence 7 rather than 6.
    assert bst.best_iteration == 7
    assert bst.current_iteration() == 7
def test_continue_train():
X, y = make_synthetic_regression()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment