dense_nbits_bin.hpp 11.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
9
10
#ifndef LIGHTGBM_IO_DENSE_NBITS_BIN_HPP_
#define LIGHTGBM_IO_DENSE_NBITS_BIN_HPP_

#include <LightGBM/bin.h>

#include <cstdint>
11
12
#include <cstring>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
13
14
15
16
17

namespace LightGBM {

class Dense4bitsBin;

Guolin Ke's avatar
Guolin Ke committed
18
class Dense4bitsBinIterator : public BinIterator {
19
 public:
Guolin Ke's avatar
Guolin Ke committed
20
  explicit Dense4bitsBinIterator(const Dense4bitsBin* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
Guolin Ke's avatar
Guolin Ke committed
21
22
    : bin_data_(bin_data), min_bin_(static_cast<uint8_t>(min_bin)),
    max_bin_(static_cast<uint8_t>(max_bin)),
Guolin Ke's avatar
Guolin Ke committed
23
24
    most_freq_bin_(static_cast<uint8_t>(most_freq_bin)) {
    if (most_freq_bin_ == 0) {
25
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
26
    } else {
27
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
28
29
    }
  }
30
  inline uint32_t RawGet(data_size_t idx) override;
Guolin Ke's avatar
Guolin Ke committed
31
  inline uint32_t Get(data_size_t idx) override;
Guolin Ke's avatar
Guolin Ke committed
32
  inline void Reset(data_size_t) override {}
Nikita Titov's avatar
Nikita Titov committed
33

34
 private:
Guolin Ke's avatar
Guolin Ke committed
35
36
37
  const Dense4bitsBin* bin_data_;
  uint8_t min_bin_;
  uint8_t max_bin_;
Guolin Ke's avatar
Guolin Ke committed
38
  uint8_t most_freq_bin_;
39
  uint8_t offset_;
Guolin Ke's avatar
Guolin Ke committed
40
41
};

Guolin Ke's avatar
Guolin Ke committed
42
class Dense4bitsBin : public Bin {
43
 public:
Guolin Ke's avatar
Guolin Ke committed
44
  friend Dense4bitsBinIterator;
Guolin Ke's avatar
Guolin Ke committed
45
  explicit Dense4bitsBin(data_size_t num_data)
Guolin Ke's avatar
Guolin Ke committed
46
47
    : num_data_(num_data) {
    int len = (num_data_ + 1) / 2;
48
    data_.resize(len, static_cast<uint8_t>(0));
Guolin Ke's avatar
Guolin Ke committed
49
    buf_ = std::vector<uint8_t>(len, static_cast<uint8_t>(0));
Guolin Ke's avatar
Guolin Ke committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  }

  ~Dense4bitsBin() {
  }

  void Push(int, data_size_t idx, uint32_t value) override {
    const int i1 = idx >> 1;
    const int i2 = (idx & 1) << 2;
    const uint8_t val = static_cast<uint8_t>(value) << i2;
    if (i2 == 0) {
      data_[i1] = val;
    } else {
      buf_[i1] = val;
    }
  }

  void ReSize(data_size_t num_data) override {
    if (num_data_ != num_data) {
      num_data_ = num_data;
Guolin Ke's avatar
Guolin Ke committed
69
      const int len = (num_data_ + 1) / 2;
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
      data_.resize(len);
    }
  }

Guolin Ke's avatar
Guolin Ke committed
74
  inline BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
75

76
77
78
79
80
81
82
83
  #define ACC_GH(hist, i, g, h) \
  const auto ti = (i) << 1; \
  hist[ti] += g; \
  hist[ti + 1] += h; \

  template<bool use_indices, bool use_prefetch, bool use_hessians>
  void ConstructHistogramInner(const data_size_t* data_indices, data_size_t start, data_size_t end,
    const score_t* ordered_gradients, const score_t* ordered_hessians, hist_t* out) const {
84
    data_size_t i = start;
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

    if (use_prefetch) {
      const data_size_t pf_offset = 64;
      const data_size_t pf_end = end - pf_offset;
      for (; i < pf_end; ++i) {
        const auto idx = use_indices ? data_indices[i] : i;
        const auto pf_idx = use_indices ? data_indices[i + pf_offset] : i + pf_offset;
        PREFETCH_T0(data_.data() + (pf_idx >> 1));
        const auto bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
        if (use_hessians) {
          ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
        } else {
          ACC_GH(out, bin, ordered_gradients[i], 1.0f);
        }
      }
100
    }
101
102
    for (; i < end; ++i) {
      const auto idx = use_indices ? data_indices[i] : i;
103
      const auto bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
104
105
106
107
108
      if (use_hessians) {
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
      } else {
        ACC_GH(out, bin, ordered_gradients[i], 1.0f);
      }
109
110
    }
  }
111
112
113
114
115
116
117
  #undef ACC_GH

  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
    const score_t* ordered_gradients, const score_t* ordered_hessians,
    hist_t* out) const override {
    ConstructHistogramInner<true, true, true>(data_indices, start, end, ordered_gradients, ordered_hessians, out);
  }
Guolin Ke's avatar
Guolin Ke committed
118

119
120
  void ConstructHistogram(data_size_t start, data_size_t end,
    const score_t* ordered_gradients, const score_t* ordered_hessians,
121
122
    hist_t* out) const override {
    ConstructHistogramInner<false, false, true>(nullptr, start, end, ordered_gradients, ordered_hessians, out);
Guolin Ke's avatar
Guolin Ke committed
123
124
  }

125
126
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
    const score_t* ordered_gradients,
127
128
    hist_t* out) const override {
    ConstructHistogramInner<true, true, false>(data_indices, start, end, ordered_gradients, nullptr, out);
129
  }
Guolin Ke's avatar
Guolin Ke committed
130

131
132
  void ConstructHistogram(data_size_t start, data_size_t end,
    const score_t* ordered_gradients,
133
134
    hist_t* out) const override {
    ConstructHistogramInner<false, false, false>(nullptr, start, end, ordered_gradients, nullptr, out);
135
136
  }

Guolin Ke's avatar
Guolin Ke committed
137
  data_size_t Split(
Guolin Ke's avatar
Guolin Ke committed
138
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
Guolin Ke's avatar
Guolin Ke committed
139
    uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
140
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
Guolin Ke's avatar
Guolin Ke committed
141
142
    if (num_data <= 0) { return 0; }
    uint8_t th = static_cast<uint8_t>(threshold + min_bin);
143
144
    const uint8_t minb = static_cast<uint8_t>(min_bin);
    const uint8_t maxb = static_cast<uint8_t>(max_bin);
Guolin Ke's avatar
Guolin Ke committed
145
    uint8_t t_default_bin = static_cast<uint8_t>(min_bin + default_bin);
Guolin Ke's avatar
Guolin Ke committed
146
147
    uint8_t t_most_freq_bin = static_cast<uint8_t>(min_bin + most_freq_bin);
    if (most_freq_bin == 0) {
Guolin Ke's avatar
Guolin Ke committed
148
      th -= 1;
Guolin Ke's avatar
Guolin Ke committed
149
      t_default_bin -= 1;
Guolin Ke's avatar
Guolin Ke committed
150
      t_most_freq_bin -= 1;
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
    }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
156
157
158
159
160
161
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
162
163
164
165
    if (missing_type == MissingType::NaN) {
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
Guolin Ke's avatar
Guolin Ke committed
166
      }
167
168
169
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
Guolin Ke's avatar
Guolin Ke committed
170
        if (bin == maxb) {
171
          missing_default_indices[(*missing_default_count)++] = idx;
Guolin Ke's avatar
Guolin Ke committed
172
173
        } else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
          default_indices[(*default_count)++] = idx;
174
175
176
177
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
178
179
180
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
181
182
183
184
      if ((default_left && missing_type == MissingType::Zero)
          || (default_bin <= threshold && missing_type != MissingType::Zero)) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
185
      }
Guolin Ke's avatar
Guolin Ke committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
      if (default_bin == most_freq_bin) {
        for (data_size_t i = 0; i < num_data; ++i) {
          const data_size_t idx = data_indices[i];
          const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
          if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else if (bin > th) {
            gt_indices[gt_count++] = idx;
          } else {
            lte_indices[lte_count++] = idx;
          }
        }
      } else {
        for (data_size_t i = 0; i < num_data; ++i) {
          const data_size_t idx = data_indices[i];
          const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
          if (bin == t_default_bin) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
            default_indices[(*default_count)++] = idx;
          } else if (bin > th) {
            gt_indices[gt_count++] = idx;
          } else {
            lte_indices[lte_count++] = idx;
          }
211
        }
Guolin Ke's avatar
Guolin Ke committed
212
213
214
215
      }
    }
    return lte_count;
  }
216

Guolin Ke's avatar
Guolin Ke committed
217
  data_size_t SplitCategorical(
Guolin Ke's avatar
Guolin Ke committed
218
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
219
    const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
220
221
222
223
224
225
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
226
    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
227
228
229
230
231
232
233
234
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    for (data_size_t i = 0; i < num_data; ++i) {
      const data_size_t idx = data_indices[i];
      const uint32_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
      if (bin < min_bin || bin > max_bin) {
        default_indices[(*default_count)++] = idx;
235
      } else if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
236
237
238
239
240
241
242
243
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
    }
    return lte_count;
  }

Guolin Ke's avatar
Guolin Ke committed
244
245
246
247
  data_size_t num_data() const override { return num_data_; }


  void FinishLoad() override {
Guolin Ke's avatar
Guolin Ke committed
248
    if (buf_.empty()) { return; }
Guolin Ke's avatar
Guolin Ke committed
249
250
251
252
253
254
255
256
257
258
    int len = (num_data_ + 1) / 2;
    for (int i = 0; i < len; ++i) {
      data_[i] |= buf_[i];
    }
    buf_.clear();
  }

  void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
    const uint8_t* mem_data = reinterpret_cast<const uint8_t*>(memory);
    if (!local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
259
260
      const data_size_t rest = num_data_ & 1;
      for (int i = 0; i < num_data_ - rest; i += 2) {
Guolin Ke's avatar
Guolin Ke committed
261
        // get old bins
Guolin Ke's avatar
Guolin Ke committed
262
263
264
265
        data_size_t idx = local_used_indices[i];
        const auto bin1 = static_cast<uint8_t>((mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
        idx = local_used_indices[i + 1];
        const auto bin2 = static_cast<uint8_t>((mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
Guolin Ke's avatar
Guolin Ke committed
266
        // add
Guolin Ke's avatar
Guolin Ke committed
267
        const int i1 = i >> 1;
Guolin Ke's avatar
Guolin Ke committed
268
269
        data_[i1] = (bin1 | (bin2 << 4));
      }
Guolin Ke's avatar
Guolin Ke committed
270
      if (rest) {
Guolin Ke's avatar
Guolin Ke committed
271
        data_size_t idx = local_used_indices[num_data_ - 1];
Guolin Ke's avatar
Guolin Ke committed
272
        data_[num_data_ >> 1] = (mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
Guolin Ke's avatar
Guolin Ke committed
273
274
275
276
277
278
279
280
281
      }
    } else {
      for (size_t i = 0; i < data_.size(); ++i) {
        data_[i] = mem_data[i];
      }
    }
  }

  void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
282
    auto other_bin = dynamic_cast<const Dense4bitsBin*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
283
284
    const data_size_t rest = num_used_indices & 1;
    for (int i = 0; i < num_used_indices - rest; i += 2) {
Guolin Ke's avatar
Guolin Ke committed
285
286
287
288
      data_size_t idx = used_indices[i];
      const auto bin1 = static_cast<uint8_t>((other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
      idx = used_indices[i + 1];
      const auto bin2 = static_cast<uint8_t>((other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf);
Guolin Ke's avatar
Guolin Ke committed
289
      const int i1 = i >> 1;
Guolin Ke's avatar
Guolin Ke committed
290
291
      data_[i1] = (bin1 | (bin2 << 4));
    }
Guolin Ke's avatar
Guolin Ke committed
292
    if (rest) {
Guolin Ke's avatar
Guolin Ke committed
293
      data_size_t idx = used_indices[num_used_indices - 1];
Guolin Ke's avatar
Guolin Ke committed
294
      data_[num_used_indices >> 1] = (other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
Guolin Ke's avatar
Guolin Ke committed
295
296
297
    }
  }

298
299
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    writer->Write(data_.data(), sizeof(uint8_t) * data_.size());
Guolin Ke's avatar
Guolin Ke committed
300
301
302
  }

  size_t SizesInByte() const override {
303
    return sizeof(uint8_t) * data_.size();
Guolin Ke's avatar
Guolin Ke committed
304
305
  }

306
307
308
309
  Dense4bitsBin* Clone() override {
    return new Dense4bitsBin(*this);
  }

310
 protected:
311
  Dense4bitsBin(const Dense4bitsBin& other)
312
313
    : num_data_(other.num_data_), data_(other.data_), buf_(other.buf_) {
  }
314

Guolin Ke's avatar
Guolin Ke committed
315
  data_size_t num_data_;
316
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>> data_;
Guolin Ke's avatar
Guolin Ke committed
317
318
319
320
321
322
  std::vector<uint8_t> buf_;
};

/*!
 * \brief Returns the feature-level bin of row idx.
 *
 * A stored bin inside [min_bin_, max_bin_] is rebased to start at offset_;
 * any other stored bin is reported as most_freq_bin_.
 */
uint32_t Dense4bitsBinIterator::Get(data_size_t idx) {
  const uint8_t packed = bin_data_->data_[idx >> 1];
  // Odd rows occupy the high nibble, even rows the low nibble.
  const uint32_t stored = (idx & 1) ? static_cast<uint32_t>(packed >> 4)
                                    : static_cast<uint32_t>(packed & 0xf);
  if (stored < min_bin_ || stored > max_bin_) {
    return most_freq_bin_;
  }
  return stored - min_bin_ + offset_;
}

329
330
331
332
/*! \brief Returns the stored 4-bit bin of row idx without any rebasing. */
uint32_t Dense4bitsBinIterator::RawGet(data_size_t idx) {
  const uint8_t packed = bin_data_->data_[idx >> 1];
  // Odd rows occupy the high nibble, even rows the low nibble.
  if (idx & 1) {
    return static_cast<uint32_t>(packed >> 4);
  }
  return static_cast<uint32_t>(packed & 0xf);
}

Guolin Ke's avatar
Guolin Ke committed
333
334
/*!
 * \brief Creates an iterator over one feature's bins in this container.
 *
 * The iterator is allocated with new and this function keeps no reference to
 * it, so the caller is responsible for deleting it.
 */
inline BinIterator* Dense4bitsBin::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
  BinIterator* iter = new Dense4bitsBinIterator(this, min_bin, max_bin, most_freq_bin);
  return iter;
}

}  // namespace LightGBM
#endif   // LIGHTGBM_IO_DENSE_NBITS_BIN_HPP_