Unverified Commit 0b3d9da2 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] mark EarlyStopException as part of public API (#6095)

parent 1a6e6ff9
...@@ -6,7 +6,7 @@ Contributors: https://github.com/microsoft/LightGBM/graphs/contributors. ...@@ -6,7 +6,7 @@ Contributors: https://github.com/microsoft/LightGBM/graphs/contributors.
from pathlib import Path from pathlib import Path
from .basic import Booster, Dataset, Sequence, register_logger from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter from .callback import EarlyStopException, early_stopping, log_evaluation, record_evaluation, reset_parameter
from .engine import CVBooster, cv, train from .engine import CVBooster, cv, train
try: try:
...@@ -32,5 +32,5 @@ __all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence', ...@@ -32,5 +32,5 @@ __all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence',
'train', 'cv', 'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker', 'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'EarlyStopException',
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph'] 'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']
...@@ -12,6 +12,7 @@ if TYPE_CHECKING: ...@@ -12,6 +12,7 @@ if TYPE_CHECKING:
from .engine import CVBooster from .engine import CVBooster
__all__ = [ __all__ = [
'EarlyStopException',
'early_stopping', 'early_stopping',
'log_evaluation', 'log_evaluation',
'record_evaluation', 'record_evaluation',
...@@ -30,7 +31,11 @@ _ListOfEvalResultTuples = Union[ ...@@ -30,7 +31,11 @@ _ListOfEvalResultTuples = Union[
class EarlyStopException(Exception): class EarlyStopException(Exception):
"""Exception of early stopping.""" """Exception of early stopping.
Raise this from a callback passed in via keyword argument ``callbacks``
in ``cv()`` or ``train()`` to trigger early stopping.
"""
def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None: def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
"""Create early stopping exception. """Create early stopping exception.
...@@ -39,6 +44,7 @@ class EarlyStopException(Exception): ...@@ -39,6 +44,7 @@ class EarlyStopException(Exception):
---------- ----------
best_iteration : int best_iteration : int
The best iteration stopped. The best iteration stopped.
0-based... pass ``best_iteration=2`` to indicate that the third iteration was the best one.
best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
Scores for each metric, on each validation set, as of the best iteration. Scores for each metric, on each validation set, as of the best iteration.
""" """
......
...@@ -1092,6 +1092,33 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better): ...@@ -1092,6 +1092,33 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
assert np.greater_equal(last_score, best_score - min_delta).any() assert np.greater_equal(last_score, best_score - min_delta).any()
def test_early_stopping_can_be_triggered_via_custom_callback():
X, y = make_synthetic_regression()
def _early_stop_after_seventh_iteration(env):
if env.iteration == 6:
exc = lgb.EarlyStopException(
best_iteration=6,
best_score=[("some_validation_set", "some_metric", 0.708, True)]
)
raise exc
bst = lgb.train(
params={
"objective": "regression",
"verbose": -1,
"num_leaves": 2
},
train_set=lgb.Dataset(X, label=y),
num_boost_round=23,
callbacks=[_early_stop_after_seventh_iteration]
)
assert bst.num_trees() == 7
assert bst.best_score["some_validation_set"]["some_metric"] == 0.708
assert bst.best_iteration == 7
assert bst.current_iteration() == 7
def test_continue_train(): def test_continue_train():
X, y = make_synthetic_regression() X, y = make_synthetic_regression()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment