Commit 13d4581b authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

add data_splitter to cv (#298)

* add data_splitter for cv

* update gitignore

* clean code
parent ea6bc0a5
...@@ -368,5 +368,9 @@ ENV/ ...@@ -368,5 +368,9 @@ ENV/
# R testing artefact # R testing artefact
lightgbm.model lightgbm.model
# saved or dumped model
*.model
*.pkl
# macOS # macOS
.DS_Store .DS_Store
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* [Training API](Python-API.md#training-api) * [Training API](Python-API.md#training-api)
- [train](Python-API.md#trainparams-train_set-num_boost_round100-valid_setsnone-valid_namesnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-categorical_featureauto-early_stopping_roundsnone-evals_resultnone-verbose_evaltrue-learning_ratesnone-callbacksnone) - [train](Python-API.md#trainparams-train_set-num_boost_round100-valid_setsnone-valid_namesnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-categorical_featureauto-early_stopping_roundsnone-evals_resultnone-verbose_evaltrue-learning_ratesnone-callbacksnone)
- [cv](Python-API.md#cvparams-train_set-num_boost_round10-nfold5-stratifiedfalse-shuffletrue-metricsnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-categorical_featureauto-early_stopping_roundsnone-fpreprocnone-verbose_evalnone-show_stdvtrue-seed0-callbacksnone) - [cv](Python-API.md#cvparams-train_set-num_boost_round10-data_splitternone-nfold5-stratifiedfalse-shuffletrue-metricsnone-fobjnone-fevalnone-init_modelnone-feature_nameauto-categorical_featureauto-early_stopping_roundsnone-fpreprocnone-verbose_evalnone-show_stdvtrue-seed0-callbacksnone)
* [Scikit-learn API](Python-API.md#scikit-learn-api) * [Scikit-learn API](Python-API.md#scikit-learn-api)
- [Common Methods](Python-API.md#common-methods) - [Common Methods](Python-API.md#common-methods)
...@@ -536,7 +536,7 @@ The methods of each Class is in alphabetical order. ...@@ -536,7 +536,7 @@ The methods of each Class is in alphabetical order.
booster : a trained booster model booster : a trained booster model
####cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name='auto', categorical_feature='auto', early_stopping_rounds=None, fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, callbacks=None) ####cv(params, train_set, num_boost_round=10, data_splitter=None, nfold=5, stratified=False, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name='auto', categorical_feature='auto', early_stopping_rounds=None, fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, callbacks=None)
Cross-validation with given parameters. Cross-validation with given parameters.
...@@ -548,14 +548,14 @@ The methods of each Class is in alphabetical order. ...@@ -548,14 +548,14 @@ The methods of each Class is in alphabetical order.
Data to be trained. Data to be trained.
num_boost_round : int num_boost_round : int
Number of boosting iterations. Number of boosting iterations.
data_splitter : an instance with split(X) method
        Instance whose split(X) method yields (train_indices, test_indices) pairs; when provided, nfold, stratified and shuffle are ignored.
nfold : int nfold : int
Number of folds in CV. Number of folds in CV.
stratified : bool stratified : bool
Perform stratified sampling. Perform stratified sampling.
shuffle: bool shuffle: bool
        Whether to shuffle data before splitting. Whether to shuffle data before splitting.
folds : a KFold or StratifiedKFold instance
Sklearn KFolds or StratifiedKFolds.
metrics : str or list of str metrics : str or list of str
Evaluation metrics to be watched in CV. Evaluation metrics to be watched in CV.
fobj : function fobj : function
......
...@@ -221,30 +221,35 @@ class CVBooster(object): ...@@ -221,30 +221,35 @@ class CVBooster(object):
return handlerFunction return handlerFunction
def _make_n_folds(full_data, nfold, params, seed, fpreproc=None, stratified=False, shuffle=True): def _make_n_folds(full_data, data_splitter, nfold, params, seed, fpreproc=None, stratified=False, shuffle=True):
""" """
Make an n-fold list of Booster from random indices. Make an n-fold list of Booster from random indices.
""" """
np.random.seed(seed) np.random.seed(seed)
if stratified: num_data = full_data.construct().num_data()
if SKLEARN_INSTALLED: if data_splitter is not None:
sfk = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed) if not hasattr(data_splitter, 'split'):
idset = [x[1] for x in sfk.split(X=full_data.get_label(), y=full_data.get_label())] raise AttributeError("data_splitter has no method 'split'")
else: folds = data_splitter.split(np.arange(num_data))
elif stratified:
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv') raise LightGBMError('Scikit-learn is required for stratified cv')
sfk = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = sfk.split(X=np.zeros(num_data), y=full_data.get_label())
else: else:
full_data.construct()
if shuffle: if shuffle:
randidx = np.random.permutation(full_data.num_data()) randidx = np.random.permutation(num_data)
else: else:
randidx = np.arange(full_data.num_data()) randidx = np.arange(num_data)
kstep = int(len(randidx) / nfold) kstep = int(num_data / nfold)
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range_(nfold)] test_id = [randidx[i: i + kstep] for i in range_(0, num_data, kstep)]
train_id = [np.concatenate([test_id[i] for i in range_(nfold) if k != i]) for k in range_(nfold)]
folds = zip(train_id, test_id)
ret = CVBooster() ret = CVBooster()
for k in range_(nfold): for train_idx, test_idx in folds:
train_set = full_data.subset(np.concatenate([idset[i] for i in range_(nfold) if k != i])) train_set = full_data.subset(train_idx)
valid_set = full_data.subset(idset[k]) valid_set = full_data.subset(test_idx)
# run preprocessing on the data set if needed # run preprocessing on the data set if needed
if fpreproc is not None: if fpreproc is not None:
train_set, valid_set, tparam = fpreproc(train_set, valid_set, params.copy()) train_set, valid_set, tparam = fpreproc(train_set, valid_set, params.copy())
...@@ -269,8 +274,9 @@ def _agg_cv_result(raw_results): ...@@ -269,8 +274,9 @@ def _agg_cv_result(raw_results):
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()] return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, def cv(params, train_set, num_boost_round=10,
shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, data_splitter=None, nfold=5, stratified=False, shuffle=True,
metrics=None, fobj=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto', feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, fpreproc=None, early_stopping_rounds=None, fpreproc=None,
verbose_eval=None, show_stdv=True, seed=0, verbose_eval=None, show_stdv=True, seed=0,
...@@ -286,14 +292,14 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, ...@@ -286,14 +292,14 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
Data to be trained. Data to be trained.
num_boost_round : int num_boost_round : int
Number of boosting iterations. Number of boosting iterations.
data_splitter : an instance with split(X) method
        Instance whose split(X) method yields (train_indices, test_indices) pairs; when provided, nfold, stratified and shuffle are ignored.
nfold : int nfold : int
Number of folds in CV. Number of folds in CV.
stratified : bool stratified : bool
Perform stratified sampling. Perform stratified sampling.
shuffle: bool shuffle: bool
        Whether to shuffle data before splitting Whether to shuffle data before splitting
folds : a KFold or StratifiedKFold instance
Sklearn KFolds or StratifiedKFolds.
metrics : string or list of strings metrics : string or list of strings
Evaluation metrics to be watched in CV. Evaluation metrics to be watched in CV.
fobj : function fobj : function
...@@ -358,7 +364,10 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, ...@@ -358,7 +364,10 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
params['metric'].extend(metrics) params['metric'].extend(metrics)
results = collections.defaultdict(list) results = collections.defaultdict(list)
cvfolds = _make_n_folds(train_set, nfold, params, seed, fpreproc, stratified, shuffle) cvfolds = _make_n_folds(train_set, data_splitter=data_splitter,
nfold=nfold, params=params, seed=seed,
fpreproc=fpreproc, stratified=stratified,
shuffle=shuffle)
# setup callbacks # setup callbacks
if callbacks is None: if callbacks is None:
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris) load_iris)
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split, TimeSeriesSplit
try: try:
import pandas as pd import pandas as pd
...@@ -120,9 +120,12 @@ class TestEngine(unittest.TestCase): ...@@ -120,9 +120,12 @@ class TestEngine(unittest.TestCase):
def test_cv(self): def test_cv(self):
lgb_train, _ = template.test_template(return_data=True) lgb_train, _ = template.test_template(return_data=True)
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=20, nfold=5, lgb.cv({'verbose': -1}, lgb_train, num_boost_round=20, nfold=5, shuffle=False,
metrics='l1', verbose_eval=False, metrics='l1', verbose_eval=False,
callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)]) callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)])
tss = TimeSeriesSplit(3)
lgb.cv({'verbose': -1}, lgb_train, num_boost_round=20, data_splitter=tss, nfold=5, # test if wrong nfold is ignored
metrics='l2', verbose_eval=False)
def test_feature_name(self): def test_feature_name(self):
lgb_train, _ = template.test_template(return_data=True) lgb_train, _ = template.test_template(return_data=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment