Commit 3beee91d authored by Guolin Ke's avatar Guolin Ke
Browse files

only stop training when all classes are finshed in multi-class.

parent 2e962c77
......@@ -162,7 +162,6 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if (train_data_ != nullptr) {
// reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config);
is_class_end_ = std::vector<bool>(num_class_, false);
}
gbdt_config_.reset(new_config.release());
}
......@@ -284,7 +283,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
#ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now();
#endif
// we need to predict out-of-bag socres of data for boosting
// we need to predict out-of-bag scores of data for boosting
if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
}
......@@ -351,7 +350,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// get sub gradients
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto bias = curr_class * num_data_;
// cannot multi-threding
// cannot multi-threading
for (int i = 0; i < bag_data_cnt_; ++i) {
gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
......@@ -369,10 +368,8 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
start_time = std::chrono::steady_clock::now();
#endif
std::unique_ptr<Tree> new_tree(new Tree(2));
if (!is_class_end_[curr_class]) {
// train a new tree
new_tree.reset(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
}
#ifdef TIMETAG
tree_time += std::chrono::steady_clock::now() - start_time;
#endif
......@@ -384,10 +381,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// update score
UpdateScore(new_tree.get(), curr_class);
UpdateScoreOutOfBag(new_tree.get(), curr_class);
} else {
is_class_end_[curr_class] = true;
}
// add model
models_.push_back(std::move(new_tree));
}
......@@ -423,7 +417,6 @@ void GBDT::RollbackOneIter() {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back();
}
is_class_end_ = std::vector<bool>(num_class_, false);
--iter_;
}
......
......@@ -344,7 +344,6 @@ protected:
std::vector<data_size_t> right_write_pos_buf_;
std::unique_ptr<Dataset> tmp_subset_;
bool is_use_subset_;
std::vector<bool> is_class_end_;
bool boost_from_average_;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment