"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "61fb5ea24d5aaa2bed4b893e01c39ab966893f7d"
Commit 586d53fb authored by Guolin Ke's avatar Guolin Ke
Browse files

remove the useless training.

parent fe9061fa
...@@ -161,6 +161,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -161,6 +161,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if (train_data_ != nullptr) { if (train_data_ != nullptr) {
// reset config for tree learner // reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config); tree_learner_->ResetConfig(&new_config->tree_config);
is_class_end_ = std::vector<bool>(num_class_, false);
} }
gbdt_config_.reset(new_config.release()); gbdt_config_.reset(new_config.release());
} }
...@@ -336,31 +337,35 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -336,31 +337,35 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
sub_gradient_time += std::chrono::steady_clock::now() - start_time; sub_gradient_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
} }
bool shouldContinue = false; bool should_continue = false;
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
#endif #endif
// train a new tree std::unique_ptr<Tree> new_tree(new Tree(2));
std::unique_ptr<Tree> new_tree(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_)); if (!is_class_end_[curr_class]) {
// train a new tree
new_tree.reset(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
}
#ifdef TIMETAG #ifdef TIMETAG
tree_time += std::chrono::steady_clock::now() - start_time; tree_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
if (new_tree->num_leaves() > 1) { if (new_tree->num_leaves() > 1) {
shouldContinue = true; should_continue = true;
// shrinkage by learning rate
new_tree->Shrinkage(shrinkage_rate_);
// update score
UpdateScore(new_tree.get(), curr_class);
UpdateScoreOutOfBag(new_tree.get(), curr_class);
} else {
is_class_end_[curr_class] = true;
} }
// shrinkage by learning rate
new_tree->Shrinkage(shrinkage_rate_);
// update score
UpdateScore(new_tree.get(), curr_class);
UpdateScoreOutOfBag(new_tree.get(), curr_class);
// add model // add model
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
if (!shouldContinue) { if (!should_continue) {
Log::Warning("Stopped training because there are no more leaves that meet the split requirements."); Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back(); models_.pop_back();
...@@ -392,6 +397,7 @@ void GBDT::RollbackOneIter() { ...@@ -392,6 +397,7 @@ void GBDT::RollbackOneIter() {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back(); models_.pop_back();
} }
is_class_end_ = std::vector<bool>(num_class_, false);
--iter_; --iter_;
} }
......
...@@ -344,6 +344,7 @@ protected: ...@@ -344,6 +344,7 @@ protected:
std::vector<data_size_t> right_write_pos_buf_; std::vector<data_size_t> right_write_pos_buf_;
std::unique_ptr<Dataset> tmp_subset_; std::unique_ptr<Dataset> tmp_subset_;
bool is_use_subset_; bool is_use_subset_;
std::vector<bool> is_class_end_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment