Commit c8142e30 authored by wxchan, committed by Guolin Ke
Browse files

[MRG] [python] check params for num_boost_round & early_stopping_rounds (#806)

* check params

* add test case

* fix pylint
parent ddda85b0
......@@ -4,14 +4,15 @@
from __future__ import absolute_import
import collections
import warnings
from operator import attrgetter
import numpy as np
from . import callback
from .basic import Booster, Dataset, LightGBMError, _InnerPredictor
from .compat import (SKLEARN_INSTALLED, LGBMStratifiedKFold, LGBMGroupKFold, integer_types,
range_, string_type)
from .compat import (SKLEARN_INSTALLED, LGBMGroupKFold, LGBMStratifiedKFold,
integer_types, range_, string_type)
def train(params, train_set, num_boost_round=100,
......@@ -94,6 +95,17 @@ def train(params, train_set, num_boost_round=100,
booster : a trained booster model
"""
"""create predictor first"""
for alias in ["num_boost_round", "num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias)
break
for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias)
break
if isinstance(init_model, string_type):
predictor = _InnerPredictor(model_file=init_model)
elif isinstance(init_model, Booster):
......@@ -370,6 +382,17 @@ def cv(params, train_set, num_boost_round=10,
if not isinstance(train_set, Dataset):
raise TypeError("Traninig only accepts Dataset object")
for alias in ["num_boost_round", "num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias)
break
for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias)
break
if isinstance(init_model, string_type):
predictor = _InnerPredictor(model_file=init_model)
elif isinstance(init_model, Booster):
......
......@@ -37,18 +37,20 @@ class TestEngine(unittest.TestCase):
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1
'verbose': -1,
'num_iteration': 50 # test num_iteration in dict here
}
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
evals_result = {}
gbm = lgb.train(params, lgb_train,
num_boost_round=50,
num_boost_round=20,
valid_sets=lgb_eval,
verbose_eval=False,
evals_result=evals_result)
ret = log_loss(y_test, gbm.predict(X_test))
self.assertLess(ret, 0.15)
self.assertEqual(len(evals_result['valid_0']['binary_logloss']), 50)
self.assertAlmostEqual(evals_result['valid_0']['binary_logloss'][-1], ret, places=5)
def test_rf(self):
......@@ -454,17 +456,12 @@ class TestEngine(unittest.TestCase):
np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred0, pred4)
def test_subset_train_val(self):
'''
Tests that it's fine to construct a single lgb.Dataframe object,
takes subsets of it, and uses the subsets for training and validation
'''
n = 1000
X = np.random.normal(size=(n, 2))
y = np.random.normal(size=n)
def test_reference_chain(self):
X = np.random.normal(size=(100, 2))
y = np.random.normal(size=100)
tmp_dat = lgb.Dataset(X, y)
# take subsets and train
tmp_dat_train = tmp_dat.subset(np.arange(int(n * .8)))
tmp_dat_val = tmp_dat.subset(np.arange(int(n * .8), n)).subset(np.arange(n * .2 * .9))
tmp_dat_train = tmp_dat.subset(np.arange(80))
tmp_dat_val = tmp_dat.subset(np.arange(80, 100)).subset(np.arange(18))
params = {'objective': 'regression_l2', 'metric': 'rmse'}
gbm = lgb.train(params, tmp_dat_train, num_boost_round=20, valid_sets=[tmp_dat_train, tmp_dat_val])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment