Commit 918c157b authored by Guolin Ke's avatar Guolin Ke
Browse files

refine fobj and feval in sklearn interface

parent ebfc8521
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# pylint: disable = invalid-name, W0105, C0111 # pylint: disable = invalid-name, W0105, C0111
"""Scikit-Learn Wrapper interface for LightGBM.""" """Scikit-Learn Wrapper interface for LightGBM."""
from __future__ import absolute_import from __future__ import absolute_import
import inspect
import numpy as np import numpy as np
from .basic import LightGBMError, Dataset, is_str from .basic import LightGBMError, Dataset, is_str
...@@ -23,7 +24,7 @@ except ImportError: ...@@ -23,7 +24,7 @@ except ImportError:
LGBMRegressorBase = object LGBMRegressorBase = object
LGBMLabelEncoder = None LGBMLabelEncoder = None
def _point_wise_objective(func): def _objective_function_wrapper(func):
"""Decorate an objective function """Decorate an objective function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
...@@ -31,12 +32,13 @@ def _point_wise_objective(func): ...@@ -31,12 +32,13 @@ def _point_wise_objective(func):
Parameters Parameters
---------- ----------
func: callable func: callable
Expects a callable with signature ``func(y_true, y_pred)``: Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group):
y_true: array_like of shape [n_samples]
y_true: array_like of shape [n_samples] The target values
The target values y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class) The predicted values
The predicted values group: array_like
group/query data, used for ranking task
Returns Returns
------- -------
...@@ -53,7 +55,13 @@ def _point_wise_objective(func): ...@@ -53,7 +55,13 @@ def _point_wise_objective(func):
def inner(preds, dataset): def inner(preds, dataset):
"""internal function""" """internal function"""
labels = dataset.get_label() labels = dataset.get_label()
grad, hess = func(labels, preds) argc = len(inspect.getargspec(func).args)
if argc == 2:
grad, hess = func(labels, preds)
elif argc == 3:
grad, hess = func(labels, preds, dataset.get_group())
else:
raise TypeError("parameter number of objective function should be (2, 3), got %d" %(argc))
"""weighted for objective""" """weighted for objective"""
weight = dataset.get_weight() weight = dataset.get_weight()
if weight is not None: if weight is not None:
...@@ -74,6 +82,51 @@ def _point_wise_objective(func): ...@@ -74,6 +82,51 @@ def _point_wise_objective(func):
return grad, hess return grad, hess
return inner return inner
def _eval_function_wrapper(func):
"""Decorate an eval function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
Parameters
----------
func: callable
Expects a callable with following functions: ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)`` and return (eval_name->str, eval_result->float, is_bigger_better->Bool):
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
weight: array_like of shape [n_samples]
The weight of samples
group: array_like
group/query data, used for ranking task
Returns
-------
new_func: callable
The new eval function as expected by ``lightgbm.engine.train``.
The signature is ``new_func(preds, dataset)``:
preds: array_like, shape [n_samples] or shape[n_samples* n_class]
The predicted values
dataset: ``dataset``
The training set from which the labels will be extracted using
``dataset.get_label()``
"""
def inner(preds, dataset):
"""internal function"""
labels = dataset.get_label()
argc = len(inspect.getargspec(func).args)
if argc == 2:
return func(labels, preds)
elif argc == 3:
return func(labels, preds, dataset.get_weight())
elif argc == 4:
return func(labels, preds, dataset.get_weight(), dataset.get_group())
else:
raise TypeError("parameter number of eval function should be (2, 3, 4), got %d" %(argc))
return inner
class LGBMModel(LGBMModelBase): class LGBMModel(LGBMModelBase):
"""Implementation of the Scikit-Learn API for LightGBM. """Implementation of the Scikit-Learn API for LightGBM.
...@@ -121,17 +174,18 @@ class LGBMModel(LGBMModelBase): ...@@ -121,17 +174,18 @@ class LGBMModel(LGBMModelBase):
---- ----
A custom objective function can be provided for the ``objective`` A custom objective function can be provided for the ``objective``
parameter. In this case, it should have the signature parameter. In this case, it should have the signature
``objective(y_true, y_pred) -> grad, hess``: ``objective(y_true, y_pred) -> grad, hess`` or ``objective(y_true, y_pred, group) -> grad, hess``:
y_true: array_like of shape [n_samples] y_true: array_like of shape [n_samples]
The target values The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] y_pred: array_like of shape [n_samples] or shape[n_samples* n_class]
The predicted values The predicted values
group: array_like
grad: array_like of shape [n_samples] or shape[n_samples* n_class] group/query data, used for ranking task
The value of the gradient for each sample point. grad: array_like of shape [n_samples] or shape[n_samples* n_class]
hess: array_like of shape [n_samples] or shape[n_samples* n_class] The value of the gradient for each sample point.
The value of the second derivative for each sample point hess: array_like of shape [n_samples] or shape[n_samples* n_class]
The value of the second derivative for each sample point
for multi-class task, the y_pred is group by class_id first, then group by row_id for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
...@@ -170,7 +224,7 @@ class LGBMModel(LGBMModelBase): ...@@ -170,7 +224,7 @@ class LGBMModel(LGBMModelBase):
self._Booster = None self._Booster = None
self.best_iteration = -1 self.best_iteration = -1
if callable(self.objective): if callable(self.objective):
self.fobj = _point_wise_objective(self.objective) self.fobj = _objective_function_wrapper(self.objective)
else: else:
self.fobj = None self.fobj = None
...@@ -227,10 +281,7 @@ class LGBMModel(LGBMModelBase): ...@@ -227,10 +281,7 @@ class LGBMModel(LGBMModelBase):
group data of eval data group data of eval data
eval_metric : str, list of str, callable, optional eval_metric : str, list of str, callable, optional
If a str, should be a built-in evaluation metric to use. If a str, should be a built-in evaluation metric to use.
If callable, a custom evaluation metric. The call \ If callable, a custom evaluation metric, see note for more details.
signature is func(y_predicted, dataset) where dataset will be a \
Dateset object such that you may need to call the get_label \
method. And it must return (eval_name->str, eval_result->float, is_bigger_better->Bool)
early_stopping_rounds : int early_stopping_rounds : int
verbose : bool verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation If `verbose` and an evaluation set is used, writes the evaluation
...@@ -241,6 +292,29 @@ class LGBMModel(LGBMModelBase): ...@@ -241,6 +292,29 @@ class LGBMModel(LGBMModelBase):
type str represents feature names (need to specify feature_name as well) type str represents feature names (need to specify feature_name as well)
other_params: dict other_params: dict
Other parameters Other parameters
Note
----
Custom eval function expects a callable with following functions: ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)``.
return (eval_name, eval_result, is_bigger_better) or list of (eval_name, eval_result, is_bigger_better)
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
weight: array_like of shape [n_samples]
The weight of samples
group: array_like
group/query data, used for ranking task
eval_name: str
name of evaluation
eval_result: float
eval result
is_bigger_better: bool
is eval result bigger better, e.g. AUC is bigger_better.
for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
""" """
evals_result = {} evals_result = {}
params = self.get_params() params = self.get_params()
...@@ -262,13 +336,12 @@ class LGBMModel(LGBMModelBase): ...@@ -262,13 +336,12 @@ class LGBMModel(LGBMModelBase):
}.get(self.objective, None) }.get(self.objective, None)
if callable(eval_metric): if callable(eval_metric):
feval = eval_metric feval = _eval_function_wrapper(eval_metric)
elif is_str(eval_metric) or isinstance(eval_metric, list): elif is_str(eval_metric) or isinstance(eval_metric, list):
feval = None feval = None
params.update({'metric': eval_metric}) params.update({'metric': eval_metric})
else: else:
feval = None feval = None
feval = eval_metric if callable(eval_metric) else None
def _construct_dataset(X, y, sample_weight, init_score, group): def _construct_dataset(X, y, sample_weight, init_score, group):
ret = Dataset(X, label=y, weight=sample_weight, group=group) ret = Dataset(X, label=y, weight=sample_weight, group=group)
...@@ -448,51 +521,6 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase): ...@@ -448,51 +521,6 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
classzero_probs = 1.0 - classone_probs classzero_probs = 1.0 - classone_probs
return np.vstack((classzero_probs, classone_probs)).transpose() return np.vstack((classzero_probs, classone_probs)).transpose()
def _group_wise_objective(func):
"""Decorate an objective function
Parameters
----------
func: callable
Expects a callable with signature ``func(y_true, group, y_pred)``:
y_true: array_like of shape [n_samples]
The target values
group : array_like of shape
Group size data of data
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
Returns
-------
new_func: callable
The new objective function as expected by ``lightgbm.engine.train``.
The signature is ``new_func(preds, dataset)``:
preds: array_like, shape [n_samples] or shape[n_samples* n_class]
The predicted values
dataset: ``dataset``
The training set from which the labels will be extracted using
``dataset.get_label()``
"""
def inner(preds, dataset):
"""internal function"""
labels = dataset.get_label()
group = dataset.get_group()
if group is None:
raise ValueError("Group should not be None for ranking task")
grad, hess = func(labels, group, preds)
"""weighted for objective"""
weight = dataset.get_weight()
if weight is not None:
"""only one class"""
if len(weight) == len(grad):
grad = np.multiply(grad, weight)
hess = np.multiply(hess, weight)
else:
raise ValueError("Length of grad and hess should equal with num_data")
return grad, hess
return inner
class LGBMRanker(LGBMModel): class LGBMRanker(LGBMModel):
__doc__ = """Implementation of the scikit-learn API for LightGBM ranking application. __doc__ = """Implementation of the scikit-learn API for LightGBM ranking application.
...@@ -512,10 +540,6 @@ class LGBMRanker(LGBMModel): ...@@ -512,10 +540,6 @@ class LGBMRanker(LGBMModel):
subsample, subsample_freq, colsample_bytree, subsample, subsample_freq, colsample_bytree,
reg_alpha, reg_lambda, scale_pos_weight, reg_alpha, reg_lambda, scale_pos_weight,
is_unbalance, seed) is_unbalance, seed)
if callable(self.objective):
self.fobj = _group_wise_objective(self.objective)
else:
self.fobj = None
def fit(self, X, y, def fit(self, X, y,
sample_weight=None, init_score=None, group=None, sample_weight=None, init_score=None, group=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment