/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

9
10
11
12
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

13
14
15
#include <algorithm>
#include <cstdint>
#include <cstring>
16
#include <limits>
17
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
#include <vector>

namespace LightGBM {

22
23
template <typename VAL_T>
class SparseBin;

/*!
 * \brief Number of coarse buckets in the sparse bin's "fast index", which maps
 *        a row range to a starting position inside the delta-encoded storage.
 *        Made constexpr so it is usable in constant expressions.
 */
constexpr size_t kNumFastIndex = 64;

27
template <typename VAL_T>
28
class SparseBinIterator : public BinIterator {
29
 public:
30
31
32
33
34
35
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, uint32_t min_bin,
                    uint32_t max_bin, uint32_t most_freq_bin)
      : bin_data_(bin_data),
        min_bin_(static_cast<VAL_T>(min_bin)),
        max_bin_(static_cast<VAL_T>(max_bin)),
        most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
Guolin Ke's avatar
Guolin Ke committed
36
    if (most_freq_bin_ == 0) {
37
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
38
    } else {
39
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
40
41
42
    }
    Reset(0);
  }
43
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
44
      : bin_data_(bin_data) {
45
46
47
    Reset(start_idx);
  }

48
49
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
50

51
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
52
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
53
    if (ret >= min_bin_ && ret <= max_bin_) {
54
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
55
    } else {
Guolin Ke's avatar
Guolin Ke committed
56
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
57
    }
58
59
  }

Guolin Ke's avatar
Guolin Ke committed
60
  inline void Reset(data_size_t idx) override;
61

62
 private:
63
64
65
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
66
67
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
68
  VAL_T most_freq_bin_;
69
  uint8_t offset_;
70
71
};

Guolin Ke's avatar
Guolin Ke committed
72
template <typename VAL_T>
73
class SparseBin : public Bin {
74
 public:
Guolin Ke's avatar
Guolin Ke committed
75
76
  friend class SparseBinIterator<VAL_T>;

77
  explicit SparseBin(data_size_t num_data) : num_data_(num_data) {
78
    int num_threads = OMP_NUM_THREADS();
Guolin Ke's avatar
Guolin Ke committed
79
    push_buffers_.resize(num_threads);
Guolin Ke's avatar
Guolin Ke committed
80
81
  }

82
  ~SparseBin() {}
Guolin Ke's avatar
Guolin Ke committed
83

84
85
86
87
  void InitStreaming(uint32_t num_thread, int32_t omp_max_threads) override {
    // Each external thread needs its own set of OpenMP push buffers,
    // so allocate num_thread times the maximum number of OMP threads per external thread
    push_buffers_.resize(omp_max_threads * num_thread);
88
89
  };

90
  void ReSize(data_size_t num_data) override { num_data_ = num_data; }
Guolin Ke's avatar
Guolin Ke committed
91
92

  void Push(int tid, data_size_t idx, uint32_t value) override {
93
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
94
    if (cur_bin != 0) {
95
96
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
97
98
  }

99
100
  /*! \brief Create an iterator over this bin's values (defined out of line). */
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin,
                           uint32_t most_freq_bin) const override;

// Accumulate one (gradient, hessian) pair into the interleaved histogram:
// slot 2*i holds the gradient sum, slot 2*i+1 the hessian sum.
#define ACC_GH(hist, i, g, h)               \
  const auto ti = static_cast<int>(i) << 1; \
  hist[ti] += g;                            \
  hist[ti + 1] += h;
106

107
108
109
110
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
111
112
113
114
115
116
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
117
118
119
        if (i_delta >= num_vals_) {
          break;
        }
120
      } else if (cur_pos > data_indices[i]) {
121
122
123
        if (++i >= end) {
          break;
        }
124
125
126
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
127
128
129
        if (++i >= end) {
          break;
        }
130
        cur_pos += deltas_[++i_delta];
131
132
133
        if (i_delta >= num_vals_) {
          break;
        }
134
135
      }
    }
Guolin Ke's avatar
Guolin Ke committed
136
137
  }

138
  /*!
   * \brief Accumulate gradients and hessians for the contiguous row range
   *        [start, end).
   */
  void ConstructHistogram(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    // Skip stored entries that fall before the range.
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    // Fold every stored entry inside [start, end) into its bin.
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], ordered_hessians[cur_pos]);
      cur_pos += deltas_[++i_delta];
    }
  }

154
155
156
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
157
158
159
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
Guolin Ke's avatar
Guolin Ke committed
160
161
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
162
163
164
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
165
166
167
        if (i_delta >= num_vals_) {
          break;
        }
168
      } else if (cur_pos > data_indices[i]) {
169
170
171
        if (++i >= end) {
          break;
        }
172
      } else {
Guolin Ke's avatar
Guolin Ke committed
173
174
175
        const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
        grad[ti] += ordered_gradients[i];
        ++cnt[ti];
176
177
178
        if (++i >= end) {
          break;
        }
179
        cur_pos += deltas_[++i_delta];
180
181
182
        if (i_delta >= num_vals_) {
          break;
        }
183
184
      }
    }
185
186
  }

187
  void ConstructHistogram(data_size_t start, data_size_t end,
188
189
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
190
191
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
192
193
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
194
195
196
197
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
Guolin Ke's avatar
Guolin Ke committed
198
199
200
      const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
      grad[ti] += ordered_gradients[cur_pos];
      ++cnt[ti];
201
202
      cur_pos += deltas_[++i_delta];
    }
203
  }
204
#undef ACC_GH
  // Quantized-histogram accumulation over the contiguous row range [start, end).
  // When USE_HESSIAN, each histogram slot is one PACKED_HIST_T holding the
  // hessian sum in the low HIST_BITS bits and the gradient sum in the high
  // bits; otherwise gradient sums and counts are interleaved.
  template <bool USE_HESSIAN, typename PACKED_HIST_T, typename GRAD_HIST_T, typename HESS_HIST_T, int HIST_BITS>
  void ConstructIntHistogramInner(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients_and_hessians,
                          hist_t* out) const {
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    if (USE_HESSIAN) {
      PACKED_HIST_T* out_ptr = reinterpret_cast<PACKED_HIST_T*>(out);
      // Input buffer is reinterpreted as int16 pairs: high byte = gradient,
      // low byte = hessian.
      const int16_t* gradients_and_hessians_ptr = reinterpret_cast<const int16_t*>(ordered_gradients_and_hessians);
      // Skip stored entries before the range.
      while (cur_pos < start && i_delta < num_vals_) {
        cur_pos += deltas_[++i_delta];
      }
      while (cur_pos < end && i_delta < num_vals_) {
        const VAL_T bin = vals_[i_delta];
        const int16_t gradient_16 = gradients_and_hessians_ptr[cur_pos];
        // Sign-extend the gradient byte into the high half and keep the
        // hessian byte in the low half, so one add updates both sums.
        const PACKED_HIST_T gradient_64 = (static_cast<PACKED_HIST_T>(static_cast<int8_t>(gradient_16 >> 8)) << HIST_BITS) | (gradient_16 & 0xff);
        out_ptr[bin] += gradient_64;
        cur_pos += deltas_[++i_delta];
      }
    } else {
      // No hessian: interleave gradient sums (even slots) and counts (odd slots).
      GRAD_HIST_T* grad = reinterpret_cast<GRAD_HIST_T*>(out);
      HESS_HIST_T* cnt = reinterpret_cast<HESS_HIST_T*>(out) + 1;
      const int8_t* gradients_and_hessians_ptr = reinterpret_cast<const int8_t*>(ordered_gradients_and_hessians);
      while (cur_pos < start && i_delta < num_vals_) {
        cur_pos += deltas_[++i_delta];
      }
      while (cur_pos < end && i_delta < num_vals_) {
        const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
        grad[ti] += gradients_and_hessians_ptr[cur_pos];
        ++cnt[ti];
        cur_pos += deltas_[++i_delta];
      }
    }
  }

  // Quantized-histogram accumulation for the row subset data_indices[start, end),
  // using the same merge-join as the float ConstructHistogram overloads.
  template <bool USE_HESSIAN, typename PACKED_HIST_T, typename GRAD_HIST_T, typename HESS_HIST_T, int HIST_BITS>
  void ConstructIntHistogramInner(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients_and_hessians,
                          hist_t* out) const {
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    if (USE_HESSIAN) {
      PACKED_HIST_T* out_ptr = reinterpret_cast<PACKED_HIST_T*>(out);
      // int16 pairs: high byte = gradient, low byte = hessian.
      const int16_t* gradients_and_hessians_ptr = reinterpret_cast<const int16_t*>(ordered_gradients_and_hessians);
      for (;;) {
        if (cur_pos < data_indices[i]) {
          // Stored entry behind the wanted row: advance the entry.
          cur_pos += deltas_[++i_delta];
          if (i_delta >= num_vals_) {
            break;
          }
        } else if (cur_pos > data_indices[i]) {
          // Wanted row has no stored entry: advance the row.
          if (++i >= end) {
            break;
          }
        } else {
          const VAL_T bin = vals_[i_delta];
          const int16_t gradient_16 = gradients_and_hessians_ptr[i];
          // For 8-bit histograms the int16 already is the packed layout;
          // otherwise sign-extend gradient into the high HIST_BITS bits.
          const PACKED_HIST_T gradient_packed = (HIST_BITS == 8) ? gradient_16 :
            (static_cast<PACKED_HIST_T>(static_cast<int8_t>(gradient_16 >> 8)) << HIST_BITS) | (gradient_16 & 0xff);
          out_ptr[bin] += gradient_packed;
          if (++i >= end) {
            break;
          }
          cur_pos += deltas_[++i_delta];
          if (i_delta >= num_vals_) {
            break;
          }
        }
      }
    } else {
      // No hessian: interleaved gradient sums (even) and counts (odd).
      GRAD_HIST_T* grad = reinterpret_cast<GRAD_HIST_T*>(out);
      HESS_HIST_T* cnt = reinterpret_cast<HESS_HIST_T*>(out) + 1;
      const int8_t* gradients_and_hessians_ptr = reinterpret_cast<const int8_t*>(ordered_gradients_and_hessians);
      for (;;) {
        if (cur_pos < data_indices[i]) {
          cur_pos += deltas_[++i_delta];
          if (i_delta >= num_vals_) {
            break;
          }
        } else if (cur_pos > data_indices[i]) {
          if (++i >= end) {
            break;
          }
        } else {
          const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
          // i << 1 because gradients are stored as (grad, hess) byte pairs.
          grad[ti] += gradients_and_hessians_ptr[i << 1];
          ++cnt[ti];
          if (++i >= end) {
            break;
          }
          cur_pos += deltas_[++i_delta];
          if (i_delta >= num_vals_) {
            break;
          }
        }
      }
    }
  }

  // ---- 32-bit quantized histograms: packed slot = int64 (grad high 32 / hess low 32).
  void ConstructHistogramInt32(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* /*ordered_hessians*/,
                          hist_t* out) const override {
    ConstructIntHistogramInner<true, int64_t, int32_t, uint32_t, 32>(data_indices, start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt32(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          const score_t* /*ordered_hessians*/,
                          hist_t* out) const override {
    ConstructIntHistogramInner<true, int64_t, int32_t, uint32_t, 32>(start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt32(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
    ConstructIntHistogramInner<false, int64_t, int32_t, uint32_t, 32>(data_indices, start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt32(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
    ConstructIntHistogramInner<false, int64_t, int32_t, uint32_t, 32>(start, end, ordered_gradients, out);
  }

  // ---- 16-bit quantized histograms: packed slot = int32 (grad high 16 / hess low 16).
  void ConstructHistogramInt16(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* /*ordered_hessians*/,
                          hist_t* out) const override {
    ConstructIntHistogramInner<true, int32_t, int16_t, uint16_t, 16>(data_indices, start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt16(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          const score_t* /*ordered_hessians*/,
                          hist_t* out) const override {
    ConstructIntHistogramInner<true, int32_t, int16_t, uint16_t, 16>(start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt16(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
    ConstructIntHistogramInner<false, int32_t, int16_t, uint16_t, 16>(data_indices, start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt16(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
    ConstructIntHistogramInner<false, int32_t, int16_t, uint16_t, 16>(start, end, ordered_gradients, out);
  }

  // ---- 8-bit quantized histograms: packed slot = int16 (grad high 8 / hess low 8).
  // NOTE(review): the no-hessian variants use uint8_t for the gradient slot
  // while Int16/Int32 use signed types — presumably intentional modular
  // accumulation, but verify against dense_bin.hpp.
  void ConstructHistogramInt8(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* /*ordered_hessians*/,
                          hist_t* out) const override {
    ConstructIntHistogramInner<true, int16_t, uint8_t, uint8_t, 8>(data_indices, start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt8(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          const score_t* /*ordered_hessians*/,
                          hist_t* out) const override {
    ConstructIntHistogramInner<true, int16_t, uint8_t, uint8_t, 8>(start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt8(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
    ConstructIntHistogramInner<false, int16_t, uint8_t, uint8_t, 8>(data_indices, start, end, ordered_gradients, out);
  }

  void ConstructHistogramInt8(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
    ConstructIntHistogramInner<false, int16_t, uint8_t, uint8_t, 8>(start, end, ordered_gradients, out);
  }

384
385
  /*!
   * \brief Advance to the next stored entry without reporting exhaustion;
   *        when the entries run out, cur_pos is parked at num_data_ so every
   *        later comparison fails.
   */
  inline void NextNonzeroFast(data_size_t* i_delta,
                              data_size_t* cur_pos) const {
    *cur_pos += deltas_[++(*i_delta)];
    if (*i_delta >= num_vals_) {
      *cur_pos = num_data_;
    }
  }

392
  /*!
   * \brief Advance to the next stored entry.
   * \return true while a valid entry remains; false once exhausted, in which
   *         case cur_pos is parked at num_data_.
   */
  inline bool NextNonzero(data_size_t* i_delta, data_size_t* cur_pos) const {
    *cur_pos += deltas_[++(*i_delta)];
    if (*i_delta < num_vals_) {
      return true;
    }
    *cur_pos = num_data_;
    return false;
  }

402
403
404
405
406
407
408
409
410
411
  /*!
   * \brief Partition data_indices into <=threshold (lte) and >threshold (gt)
   *        sides of a numerical split. The template flags bake in the missing
   *        value policy: MISS_IS_ZERO/MISS_IS_NA say how missing is encoded,
   *        MFB_* say whether the most frequent bin coincides with the missing
   *        encoding, USE_MIN_BIN selects grouped vs. standalone bin ranges.
   * \return Number of indices routed to the lte side.
   */
  template <bool MISS_IS_ZERO, bool MISS_IS_NA, bool MFB_IS_ZERO,
            bool MFB_IS_NA, bool USE_MIN_BIN>
  data_size_t SplitInner(uint32_t min_bin, uint32_t max_bin,
                         uint32_t default_bin, uint32_t most_freq_bin,
                         bool default_left, uint32_t threshold,
                         const data_size_t* data_indices, data_size_t cnt,
                         data_size_t* lte_indices,
                         data_size_t* gt_indices) const {
    // Thresholds in raw (unshifted) bin space.
    auto th = static_cast<VAL_T>(threshold + min_bin);
    auto t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
    if (most_freq_bin == 0) {
      // Bin 0 is implicit, so raw values are shifted down by one.
      --th;
      --t_zero_bin;
    }
    const auto minb = static_cast<VAL_T>(min_bin);
    const auto maxb = static_cast<VAL_T>(max_bin);
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    // Rows holding the most frequent bin go to whichever side it belongs on.
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
    // Rows holding a missing value follow default_left.
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    if (MISS_IS_ZERO || MISS_IS_NA) {
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
      }
    }
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (min_bin < max_bin) {
      // General case: this feature occupies a strict sub-range of bins.
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if ((MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) ||
            (MISS_IS_NA && !MFB_IS_NA && bin == maxb)) {
          // Explicitly stored missing value.
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if ((USE_MIN_BIN && (bin < minb || bin > maxb)) ||
                   (!USE_MIN_BIN && bin == 0)) {
          // Row holds the implicit most frequent bin.
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
        }
      }
    } else {
      // Degenerate case: the feature has a single non-default bin (maxb).
      data_size_t* max_bin_indices = gt_indices;
      data_size_t* max_bin_count = &gt_count;
      if (maxb <= th) {
        max_bin_indices = lte_indices;
        max_bin_count = &lte_count;
      }
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if (MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if (bin != maxb) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
        } else {
          if (MISS_IS_NA && !MFB_IS_NA) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            max_bin_indices[(*max_bin_count)++] = idx;
          }
        }
      }
    }
    return lte_count;
  }

485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
  // Numerical split for a feature inside a grouped bin range [min_bin, max_bin].
  // Dispatches to the SplitInner specialization matching the missing-value
  // policy; returns the number of indices routed to the lte side.
  data_size_t Split(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                        \
  min_bin, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, true>(ARGUMENTS);
    } else if (missing_type == MissingType::Zero) {
      // MFB_IS_ZERO iff zero (the missing encoding) is also the most frequent bin.
      if (default_bin == most_freq_bin) {
        return SplitInner<true, false, true, false, true>(ARGUMENTS);
      } else {
        return SplitInner<true, false, false, false, true>(ARGUMENTS);
      }
    } else {
      // MissingType::NA: MFB_IS_NA iff the NA bin (max_bin) is the most frequent bin.
      if (max_bin == most_freq_bin + min_bin && most_freq_bin > 0) {
        return SplitInner<false, true, false, true, true>(ARGUMENTS);
      } else {
        return SplitInner<false, true, false, false, true>(ARGUMENTS);
      }
    }
#undef ARGUMENTS
  }

  // Numerical split for a standalone feature (bin range starts at 1).
  // Same dispatch as the grouped overload but with USE_MIN_BIN = false.
  data_size_t Split(uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                  \
  1, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, false>(ARGUMENTS);
    } else if (missing_type == MissingType::Zero) {
      if (default_bin == most_freq_bin) {
        return SplitInner<true, false, true, false, false>(ARGUMENTS);
      } else {
        return SplitInner<true, false, false, false, false>(ARGUMENTS);
      }
    } else {
      if (max_bin == most_freq_bin + 1 && most_freq_bin > 0) {
        return SplitInner<false, true, false, true, false>(ARGUMENTS);
      } else {
        return SplitInner<false, true, false, false, false>(ARGUMENTS);
      }
    }
#undef ARGUMENTS
  }
  /*!
   * \brief Partition data_indices for a categorical split: rows whose bin is
   *        in the threshold bitset go left (lte), the rest go right (gt).
   *        Rows holding the most frequent bin follow its bitset membership.
   * \return Number of indices routed to the lte side.
   */
  template <bool USE_MIN_BIN>
  data_size_t SplitCategoricalInner(uint32_t min_bin, uint32_t max_bin,
                                    uint32_t most_freq_bin,
                                    const uint32_t* threshold,
                                    int num_threshold,
                                    const data_size_t* data_indices,
                                    data_size_t cnt, data_size_t* lte_indices,
                                    data_size_t* gt_indices) const {
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    // When bin 0 is implicit, raw bins are shifted by one relative to the bitset.
    const int8_t offset = most_freq_bin == 0 ? 1 : 0;
    if (most_freq_bin > 0 && Common::FindInBitset(threshold, num_threshold, most_freq_bin)) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    for (data_size_t i = 0; i < cnt; ++i) {
      const data_size_t idx = data_indices[i];
      const uint32_t bin = iterator.RawGet(idx);
      // Out-of-range (grouped) or implicit-zero (standalone) rows hold the
      // most frequent bin and follow the default side.
      if ((USE_MIN_BIN && (bin < min_bin || bin > max_bin)) ||
          (!USE_MIN_BIN && bin == 0)) {
        default_indices[(*default_count)++] = idx;
      } else if (Common::FindInBitset(threshold, num_threshold,
                                      bin - min_bin + offset)) {
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
    }
    return lte_count;
  }

573
574
  /*! \brief Categorical split for a feature inside a grouped bin range. */
  data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
                               uint32_t most_freq_bin,
                               const uint32_t* threshold, int num_threshold,
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<true>(min_bin, max_bin, most_freq_bin,
                                       threshold, num_threshold, data_indices,
                                       cnt, lte_indices, gt_indices);
  }

  /*! \brief Categorical split for a standalone feature (bin range starts at 1). */
  data_size_t SplitCategorical(uint32_t max_bin, uint32_t most_freq_bin,
                               const uint32_t* threshold, int num_threshold,
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<false>(1, max_bin, most_freq_bin, threshold,
                                        num_threshold, data_indices, cnt,
                                        lte_indices, gt_indices);
  }

Guolin Ke's avatar
Guolin Ke committed
594
595
  data_size_t num_data() const override { return num_data_; }

596
597
  void* get_data() override { return nullptr; }

Guolin Ke's avatar
Guolin Ke committed
598
599
  void FinishLoad() override {
    // get total non zero size
600
    size_t pair_cnt = 0;
601
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
602
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
603
    }
604
605
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs =
        push_buffers_[0];
606
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
607
608

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
609
610
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(),
                           push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
611
612
613
614
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
615
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
616
617
618
619
              [](const std::pair<data_size_t, VAL_T>& a,
                 const std::pair<data_size_t, VAL_T>& b) {
                return a.first < b.first;
              });
zhangyafeikimi's avatar
zhangyafeikimi committed
620
    // load delta array
621
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
622
623
  }

624
625
  void LoadFromPair(
      const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
626
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
627
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
628
629
    deltas_.reserve(idx_val_pairs.size());
    vals_.reserve(idx_val_pairs.size());
Guolin Ke's avatar
Guolin Ke committed
630
631
    // transform to delta array
    data_size_t last_idx = 0;
632
633
634
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
635
      data_size_t cur_delta = cur_idx - last_idx;
636
      // disallow the multi-val in one row
637
638
639
      if (i > 0 && cur_delta == 0) {
        continue;
      }
640
      while (cur_delta >= 256) {
641
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
642
        vals_.push_back(0);
643
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
644
      }
645
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
646
647
648
649
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
650
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
651
652
653
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
654
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
672
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
673
    data_size_t cur_pos = 0;
674
675
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
676
      while (next_threshold <= cur_pos) {
677
678
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
679
680
681
      }
    }
    // avoid out of range
682
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
683
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
684
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
685
686
687
688
    }
    fast_index_.shrink_to_fit();
  }

689
  /*!
   * \brief Serialize this bin: entry count, then the delta array (including
   *        the trailing sentinel), then the bin values.
   */
  void SaveBinaryToFile(BinaryWriter* writer) const override {
    writer->AlignedWrite(&num_vals_, sizeof(num_vals_));
    // num_vals_ + 1 accounts for the sentinel delta appended at load time.
    writer->AlignedWrite(deltas_.data(), sizeof(uint8_t) * (num_vals_ + 1));
    writer->AlignedWrite(vals_.data(), sizeof(VAL_T) * num_vals_);
  }

  size_t SizesInByte() const override {
696
697
698
    return VirtualFileWriter::AlignedSize(sizeof(num_vals_)) +
           VirtualFileWriter::AlignedSize(sizeof(uint8_t) * (num_vals_ + 1)) +
           VirtualFileWriter::AlignedSize(sizeof(VAL_T) * num_vals_);
Guolin Ke's avatar
Guolin Ke committed
699
700
  }

701
702
703
  void LoadFromMemory(
      const void* memory,
      const std::vector<data_size_t>& local_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
704
705
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
706
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(tmp_num_vals));
Guolin Ke's avatar
Guolin Ke committed
707
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
708
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(uint8_t) * (tmp_num_vals + 1));
Guolin Ke's avatar
Guolin Ke committed
709
710
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

711
712
713
714
715
716
717
718
719
720
721
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
722

Guolin Ke's avatar
Guolin Ke committed
723
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
724
725
726
727
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
728
729
      data_size_t cur_pos = 0;
      data_size_t j = -1;
730
731
      for (data_size_t i = 0;
           i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
732
        const data_size_t idx = local_used_indices[i];
733
734
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
735
        }
736
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
737
          // new row index is i
738
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
739
740
741
742
        }
      }
      LoadFromPair(tmp_pair);
    }
743
  }
Guolin Ke's avatar
Guolin Ke committed
744

745
746
  void CopySubrow(const Bin* full_bin, const data_size_t* used_indices,
                  data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
747
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
748
749
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
750
751
752
753
754
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
755
756
    // transform to delta array
    data_size_t last_idx = 0;
757
    for (data_size_t i = 0; i < num_used_indices; ++i) {
758
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
759
      if (bin > 0) {
760
761
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
762
          deltas_.push_back(255);
763
          vals_.push_back(0);
764
          cur_delta -= 255;
765
766
767
768
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
769
770
      }
    }
771
772
773
774
775
776
777
778
779
780
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
781
782
  }

783
784
  SparseBin<VAL_T>* Clone() override;

785
  SparseBin(const SparseBin<VAL_T>& other)
786
787
788
789
790
791
792
793
794
795
      : num_data_(other.num_data_),
        deltas_(other.deltas_),
        vals_(other.vals_),
        num_vals_(other.num_vals_),
        push_buffers_(other.push_buffers_),
        fast_index_(other.fast_index_),
        fast_index_shift_(other.fast_index_shift_) {}

  // Seeds an iterator so it can start scanning at start_idx: look up the
  // fast-index bucket covering start_idx and hand back its cached
  // (delta-array position, row position).  Falls back to the very start
  // (-1, 0) when the bucket is out of range (e.g. an empty fast index).
  void InitIndex(data_size_t start_idx, data_size_t* i_delta,
                 data_size_t* cur_pos) const {
    const auto idx = start_idx >> fast_index_shift_;
    if (static_cast<size_t>(idx) < fast_index_.size()) {
      // Reuse the already-computed bucket index (the original shifted
      // start_idx a second time) and bind the pair by reference instead of
      // copying it.
      const auto& fast_pair = fast_index_[idx];
      *i_delta = fast_pair.first;
      *cur_pos = fast_pair.second;
    } else {
      *i_delta = -1;
      *cur_pos = 0;
    }
  }

807
808
809
810
  // Column-wise data accessors (bodies defined outside this section).
  // NOTE(review): for a sparse bin these presumably report sparsity via
  // *is_sparse and expose access through the returned iterator(s) — confirm
  // against the out-of-view definitions.
  const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, std::vector<BinIterator*>* bin_iterator, const int num_threads) const override;

  const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, BinIterator** bin_iterator) const override;

811
 private:
  // Total number of rows (data points) this feature column covers.
  data_size_t num_data_;
  // Row gaps between consecutive stored values, plus one trailing sentinel
  // 0 (size num_vals_ + 1).  Gaps of 256+ rows are encoded as runs of
  // 255-deltas whose paired entry in vals_ is 0 (see CopySubrow).
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>>
      deltas_;
  // Stored bin values, parallel to deltas_; 0 marks a gap-filler entry.
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
  // Number of stored (delta, value) pairs.
  data_size_t num_vals_;
  // Staging buffers of (row, value) pairs — presumably one per thread,
  // filled by a Push path defined outside this section; confirm there.
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  // fast_index_[row >> fast_index_shift_] caches a (delta-array position,
  // row position) pair used by InitIndex to seed iterators near that row.
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  // log2 of the fast-index bucket width (set by GetFastIndex).
  data_size_t fast_index_shift_;
};

822
template <typename VAL_T>
823
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
824
825
826
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
827
template <typename VAL_T>
828
829
830
831
832
833
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

// Returns the raw bin value stored at row `idx`, or 0 when that row holds
// no entry.  Forward-only: `idx` must not decrease between calls on the
// same iterator, since the scan position never moves backwards.
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  // Skip ahead until the scan position reaches or passes the target row.
  while (cur_pos_ < idx) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  // Landing exactly on idx means a stored value; overshooting means the
  // row is empty.
  return cur_pos_ == idx ? bin_data_->vals_[i_delta_] : static_cast<VAL_T>(0);
}
Guolin Ke's avatar
Guolin Ke committed
843

844
845
// Repositions the iterator so the next Get/InnerRawGet can start scanning
// from start_idx, using the bin's fast index to seed i_delta_/cur_pos_.
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  bin_data_->InitIndex(start_idx, &i_delta_, &cur_pos_);
}
Guolin Ke's avatar
Guolin Ke committed
848
849

template <typename VAL_T>
850
851
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin,
                                           uint32_t most_freq_bin) const {
Guolin Ke's avatar
Guolin Ke committed
852
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
853
854
855
}

}  // namespace LightGBM
856

857
#endif  // LIGHTGBM_IO_SPARSE_BIN_HPP_