"src/vscode:/vscode.git/clone" did not exist on "4f232570a8c381f7995f584ab81e0a3ab4452969"
Commit 586d53fb authored by Guolin Ke's avatar Guolin Ke
Browse files

remove the useless training.

parent fe9061fa
......@@ -161,6 +161,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if (train_data_ != nullptr) {
// reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config);
is_class_end_ = std::vector<bool>(num_class_, false);
}
gbdt_config_.reset(new_config.release());
}
......@@ -336,31 +337,35 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
sub_gradient_time += std::chrono::steady_clock::now() - start_time;
#endif
}
bool shouldContinue = false;
bool should_continue = false;
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
#endif
std::unique_ptr<Tree> new_tree(new Tree(2));
if (!is_class_end_[curr_class]) {
// train a new tree
std::unique_ptr<Tree> new_tree(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
new_tree.reset(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
}
#ifdef TIMETAG
tree_time += std::chrono::steady_clock::now() - start_time;
#endif
if (new_tree->num_leaves() > 1) {
shouldContinue = true;
}
should_continue = true;
// shrinkage by learning rate
new_tree->Shrinkage(shrinkage_rate_);
// update score
UpdateScore(new_tree.get(), curr_class);
UpdateScoreOutOfBag(new_tree.get(), curr_class);
} else {
is_class_end_[curr_class] = true;
}
// add model
models_.push_back(std::move(new_tree));
}
if (!shouldContinue) {
if (!should_continue) {
Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back();
......@@ -392,6 +397,7 @@ void GBDT::RollbackOneIter() {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back();
}
is_class_end_ = std::vector<bool>(num_class_, false);
--iter_;
}
......
......@@ -344,6 +344,7 @@ protected:
std::vector<data_size_t> right_write_pos_buf_;
std::unique_ptr<Dataset> tmp_subset_;
bool is_use_subset_;
std::vector<bool> is_class_end_;
};
} // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment