"qa/vscode:/vscode.git/clone" did not exist on "847f4654f0c4bcd9a3d1e3c33da9b79a836488f2"
Commit 3beee91d authored by Guolin Ke
Browse files

only stop training when all classes are finished in multi-class.

parent 2e962c77
...@@ -162,7 +162,6 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -162,7 +162,6 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if (train_data_ != nullptr) { if (train_data_ != nullptr) {
// reset config for tree learner // reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config); tree_learner_->ResetConfig(&new_config->tree_config);
is_class_end_ = std::vector<bool>(num_class_, false);
} }
gbdt_config_.reset(new_config.release()); gbdt_config_.reset(new_config.release());
} }
...@@ -284,7 +283,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) { ...@@ -284,7 +283,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
// we need to predict out-of-bag socres of data for boosting // we need to predict out-of-bag scores of data for boosting
if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) { if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class); train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
} }
...@@ -351,7 +350,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -351,7 +350,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// get sub gradients // get sub gradients
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto bias = curr_class * num_data_; auto bias = curr_class * num_data_;
// cannot multi-threding // cannot multi-threading
for (int i = 0; i < bag_data_cnt_; ++i) { for (int i = 0; i < bag_data_cnt_; ++i) {
gradients_[bias + i] = gradient[bias + bag_data_indices_[i]]; gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
hessians_[bias + i] = hessian[bias + bag_data_indices_[i]]; hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
...@@ -369,10 +368,8 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -369,10 +368,8 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
#endif #endif
std::unique_ptr<Tree> new_tree(new Tree(2)); std::unique_ptr<Tree> new_tree(new Tree(2));
if (!is_class_end_[curr_class]) { // train a new tree
// train a new tree new_tree.reset(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
new_tree.reset(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
}
#ifdef TIMETAG #ifdef TIMETAG
tree_time += std::chrono::steady_clock::now() - start_time; tree_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
...@@ -384,10 +381,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -384,10 +381,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// update score // update score
UpdateScore(new_tree.get(), curr_class); UpdateScore(new_tree.get(), curr_class);
UpdateScoreOutOfBag(new_tree.get(), curr_class); UpdateScoreOutOfBag(new_tree.get(), curr_class);
} else {
is_class_end_[curr_class] = true;
} }
// add model // add model
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
...@@ -423,7 +417,6 @@ void GBDT::RollbackOneIter() { ...@@ -423,7 +417,6 @@ void GBDT::RollbackOneIter() {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back(); models_.pop_back();
} }
is_class_end_ = std::vector<bool>(num_class_, false);
--iter_; --iter_;
} }
......
...@@ -344,7 +344,6 @@ protected: ...@@ -344,7 +344,6 @@ protected:
std::vector<data_size_t> right_write_pos_buf_; std::vector<data_size_t> right_write_pos_buf_;
std::unique_ptr<Dataset> tmp_subset_; std::unique_ptr<Dataset> tmp_subset_;
bool is_use_subset_; bool is_use_subset_;
std::vector<bool> is_class_end_;
bool boost_from_average_; bool boost_from_average_;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment