Unverified Commit d6f20e37 authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

fix max_block_size in train states (fix #3570) (#3575)

* remove max_block_size_ in train states (fix #3570)

* avoid zero elements per row

* add min constraint for min_block_size_
parent 1ee7c292
...@@ -57,7 +57,7 @@ class MultiValBinWrapper { ...@@ -57,7 +57,7 @@ class MultiValBinWrapper {
n_data_block_ = 1; n_data_block_ = 1;
data_block_size_ = num_data; data_block_size_ = num_data;
Threading::BlockInfo<data_size_t>(num_threads_, num_data, min_block_size_, Threading::BlockInfo<data_size_t>(num_threads_, num_data, min_block_size_,
max_block_size_, &n_data_block_, &data_block_size_); &n_data_block_, &data_block_size_);
ResizeHistBuf(hist_buf, cur_multi_val_bin, origin_hist_data); ResizeHistBuf(hist_buf, cur_multi_val_bin, origin_hist_data);
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(num_threads_) #pragma omp parallel for schedule(static) num_threads(num_threads_)
...@@ -137,7 +137,6 @@ class MultiValBinWrapper { ...@@ -137,7 +137,6 @@ class MultiValBinWrapper {
const std::vector<int> feature_groups_contained_; const std::vector<int> feature_groups_contained_;
int num_threads_; int num_threads_;
int max_block_size_;
int num_bin_; int num_bin_;
int num_bin_aligned_; int num_bin_aligned_;
int n_data_block_; int n_data_block_;
......
...@@ -40,24 +40,6 @@ class Threading { ...@@ -40,24 +40,6 @@ class Threading {
} }
} }
/*!
 * \brief Pick how many blocks to split `cnt` items into, and the size of each block.
 *
 * The block count starts as the smaller of `num_threads` and
 * ceil(cnt / min_cnt_per_block), and is then raised, if necessary, to
 * ceil(cnt / max_cnt_per_block) so that a block does not need to exceed
 * `max_cnt_per_block` items before alignment.
 *
 * \param num_threads Upper bound applied to the block count derived from `min_cnt_per_block`.
 * \param cnt Total number of items to distribute.
 * \param min_cnt_per_block Divisor for the first ceiling division; must be non-zero.
 * \param max_cnt_per_block Divisor for the second ceiling division; must be non-zero and,
 *        as enforced by CHECK, not smaller than `min_cnt_per_block`.
 * \param out_nblock Receives the resulting number of blocks.
 * \param block_size Receives `cnt` when there is at most one block; otherwise
 *        ceil(cnt / nblock) passed through SIZE_ALIGNED
 *        (NOTE(review): presumably rounds up to an alignment multiple - confirm against the macro).
 */
template <typename INDEX_T>
static inline void BlockInfo(int num_threads, INDEX_T cnt,
                             INDEX_T min_cnt_per_block, INDEX_T max_cnt_per_block,
                             int* out_nblock, INDEX_T* block_size) {
  CHECK(max_cnt_per_block >= min_cnt_per_block);

  // Ceiling divisions: blocks required at the minimum and at the maximum block size.
  const int blocks_at_min_size =
      static_cast<int>((cnt + min_cnt_per_block - 1) / min_cnt_per_block);
  const int blocks_at_max_size =
      static_cast<int>((cnt + max_cnt_per_block - 1) / max_cnt_per_block);

  // Cap by the thread count first, then enforce the lower bound implied by the max size.
  const int thread_capped = std::min<int>(num_threads, blocks_at_min_size);
  const int n_block = std::max<int>(thread_capped, blocks_at_max_size);
  *out_nblock = n_block;

  // A single block (or none) simply spans every item; no alignment is applied.
  if (n_block <= 1) {
    *block_size = cnt;
    return;
  }
  *block_size = SIZE_ALIGNED((cnt + n_block - 1) / n_block);
}
template <typename INDEX_T> template <typename INDEX_T>
static inline void BlockInfoForceSize(int num_threads, INDEX_T cnt, static inline void BlockInfoForceSize(int num_threads, INDEX_T cnt,
INDEX_T min_cnt_per_block, INDEX_T min_cnt_per_block,
......
...@@ -12,7 +12,6 @@ MultiValBinWrapper::MultiValBinWrapper(MultiValBin* bin, data_size_t num_data, ...@@ -12,7 +12,6 @@ MultiValBinWrapper::MultiValBinWrapper(MultiValBin* bin, data_size_t num_data,
const std::vector<int>& feature_groups_contained): const std::vector<int>& feature_groups_contained):
feature_groups_contained_(feature_groups_contained) { feature_groups_contained_(feature_groups_contained) {
num_threads_ = OMP_NUM_THREADS(); num_threads_ = OMP_NUM_THREADS();
max_block_size_ = num_data;
num_data_ = num_data; num_data_ = num_data;
multi_val_bin_.reset(bin); multi_val_bin_.reset(bin);
if (bin == nullptr) { if (bin == nullptr) {
...@@ -39,8 +38,10 @@ void MultiValBinWrapper::InitTrain(const std::vector<int>& group_feature_start, ...@@ -39,8 +38,10 @@ void MultiValBinWrapper::InitTrain(const std::vector<int>& group_feature_start,
if (cur_multi_val_bin != nullptr) { if (cur_multi_val_bin != nullptr) {
num_bin_ = cur_multi_val_bin->num_bin(); num_bin_ = cur_multi_val_bin->num_bin();
num_bin_aligned_ = (num_bin_ + kAlignedSize - 1) / kAlignedSize * kAlignedSize; num_bin_aligned_ = (num_bin_ + kAlignedSize - 1) / kAlignedSize * kAlignedSize;
auto num_element_per_row = cur_multi_val_bin->num_element_per_row();
min_block_size_ = std::min<int>(static_cast<int>(0.3f * num_bin_ / min_block_size_ = std::min<int>(static_cast<int>(0.3f * num_bin_ /
cur_multi_val_bin->num_element_per_row()) + 1, 1024); (num_element_per_row + kZeroThreshold)) + 1, 1024);
min_block_size_ = std::max<int>(min_block_size_, 32);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment