Unverified Commit 8f5cd522 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix bug for multi-val-bin construction (#2841)

* fix

* Update multi_val_sparse_bin.hpp
parent 49ea824f
......@@ -694,8 +694,7 @@ namespace LightGBM {
int num_bin,
double estimate_element_per_row) {
size_t estimate_total_entries =
static_cast<size_t>(estimate_element_per_row * 1.1) *
static_cast<size_t>(num_data);
static_cast<size_t>(estimate_element_per_row * 1.1 * num_data);
if (estimate_total_entries <= std::numeric_limits<uint16_t>::max()) {
if (num_bin <= 256) {
return new MultiValSparseBin<uint16_t, uint8_t>(
......
......@@ -24,9 +24,7 @@ class MultiValSparseBin : public MultiValBin {
num_bin_(num_bin),
estimate_element_per_row_(estimate_element_per_row) {
row_ptr_.resize(num_data_ + 1, 0);
INDEX_T estimate_num_data =
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1) *
static_cast<INDEX_T>(num_data_);
INDEX_T estimate_num_data = static_cast<INDEX_T>(estimate_element_per_row_ * 1.1 * num_data_);
int num_threads = 1;
#pragma omp parallel
#pragma omp master
......@@ -73,7 +71,7 @@ class MultiValSparseBin : public MultiValBin {
void MergeData(const INDEX_T* sizes) {
Common::FunctionTimer fun_time("MultiValSparseBin::MergeData", global_timer);
for (INDEX_T i = 0; i < static_cast<INDEX_T>(num_data_); ++i) {
for (data_size_t i = 0; i < num_data_; ++i) {
row_ptr_[i + 1] += row_ptr_[i];
}
if (t_data_.size() > 0) {
......@@ -83,7 +81,7 @@ class MultiValSparseBin : public MultiValBin {
offsets[tid + 1] = offsets[tid] + sizes[tid + 1];
}
data_.resize(row_ptr_[num_data_]);
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static, 1)
for (int tid = 0; tid < static_cast<int>(t_data_.size()); ++tid) {
std::copy_n(t_data_[tid].data(), sizes[tid + 1],
data_.data() + offsets[tid]);
......@@ -199,8 +197,7 @@ class MultiValSparseBin : public MultiValBin {
auto other_bin = dynamic_cast<const MultiValSparseBin<INDEX_T, VAL_T>*>(full_bin);
row_ptr_.resize(num_data_ + 1, 0);
INDEX_T estimate_num_data =
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1) *
static_cast<INDEX_T>(num_data_);
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1 * num_data_);
data_.clear();
data_.reserve(estimate_num_data);
for (data_size_t i = 0; i < num_used_indices; ++i) {
......@@ -224,8 +221,7 @@ class MultiValSparseBin : public MultiValBin {
num_bin_ = num_bin;
estimate_element_per_row_ = estimate_element_per_row;
INDEX_T estimate_num_data =
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1) *
static_cast<INDEX_T>(num_data_);
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1 * num_data_);
size_t npart = 1 + t_data_.size();
INDEX_T avg_num_data = static_cast<INDEX_T>(estimate_num_data / npart);
if (static_cast<INDEX_T>(data_.size()) < avg_num_data) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment