Commit 5e3b7193 authored by Guolin Ke
Browse files

speed up subset of 4bits_bin.

parent 834a8986
......@@ -347,7 +347,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
#endif
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) {
Log::Info("Stopped training because there are no more leafs that meet the split requirements.");
Log::Info("Stopped training because there are no more leaves that meet the split requirements.");
return true;
}
......
......@@ -206,15 +206,19 @@ public:
void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
const uint8_t* mem_data = reinterpret_cast<const uint8_t*>(memory);
if (!local_used_indices.empty()) {
for (int i = 0; i < num_data_; ++i) {
for (int i = 0; i < num_data_; i += 2) {
// get old bins
const data_size_t idx = local_used_indices[i];
const auto bin = (mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
data_size_t idx = local_used_indices[i];
const auto bin1 = static_cast<uint8_t>((mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
idx = local_used_indices[i + 1];
const auto bin2 = static_cast<uint8_t>((mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
// add
const int i1 = i >> 1;
const int i2 = (i & 1) << 2;
const uint8_t val = static_cast<uint8_t>(bin) << i2;
data_[i1] |= val;
data_[i1] = (bin1 | (bin2 << 4));
}
if ((num_data_ & 1) == 1) {
data_size_t idx = local_used_indices[num_data_ - 1];
data_[num_data_ / 2 + 1] = (mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
}
} else {
for (size_t i = 0; i < data_.size(); ++i) {
......@@ -225,13 +229,17 @@ public:
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const Dense4bitsBin*>(full_bin);
for (int i = 0; i < num_used_indices; ++i) {
const data_size_t idx = used_indices[i];
const auto bin = (other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
for (int i = 0; i < num_used_indices; i += 2) {
data_size_t idx = used_indices[i];
const auto bin1 = static_cast<uint8_t>((other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
idx = used_indices[i + 1];
const auto bin2 = static_cast<uint8_t>((other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
const int i1 = i >> 1;
const int i2 = (i & 1) << 2;
const uint8_t val = static_cast<uint8_t>(bin) << i2;
data_[i1] |= val;
data_[i1] = (bin1 | (bin2 << 4));
}
if ((num_used_indices & 1) == 1) {
data_size_t idx = used_indices[num_used_indices - 1];
data_[num_used_indices / 2 + 1] = (other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment