Commit 577a03ce authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

[python] added possibility to use sklearn splitter classes in cv function (#1685)

* added sklearn splitter classes in cv function

* added tests
parent eb131ead
...@@ -261,8 +261,17 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi ...@@ -261,8 +261,17 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
full_data = full_data.construct() full_data = full_data.construct()
num_data = full_data.num_data() num_data = full_data.num_data()
if folds is not None: if folds is not None:
if not hasattr(folds, '__iter__'): if not hasattr(folds, '__iter__') and not hasattr(folds, 'split'):
raise AttributeError("folds should be a generator or iterator of (train_idx, test_idx)") raise AttributeError("folds should be a generator or iterator of (train_idx, test_idx) tuples "
"or scikit-learn splitter object with split method")
if hasattr(folds, 'split'):
group_info = full_data.get_group()
if group_info is not None:
group_info = group_info.astype(int)
flatted_group = np.repeat(range_(len(group_info)), repeats=group_info)
else:
flatted_group = np.zeros(num_data, dtype=int)
folds = folds.split(X=np.zeros(num_data), y=full_data.get_label(), groups=flatted_group)
else: else:
if 'objective' in params and params['objective'] == 'lambdarank': if 'objective' in params and params['objective'] == 'lambdarank':
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
...@@ -332,8 +341,11 @@ def cv(params, train_set, num_boost_round=100, ...@@ -332,8 +341,11 @@ def cv(params, train_set, num_boost_round=100,
Data to be trained on. Data to be trained on.
num_boost_round : int, optional (default=100) num_boost_round : int, optional (default=100)
Number of boosting iterations. Number of boosting iterations.
folds : a generator or iterator of (train_idx, test_idx) tuples or None, optional (default=None) folds : a generator or iterator of (train_idx, test_idx) tuples, scikit-learn splitter object or None, optional (default=None)
The train and test indices for the each fold. If generator or iterator, it should yield the train and test indices for the each fold.
If object, it should be one of the scikit-learn splitter classes
(http://scikit-learn.org/stable/modules/classes.html#splitter-classes)
and have ``split`` method.
This argument has highest priority over other data split arguments. This argument has highest priority over other data split arguments.
nfold : int, optional (default=5) nfold : int, optional (default=5)
Number of folds in CV. Number of folds in CV.
......
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris, load_svmlight_file) load_iris, load_svmlight_file)
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split, TimeSeriesSplit from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
try: try:
...@@ -438,29 +438,47 @@ class TestEngine(unittest.TestCase): ...@@ -438,29 +438,47 @@ class TestEngine(unittest.TestCase):
lgb_train = lgb.Dataset(X_train, y_train) lgb_train = lgb.Dataset(X_train, y_train)
# shuffle = False, override metric in params # shuffle = False, override metric in params
params_with_metric = {'metric': 'l2', 'verbose': -1} params_with_metric = {'metric': 'l2', 'verbose': -1}
lgb.cv(params_with_metric, lgb_train, num_boost_round=10, nfold=3, stratified=False, shuffle=False, cv_res = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, nfold=3, stratified=False, shuffle=False,
metrics='l1', verbose_eval=False) metrics='l1', verbose_eval=False)
self.assertIn('l1-mean', cv_res)
self.assertNotIn('l2-mean', cv_res)
self.assertEqual(len(cv_res['l1-mean']), 10)
# shuffle = True, callbacks # shuffle = True, callbacks
lgb.cv(params, lgb_train, num_boost_round=10, nfold=3, stratified=False, shuffle=True, cv_res = lgb.cv(params, lgb_train, num_boost_round=10, nfold=3, stratified=False, shuffle=True,
metrics='l1', verbose_eval=False, metrics='l1', verbose_eval=False,
callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)]) callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)])
self.assertIn('l1-mean', cv_res)
self.assertEqual(len(cv_res['l1-mean']), 10)
# self defined folds # self defined folds
tss = TimeSeriesSplit(3) tss = TimeSeriesSplit(3)
folds = tss.split(X_train) folds = tss.split(X_train)
lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=folds, stratified=False, verbose_eval=False) cv_res_gen = lgb.cv(params, lgb_train, num_boost_round=10, folds=folds,
metrics='l2', verbose_eval=False)
cv_res_obj = lgb.cv(params, lgb_train, num_boost_round=10, folds=tss,
metrics='l2', verbose_eval=False)
np.testing.assert_almost_equal(cv_res_gen['l2-mean'], cv_res_obj['l2-mean'])
# lambdarank # lambdarank
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train')) X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train.query')) '../../examples/lambdarank/rank.train'))
q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train.query'))
params_lambdarank = {'objective': 'lambdarank', 'verbose': -1, 'eval_at': 3} params_lambdarank = {'objective': 'lambdarank', 'verbose': -1, 'eval_at': 3}
lgb_train = lgb.Dataset(X_train, y_train, group=q_train) lgb_train = lgb.Dataset(X_train, y_train, group=q_train)
# ... with NDCG (default) metric # ... with NDCG (default) metric
cv_res = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, stratified=False, verbose_eval=False) cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3,
self.assertEqual(len(cv_res), 2) verbose_eval=False)
self.assertFalse(np.isnan(cv_res['ndcg@3-mean']).any()) self.assertEqual(len(cv_res_lambda), 2)
self.assertFalse(np.isnan(cv_res_lambda['ndcg@3-mean']).any())
# ... with l2 metric # ... with l2 metric
cv_res = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, stratified=False, metrics='l2', verbose_eval=False) cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3,
self.assertEqual(len(cv_res), 2) metrics='l2', verbose_eval=False)
self.assertFalse(np.isnan(cv_res['l2-mean']).any()) self.assertEqual(len(cv_res_lambda), 2)
self.assertFalse(np.isnan(cv_res_lambda['l2-mean']).any())
# self defined folds with lambdarank
cv_res_lambda_obj = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10,
folds=GroupKFold(n_splits=3),
metrics='l2', verbose_eval=False)
np.testing.assert_almost_equal(cv_res_lambda['l2-mean'], cv_res_lambda_obj['l2-mean'])
def test_feature_name(self): def test_feature_name(self):
X, y = load_boston(True) X, y = load_boston(True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment