Commit a8bb3951 authored by Guolin Ke's avatar Guolin Ke
Browse files

add 4bits_bin

parent 67f98d66
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include "dense_bin.hpp" #include "dense_bin.hpp"
#include "dense_nbits_bin.hpp"
#include "sparse_bin.hpp" #include "sparse_bin.hpp"
#include "ordered_sparse_bin.hpp" #include "ordered_sparse_bin.hpp"
...@@ -277,14 +278,15 @@ Bin* Bin::CreateBin(data_size_t num_data, int num_bin, double sparse_rate, ...@@ -277,14 +278,15 @@ Bin* Bin::CreateBin(data_size_t num_data, int num_bin, double sparse_rate,
} }
Bin* Bin::CreateDenseBin(data_size_t num_data, int num_bin) { Bin* Bin::CreateDenseBin(data_size_t num_data, int num_bin) {
if (num_bin <= 256) { if (num_bin <= 16) {
return new Dense4bitsBin(num_data);
} else if (num_bin <= 256) {
return new DenseBin<uint8_t>(num_data); return new DenseBin<uint8_t>(num_data);
} else if (num_bin <= 65536) { } else if (num_bin <= 65536) {
return new DenseBin<uint16_t>(num_data); return new DenseBin<uint16_t>(num_data);
} else { } else {
return new DenseBin<uint32_t>(num_data); return new DenseBin<uint32_t>(num_data);
} }
} }
Bin* Bin::CreateSparseBin(data_size_t num_data, int num_bin) { Bin* Bin::CreateSparseBin(data_size_t num_data, int num_bin) {
......
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
HistogramBinEntry* out) const override { HistogramBinEntry* out) const override {
// use 4-way unrolling, will be faster // use 4-way unrolling, will be faster
if (data_indices != nullptr) { // if use part of data if (data_indices != nullptr) { // if use part of data
const data_size_t rest = num_data % 4; const data_size_t rest = num_data & 0x3;
data_size_t i = 0; data_size_t i = 0;
for (; i < num_data - rest; i += 4) { for (; i < num_data - rest; i += 4) {
const VAL_T bin0 = data_[data_indices[i]]; const VAL_T bin0 = data_[data_indices[i]];
...@@ -97,7 +97,7 @@ public: ...@@ -97,7 +97,7 @@ public:
++out[bin].cnt; ++out[bin].cnt;
} }
} else { // use full data } else { // use full data
const data_size_t rest = num_data % 4; const data_size_t rest = num_data & 0x3;
data_size_t i = 0; data_size_t i = 0;
for (; i < num_data - rest; i += 4) { for (; i < num_data - rest; i += 4) {
const VAL_T bin0 = data_[i]; const VAL_T bin0 = data_[i];
......
#ifndef LIGHTGBM_IO_DENSE_NBITS_BIN_HPP_
#define LIGHTGBM_IO_DENSE_NBITS_BIN_HPP_
#include <LightGBM/bin.h>
#include <vector>
#include <cstring>
#include <cstdint>
namespace LightGBM {
class Dense4bitsBin;
/*!
* \brief Iterator that reads individual 4-bit bin values out of a Dense4bitsBin.
*
* The bin range and default bin are narrowed to uint8_t on construction.
* A bias of 1 is applied to in-range values when the default bin is 0,
* otherwise the bias is 0.
*/
class Dense4bitsBinIterator: public BinIterator {
public:
  /*!
  * \brief Build an iterator over \a bin_data.
  * \param bin_data Packed bin storage to read from (not owned by the iterator)
  * \param min_bin Smallest stored bin value regarded as in range
  * \param max_bin Largest stored bin value regarded as in range
  * \param default_bin Value reported for stored bins outside [min_bin, max_bin]
  */
  explicit Dense4bitsBinIterator(const Dense4bitsBin* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin)
    : bin_data_(bin_data),
      min_bin_(static_cast<uint8_t>(min_bin)),
      max_bin_(static_cast<uint8_t>(max_bin)),
      default_bin_(static_cast<uint8_t>(default_bin)),
      // members are initialised in declaration order, so default_bin_ is ready here
      bias_(static_cast<uint8_t>(default_bin_ == 0 ? 1 : 0)) {
  }

  /*! \brief Value for row \a idx; defined after Dense4bitsBin is complete */
  inline uint32_t Get(data_size_t idx) override;

  /*! \brief Nothing to reset: this iterator keeps no position state */
  inline void Reset(data_size_t) override { }

private:
  /*! \brief Storage being iterated */
  const Dense4bitsBin* bin_data_;
  /*! \brief Lower bound (inclusive) of the in-range stored bins */
  uint8_t min_bin_;
  /*! \brief Upper bound (inclusive) of the in-range stored bins */
  uint8_t max_bin_;
  /*! \brief Returned for stored bins outside the range */
  uint8_t default_bin_;
  /*! \brief Offset added to in-range values: 1 if default_bin_ is 0, else 0 */
  uint8_t bias_;
};
/*!
* \brief Dense bin storage that packs two 4-bit bin values into each byte.
*
* Row \a idx lives in byte (idx >> 1): even rows use the low nibble, odd rows
* the high nibble. Values written through Push() are split between data_
* (even rows) and a staging buffer buf_ (odd rows); FinishLoad() merges the
* staging buffer into data_. LoadFromMemory() and CopySubset() write packed
* bytes into data_ directly and do not use the staging buffer.
*/
class Dense4bitsBin: public Bin {
public:
  friend Dense4bitsBinIterator;

  /*!
  * \brief Allocate zero-filled packed storage for \a num_data rows.
  * \param num_data Number of rows to hold
  */
  explicit Dense4bitsBin(data_size_t num_data)
    : num_data_(num_data) {
    data_ = std::vector<uint8_t>(PackedLength(num_data_), static_cast<uint8_t>(0));
  }

  ~Dense4bitsBin() {
  }

  /*!
  * \brief Store one bin value; FinishLoad() must run before odd rows are readable.
  * \param idx Row index
  * \param value Bin value; only the low 4 bits are kept
  */
  void Push(int, data_size_t idx, uint32_t value) override {
    // The staging buffer is allocated lazily on first use; the allocation
    // itself is serialised through an OpenMP critical section.
    if (buf_.empty()) {
      #pragma omp critical
      {
        if (buf_.empty()) {
          buf_ = std::vector<uint8_t>(PackedLength(num_data_), static_cast<uint8_t>(0));
        }
      }
    }
    const int byte_pos = idx >> 1;
    // Mask so an out-of-range value cannot spill into the neighbouring row.
    const uint8_t nibble = static_cast<uint8_t>(value & 0xf);
    if ((idx & 1) == 0) {
      // Even row: overwrites the whole byte in data_ (high nibble becomes 0).
      data_[byte_pos] = nibble;
    } else {
      // Odd row: kept in a separate array, presumably so that pushes of two
      // neighbouring rows never write to the same byte -- TODO confirm.
      buf_[byte_pos] = static_cast<uint8_t>(nibble << 4);
    }
  }

  /*!
  * \brief Change the number of rows held.
  * \param num_data New row count
  */
  void ReSize(data_size_t num_data) override {
    if (num_data_ != num_data) {
      num_data_ = num_data;
      const int len = PackedLength(num_data_);
      data_.resize(len);
      // Keep a live staging buffer the same length so FinishLoad() stays in bounds.
      if (!buf_.empty()) {
        buf_.resize(len);
      }
    }
  }

  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override;

  /*!
  * \brief Accumulate gradients, hessians and counts per bin into \a out.
  * \param data_indices Rows to visit, or nullptr to visit rows [0, num_data)
  * \param num_data Number of rows to visit
  * \param ordered_gradients Gradients, indexed by visiting position (not by row)
  * \param ordered_hessians Hessians, indexed by visiting position (not by row)
  * \param out Histogram entries indexed by stored bin value; updated in place
  */
  void ConstructHistogram(const data_size_t* data_indices, data_size_t num_data,
                          const score_t* ordered_gradients, const score_t* ordered_hessians,
                          HistogramBinEntry* out) const override {
    const uint8_t* packed = data_.data();
    // Main loops are unrolled 4-way; `rest` rows are handled one by one afterwards.
    const data_size_t rest = num_data & 0x3;
    data_size_t i = 0;
    if (data_indices != nullptr) {  // if use part of data
      for (; i < num_data - rest; i += 4) {
        const auto bin0 = Extract(packed, data_indices[i]);
        const auto bin1 = Extract(packed, data_indices[i + 1]);
        const auto bin2 = Extract(packed, data_indices[i + 2]);
        const auto bin3 = Extract(packed, data_indices[i + 3]);

        out[bin0].sum_gradients += ordered_gradients[i];
        out[bin1].sum_gradients += ordered_gradients[i + 1];
        out[bin2].sum_gradients += ordered_gradients[i + 2];
        out[bin3].sum_gradients += ordered_gradients[i + 3];

        out[bin0].sum_hessians += ordered_hessians[i];
        out[bin1].sum_hessians += ordered_hessians[i + 1];
        out[bin2].sum_hessians += ordered_hessians[i + 2];
        out[bin3].sum_hessians += ordered_hessians[i + 3];

        ++out[bin0].cnt;
        ++out[bin1].cnt;
        ++out[bin2].cnt;
        ++out[bin3].cnt;
      }
      for (; i < num_data; ++i) {
        const auto bin = Extract(packed, data_indices[i]);
        out[bin].sum_gradients += ordered_gradients[i];
        out[bin].sum_hessians += ordered_hessians[i];
        ++out[bin].cnt;
      }
    } else {  // use full data
      for (; i < num_data - rest; i += 4) {
        // i is a multiple of 4 here, so rows i..i+3 are exactly two whole bytes.
        int j = i >> 1;
        const auto bin0 = (packed[j]) & 0xf;
        const auto bin1 = (packed[j] >> 4) & 0xf;
        ++j;
        const auto bin2 = (packed[j]) & 0xf;
        const auto bin3 = (packed[j] >> 4) & 0xf;

        out[bin0].sum_gradients += ordered_gradients[i];
        out[bin1].sum_gradients += ordered_gradients[i + 1];
        out[bin2].sum_gradients += ordered_gradients[i + 2];
        out[bin3].sum_gradients += ordered_gradients[i + 3];

        out[bin0].sum_hessians += ordered_hessians[i];
        out[bin1].sum_hessians += ordered_hessians[i + 1];
        out[bin2].sum_hessians += ordered_hessians[i + 2];
        out[bin3].sum_hessians += ordered_hessians[i + 3];

        ++out[bin0].cnt;
        ++out[bin1].cnt;
        ++out[bin2].cnt;
        ++out[bin3].cnt;
      }
      for (; i < num_data; ++i) {
        const auto bin = Extract(packed, i);
        out[bin].sum_gradients += ordered_gradients[i];
        out[bin].sum_hessians += ordered_hessians[i];
        ++out[bin].cnt;
      }
    }
  }

  /*!
  * \brief Partition rows into "less than or equal" and "greater than" sets.
  *
  * The stored-bin cut point is (threshold + min_bin), lowered by one when
  * default_bin is 0. Rows whose stored bin lies outside [min_bin, max_bin]
  * go to the lte side when default_bin <= threshold, otherwise to the gt side.
  *
  * \param min_bin Smallest stored bin regarded as in range
  * \param max_bin Largest stored bin regarded as in range
  * \param default_bin Default bin used to route out-of-range rows
  * \param threshold Split threshold
  * \param data_indices Rows to partition
  * \param num_data Number of entries in \a data_indices
  * \param lte_indices Output: rows routed to the lte side
  * \param gt_indices Output: rows routed to the gt side
  * \return Number of rows written to \a lte_indices (0 when num_data <= 0)
  */
  data_size_t Split(
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
    uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }
    uint8_t th = static_cast<uint8_t>(threshold + min_bin);
    const uint8_t minb = static_cast<uint8_t>(min_bin);
    const uint8_t maxb = static_cast<uint8_t>(max_bin);
    if (default_bin == 0) {
      th -= 1;
    }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    // Out-of-range rows follow the default bin to whichever side it falls on.
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
    if (default_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    const uint8_t* packed = data_.data();
    for (data_size_t i = 0; i < num_data; ++i) {
      const data_size_t idx = data_indices[i];
      const uint8_t bin = Extract(packed, idx);
      if (bin > maxb || bin < minb) {
        default_indices[(*default_count)++] = idx;
      } else if (bin > th) {
        gt_indices[gt_count++] = idx;
      } else {
        lte_indices[lte_count++] = idx;
      }
    }
    return lte_count;
  }

  data_size_t num_data() const override { return num_data_; }

  /*! \brief not ordered bin for dense feature */
  OrderedBin* CreateOrderedBin() const override { return nullptr; }

  /*!
  * \brief Merge the odd rows staged by Push() into the packed storage.
  *
  * Does nothing when no staging buffer exists (no Push() since the last
  * merge, or data was loaded through LoadFromMemory()/CopySubset()).
  */
  void FinishLoad() override {
    if (buf_.empty()) { return; }
    const size_t len = data_.size() < buf_.size() ? data_.size() : buf_.size();
    for (size_t i = 0; i < len; ++i) {
      data_[i] |= buf_[i];
    }
    buf_.clear();
    // Ask the vector to give the staging memory back; it is no longer needed.
    buf_.shrink_to_fit();
  }

  /*!
  * \brief Fill the storage from a packed 4-bit buffer.
  * \param memory Buffer in the same packed layout as this class
  * \param local_used_indices If non-empty, row i of this bin is taken from row
  *        local_used_indices[i] of \a memory (at least num_data() entries are
  *        read); if empty, the buffer is copied byte for byte
  */
  void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
    const uint8_t* mem_data = reinterpret_cast<const uint8_t*>(memory);
    if (!local_used_indices.empty()) {
      PackSubset(mem_data, local_used_indices.data(), num_data_);
    } else if (!data_.empty()) {
      std::memcpy(data_.data(), mem_data, data_.size());
    }
  }

  /*!
  * \brief Copy selected rows of another Dense4bitsBin into rows [0, num_used_indices).
  * \param full_bin Source bin; must actually be a Dense4bitsBin
  * \param used_indices Rows of \a full_bin to copy, in destination order
  * \param num_used_indices Number of entries in \a used_indices
  */
  void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
    const auto* other_bin = static_cast<const Dense4bitsBin*>(full_bin);
    PackSubset(other_bin->data_.data(), used_indices, num_used_indices);
  }

  /*! \brief Write the packed bytes to \a file */
  void SaveBinaryToFile(FILE* file) const override {
    fwrite(data_.data(), sizeof(uint8_t), data_.size(), file);
  }

  /*! \brief Number of bytes SaveBinaryToFile() writes */
  size_t SizesInByte() const override {
    return sizeof(uint8_t) * data_.size();
  }

protected:
  /*! \brief Number of rows */
  data_size_t num_data_;
  /*! \brief Packed storage, two rows per byte */
  std::vector<uint8_t> data_;
  /*! \brief Staging area for odd rows written by Push(); empty when unused */
  std::vector<uint8_t> buf_;

private:
  /*! \brief Bytes needed to hold \a num_data 4-bit values */
  static int PackedLength(data_size_t num_data) {
    return (num_data + 1) / 2;
  }

  /*! \brief Read the 4-bit value of row \a idx from a packed buffer */
  static uint8_t Extract(const uint8_t* packed, data_size_t idx) {
    return static_cast<uint8_t>((packed[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
  }

  /*!
  * \brief Write rows src[indices[0..count)] into rows [0, count) of data_.
  *
  * Each destination byte is written once with both nibbles in place; a
  * trailing unpaired row gets a zero high nibble.
  */
  void PackSubset(const uint8_t* src, const data_size_t* indices, data_size_t count) {
    if (count <= 0) { return; }
    const data_size_t paired = count - (count & 1);
    for (data_size_t i = 0; i < paired; i += 2) {
      const uint8_t lo = Extract(src, indices[i]);
      const uint8_t hi = Extract(src, indices[i + 1]);
      data_[i >> 1] = static_cast<uint8_t>(lo | (hi << 4));
    }
    if (paired < count) {
      data_[paired >> 1] = Extract(src, indices[paired]);
    }
  }
};
/*!
* \brief Read the value of row \a idx.
* \return default_bin_ when the stored bin is outside [min_bin_, max_bin_];
*         otherwise the stored bin shifted by -min_bin_ + bias_
*/
uint32_t Dense4bitsBinIterator::Get(data_size_t idx) {
  // Even rows sit in the low nibble of byte idx/2, odd rows in the high nibble.
  const auto shift = (idx & 1) << 2;
  const auto stored = (bin_data_->data_[idx >> 1] >> shift) & 0xf;
  const bool outside_range = stored < min_bin_ || stored > max_bin_;
  if (outside_range) {
    return default_bin_;
  }
  return stored - min_bin_ + bias_;
}
/*!
* \brief Create an iterator over this bin.
*
* Defined out of class because it needs the complete Dense4bitsBinIterator
* type, and marked inline because it lives in a header: without `inline`,
* including this header from more than one translation unit produces
* duplicate definitions of the function.
*
* \param min_bin Smallest stored bin regarded as in range
* \param max_bin Largest stored bin regarded as in range
* \param default_bin Value the iterator reports for out-of-range stored bins
* \return Iterator allocated with new; presumably owned by the caller -- confirm against callers
*/
inline BinIterator* Dense4bitsBin::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const {
  return new Dense4bitsBinIterator(this, min_bin, max_bin, default_bin);
}
} // namespace LightGBM
#endif // LIGHTGBM_IO_DENSE_NBITS_BIN_HPP_
...@@ -216,6 +216,7 @@ ...@@ -216,6 +216,7 @@
<ClInclude Include="..\src\boosting\goss.hpp" /> <ClInclude Include="..\src\boosting\goss.hpp" />
<ClInclude Include="..\src\boosting\score_updater.hpp" /> <ClInclude Include="..\src\boosting\score_updater.hpp" />
<ClInclude Include="..\src\io\dense_bin.hpp" /> <ClInclude Include="..\src\io\dense_bin.hpp" />
<ClInclude Include="..\src\io\dense_nbits_bin.hpp" />
<ClInclude Include="..\src\io\ordered_sparse_bin.hpp" /> <ClInclude Include="..\src\io\ordered_sparse_bin.hpp" />
<ClInclude Include="..\src\io\parser.hpp" /> <ClInclude Include="..\src\io\parser.hpp" />
<ClInclude Include="..\src\io\sparse_bin.hpp" /> <ClInclude Include="..\src\io\sparse_bin.hpp" />
......
...@@ -174,6 +174,9 @@ ...@@ -174,6 +174,9 @@
<ClInclude Include="..\src\boosting\goss.hpp"> <ClInclude Include="..\src\boosting\goss.hpp">
<Filter>src\boosting</Filter> <Filter>src\boosting</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\src\io\dense_nbits_bin.hpp">
<Filter>src\io</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="..\src\application\application.cpp"> <ClCompile Include="..\src\application\application.cpp">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment