Commit ef408f55 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

lambdarank cv (#459)

parent bcb9613e
......@@ -1109,7 +1109,7 @@ class Dataset(object):
def get_group(self):
"""
Get the initial score of the Dataset.
Get the group of the Dataset.
Returns
-------
......
......@@ -65,9 +65,9 @@ try:
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import deprecated
try:
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedKFold, GroupKFold
except ImportError:
from sklearn.cross_validation import StratifiedKFold
from sklearn.cross_validation import StratifiedKFold, GroupKFold
SKLEARN_INSTALLED = True
LGBMModelBase = BaseEstimator
LGBMRegressorBase = RegressorMixin
......@@ -75,6 +75,7 @@ try:
LGBMLabelEncoder = LabelEncoder
LGBMDeprecated = deprecated
LGBMStratifiedKFold = StratifiedKFold
LGBMGroupKFold = GroupKFold
except ImportError:
SKLEARN_INSTALLED = False
LGBMModelBase = object
......@@ -82,3 +83,4 @@ except ImportError:
LGBMRegressorBase = object
LGBMLabelEncoder = None
LGBMStratifiedKFold = None
LGBMGroupKFold = None
......@@ -10,7 +10,7 @@ import numpy as np
from . import callback
from .basic import Booster, Dataset, LightGBMError, _InnerPredictor
from .compat import (SKLEARN_INSTALLED, LGBMStratifiedKFold, integer_types,
from .compat import (SKLEARN_INSTALLED, LGBMStratifiedKFold, LGBMGroupKFold, integer_types,
range_, string_type)
......@@ -228,25 +228,35 @@ def _make_n_folds(full_data, data_splitter, nfold, params, seed, fpreproc=None,
"""
Make an n-fold list of Booster from random indices.
"""
num_data = full_data.construct().num_data()
full_data = full_data.construct()
num_data = full_data.num_data()
if data_splitter is not None:
if not hasattr(data_splitter, 'split'):
raise AttributeError("data_splitter has no method 'split'")
folds = data_splitter.split(np.arange(num_data))
elif stratified:
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv')
sfk = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = sfk.split(X=np.zeros(num_data), y=full_data.get_label())
else:
if shuffle:
randidx = np.random.RandomState(seed).permutation(num_data)
if 'objective' in params and params['objective'] == 'lambdarank':
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for lambdarank cv.')
# lambdarank task, split according to groups
group_info = full_data.get_group().astype(int)
flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
group_kfold = LGBMGroupKFold(n_splits=nfold)
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
elif stratified:
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv.')
skf = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = skf.split(X=np.zeros(num_data), y=full_data.get_label())
else:
randidx = np.arange(num_data)
kstep = int(num_data / nfold)
test_id = [randidx[i: i + kstep] for i in range_(0, num_data, kstep)]
train_id = [np.concatenate([test_id[i] for i in range_(nfold) if k != i]) for k in range_(nfold)]
folds = zip(train_id, test_id)
if shuffle:
randidx = np.random.RandomState(seed).permutation(num_data)
else:
randidx = np.arange(num_data)
kstep = int(num_data / nfold)
test_id = [randidx[i: i + kstep] for i in range_(0, num_data, kstep)]
train_id = [np.concatenate([test_id[i] for i in range_(nfold) if k != i]) for k in range_(nfold)]
folds = zip(train_id, test_id)
ret = CVBooster()
for train_idx, test_idx in folds:
......
......@@ -8,7 +8,7 @@ import unittest
import lightgbm as lgb
import numpy as np
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris)
load_iris, load_svmlight_file)
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split, TimeSeriesSplit
......@@ -152,15 +152,23 @@ class TestEngine(unittest.TestCase):
def test_cv(self):
lgb_train, _ = template.test_template(return_data=True)
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=20, nfold=5, shuffle=False,
metrics='l1', verbose_eval=False,
callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)])
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=20, nfold=5, shuffle=True,
# shuffle = False
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=10, nfold=3, shuffle=False,
metrics='l1', verbose_eval=False)
# shuffle = True, callbacks
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=10, nfold=3, shuffle=True,
metrics='l1', verbose_eval=False,
callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)])
# self defined data_splitter
tss = TimeSeriesSplit(3)
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=20, data_splitter=tss, nfold=5, # test if wrong nfold is ignored
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=10, data_splitter=tss, nfold=5, # test if wrong nfold is ignored
metrics='l2', verbose_eval=False)
# lambdarank
X_train, y_train = load_svmlight_file('../../examples/lambdarank/rank.train')
q_train = np.loadtxt('../../examples/lambdarank/rank.train.query')
params = {'objective': 'lambdarank', 'verbose': -1}
lgb_train = lgb.Dataset(X_train, y_train, group=q_train, params=params)
lgb.cv(params, lgb_train, num_boost_round=20, nfold=3, metrics='l2', verbose_eval=False)
def test_feature_name(self):
lgb_train, _ = template.test_template(return_data=True)
......
......@@ -95,11 +95,11 @@ class TestSklearn(unittest.TestCase):
def test_grid_search(self):
X_train, X_test, y_train, y_test = template.test_template(return_data=True)
params = {'boosting_type': ['dart', 'gbdt'],
'n_estimators': [15, 20],
'drop_rate': [0.1, 0.2]}
'n_estimators': [5, 8],
'drop_rate': [0.05, 0.1]}
gbm = GridSearchCV(lgb.LGBMRegressor(), params, cv=3)
gbm.fit(X_train, y_train)
self.assertIn(gbm.best_params_['n_estimators'], [15, 20])
self.assertIn(gbm.best_params_['n_estimators'], [5, 8])
def test_clone_and_property(self):
gbm = template.test_template(return_model=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment