Commit 80a52ad4 authored by Guolin Ke's avatar Guolin Ke
Browse files

add LGBMRanker

parent 164524d8
......@@ -20,4 +20,4 @@ __version__ = 0.1
__all__ = ['Dataset', 'Booster',
'train', 'cv',
'LGBMModel','LGBMRegressor', 'LGBMClassifier']
\ No newline at end of file
'LGBMModel','LGBMRegressor', 'LGBMClassifier', 'LGBMRanker']
\ No newline at end of file
......@@ -23,9 +23,6 @@ except ImportError:
def _point_wise_objective(func):
"""Decorate an objective function
Converts an objective function using the typical sklearn metrics to LightGBM fobj
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
and you should group grad and hess in this way as well
......@@ -36,7 +33,7 @@ def _point_wise_objective(func):
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class]
y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
The predicted values
......@@ -66,6 +63,8 @@ def _point_wise_objective(func):
else:
num_data = len(weight)
num_class = len(grad) // num_data
if num_class * num_data != len(grad):
raise ValueError("lenght of grad and hess should equal with num_class * num_data")
for k in range(num_class):
for i in range(num_data):
idx = k * num_data + i
......@@ -74,7 +73,6 @@ def _point_wise_objective(func):
return grad, hess
return inner
class LGBMModel(LGBMModelBase):
"""Implementation of the Scikit-Learn API for LightGBM.
......@@ -169,6 +167,10 @@ class LGBMModel(LGBMModelBase):
self.is_unbalance = is_unbalance
self.seed = seed
self._Booster = None
if callable(self.objective):
self.fobj = _point_wise_objective(self.objective)
else:
self.fobj = None
def booster(self):
"""Get the underlying lightgbm Booster of this model.
......@@ -205,11 +207,11 @@ class LGBMModel(LGBMModelBase):
eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for early-stopping
eval_metric : str, list of str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.md. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
If a str, should be a built-in evaluation metric to use.
If callable, a custom evaluation metric. The call
signature is func(y_predicted, dataset) where dataset will be a
Dataset object such that you may need to call the get_label
method. And it must return (eval_name, feature_result, is_bigger_better)
method. And it must return (eval_name->str, eval_result->float, is_bigger_better->Bool)
early_stopping_rounds : int
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
......@@ -228,12 +230,11 @@ class LGBMModel(LGBMModelBase):
if other_params is not None:
params.update(other_params)
if callable(self.objective):
fobj = _point_wise_objective(self.objective)
if self.fobj:
params["objective"] = "None"
else:
params["objective"] = self.objective
fobj = None
if callable(eval_metric):
feval = eval_metric
elif is_str(eval_metric) or isinstance(eval_metric, list):
......@@ -246,7 +247,7 @@ class LGBMModel(LGBMModelBase):
self._Booster = train(params, (X, y),
self.n_estimators, valid_datas=eval_set,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, fobj=fobj, feval=feval,
evals_result=evals_result, fobj=self.fobj, feval=feval,
verbose_eval=verbose, train_fields=train_fields, valid_fields=valid_fields)
if evals_result:
......@@ -316,11 +317,9 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
other_params = {}
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying LGBM instance
if not callable(self.objective):
self.objective = "multiclass"
other_params['num_class'] = self.n_classes_
else:
if not callable(self.objective):
self.objective = "binary"
self._le = LGBMLabelEncoder().fit(y)
......@@ -355,3 +354,80 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
classzero_probs = 1.0 - classone_probs
return np.vstack((classzero_probs, classone_probs)).transpose()
def _group_wise_objective(func):
    """Decorate a group-wise (ranking) objective function.

    Adapts an objective written as ``func(y_true, group, y_pred)`` to the
    ``new_func(preds, dataset)`` form, pulling labels, group sizes and
    optional sample weights out of the dataset.

    Parameters
    ----------
    func: callable
        Expects a callable with signature ``func(y_true, group, y_pred)``
        that returns a ``(grad, hess)`` pair:
            y_true: array_like of shape [n_samples]
                The target values
            group: array_like
                The group size data, as returned by ``dataset.get_group()``
            y_pred: array_like of shape [n_samples]
                The predicted values

    Returns
    -------
    new_func: callable
        The new objective function as expected by ``lightgbm.engine.train``.
        The signature is ``new_func(preds, dataset)``:
            preds: array_like, shape [n_samples]
                The predicted values
            dataset: ``dataset``
                The training set from which the labels, group data and
                weights will be extracted using ``dataset.get_label()``,
                ``dataset.get_group()`` and ``dataset.get_weight()``

        ``new_func`` raises ``ValueError`` when the dataset has no group
        data, or when weights are present and ``grad``/``hess`` do not have
        exactly one entry per weight.
    """
    def inner(preds, dataset):
        """Evaluate ``func`` on the dataset and apply sample weights."""
        labels = dataset.get_label()
        group = dataset.get_group()
        if group is None:
            raise ValueError("group should not be None for ranking task")
        grad, hess = func(labels, group, preds)
        # Weight the objective: scale grad and hess element-wise by the
        # per-sample weights when the dataset carries any.
        weight = dataset.get_weight()
        if weight is not None:
            # Only one class is supported here, so both arrays must line up
            # one-to-one with the weights. Checking hess as well prevents a
            # short hess from being silently broadcast by numpy.
            num_data = len(weight)
            if len(grad) != num_data or len(hess) != num_data:
                raise ValueError("length of grad and hess should equal with num_data")
            grad = np.multiply(grad, weight)
            hess = np.multiply(hess, weight)
        return grad, hess
    return inner
class LGBMRanker(LGBMModel):
    # Reuse the parameter documentation of LGBMModel, replacing only its
    # first two lines. ``__doc__`` is None when docstrings are stripped
    # (``python -OO``), so fall back to an empty string rather than failing
    # at import time.
    __doc__ = """Implementation of the scikit-learn API for LightGBM ranking application.

    """ + '\n'.join((LGBMModel.__doc__ or '').split('\n')[2:])

    def fit(self, X, y, eval_set=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True,
            train_fields=None, valid_fields=None, other_params=None):
        """Fit the ranking model after validating the group data.

        All arguments are forwarded unchanged to ``LGBMModel.fit``. This
        override only checks them and selects the objective first.

        Parameters
        ----------
        train_fields : container
            Must be given and must contain ``"group"``.
        valid_fields : sequence of containers
            Required when ``eval_set`` is given: one entry per element of
            ``eval_set``, each containing ``"group"``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``train_fields`` is missing or lacks ``"group"``, or if
            ``eval_set`` is given and ``valid_fields`` is missing, has a
            different length, or has an entry lacking ``"group"``.
        """
        # Check group data. train_fields defaults to None, which must be
        # rejected explicitly: ``"group" not in None`` raises TypeError.
        if train_fields is None or "group" not in train_fields:
            raise ValueError("should set group in train_fields for ranking task")

        if eval_set is not None:
            if valid_fields is None:
                raise ValueError("valid_fields cannot be None when eval_set is not None")
            if len(valid_fields) != len(eval_set):
                raise ValueError("length of valid_fields should equal with eval_set")
            for fields in valid_fields:
                if fields is None or "group" not in fields:
                    raise ValueError("should set group in valid_fields for ranking task")

        if callable(self.objective):
            # A custom objective is wrapped so it also receives group data.
            self.fobj = _group_wise_objective(self.objective)
        else:
            # Any non-callable objective is replaced by the built-in one.
            self.objective = "lambdarank"
            self.fobj = None

        super(LGBMRanker, self).fit(X, y, eval_set, eval_metric,
                                    early_stopping_rounds, verbose,
                                    train_fields, valid_fields, other_params)
        return self
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment