Unverified Commit 8f5cd522 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix bug for multi-val-bin construction (#2841)

* fix

* Update multi_val_sparse_bin.hpp
parent 49ea824f
...@@ -694,8 +694,7 @@ namespace LightGBM { ...@@ -694,8 +694,7 @@ namespace LightGBM {
int num_bin, int num_bin,
double estimate_element_per_row) { double estimate_element_per_row) {
size_t estimate_total_entries = size_t estimate_total_entries =
static_cast<size_t>(estimate_element_per_row * 1.1) * static_cast<size_t>(estimate_element_per_row * 1.1 * num_data);
static_cast<size_t>(num_data);
if (estimate_total_entries <= std::numeric_limits<uint16_t>::max()) { if (estimate_total_entries <= std::numeric_limits<uint16_t>::max()) {
if (num_bin <= 256) { if (num_bin <= 256) {
return new MultiValSparseBin<uint16_t, uint8_t>( return new MultiValSparseBin<uint16_t, uint8_t>(
......
...@@ -24,9 +24,7 @@ class MultiValSparseBin : public MultiValBin { ...@@ -24,9 +24,7 @@ class MultiValSparseBin : public MultiValBin {
num_bin_(num_bin), num_bin_(num_bin),
estimate_element_per_row_(estimate_element_per_row) { estimate_element_per_row_(estimate_element_per_row) {
row_ptr_.resize(num_data_ + 1, 0); row_ptr_.resize(num_data_ + 1, 0);
INDEX_T estimate_num_data = INDEX_T estimate_num_data = static_cast<INDEX_T>(estimate_element_per_row_ * 1.1 * num_data_);
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1) *
static_cast<INDEX_T>(num_data_);
int num_threads = 1; int num_threads = 1;
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
...@@ -73,7 +71,7 @@ class MultiValSparseBin : public MultiValBin { ...@@ -73,7 +71,7 @@ class MultiValSparseBin : public MultiValBin {
void MergeData(const INDEX_T* sizes) { void MergeData(const INDEX_T* sizes) {
Common::FunctionTimer fun_time("MultiValSparseBin::MergeData", global_timer); Common::FunctionTimer fun_time("MultiValSparseBin::MergeData", global_timer);
for (INDEX_T i = 0; i < static_cast<INDEX_T>(num_data_); ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
row_ptr_[i + 1] += row_ptr_[i]; row_ptr_[i + 1] += row_ptr_[i];
} }
if (t_data_.size() > 0) { if (t_data_.size() > 0) {
...@@ -83,7 +81,7 @@ class MultiValSparseBin : public MultiValBin { ...@@ -83,7 +81,7 @@ class MultiValSparseBin : public MultiValBin {
offsets[tid + 1] = offsets[tid] + sizes[tid + 1]; offsets[tid + 1] = offsets[tid] + sizes[tid + 1];
} }
data_.resize(row_ptr_[num_data_]); data_.resize(row_ptr_[num_data_]);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static, 1)
for (int tid = 0; tid < static_cast<int>(t_data_.size()); ++tid) { for (int tid = 0; tid < static_cast<int>(t_data_.size()); ++tid) {
std::copy_n(t_data_[tid].data(), sizes[tid + 1], std::copy_n(t_data_[tid].data(), sizes[tid + 1],
data_.data() + offsets[tid]); data_.data() + offsets[tid]);
...@@ -199,8 +197,7 @@ class MultiValSparseBin : public MultiValBin { ...@@ -199,8 +197,7 @@ class MultiValSparseBin : public MultiValBin {
auto other_bin = dynamic_cast<const MultiValSparseBin<INDEX_T, VAL_T>*>(full_bin); auto other_bin = dynamic_cast<const MultiValSparseBin<INDEX_T, VAL_T>*>(full_bin);
row_ptr_.resize(num_data_ + 1, 0); row_ptr_.resize(num_data_ + 1, 0);
INDEX_T estimate_num_data = INDEX_T estimate_num_data =
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1) * static_cast<INDEX_T>(estimate_element_per_row_ * 1.1 * num_data_);
static_cast<INDEX_T>(num_data_);
data_.clear(); data_.clear();
data_.reserve(estimate_num_data); data_.reserve(estimate_num_data);
for (data_size_t i = 0; i < num_used_indices; ++i) { for (data_size_t i = 0; i < num_used_indices; ++i) {
...@@ -224,8 +221,7 @@ class MultiValSparseBin : public MultiValBin { ...@@ -224,8 +221,7 @@ class MultiValSparseBin : public MultiValBin {
num_bin_ = num_bin; num_bin_ = num_bin;
estimate_element_per_row_ = estimate_element_per_row; estimate_element_per_row_ = estimate_element_per_row;
INDEX_T estimate_num_data = INDEX_T estimate_num_data =
static_cast<INDEX_T>(estimate_element_per_row_ * 1.1) * static_cast<INDEX_T>(estimate_element_per_row_ * 1.1 * num_data_);
static_cast<INDEX_T>(num_data_);
size_t npart = 1 + t_data_.size(); size_t npart = 1 + t_data_.size();
INDEX_T avg_num_data = static_cast<INDEX_T>(estimate_num_data / npart); INDEX_T avg_num_data = static_cast<INDEX_T>(estimate_num_data / npart);
if (static_cast<INDEX_T>(data_.size()) < avg_num_data) { if (static_cast<INDEX_T>(data_.size()) < avg_num_data) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment