Unverified Commit 9c386db1 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

reduce the overhead of OMP_NUM_THREADS in training (#2852)

* reduce overhead of get num_threads

* add warning

* Apply suggestions from code review

* Apply suggestions from code review
parent d0bec9e9
...@@ -277,6 +277,7 @@ class Parser { ...@@ -277,6 +277,7 @@ class Parser {
}; };
struct TrainingShareStates { struct TrainingShareStates {
int num_threads = 0;
bool is_colwise = true; bool is_colwise = true;
bool is_use_subcol = false; bool is_use_subcol = false;
bool is_use_subrow = false; bool is_use_subrow = false;
...@@ -298,7 +299,7 @@ struct TrainingShareStates { ...@@ -298,7 +299,7 @@ struct TrainingShareStates {
return; return;
} }
multi_val_bin.reset(bin); multi_val_bin.reset(bin);
int num_threads = OMP_NUM_THREADS(); num_threads = OMP_NUM_THREADS();
num_bin_aligned = num_bin_aligned =
(bin->num_bin() + kAlignedSize - 1) / kAlignedSize * kAlignedSize; (bin->num_bin() + kAlignedSize - 1) / kAlignedSize * kAlignedSize;
size_t new_size = static_cast<size_t>(num_bin_aligned) * 2 * num_threads; size_t new_size = static_cast<size_t>(num_bin_aligned) * 2 * num_threads;
......
...@@ -1208,15 +1208,13 @@ void Dataset::ConstructHistogramsMultiVal(const data_size_t* data_indices, ...@@ -1208,15 +1208,13 @@ void Dataset::ConstructHistogramsMultiVal(const data_size_t* data_indices,
if (multi_val_bin == nullptr) { if (multi_val_bin == nullptr) {
return; return;
} }
int num_threads = OMP_NUM_THREADS();
global_timer.Start("Dataset::sparse_bin_histogram"); global_timer.Start("Dataset::sparse_bin_histogram");
const int num_bin = multi_val_bin->num_bin(); const int num_bin = multi_val_bin->num_bin();
const int num_bin_aligned = const int num_bin_aligned =
(num_bin + kAlignedSize - 1) / kAlignedSize * kAlignedSize; (num_bin + kAlignedSize - 1) / kAlignedSize * kAlignedSize;
int n_data_block = 1; int n_data_block = 1;
int data_block_size = num_data; int data_block_size = num_data;
Threading::BlockInfo<data_size_t>(num_threads, num_data, 1024, Threading::BlockInfo<data_size_t>(share_state->num_threads, num_data, 1024,
&n_data_block, &data_block_size); &n_data_block, &data_block_size);
const size_t buf_size = const size_t buf_size =
static_cast<size_t>(n_data_block - 1) * num_bin_aligned * 2; static_cast<size_t>(n_data_block - 1) * num_bin_aligned * 2;
...@@ -1263,7 +1261,7 @@ void Dataset::ConstructHistogramsMultiVal(const data_size_t* data_indices, ...@@ -1263,7 +1261,7 @@ void Dataset::ConstructHistogramsMultiVal(const data_size_t* data_indices,
global_timer.Start("Dataset::sparse_bin_histogram_merge"); global_timer.Start("Dataset::sparse_bin_histogram_merge");
int n_bin_block = 1; int n_bin_block = 1;
int bin_block_size = num_bin; int bin_block_size = num_bin;
Threading::BlockInfo<data_size_t>(num_threads, num_bin, 512, &n_bin_block, Threading::BlockInfo<data_size_t>(share_state->num_threads, num_bin, 512, &n_bin_block,
&bin_block_size); &bin_block_size);
if (!share_state->is_constant_hessian) { if (!share_state->is_constant_hessian) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
......
...@@ -165,9 +165,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -165,9 +165,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) { void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
int num_threads = OMP_NUM_THREADS(); std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> smaller_bests_per_thread(num_threads, SplitInfo()); std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(num_threads, SplitInfo());
std::vector<int8_t> smaller_node_used_features(this->num_features_, 1); std::vector<int8_t> smaller_node_used_features(this->num_features_, 1);
std::vector<int8_t> larger_node_used_features(this->num_features_, 1); std::vector<int8_t> larger_node_used_features(this->num_features_, 1);
if (this->config_->feature_fraction_bynode < 1.0f) { if (this->config_->feature_fraction_bynode < 1.0f) {
......
...@@ -152,7 +152,15 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -152,7 +152,15 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer); Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
gradients_ = gradients; gradients_ = gradients;
hessians_ = hessians; hessians_ = hessians;
int num_threads = OMP_NUM_THREADS();
if (share_state_->num_threads != num_threads && share_state_->num_threads > 0){
Log::Warning(
"Detect num_threads changed durning traing (from %d to %d), may cause "
"unexpected errors.",
share_state_->num_threads, num_threads);
}
share_state_->num_threads = num_threads;
// some initial works before training // some initial works before training
BeforeTrain(); BeforeTrain();
...@@ -403,9 +411,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms( ...@@ -403,9 +411,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
const std::vector<int8_t>& is_feature_used, bool use_subtract) { const std::vector<int8_t>& is_feature_used, bool use_subtract) {
Common::FunctionTimer fun_timer( Common::FunctionTimer fun_timer(
"SerialTreeLearner::FindBestSplitsFromHistograms", global_timer); "SerialTreeLearner::FindBestSplitsFromHistograms", global_timer);
int num_threads = OMP_NUM_THREADS(); std::vector<SplitInfo> smaller_best(share_state_->num_threads);
std::vector<SplitInfo> smaller_best(num_threads); std::vector<SplitInfo> larger_best(share_state_->num_threads);
std::vector<SplitInfo> larger_best(num_threads);
std::vector<int8_t> smaller_node_used_features(num_features_, 1); std::vector<int8_t> smaller_node_used_features(num_features_, 1);
std::vector<int8_t> larger_node_used_features(num_features_, 1); std::vector<int8_t> larger_node_used_features(num_features_, 1);
if (config_->feature_fraction_bynode < 1.0f) { if (config_->feature_fraction_bynode < 1.0f) {
......
...@@ -349,9 +349,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -349,9 +349,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) { void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
int num_threads = OMP_NUM_THREADS(); std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> smaller_bests_per_thread(num_threads); std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_best_per_thread(num_threads);
std::vector<int8_t> smaller_node_used_features(this->num_features_, 1); std::vector<int8_t> smaller_node_used_features(this->num_features_, 1);
std::vector<int8_t> larger_node_used_features(this->num_features_, 1); std::vector<int8_t> larger_node_used_features(this->num_features_, 1);
if (this->config_->feature_fraction_bynode < 1.0f) { if (this->config_->feature_fraction_bynode < 1.0f) {
...@@ -395,8 +394,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -395,8 +394,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
real_feature_index, real_feature_index,
larger_node_used_features[feature_index], larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()), GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()),
larger_leaf_splits_global_.get(), larger_leaf_splits_global_.get(), &larger_bests_per_thread[tid]);
&larger_best_per_thread[tid]);
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
...@@ -408,8 +406,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -408,8 +406,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) { if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
leaf = this->larger_leaf_splits_->leaf_index(); leaf = this->larger_leaf_splits_->leaf_index();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best_per_thread); auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
this->best_split_per_leaf_[leaf] = larger_best_per_thread[larger_best_idx]; this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
} }
// find local best // find local best
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment