#ifndef LIGHTGBM_BOOSTING_GOSS_H_ #define LIGHTGBM_BOOSTING_GOSS_H_ #include #include #include #include #include "score_updater.hpp" #include "gbdt.h" #include #include #include #include #include #include namespace LightGBM { #ifdef TIMETAG std::chrono::duration subset_time; std::chrono::duration re_init_tree_time; #endif class GOSS: public GBDT { public: /*! * \brief Constructor */ GOSS() : GBDT() { } ~GOSS() { #ifdef TIMETAG Log::Info("GOSS::subset costs %f", subset_time * 1e-3); Log::Info("GOSS::re_init_tree costs %f", re_init_tree_time * 1e-3); #endif } void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector& training_metrics) override { GBDT::Init(config, train_data, objective_function, training_metrics); ResetGoss(); } void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector& training_metrics) override { GBDT::ResetTrainingData(train_data, objective_function, training_metrics); ResetGoss(); } void ResetConfig(const Config* config) override { GBDT::ResetConfig(config); ResetGoss(); } void ResetGoss() { CHECK(config_->top_rate + config_->other_rate <= 1.0f); CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f); if (config_->bagging_freq > 0 && config_->bagging_fraction != 1.0f) { Log::Fatal("Cannot use bagging in GOSS"); } Log::Info("Using GOSS"); bag_data_indices_.resize(num_data_); tmp_indices_.resize(num_data_); tmp_indice_right_.resize(num_data_); offsets_buf_.resize(num_threads_); left_cnts_buf_.resize(num_threads_); right_cnts_buf_.resize(num_threads_); left_write_pos_buf_.resize(num_threads_); right_write_pos_buf_.resize(num_threads_); is_use_subset_ = false; if (config_->top_rate + config_->other_rate <= 0.5) { auto bag_data_cnt = static_cast((config_->top_rate + config_->other_rate) * num_data_); bag_data_cnt = std::max(1, bag_data_cnt); tmp_subset_.reset(new Dataset(bag_data_cnt)); tmp_subset_->CopyFeatureMapperFrom(train_data_); is_use_subset_ = true; } // flag to not bagging first bag_data_cnt_ = num_data_; } data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer, data_size_t* buffer_right) { std::vector tmp_gradients(cnt, 0.0f); for (data_size_t i = 0; i < cnt; ++i) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { size_t idx = static_cast(cur_tree_id) * num_data_ + start + i; tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]); } } data_size_t top_k = static_cast(cnt * config_->top_rate); data_size_t other_k = static_cast(cnt * config_->other_rate); top_k = std::max(1, top_k); ArrayArgs::ArgMaxAtK(&tmp_gradients, 0, static_cast(tmp_gradients.size()), top_k - 1); score_t threshold = tmp_gradients[top_k - 1]; score_t multiply = static_cast(cnt - top_k) / other_k; data_size_t cur_left_cnt = 0; data_size_t cur_right_cnt = 0; data_size_t big_weight_cnt = 0; for (data_size_t i = 0; i < cnt; ++i) { score_t grad = 0.0f; for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { size_t idx = static_cast(cur_tree_id) * num_data_ + start + i; grad += std::fabs(gradients_[idx] * hessians_[idx]); } if (grad >= threshold) { buffer[cur_left_cnt++] = start + i; ++big_weight_cnt; } else { data_size_t sampled = cur_left_cnt - big_weight_cnt; data_size_t rest_need = other_k - sampled; data_size_t rest_all = (cnt - i) - (top_k - big_weight_cnt); double prob = (rest_need) / static_cast(rest_all); if (cur_rand.NextFloat() < prob) { buffer[cur_left_cnt++] = start + i; for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { size_t idx = static_cast(cur_tree_id) * num_data_ + start + i; gradients_[idx] *= multiply; hessians_[idx] *= multiply; } } else { buffer_right[cur_right_cnt++] = start + i; } } } return cur_left_cnt; } void Bagging(int iter) override { bag_data_cnt_ = num_data_; // not subsample for first iterations if (iter < static_cast(1.0f / config_->learning_rate)) { return; } const data_size_t min_inner_size = 100; data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_; if (inner_size < min_inner_size) { inner_size = min_inner_size; } OMP_INIT_EX(); #pragma omp parallel for schedule(static, 1) for (int i = 0; i < num_threads_; ++i) { OMP_LOOP_EX_BEGIN(); left_cnts_buf_[i] = 0; right_cnts_buf_[i] = 0; data_size_t cur_start = i * inner_size; if (cur_start > num_data_) { continue; } data_size_t cur_cnt = inner_size; if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; } Random cur_rand(config_->bagging_seed + iter * num_threads_ + i); data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start, tmp_indice_right_.data() + cur_start); offsets_buf_[i] = cur_start; left_cnts_buf_[i] = cur_left_count; right_cnts_buf_[i] = cur_cnt - cur_left_count; OMP_LOOP_EX_END(); } OMP_THROW_EX(); data_size_t left_cnt = 0; left_write_pos_buf_[0] = 0; right_write_pos_buf_[0] = 0; for (int i = 1; i < num_threads_; ++i) { left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1]; right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1]; } left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1]; #pragma omp parallel for schedule(static, 1) for (int i = 0; i < num_threads_; ++i) { OMP_LOOP_EX_BEGIN(); if (left_cnts_buf_[i] > 0) { std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i], tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t)); } if (right_cnts_buf_[i] > 0) { std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i], tmp_indice_right_.data() + offsets_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t)); } OMP_LOOP_EX_END(); } OMP_THROW_EX(); bag_data_cnt_ = left_cnt; // set bagging data to tree learner if (!is_use_subset_) { tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_); } else { // get subset #ifdef TIMETAG auto start_time = std::chrono::steady_clock::now(); #endif tmp_subset_->ReSize(bag_data_cnt_); tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false); #ifdef TIMETAG subset_time += std::chrono::steady_clock::now() - start_time; #endif #ifdef TIMETAG start_time = std::chrono::steady_clock::now(); #endif tree_learner_->ResetTrainingData(tmp_subset_.get()); #ifdef TIMETAG re_init_tree_time += std::chrono::steady_clock::now() - start_time; #endif } } private: std::vector tmp_indice_right_; }; } // namespace LightGBM #endif // LIGHTGBM_BOOSTING_GOSS_H_