Commit 7a81f7bd authored by wxchan, committed by Guolin Ke

reset_learning_rate->reset_parameter (#131)

* add gridsearch example for python sklearn

* reset_learning_rate->reset_parameter

* make callbacks public
parent 97df0684
```diff
@@ -14,6 +14,14 @@
     - [LGBMRegressor](Python-API.md#lgbmregressor)
     - [LGBMRanker](Python-API.md#lgbmranker)
+* [Callbacks](Python-API.md#callbacks)
+    - [Before iteration](Python-API.md#before-iteration)
+        + [reset_parameter](Python-API.md#reset_parameterkwargs)
+    - [After iteration](Python-API.md#after-iteration)
+        + [print_evaluation](Python-API.md#print_evaluationperiod1-show_stdvtrue)
+        + [record_evaluation](Python-API.md#record_evaluationeval_result)
+        + [early_stopping](Python-API.md#early_stoppingstopping_rounds-verbosetrue)
 The methods of each Class is in alphabetical order.
 ----
```
```diff
@@ -496,12 +504,10 @@
         an evaluation metric is printed every 4 (instead of 1) boosting stages.
     learning_rates: list or function
         List of learning rate for each boosting round
-        or a customized function that calculates learning_rate in terms of
-        current number of round (and the total number of boosting round)
-        (e.g. yields learning rate decay)
+        or a customized function that calculates learning_rate
+        in terms of current number of round (e.g. yields learning rate decay)
         - list l: learning_rate = l[current_round]
-        - function f: learning_rate = f(current_round, total_boost_round)
-                      or learning_rate = f(current_round)
+        - function f: learning_rate = f(current_round)
     callbacks : list of callback functions
         List of callback functions that are applied at end of each iteration.
```
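The two accepted forms of `learning_rates` can be sketched in plain Python. This is only an illustration of the documented contract (the helper `rate_for_round` is hypothetical, not lightgbm's internal code):

```python
# Illustration of the documented `learning_rates` contract:
# a list is indexed by the current round, a function is called with it.
def rate_for_round(learning_rates, current_round):
    if isinstance(learning_rates, list):
        return learning_rates[current_round]
    return learning_rates(current_round)

fixed_schedule = [0.1, 0.05, 0.025]   # one rate per boosting round
decay = lambda i: 0.1 * (0.99 ** i)   # exponential decay by round

print(rate_for_round(fixed_schedule, 1))  # 0.05
print(rate_for_round(decay, 0))           # 0.1
```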
```diff
@@ -805,3 +811,80 @@
     eval_at : list of int
         The evaluation positions of NDCG
```
## Callbacks

### Before iteration

#### reset_parameter(**kwargs)

Reset parameters after the first iteration.
NOTE: the initial parameters still take effect on the first iteration.

Parameters
----------
**kwargs: value should be list or function
    List of parameters for each boosting round
    or a customized function that calculates the parameter in terms of
    the current number of rounds (e.g. yields learning rate decay)
    - list l: parameter = l[current_round]
    - function f: parameter = f(current_round)

Returns
-------
callback : function
    The requested callback function.
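A minimal sketch of how each keyword argument is resolved per round under the list/function rules above. The helper `resolve_params` and the chosen parameter names are illustrative, not the library implementation:

```python
def resolve_params(current_round, **kwargs):
    """Map each parameter to its value for `current_round`,
    following the list / function rules documented above."""
    params = {}
    for key, value in kwargs.items():
        if isinstance(value, list):
            params[key] = value[current_round]   # one entry per round
        else:
            params[key] = value(current_round)   # computed from the round
    return params

# Round 2 of a 4-round schedule: the list is indexed, the function is called.
resolve_params(2,
               learning_rate=[0.1, 0.08, 0.06, 0.04],
               feature_fraction=lambda i: 1.0 / (i + 2))
# -> {'learning_rate': 0.06, 'feature_fraction': 0.25}
```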
### After iteration

#### print_evaluation(period=1, show_stdv=True)

Create a callback that prints the evaluation results.
(Same functionality as `verbose_eval` in lightgbm.train())

Parameters
----------
period : int
    The period, in iterations, at which to log the evaluation results
show_stdv : bool, optional
    Whether to show the standard deviation, if provided

Returns
-------
callback : function
    A callback that prints the evaluation results every `period` iterations.
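The `period` argument amounts to a simple modulus check on the iteration counter. A sketch of the gating logic (an assumed shape for illustration, not the library code):

```python
def should_log(iteration, period=1):
    """Return True on iterations where results would be printed:
    every `period`-th iteration, counting from the first."""
    return period > 0 and (iteration + 1) % period == 0

# With period=4, only every 4th boosting stage is logged.
logged = [i for i in range(8) if should_log(i, period=4)]
# -> [3, 7]
```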
#### record_evaluation(eval_result)

Create a callback that records the evaluation history into eval_result.
(Same functionality as `evals_result` in lightgbm.train())

Parameters
----------
eval_result : dict
    A dictionary to store the evaluation results.

Returns
-------
callback : function
    The requested callback function.
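After training, `eval_result` holds one mapping per validation set, keyed by metric name, with one value per boosting round. The shape below uses a hypothetical validation-set name and invented numbers purely to show the structure:

```python
# Hypothetical contents of `eval_result` after 3 rounds with one
# validation set named 'valid_0' and the 'l2' metric (numbers invented).
eval_result = {
    'valid_0': {
        'l2': [0.215, 0.193, 0.180],   # one entry per boosting round
    },
}

# The recorded history can then be inspected, e.g. to find the best round.
scores = eval_result['valid_0']['l2']
best_round = min(range(len(scores)), key=scores.__getitem__)
# best_round -> 2 (the round with the lowest l2)
```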
#### early_stopping(stopping_rounds, verbose=True)

Create a callback that activates early stopping.
To activate early stopping, at least one validation set and one metric are required;
if there is more than one of either, all of them will be checked.
(Same functionality as `early_stopping_rounds` in lightgbm.train())

Parameters
----------
stopping_rounds : int
    Training stops if no validation metric has improved in the last
    `stopping_rounds` rounds.
verbose : bool, optional
    Whether to print a message when early stopping triggers.

Returns
-------
callback : function
    The requested callback function.
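The core bookkeeping behind such a callback can be sketched as a counter of non-improving rounds. This is an illustration of the idea under the assumption of a lower-is-better metric, not lightgbm's implementation:

```python
def early_stop_round(scores, stopping_rounds):
    """Return the round at which training would stop, or None.
    Stops once the best score hasn't improved for `stopping_rounds`
    rounds (lower is better here, as for a loss metric)."""
    best, best_round = float('inf'), 0
    for i, score in enumerate(scores):
        if score < best:
            best, best_round = score, i        # new best: reset the clock
        elif i - best_round >= stopping_rounds:
            return i                           # no improvement for too long
    return None

early_stop_round([0.5, 0.4, 0.41, 0.42, 0.43], stopping_rounds=2)
# -> 3 (no improvement since round 1)
```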
```diff
@@ -36,7 +36,7 @@ print('Calculate feature importances...')
 # feature importances
 print('Feature importances:', list(gbm.feature_importance()))
-# other scikit-learn built-in module
+# other scikit-learn modules
 estimator = lgb.LGBMRegressor(num_leaves=31)
 param_grid = {
```
```diff
@@ -10,6 +10,7 @@ import os
 from .basic import Dataset, Booster
 from .engine import train, cv
+from .callback import print_evaluation, record_evaluation, reset_parameter, early_stopping
 try:
     from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier, LGBMRanker
 except ImportError:
@@ -20,5 +21,6 @@ __version__ = 0.1
 __all__ = ['Dataset', 'Booster',
            'train', 'cv',
-           'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker']
+           'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
+           'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping']
```
```diff
@@ -2,7 +2,6 @@
 # pylint: disable = invalid-name, W0105, C0301
 from __future__ import absolute_import
 import collections
-import inspect

 class EarlyStopException(Exception):
     """Exception of early stopping.
@@ -98,21 +97,19 @@ def record_evaluation(eval_result):
     return callback

-def reset_learning_rate(learning_rates):
-    """Reset learning rate after first iteration
-    NOTE: the initial learning rate will still take in-effect on first iteration.
+def reset_parameter(**kwargs):
+    """Reset parameter after first iteration
+    NOTE: the initial parameter will still take in-effect on first iteration.
     Parameters
     ----------
-    learning_rates: list or function
-        List of learning rate for each boosting round \
-        or a customized function that calculates learning_rate in terms of \
-        current number of round and the total number of boosting round \
-        (e.g. yields learning rate decay)
-        - list l: learning_rate = l[current_round]
-        - function f: learning_rate = f(current_round, total_boost_round) \
-                      or learning_rate = f(current_round)
+    **kwargs: value should be list or function
+        List of parameters for each boosting round \
+        or a customized function that calculates learning_rate in terms of \
+        current number of round (e.g. yields learning rate decay)
+        - list l: parameter = l[current_round]
+        - function f: parameter = f(current_round)
     Returns
     -------
     callback : function
@@ -120,25 +117,19 @@ def reset_learning_rate(learning_rates):
     """
     def callback(env):
         """internal function"""
-        if isinstance(learning_rates, list):
-            if len(learning_rates) != env.end_iteration - env.begin_iteration:
-                raise ValueError("Length of list 'learning_rates' has to equal to 'num_boost_round'.")
-            env.model.reset_parameter({'learning_rate': learning_rates[env.iteration]})
-        else:
-            argc = len(inspect.getargspec(learning_rates).args)
-            if argc is 1:
-                env.model.reset_parameter({"learning_rate": learning_rates(env.iteration - env.begin_iteration)})
-            elif argc is 2:
-                env.model.reset_parameter({"learning_rate": \
-                    learning_rates(env.iteration - env.begin_iteration, env.end_iteration - env.begin_iteration)})
-            else:
-                raise ValueError("Self-defined function 'learning_rates' should have 1 or 2 arguments, got %d" % (argc))
+        for key, value in kwargs.items():
+            if isinstance(value, list):
+                if len(value) != env.end_iteration - env.begin_iteration:
+                    raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(key)))
+                env.model.reset_parameter({key: value[env.iteration - env.begin_iteration]})
+            else:
+                env.model.reset_parameter({key: value(env.iteration - env.begin_iteration)})
     callback.before_iteration = True
     callback.order = 10
     return callback

-def early_stop(stopping_rounds, verbose=True):
+def early_stopping(stopping_rounds, verbose=True):
     """Create a callback that activates early stopping.
     Activates early stopping.
     Requires at least one validation data and one metric
```
```diff
@@ -69,12 +69,10 @@ def train(params, train_set, num_boost_round=100,
         an evaluation metric is printed every 4 (instead of 1) boosting stages.
     learning_rates: list or function
         List of learning rate for each boosting round
-        or a customized function that calculates learning_rate in terms of
-        current number of round (and the total number of boosting round)
-        (e.g. yields learning rate decay)
+        or a customized function that calculates learning_rate
+        in terms of current number of round (e.g. yields learning rate decay)
         - list l: learning_rate = l[current_round]
-        - function f: learning_rate = f(current_round, total_boost_round)
-                      or learning_rate = f(current_round)
+        - function f: learning_rate = f(current_round)
     callbacks : list of callback functions
         List of callback functions that are applied at end of each iteration.
```
```diff
@@ -138,11 +136,10 @@ def train(params, train_set, num_boost_round=100,
         callbacks.add(callback.print_evaluation(verbose_eval))
     if early_stopping_rounds is not None:
-        callbacks.add(callback.early_stop(early_stopping_rounds,
-                                          verbose=bool(verbose_eval)))
+        callbacks.add(callback.early_stopping(early_stopping_rounds, verbose=bool(verbose_eval)))
     if learning_rates is not None:
-        callbacks.add(callback.reset_learning_rate(learning_rates))
+        callbacks.add(callback.reset_parameter(learning_rate=learning_rates))
     if evals_result is not None:
         callbacks.add(callback.record_evaluation(evals_result))
```
```diff
@@ -355,7 +352,7 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
         cb.__dict__.setdefault('order', i - len(callbacks))
     callbacks = set(callbacks)
     if early_stopping_rounds is not None:
-        callbacks.add(callback.early_stop(early_stopping_rounds, verbose=False))
+        callbacks.add(callback.early_stopping(early_stopping_rounds, verbose=False))
     if verbose_eval is True:
         callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
     elif isinstance(verbose_eval, int):
```