Unverified commit 4c3be168, authored by Guolin Ke, committed by GitHub
Browse files

Slightly reduce the cost of multi-val bin construction. (#2728)

* Slightly reduce the cost of multi-val bin construction.

* Update multi_val_sparse_bin.hpp
parent 85889901
......@@ -493,6 +493,7 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec
#pragma omp parallel for schedule(static)
for (int tid = 0; tid < n_block; ++tid) {
std::vector<uint32_t> cur_data;
cur_data.reserve(most_freq_bins.size());
data_size_t start = tid * block_size;
data_size_t end = std::min(num_data, start + block_size);
for (size_t j = 0; j < most_freq_bins.size(); ++j) {
......@@ -517,14 +518,13 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec
} else {
#pragma omp parallel for schedule(static)
for (int tid = 0; tid < n_block; ++tid) {
std::vector<uint32_t> cur_data;
std::vector<uint32_t> cur_data(most_freq_bins.size(), 0);
data_size_t start = tid * block_size;
data_size_t end = std::min(num_data, start + block_size);
for (size_t j = 0; j < most_freq_bins.size(); ++j) {
iters[tid][j]->Reset(start);
}
for (data_size_t i = start; i < end; ++i) {
cur_data.clear();
for (size_t j = 0; j < most_freq_bins.size(); ++j) {
auto cur_bin = iters[tid][j]->Get(i);
if (cur_bin == most_freq_bins[j]) {
......@@ -535,7 +535,7 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec
cur_bin -= 1;
}
}
cur_data.push_back(cur_bin);
cur_data[j] = cur_bin;
}
ret->PushOneRow(tid, i, cur_data);
}
......
......@@ -7,7 +7,6 @@
#include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <cstdint>
#include <cstring>
......@@ -39,7 +38,9 @@ public:
void PushOneRow(int , data_size_t idx, const std::vector<uint32_t>& values) override {
auto start = RowPtr(idx);
#ifdef DEBUG
CHECK(num_feature_ == static_cast<int>(values.size()));
#endif // DEBUG
for (auto i = 0; i < num_feature_; ++i) {
data_[start + i] = static_cast<VAL_T>(values[i]);
}
......
......@@ -65,12 +65,16 @@ public:
row_ptr_[i + 1] += row_ptr_[i];
}
if (t_data_.size() > 0) {
size_t offset = data_.size();
std::vector<size_t> offsets;
offsets.push_back(data_.size());
for (size_t tid = 0; tid < t_data_.size() - 1; ++tid) {
offsets.push_back(offsets.back() + t_data_[tid].size());
}
data_.resize(row_ptr_[num_data_]);
for (size_t tid = 0; tid < t_data_.size(); ++tid) {
std::memcpy(data_.data() + offset, t_data_[tid].data(), t_data_[tid].size() * sizeof(VAL_T));
offset += t_data_[tid].size();
t_data_[tid].clear();
#pragma omp parallel for schedule(static)
for (int tid = 0; tid < static_cast<int>(t_data_.size()); ++tid) {
std::copy_n(t_data_[tid].data(), t_data_[tid].size(),
data_.data() + offsets[tid]);
}
}
row_ptr_.shrink_to_fit();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment