Unverified Commit d6f20e37 authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

fix max_block_size in train states (fix #3570) (#3575)

* remove max_block_size_ in train states (fix #3570)

* avoid zero elements per row

* add min constraint for min_block_size_
parent 1ee7c292
......@@ -57,7 +57,7 @@ class MultiValBinWrapper {
n_data_block_ = 1;
data_block_size_ = num_data;
Threading::BlockInfo<data_size_t>(num_threads_, num_data, min_block_size_,
max_block_size_, &n_data_block_, &data_block_size_);
&n_data_block_, &data_block_size_);
ResizeHistBuf(hist_buf, cur_multi_val_bin, origin_hist_data);
OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(num_threads_)
......@@ -137,7 +137,6 @@ class MultiValBinWrapper {
const std::vector<int> feature_groups_contained_;
int num_threads_;
int max_block_size_;
int num_bin_;
int num_bin_aligned_;
int n_data_block_;
......
......@@ -40,24 +40,6 @@ class Threading {
}
}
/*!
 * \brief Decide how many blocks `cnt` items are split into and how large each block is.
 *
 * The block count is first limited to `num_threads` and to the number of
 * blocks of `min_cnt_per_block` items needed to cover `cnt`; it is then
 * raised, if necessary, so that no block has to hold more than
 * `max_cnt_per_block` items (this may exceed `num_threads`).
 *
 * \param num_threads Upper bound on the block count derived from the minimum block size.
 * \param cnt Total number of items to distribute.
 * \param min_cnt_per_block Smallest desired number of items per block; must be positive.
 * \param max_cnt_per_block Largest allowed number of items per block; must be >= min_cnt_per_block.
 * \param out_nblock Receives the resulting number of blocks.
 * \param block_size Receives the per-block item count: `cnt` itself when there is at most
 *        one block, otherwise the ceiling of cnt / nblock passed through SIZE_ALIGNED
 *        (project macro; presumably rounds up to an alignment multiple -- confirm).
 */
template <typename INDEX_T>
static inline void BlockInfo(int num_threads, INDEX_T cnt,
                             INDEX_T min_cnt_per_block, INDEX_T max_cnt_per_block,
                             int* out_nblock, INDEX_T* block_size) {
  // Both bounds are used as divisors below, so a non-positive minimum would
  // lead to an integer division by zero (undefined behaviour). Together with
  // the ordering check this also guarantees max_cnt_per_block > 0.
  CHECK(min_cnt_per_block > 0);
  CHECK(max_cnt_per_block >= min_cnt_per_block);

  // Blocks needed if every block held exactly the minimum / maximum item count
  // (ceiling divisions).
  const int blocks_at_min_size =
      static_cast<int>((cnt + min_cnt_per_block - 1) / min_cnt_per_block);
  const int blocks_at_max_size =
      static_cast<int>((cnt + max_cnt_per_block - 1) / max_cnt_per_block);

  // Never use more blocks than threads unless the maximum block size forces it.
  int num_blocks = std::min<int>(num_threads, blocks_at_min_size);
  num_blocks = std::max<int>(num_blocks, blocks_at_max_size);
  *out_nblock = num_blocks;

  if (num_blocks > 1) {
    *block_size = SIZE_ALIGNED((cnt + num_blocks - 1) / num_blocks);
  } else {
    // Zero or one block: the block size is simply the total item count.
    *block_size = cnt;
  }
}
template <typename INDEX_T>
static inline void BlockInfoForceSize(int num_threads, INDEX_T cnt,
INDEX_T min_cnt_per_block,
......
......@@ -12,7 +12,6 @@ MultiValBinWrapper::MultiValBinWrapper(MultiValBin* bin, data_size_t num_data,
const std::vector<int>& feature_groups_contained):
feature_groups_contained_(feature_groups_contained) {
num_threads_ = OMP_NUM_THREADS();
max_block_size_ = num_data;
num_data_ = num_data;
multi_val_bin_.reset(bin);
if (bin == nullptr) {
......@@ -39,8 +38,10 @@ void MultiValBinWrapper::InitTrain(const std::vector<int>& group_feature_start,
if (cur_multi_val_bin != nullptr) {
num_bin_ = cur_multi_val_bin->num_bin();
num_bin_aligned_ = (num_bin_ + kAlignedSize - 1) / kAlignedSize * kAlignedSize;
auto num_element_per_row = cur_multi_val_bin->num_element_per_row();
min_block_size_ = std::min<int>(static_cast<int>(0.3f * num_bin_ /
cur_multi_val_bin->num_element_per_row()) + 1, 1024);
(num_element_per_row + kZeroThreshold)) + 1, 1024);
min_block_size_ = std::max<int>(min_block_size_, 32);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment