Commit 918c157b authored by Guolin Ke's avatar Guolin Ke
Browse files

refine fobj and feval in sklearn interface

parent ebfc8521
......@@ -2,6 +2,7 @@
# pylint: disable = invalid-name, W0105, C0111
"""Scikit-Learn Wrapper interface for LightGBM."""
from __future__ import absolute_import
import inspect
import numpy as np
from .basic import LightGBMError, Dataset, is_str
......@@ -23,7 +24,7 @@ except ImportError:
LGBMRegressorBase = object
LGBMLabelEncoder = None
def _point_wise_objective(func):
def _objective_function_wrapper(func):
"""Decorate an objective function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
......@@ -31,12 +32,13 @@ def _point_wise_objective(func):
Parameters
----------
func: callable
Expects a callable with signature ``func(y_true, y_pred)``:
Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group)``:
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
group: array_like
group/query data, used for ranking task
Returns
-------
......@@ -53,7 +55,13 @@ def _point_wise_objective(func):
def inner(preds, dataset):
"""internal function"""
labels = dataset.get_label()
argc = len(inspect.getargspec(func).args)
if argc == 2:
grad, hess = func(labels, preds)
elif argc == 3:
grad, hess = func(labels, preds, dataset.get_group())
else:
raise TypeError("parameter number of objective function should be (2, 3), got %d" %(argc))
"""weighted for objective"""
weight = dataset.get_weight()
if weight is not None:
......@@ -74,6 +82,51 @@ def _point_wise_objective(func):
return grad, hess
return inner
def _eval_function_wrapper(func):
"""Decorate an eval function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
Parameters
----------
func: callable
Expects a callable with following functions: ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)`` and return (eval_name->str, eval_result->float, is_bigger_better->Bool):
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
weight: array_like of shape [n_samples]
The weight of samples
group: array_like
group/query data, used for ranking task
Returns
-------
new_func: callable
The new eval function as expected by ``lightgbm.engine.train``.
The signature is ``new_func(preds, dataset)``:
preds: array_like, shape [n_samples] or shape[n_samples* n_class]
The predicted values
dataset: ``dataset``
The training set from which the labels will be extracted using
``dataset.get_label()``
"""
def inner(preds, dataset):
"""internal function"""
labels = dataset.get_label()
argc = len(inspect.getargspec(func).args)
if argc == 2:
return func(labels, preds)
elif argc == 3:
return func(labels, preds, dataset.get_weight())
elif argc == 4:
return func(labels, preds, dataset.get_weight(), dataset.get_group())
else:
raise TypeError("parameter number of eval function should be (2, 3, 4), got %d" %(argc))
return inner
class LGBMModel(LGBMModelBase):
"""Implementation of the Scikit-Learn API for LightGBM.
......@@ -121,13 +174,14 @@ class LGBMModel(LGBMModelBase):
----
A custom objective function can be provided for the ``objective``
parameter. In this case, it should have the signature
``objective(y_true, y_pred) -> grad, hess``:
``objective(y_true, y_pred) -> grad, hess`` or ``objective(y_true, y_pred, group) -> grad, hess``:
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class]
The predicted values
group: array_like
group/query data, used for ranking task
grad: array_like of shape [n_samples] or shape[n_samples* n_class]
The value of the gradient for each sample point.
hess: array_like of shape [n_samples] or shape[n_samples* n_class]
......@@ -170,7 +224,7 @@ class LGBMModel(LGBMModelBase):
self._Booster = None
self.best_iteration = -1
if callable(self.objective):
self.fobj = _point_wise_objective(self.objective)
self.fobj = _objective_function_wrapper(self.objective)
else:
self.fobj = None
......@@ -227,10 +281,7 @@ class LGBMModel(LGBMModelBase):
group data of eval data
eval_metric : str, list of str, callable, optional
If a str, should be a built-in evaluation metric to use.
If callable, a custom evaluation metric. The call \
signature is func(y_predicted, dataset) where dataset will be a \
Dataset object such that you may need to call the get_label \
method. And it must return (eval_name->str, eval_result->float, is_bigger_better->Bool)
If callable, a custom evaluation metric, see note for more details.
early_stopping_rounds : int
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
......@@ -241,6 +292,29 @@ class LGBMModel(LGBMModelBase):
type str represents feature names (need to specify feature_name as well)
other_params: dict
Other parameters
Note
----
Custom eval function expects a callable with following functions: ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)``.
return (eval_name, eval_result, is_bigger_better) or list of (eval_name, eval_result, is_bigger_better)
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
weight: array_like of shape [n_samples]
The weight of samples
group: array_like
group/query data, used for ranking task
eval_name: str
name of evaluation
eval_result: float
eval result
is_bigger_better: bool
whether a higher eval result is better, e.g. AUC is bigger_better.
for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
"""
evals_result = {}
params = self.get_params()
......@@ -262,13 +336,12 @@ class LGBMModel(LGBMModelBase):
}.get(self.objective, None)
if callable(eval_metric):
feval = eval_metric
feval = _eval_function_wrapper(eval_metric)
elif is_str(eval_metric) or isinstance(eval_metric, list):
feval = None
params.update({'metric': eval_metric})
else:
feval = None
feval = eval_metric if callable(eval_metric) else None
def _construct_dataset(X, y, sample_weight, init_score, group):
ret = Dataset(X, label=y, weight=sample_weight, group=group)
......@@ -448,51 +521,6 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
classzero_probs = 1.0 - classone_probs
return np.vstack((classzero_probs, classone_probs)).transpose()
def _group_wise_objective(func):
"""Decorate an objective function
Parameters
----------
func: callable
Expects a callable with signature ``func(y_true, group, y_pred)``:
y_true: array_like of shape [n_samples]
The target values
group : array_like of shape
Group size data of data
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
Returns
-------
new_func: callable
The new objective function as expected by ``lightgbm.engine.train``.
The signature is ``new_func(preds, dataset)``:
preds: array_like, shape [n_samples] or shape[n_samples* n_class]
The predicted values
dataset: ``dataset``
The training set from which the labels will be extracted using
``dataset.get_label()``
"""
def inner(preds, dataset):
"""internal function"""
labels = dataset.get_label()
group = dataset.get_group()
if group is None:
raise ValueError("Group should not be None for ranking task")
grad, hess = func(labels, group, preds)
"""weighted for objective"""
weight = dataset.get_weight()
if weight is not None:
"""only one class"""
if len(weight) == len(grad):
grad = np.multiply(grad, weight)
hess = np.multiply(hess, weight)
else:
raise ValueError("Length of grad and hess should equal with num_data")
return grad, hess
return inner
class LGBMRanker(LGBMModel):
__doc__ = """Implementation of the scikit-learn API for LightGBM ranking application.
......@@ -512,10 +540,6 @@ class LGBMRanker(LGBMModel):
subsample, subsample_freq, colsample_bytree,
reg_alpha, reg_lambda, scale_pos_weight,
is_unbalance, seed)
if callable(self.objective):
self.fobj = _group_wise_objective(self.objective)
else:
self.fobj = None
def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment