Commit f6024c8b authored by Guolin Ke's avatar Guolin Ke
Browse files

fix test for continued train, due to default saved number of model is best_iteration now

parent 28972b86
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
namespace LightGBM { namespace LightGBM {
GBDT::GBDT() GBDT::GBDT()
:num_iteration_for_pred_(0), :iter_(0),
num_iteration_for_pred_(0),
num_init_iteration_(0) { num_init_iteration_(0) {
} }
GBDT::~GBDT() { GBDT::~GBDT() {
...@@ -581,6 +581,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -581,6 +581,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Log::Info("Finished loading %d models", models_.size()); Log::Info("Finished loading %d models", models_.size());
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
num_init_iteration_ = num_iteration_for_pred_; num_init_iteration_ = num_iteration_for_pred_;
iter_ = 0;
} }
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const { std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
......
...@@ -9,8 +9,7 @@ import lightgbm as lgb ...@@ -9,8 +9,7 @@ import lightgbm as lgb
class TestBasic(unittest.TestCase): class TestBasic(unittest.TestCase):
def test(self): def test(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1) X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, max_bin=255, label=y_train) train_data = lgb.Dataset(X_train, max_bin=255, label=y_train)
valid_data = train_data.create_valid(X_test, label=y_test) valid_data = train_data.create_valid(X_test, label=y_test)
......
...@@ -17,7 +17,7 @@ def multi_logloss(y_true, y_pred): ...@@ -17,7 +17,7 @@ def multi_logloss(y_true, y_pred):
def test_template(params = {'objective' : 'regression', 'metric' : 'l2'}, def test_template(params = {'objective' : 'regression', 'metric' : 'l2'},
X_y=load_boston(True), feval=mean_squared_error, X_y=load_boston(True), feval=mean_squared_error,
num_round=100, init_model=None, custom_eval=None, num_round=100, init_model=None, custom_eval=None,
return_data=False, return_model=False): return_data=False, return_model=False, early_stopping_rounds=10):
X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train, params=params) lgb_train = lgb.Dataset(X_train, y_train, params=params)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, params=params) lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, params=params)
...@@ -31,7 +31,7 @@ def test_template(params = {'objective' : 'regression', 'metric' : 'l2'}, ...@@ -31,7 +31,7 @@ def test_template(params = {'objective' : 'regression', 'metric' : 'l2'},
verbose_eval=False, verbose_eval=False,
feval=custom_eval, feval=custom_eval,
evals_result=evals_result, evals_result=evals_result,
early_stopping_rounds=10, early_stopping_rounds=early_stopping_rounds,
init_model=init_model) init_model=init_model)
if return_model: return gbm if return_model: return gbm
else: return evals_result, feval(y_test, gbm.predict(X_test, gbm.best_iteration)) else: return evals_result, feval(y_test, gbm.predict(X_test, gbm.best_iteration))
...@@ -71,7 +71,7 @@ class TestEngine(unittest.TestCase): ...@@ -71,7 +71,7 @@ class TestEngine(unittest.TestCase):
'metric' : 'l1' 'metric' : 'l1'
} }
model_name = 'model.txt' model_name = 'model.txt'
gbm = test_template(params, num_round=20, return_model=True) gbm = test_template(params, num_round=20, return_model=True, early_stopping_rounds=-1)
gbm.save_model(model_name) gbm.save_model(model_name)
evals_result, ret = test_template(params, feval=mean_absolute_error, evals_result, ret = test_template(params, feval=mean_absolute_error,
num_round=80, init_model=model_name, num_round=80, init_model=model_name,
...@@ -91,7 +91,7 @@ class TestEngine(unittest.TestCase): ...@@ -91,7 +91,7 @@ class TestEngine(unittest.TestCase):
'metric' : 'multi_logloss', 'metric' : 'multi_logloss',
'num_class' : 3 'num_class' : 3
} }
gbm = test_template(params, X_y, num_round=20, return_model=True) gbm = test_template(params, X_y, num_round=20, return_model=True, early_stopping_rounds=-1)
evals_result, ret = test_template(params, X_y, feval=multi_logloss, evals_result, ret = test_template(params, X_y, feval=multi_logloss,
num_round=80, init_model=gbm) num_round=80, init_model=gbm)
self.assertLess(ret, 1.5) self.assertLess(ret, 1.5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment