Commit 5d12a8db authored by Guolin Ke's avatar Guolin Ke
Browse files

speed up bagging by multi-threading

parent 4306b22c
#include "gbdt.h" #include "gbdt.h"
#include <omp.h>
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
#include <LightGBM/feature.h> #include <LightGBM/feature.h>
...@@ -27,6 +29,11 @@ GBDT::GBDT() ...@@ -27,6 +29,11 @@ GBDT::GBDT()
num_iteration_for_pred_(0), num_iteration_for_pred_(0),
shrinkage_rate_(0.1f), shrinkage_rate_(0.1f),
num_init_iteration_(0) { num_init_iteration_(0) {
#pragma omp parallel
#pragma omp master
{
num_threads_ = omp_get_num_threads();
}
} }
GBDT::~GBDT() { GBDT::~GBDT() {
...@@ -39,7 +46,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -39,7 +46,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
num_iteration_for_pred_ = 0; num_iteration_for_pred_ = 0;
max_feature_idx_ = 0; max_feature_idx_ = 0;
num_class_ = config->num_class; num_class_ = config->num_class;
random_ = Random(config->bagging_seed); for (int i = 0; i < num_threads_; ++i) {
random_.emplace_back(config->bagging_seed + i);
}
train_data_ = nullptr; train_data_ = nullptr;
gbdt_config_ = nullptr; gbdt_config_ = nullptr;
tree_learner_ = nullptr; tree_learner_ = nullptr;
...@@ -104,13 +113,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -104,13 +113,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
|| (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) { || (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer // if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) { if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_.resize(num_data_); bag_data_cnt_ =
static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
bag_data_indices_.resize(num_data_); bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
offsets_buf_.resize(num_threads_);
left_cnts_buf_.resize(num_threads_);
right_cnts_buf_.resize(num_threads_);
left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_);
} else { } else {
out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear();
bag_data_cnt_ = num_data_; bag_data_cnt_ = num_data_;
bag_data_indices_.clear(); bag_data_indices_.clear();
tmp_indices_.clear();
} }
} }
train_data_ = train_data; train_data_ = train_data;
...@@ -153,53 +168,65 @@ void GBDT::AddValidDataset(const Dataset* valid_data, ...@@ -153,53 +168,65 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
} }
data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer) {
  // Sample the in-bag subset for the data range [start, start + cnt).
  // In-bag indices are written to buffer[0 .. left_cnt) and out-of-bag indices
  // to buffer[left_cnt .. cnt), so one buffer holds the full partition.
  // Each thread uses its own Random instance (seeded per-thread in Init) so
  // concurrent calls from the parallel region in Bagging() are race-free.
  const int tid = omp_get_thread_num();
  // number of records to keep in-bag within this chunk
  data_size_t bag_data_cnt =
    static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t cur_left_cnt = 0;
  data_size_t cur_right_cnt = 0;
  // random bagging, minimal unit is one record
  for (data_size_t i = 0; i < cnt; ++i) {
    // selection probability = (slots still needed) / (records still remaining);
    // this scheme deterministically yields exactly bag_data_cnt selections
    double prob =
      (bag_data_cnt - cur_left_cnt) / static_cast<double>(cnt - i);
    if (random_[tid].NextDouble() < prob) {
      buffer[cur_left_cnt++] = start + i;
    } else {
      buffer[bag_data_cnt + cur_right_cnt++] = start + i;
    }
  }
  CHECK(cur_left_cnt == bag_data_cnt);
  return cur_left_cnt;
}
void GBDT::Bagging(int iter) {
// if need bagging
if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
const data_size_t min_inner_size = 10000;
data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
if (inner_size < min_inner_size) { inner_size = min_inner_size; }
#pragma omp parallel for schedule(static, 1)
for (int i = 0; i < num_threads_; ++i) {
left_cnts_buf_[i] = 0;
right_cnts_buf_[i] = 0;
data_size_t cur_start = i * inner_size;
if (cur_start > num_data_) { continue; }
data_size_t cur_cnt = inner_size;
if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
data_size_t cur_left_count = BaggingHelper(cur_start, cur_cnt, tmp_indices_.data() + cur_start);
offsets_buf_[i] = cur_start;
left_cnts_buf_[i] = cur_left_count;
right_cnts_buf_[i] = cur_cnt - cur_left_count;
}
data_size_t left_cnt = 0;
left_write_pos_buf_[0] = 0;
right_write_pos_buf_[0] = 0;
for (int i = 1; i < num_threads_; ++i) {
left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
}
left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];
#pragma omp parallel for schedule(static, 1)
for (int i = 0; i < num_threads_; ++i) {
if (left_cnts_buf_[i] > 0) {
std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
}
if (right_cnts_buf_[i] > 0) {
std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
} }
bag_data_cnt_ = cur_left_cnt;
out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
} }
Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_); Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner // set bagging data to tree learner
...@@ -209,8 +236,8 @@ void GBDT::Bagging(int iter) { ...@@ -209,8 +236,8 @@ void GBDT::Bagging(int iter) {
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) { void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
// we need to predict out-of-bag socres of data for boosting // we need to predict out-of-bag socres of data for boosting
if (!out_of_bag_data_indices_.empty()) { if (num_data_ - bag_data_cnt_ > 0) {
train_score_updater_->AddScore(tree, out_of_bag_data_indices_.data(), out_of_bag_data_cnt_, curr_class); train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
} }
} }
......
...@@ -219,7 +219,16 @@ protected: ...@@ -219,7 +219,16 @@ protected:
* \brief Implement bagging logic * \brief Implement bagging logic
* \param iter Current iteration * \param iter Current iteration
*/ */
void Bagging(int iter); virtual void Bagging(int iter);
/*!
* \brief Helper function for bagging, used for multi-threading optimization
* \param start start index of bagging
* \param cnt count
* \param buffer output buffer
* \return count of left size
*/
virtual data_size_t BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer);
/*! /*!
* \brief updating score for out-of-bag data. * \brief updating score for out-of-bag data.
* Data should be update since we may re-bagging data on training * Data should be update since we may re-bagging data on training
...@@ -282,20 +291,18 @@ protected: ...@@ -282,20 +291,18 @@ protected:
std::vector<score_t> gradients_; std::vector<score_t> gradients_;
/*! \brief Second order derivative of training data */ /*! \brief Second order derivative of training data */
std::vector<score_t> hessians_; std::vector<score_t> hessians_;
/*! \brief Store the data indices of out-of-bag */
std::vector<data_size_t> out_of_bag_data_indices_;
/*! \brief Number of out-of-bag data */
data_size_t out_of_bag_data_cnt_;
/*! \brief Store the indices of in-bag data */ /*! \brief Store the indices of in-bag data */
std::vector<data_size_t> bag_data_indices_; std::vector<data_size_t> bag_data_indices_;
/*! \brief Number of in-bag data */ /*! \brief Number of in-bag data */
data_size_t bag_data_cnt_; data_size_t bag_data_cnt_;
/*! \brief Store the indices of in-bag data */
std::vector<data_size_t> tmp_indices_;
/*! \brief Number of training data */ /*! \brief Number of training data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */ /*! \brief Number of classes */
int num_class_; int num_class_;
/*! \brief Random generator, used for bagging */ /*! \brief Random generator, used for bagging */
Random random_; std::vector<Random> random_;
/*! /*!
* \brief Sigmoid parameter, used for prediction. * \brief Sigmoid parameter, used for prediction.
* if > 0 means output score will transform by sigmoid function * if > 0 means output score will transform by sigmoid function
...@@ -311,6 +318,18 @@ protected: ...@@ -311,6 +318,18 @@ protected:
int num_init_iteration_; int num_init_iteration_;
/*! \brief Feature names */ /*! \brief Feature names */
std::vector<std::string> feature_names_; std::vector<std::string> feature_names_;
/*! \brief number of threads */
int num_threads_;
/*! \brief Buffer for multi-threading bagging */
std::vector<data_size_t> offsets_buf_;
/*! \brief Buffer for multi-threading bagging */
std::vector<data_size_t> left_cnts_buf_;
/*! \brief Buffer for multi-threading bagging */
std::vector<data_size_t> right_cnts_buf_;
/*! \brief Buffer for multi-threading bagging */
std::vector<data_size_t> left_write_pos_buf_;
/*! \brief Buffer for multi-threading bagging */
std::vector<data_size_t> right_write_pos_buf_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment