"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "c53ac33b50760129da1b60886283d4918a21d5e7"
Unverified Commit 1ce23571 authored by Darcie Delzell's avatar Darcie Delzell Committed by GitHub
Browse files

[python-package] allow use of early_stopping_round<=0 to turn off early...

[python-package] allow use of early_stopping_round<=0 to turn off early stopping (fixes #6401) (#6406)
parent f6ecd4de
...@@ -280,8 +280,7 @@ class _EarlyStoppingCallback: ...@@ -280,8 +280,7 @@ class _EarlyStoppingCallback:
verbose: bool = True, verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0, min_delta: Union[float, List[float]] = 0.0,
) -> None: ) -> None:
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0: self.enabled = _should_enable_early_stopping(stopping_rounds)
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
self.order = 30 self.order = 30
self.before_iteration = False self.before_iteration = False
...@@ -291,7 +290,6 @@ class _EarlyStoppingCallback: ...@@ -291,7 +290,6 @@ class _EarlyStoppingCallback:
self.verbose = verbose self.verbose = verbose
self.min_delta = min_delta self.min_delta = min_delta
self.enabled = True
self._reset_storages() self._reset_storages()
def _reset_storages(self) -> None: def _reset_storages(self) -> None:
...@@ -438,6 +436,18 @@ class _EarlyStoppingCallback: ...@@ -438,6 +436,18 @@ class _EarlyStoppingCallback:
self._final_iteration_check(env, eval_name_splitted, i) self._final_iteration_check(env, eval_name_splitted, i)
def _should_enable_early_stopping(stopping_rounds: Any) -> bool:
"""Check if early stopping should be activated.
This function will evaluate to True if the early stopping callback should be
activated (i.e. stopping_rounds > 0). It also provides an informative error if the
type is not int.
"""
if not isinstance(stopping_rounds, int):
raise TypeError(f"early_stopping_round should be an integer. Got '{type(stopping_rounds).__name__}'")
return stopping_rounds > 0
def early_stopping( def early_stopping(
stopping_rounds: int, stopping_rounds: int,
first_metric_only: bool = False, first_metric_only: bool = False,
......
...@@ -236,7 +236,7 @@ def train( ...@@ -236,7 +236,7 @@ def train(
cb.__dict__.setdefault("order", i - len(callbacks)) cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks) callbacks_set = set(callbacks)
if "early_stopping_round" in params: if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add( callbacks_set.add(
callback.early_stopping( callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
...@@ -760,7 +760,7 @@ def cv( ...@@ -760,7 +760,7 @@ def cv(
cb.__dict__.setdefault("order", i - len(callbacks)) cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks) callbacks_set = set(callbacks)
if "early_stopping_round" in params: if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add( callbacks_set.add(
callback.early_stopping( callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
......
...@@ -22,14 +22,14 @@ def test_early_stopping_callback_is_picklable(serializer): ...@@ -22,14 +22,14 @@ def test_early_stopping_callback_is_picklable(serializer):
def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors(): def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"): with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"):
lgb.early_stopping(stopping_rounds=0) lgb.early_stopping(stopping_rounds="neverrrr")
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"): @pytest.mark.parametrize("stopping_rounds", [-10, -1, 0])
lgb.early_stopping(stopping_rounds="neverrrr") def test_early_stopping_callback_accepts_non_positive_stopping_rounds(stopping_rounds):
cb = lgb.early_stopping(stopping_rounds=stopping_rounds)
assert cb.enabled is False
@pytest.mark.parametrize("serializer", SERIALIZERS) @pytest.mark.parametrize("serializer", SERIALIZERS)
......
...@@ -938,6 +938,54 @@ def test_early_stopping_via_global_params(first_metric_only): ...@@ -938,6 +938,54 @@ def test_early_stopping_via_global_params(first_metric_only):
assert "error" in gbm.best_score[valid_set_name] assert "error" in gbm.best_score[valid_set_name]
@pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"])
def test_early_stopping_is_not_enabled_for_non_positive_stopping_rounds(early_stopping_round):
X, y = load_breast_cancer(return_X_y=True)
num_trees = 5
params = {
"num_trees": num_trees,
"objective": "binary",
"metric": "None",
"verbose": -1,
"early_stopping_round": early_stopping_round,
"first_metric_only": True,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
valid_set_name = "valid_set"
if early_stopping_round is None:
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert "early_stopping_round" not in gbm.params
assert gbm.num_trees() == num_trees
elif early_stopping_round == "None":
with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"):
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
elif early_stopping_round <= 0:
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert gbm.params["early_stopping_round"] == early_stopping_round
assert gbm.num_trees() == num_trees
@pytest.mark.parametrize("first_only", [True, False]) @pytest.mark.parametrize("first_only", [True, False])
@pytest.mark.parametrize("single_metric", [True, False]) @pytest.mark.parametrize("single_metric", [True, False])
@pytest.mark.parametrize("greater_is_better", [True, False]) @pytest.mark.parametrize("greater_is_better", [True, False])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment