Unverified Commit ebc831bc authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix bug when using dart with init_model (#2251)

* add test

* fix a index bug
parent 291752de
......@@ -171,8 +171,8 @@ class DART: public GBDT {
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
if (!config_->uniform_drop) {
sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
tree_weight_[i] *= (k / (k + 1.0f));
sum_weight_ -= tree_weight_[i - num_init_iteration_] * (1.0f / (k + 1.0f));
tree_weight_[i - num_init_iteration_] *= (k / (k + 1.0f));
}
}
} else {
......@@ -189,8 +189,8 @@ class DART: public GBDT {
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
if (!config_->uniform_drop) {
sum_weight_ -= tree_weight_[i] * (1.0f / (k + config_->learning_rate));;
tree_weight_[i] *= (k / (k + config_->learning_rate));
sum_weight_ -= tree_weight_[i - num_init_iteration_] * (1.0f / (k + config_->learning_rate));;
tree_weight_[i - num_init_iteration_] *= (k / (k + config_->learning_rate));
}
}
}
......
......@@ -480,6 +480,31 @@ class TestEngine(unittest.TestCase):
self.assertAlmostEqual(l1, mae, places=5)
os.remove(model_name)
def test_continue_train_dart(self):
X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
'boosting_type': 'dart',
'objective': 'regression',
'metric': 'l1',
'verbose': -1
}
lgb_train = lgb.Dataset(X_train, y_train, free_raw_data=False)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, free_raw_data=False)
init_gbm = lgb.train(params, lgb_train, num_boost_round=50)
evals_result = {}
gbm = lgb.train(params, lgb_train,
num_boost_round=50,
valid_sets=lgb_eval,
verbose_eval=False,
evals_result=evals_result,
init_model=init_gbm)
ret = mean_absolute_error(y_test, gbm.predict(X_test))
self.assertLess(ret, 3.5)
self.assertAlmostEqual(evals_result['valid_0']['l1'][-1], ret, places=5)
for l1, mae in zip(evals_result['valid_0']['l1'], evals_result['valid_0']['mae']):
self.assertAlmostEqual(l1, mae, places=5)
def test_continue_train_multiclass(self):
X, y = load_iris(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment