Commit 9db054cf authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in ResetTrainingData

parent 99b483dd
......@@ -19,7 +19,7 @@ namespace LightGBM {
#ifndef CHECK_NOTNULL
#define CHECK_NOTNULL(pointer) \
if ((pointer) == nullptr) LightGBM::Log::Fatal(#pointer " Can't be NULL");
if ((pointer) == nullptr) LightGBM::Log::Fatal(#pointer " Can't be NULL at %s, line %d .\n", __FILE__, __LINE__);
#endif
......
......@@ -32,44 +32,21 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
num_iteration_for_pred_ = 0;
max_feature_idx_ = 0;
num_class_ = config->num_class;
random_ = Random(config->bagging_seed);
train_data_ = nullptr;
gbdt_config_ = nullptr;
tree_learner_.clear();
ResetTrainingData(config, train_data, object_function, training_metrics);
}
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
if (train_data == nullptr) { return; }
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
// cannot reset seed, only create one time
if (gbdt_config_ == nullptr ) {
random_ = Random(new_config->bagging_seed);
}
// create tree learner, only create once
if (gbdt_config_ == nullptr) {
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
}
// init tree learner
if (train_data_ != train_data) {
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
}
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}
object_function_ = object_function;
......@@ -80,7 +57,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
sigmoid_ = new_config->sigmoid;
}
if (train_data_ != train_data) {
if (train_data_ != train_data && train_data != nullptr) {
if (tree_learner_.empty()) {
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
}
// init tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
......@@ -109,9 +98,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
label_idx_ = train_data->label_idx();
}
if (train_data_ != train_data
|| gbdt_config_ == nullptr
|| (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
if ((train_data_ != train_data && train_data != nullptr)
|| (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_.resize(num_data_);
......@@ -124,6 +112,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
}
train_data_ = train_data;
if (train_data_ != nullptr) {
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}
}
gbdt_config_.reset(new_config.release());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment