Commit 9db054cf authored by Guolin Ke
Browse files

fix bug in ResetTrainingData

parent 99b483dd
...@@ -19,7 +19,7 @@ namespace LightGBM { ...@@ -19,7 +19,7 @@ namespace LightGBM {
#ifndef CHECK_NOTNULL #ifndef CHECK_NOTNULL
#define CHECK_NOTNULL(pointer) \ #define CHECK_NOTNULL(pointer) \
if ((pointer) == nullptr) LightGBM::Log::Fatal(#pointer " Can't be NULL"); if ((pointer) == nullptr) LightGBM::Log::Fatal(#pointer " Can't be NULL at %s, line %d .\n", __FILE__, __LINE__);
#endif #endif
......
...@@ -32,44 +32,21 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -32,44 +32,21 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
num_iteration_for_pred_ = 0; num_iteration_for_pred_ = 0;
max_feature_idx_ = 0; max_feature_idx_ = 0;
num_class_ = config->num_class; num_class_ = config->num_class;
random_ = Random(config->bagging_seed);
train_data_ = nullptr; train_data_ = nullptr;
gbdt_config_ = nullptr; gbdt_config_ = nullptr;
tree_learner_.clear();
ResetTrainingData(config, train_data, object_function, training_metrics); ResetTrainingData(config, train_data, object_function, training_metrics);
} }
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) { const std::vector<const Metric*>& training_metrics) {
if (train_data == nullptr) { return; }
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config)); auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) { if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers"); Log::Fatal("cannot reset training data, since new training data has different bin mappers");
} }
early_stopping_round_ = new_config->early_stopping_round; early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate; shrinkage_rate_ = new_config->learning_rate;
// cannot reset seed, only create one time
if (gbdt_config_ == nullptr ) {
random_ = Random(new_config->bagging_seed);
}
// create tree learner, only create once
if (gbdt_config_ == nullptr) {
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
}
// init tree learner
if (train_data_ != train_data) {
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
}
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}
object_function_ = object_function; object_function_ = object_function;
...@@ -80,7 +57,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -80,7 +57,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
sigmoid_ = new_config->sigmoid; sigmoid_ = new_config->sigmoid;
} }
if (train_data_ != train_data) { if (train_data_ != train_data && train_data != nullptr) {
if (tree_learner_.empty()) {
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
}
// init tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
// push training metrics // push training metrics
training_metrics_.clear(); training_metrics_.clear();
for (const auto& metric : training_metrics) { for (const auto& metric : training_metrics) {
...@@ -109,9 +98,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -109,9 +98,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
label_idx_ = train_data->label_idx(); label_idx_ = train_data->label_idx();
} }
if (train_data_ != train_data if ((train_data_ != train_data && train_data != nullptr)
|| gbdt_config_ == nullptr || (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
|| (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer // if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) { if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_.resize(num_data_); out_of_bag_data_indices_.resize(num_data_);
...@@ -124,6 +112,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -124,6 +112,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
} }
} }
train_data_ = train_data; train_data_ = train_data;
if (train_data_ != nullptr) {
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}
}
gbdt_config_.reset(new_config.release()); gbdt_config_.reset(new_config.release());
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment