Unverified Commit f91e5644 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] added ability to pass first_metric_only in params (#2175)

* added ability to pass first_metric_only in params

* simplified tests

* fixed test

* fixed punctuation
parent 2a369170
...@@ -207,7 +207,7 @@ Note that ``train()`` will return a model from the best iteration. ...@@ -207,7 +207,7 @@ Note that ``train()`` will return a model from the best iteration.
This works with both metrics to minimize (L2, log loss, etc.) and to maximize (NDCG, AUC, etc.).
Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in ``param`` or ``early_stopping`` callback constructor.
Prediction
----------
......
...@@ -66,8 +66,7 @@ def train(params, train_set, num_boost_round=100, ...@@ -66,8 +66,7 @@ def train(params, train_set, num_boost_round=100,
to continue training.
Requires at least one validation data and one metric.
If there's more than one, will check all of them. But the training data is ignored anyway.
To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
The index of iteration that has the best performance will be saved in the ``best_iteration`` field
if early stopping logic is enabled by setting ``early_stopping_rounds``.
evals_result: dict or None, optional (default=None) evals_result: dict or None, optional (default=None)
...@@ -116,14 +115,15 @@ def train(params, train_set, num_boost_round=100, ...@@ -116,14 +115,15 @@ def train(params, train_set, num_boost_round=100,
for alias in ["num_iterations", "num_iteration", "n_iter", "num_tree", "num_trees", for alias in ["num_iterations", "num_iteration", "n_iter", "num_tree", "num_trees",
"num_round", "num_rounds", "num_boost_round", "n_estimators"]: "num_round", "num_rounds", "num_boost_round", "n_estimators"]:
if alias in params: if alias in params:
num_boost_round = int(params.pop(alias)) num_boost_round = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break break
for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]: for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]:
if alias in params and params[alias] is not None: if alias in params:
early_stopping_rounds = int(params.pop(alias)) early_stopping_rounds = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break break
first_metric_only = params.pop('first_metric_only', False)
if num_boost_round <= 0: if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.") raise ValueError("num_boost_round should be greater than zero.")
...@@ -181,7 +181,7 @@ def train(params, train_set, num_boost_round=100, ...@@ -181,7 +181,7 @@ def train(params, train_set, num_boost_round=100,
callbacks.add(callback.print_evaluation(verbose_eval)) callbacks.add(callback.print_evaluation(verbose_eval))
if early_stopping_rounds is not None: if early_stopping_rounds is not None:
callbacks.add(callback.early_stopping(early_stopping_rounds, verbose=bool(verbose_eval))) callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
if learning_rates is not None: if learning_rates is not None:
callbacks.add(callback.reset_parameter(learning_rate=learning_rates)) callbacks.add(callback.reset_parameter(learning_rate=learning_rates))
...@@ -400,8 +400,7 @@ def cv(params, train_set, num_boost_round=100, ...@@ -400,8 +400,7 @@ def cv(params, train_set, num_boost_round=100,
CV score needs to improve at least every ``early_stopping_rounds`` round(s)
to continue.
Requires at least one metric. If there's more than one, will check all of them.
To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
Last entry in evaluation history is the one from the best iteration.
fpreproc : callable or None, optional (default=None)
    Preprocessing function that takes (dtrain, dtest, params)
Preprocessing function that takes (dtrain, dtest, params) Preprocessing function that takes (dtrain, dtest, params)
...@@ -449,6 +448,7 @@ def cv(params, train_set, num_boost_round=100, ...@@ -449,6 +448,7 @@ def cv(params, train_set, num_boost_round=100,
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias) early_stopping_rounds = params.pop(alias)
break break
first_metric_only = params.pop('first_metric_only', False)
if num_boost_round <= 0: if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.") raise ValueError("num_boost_round should be greater than zero.")
...@@ -480,7 +480,7 @@ def cv(params, train_set, num_boost_round=100, ...@@ -480,7 +480,7 @@ def cv(params, train_set, num_boost_round=100,
cb.__dict__.setdefault('order', i - len(callbacks)) cb.__dict__.setdefault('order', i - len(callbacks))
callbacks = set(callbacks) callbacks = set(callbacks)
if early_stopping_rounds is not None: if early_stopping_rounds is not None:
callbacks.add(callback.early_stopping(early_stopping_rounds, verbose=False)) callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
if verbose_eval is True: if verbose_eval is True:
callbacks.add(callback.print_evaluation(show_stdv=show_stdv)) callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
elif isinstance(verbose_eval, integer_types): elif isinstance(verbose_eval, integer_types):
......
...@@ -376,8 +376,8 @@ class LGBMModel(_LGBMModelBase): ...@@ -376,8 +376,8 @@ class LGBMModel(_LGBMModelBase):
to continue training.
Requires at least one validation data and one metric.
If there's more than one, will check all of them. But the training data is ignored anyway.
To check only the first metric, set the ``first_metric_only`` parameter to ``True``
in additional parameters ``**kwargs`` of the model constructor.
verbose : bool or int, optional (default=True)
    Requires at least one evaluation data.
    If True, the eval metric on the eval set is printed at each boosting stage.
......
...@@ -21,7 +21,7 @@ class FileLoader(object): ...@@ -21,7 +21,7 @@ class FileLoader(object):
if line and not line.startswith('#'): if line and not line.startswith('#'):
key, value = [token.strip() for token in line.split('=')] key, value = [token.strip() for token in line.split('=')]
if 'early_stopping' not in key: # disable early_stopping if 'early_stopping' not in key: # disable early_stopping
self.params[key] = value self.params[key] = value if key != 'num_trees' else int(value)
def load_dataset(self, suffix, is_sparse=False): def load_dataset(self, suffix, is_sparse=False):
filename = self.path(suffix) filename = self.path(suffix)
......
...@@ -1379,24 +1379,23 @@ class TestEngine(unittest.TestCase): ...@@ -1379,24 +1379,23 @@ class TestEngine(unittest.TestCase):
return ('constant_metric', 0.0, False) return ('constant_metric', 0.0, False)
# test that all metrics are checked (default behaviour) # test that all metrics are checked (default behaviour)
early_stop_callback = lgb.early_stopping(5, verbose=False)
gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval], gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval],
feval=lambda preds, train_data: [decreasing_metric(preds, train_data), feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
constant_metric(preds, train_data)], constant_metric(preds, train_data)],
callbacks=[early_stop_callback]) early_stopping_rounds=5, verbose_eval=False)
self.assertEqual(gbm.best_iteration, 1) self.assertEqual(gbm.best_iteration, 1)
# test that only the first metric is checked # test that only the first metric is checked
early_stop_callback = lgb.early_stopping(5, first_metric_only=True, verbose=False) gbm = lgb.train(dict(params, first_metric_only=True), lgb_train,
gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval], num_boost_round=20, valid_sets=[lgb_eval],
feval=lambda preds, train_data: [decreasing_metric(preds, train_data), feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
constant_metric(preds, train_data)], constant_metric(preds, train_data)],
callbacks=[early_stop_callback]) early_stopping_rounds=5, verbose_eval=False)
self.assertEqual(gbm.best_iteration, 20) self.assertEqual(gbm.best_iteration, 20)
# ... change the order of metrics # ... change the order of metrics
early_stop_callback = lgb.early_stopping(5, first_metric_only=True, verbose=False) gbm = lgb.train(dict(params, first_metric_only=True), lgb_train,
gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval], num_boost_round=20, valid_sets=[lgb_eval],
feval=lambda preds, train_data: [constant_metric(preds, train_data), feval=lambda preds, train_data: [constant_metric(preds, train_data),
decreasing_metric(preds, train_data)], decreasing_metric(preds, train_data)],
callbacks=[early_stop_callback]) early_stopping_rounds=5, verbose_eval=False)
self.assertEqual(gbm.best_iteration, 1) self.assertEqual(gbm.best_iteration, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment