Unverified Commit 4c3be168 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

slightly reduce the cost of multi-val bin construct. (#2728)

* slightly reduce the cost of multi-val bin construct.

* Update multi_val_sparse_bin.hpp
parent 85889901
...@@ -493,6 +493,7 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec ...@@ -493,6 +493,7 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int tid = 0; tid < n_block; ++tid) { for (int tid = 0; tid < n_block; ++tid) {
std::vector<uint32_t> cur_data; std::vector<uint32_t> cur_data;
cur_data.reserve(most_freq_bins.size());
data_size_t start = tid * block_size; data_size_t start = tid * block_size;
data_size_t end = std::min(num_data, start + block_size); data_size_t end = std::min(num_data, start + block_size);
for (size_t j = 0; j < most_freq_bins.size(); ++j) { for (size_t j = 0; j < most_freq_bins.size(); ++j) {
...@@ -517,14 +518,13 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec ...@@ -517,14 +518,13 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec
} else { } else {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int tid = 0; tid < n_block; ++tid) { for (int tid = 0; tid < n_block; ++tid) {
std::vector<uint32_t> cur_data; std::vector<uint32_t> cur_data(most_freq_bins.size(), 0);
data_size_t start = tid * block_size; data_size_t start = tid * block_size;
data_size_t end = std::min(num_data, start + block_size); data_size_t end = std::min(num_data, start + block_size);
for (size_t j = 0; j < most_freq_bins.size(); ++j) { for (size_t j = 0; j < most_freq_bins.size(); ++j) {
iters[tid][j]->Reset(start); iters[tid][j]->Reset(start);
} }
for (data_size_t i = start; i < end; ++i) { for (data_size_t i = start; i < end; ++i) {
cur_data.clear();
for (size_t j = 0; j < most_freq_bins.size(); ++j) { for (size_t j = 0; j < most_freq_bins.size(); ++j) {
auto cur_bin = iters[tid][j]->Get(i); auto cur_bin = iters[tid][j]->Get(i);
if (cur_bin == most_freq_bins[j]) { if (cur_bin == most_freq_bins[j]) {
...@@ -535,7 +535,7 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec ...@@ -535,7 +535,7 @@ void PushDataToMultiValBin(int num_threads, data_size_t num_data, const std::vec
cur_bin -= 1; cur_bin -= 1;
} }
} }
cur_data.push_back(cur_bin); cur_data[j] = cur_bin;
} }
ret->PushOneRow(tid, i, cur_data); ret->PushOneRow(tid, i, cur_data);
} }
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
...@@ -39,7 +38,9 @@ public: ...@@ -39,7 +38,9 @@ public:
void PushOneRow(int , data_size_t idx, const std::vector<uint32_t>& values) override { void PushOneRow(int , data_size_t idx, const std::vector<uint32_t>& values) override {
auto start = RowPtr(idx); auto start = RowPtr(idx);
#ifdef DEBUG
CHECK(num_feature_ == static_cast<int>(values.size())); CHECK(num_feature_ == static_cast<int>(values.size()));
#endif // DEBUG
for (auto i = 0; i < num_feature_; ++i) { for (auto i = 0; i < num_feature_; ++i) {
data_[start + i] = static_cast<VAL_T>(values[i]); data_[start + i] = static_cast<VAL_T>(values[i]);
} }
......
...@@ -65,12 +65,16 @@ public: ...@@ -65,12 +65,16 @@ public:
row_ptr_[i + 1] += row_ptr_[i]; row_ptr_[i + 1] += row_ptr_[i];
} }
if (t_data_.size() > 0) { if (t_data_.size() > 0) {
size_t offset = data_.size(); std::vector<size_t> offsets;
offsets.push_back(data_.size());
for (size_t tid = 0; tid < t_data_.size() - 1; ++tid) {
offsets.push_back(offsets.back() + t_data_[tid].size());
}
data_.resize(row_ptr_[num_data_]); data_.resize(row_ptr_[num_data_]);
for (size_t tid = 0; tid < t_data_.size(); ++tid) { #pragma omp parallel for schedule(static)
std::memcpy(data_.data() + offset, t_data_[tid].data(), t_data_[tid].size() * sizeof(VAL_T)); for (int tid = 0; tid < static_cast<int>(t_data_.size()); ++tid) {
offset += t_data_[tid].size(); std::copy_n(t_data_[tid].data(), t_data_[tid].size(),
t_data_[tid].clear(); data_.data() + offsets[tid]);
} }
} }
row_ptr_.shrink_to_fit(); row_ptr_.shrink_to_fit();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment