sparse_bin.hpp 23.1 KB
Newer Older
1
2
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
Guolin Ke's avatar
Guolin Ke committed
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

9
10
11
12
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

13
14
15
#include <algorithm>
#include <cstdint>
#include <cstring>
16
#include <limits>
17
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
#include <vector>

namespace LightGBM {

22
23
// Forward declaration: SparseBinIterator below stores a pointer to SparseBin.
template <typename VAL_T>
class SparseBin;

// Target number of buckets in SparseBin's fast index. GetFastIndex() rounds
// the bucket width ceil(num_data / kNumFastIndex) up to a power of two, so the
// index ends up with at most this many entries.
constexpr size_t kNumFastIndex = 64;
template <typename VAL_T>
28
class SparseBinIterator : public BinIterator {
29
 public:
30
31
32
33
34
35
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, uint32_t min_bin,
                    uint32_t max_bin, uint32_t most_freq_bin)
      : bin_data_(bin_data),
        min_bin_(static_cast<VAL_T>(min_bin)),
        max_bin_(static_cast<VAL_T>(max_bin)),
        most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
Guolin Ke's avatar
Guolin Ke committed
36
    if (most_freq_bin_ == 0) {
37
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
38
    } else {
39
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
40
41
42
    }
    Reset(0);
  }
43
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
44
      : bin_data_(bin_data) {
45
46
47
    Reset(start_idx);
  }

48
49
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
50

51
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
52
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
53
    if (ret >= min_bin_ && ret <= max_bin_) {
54
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
55
    } else {
Guolin Ke's avatar
Guolin Ke committed
56
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
57
    }
58
59
  }

Guolin Ke's avatar
Guolin Ke committed
60
  inline void Reset(data_size_t idx) override;
61

62
 private:
63
64
65
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
66
67
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
68
  VAL_T most_freq_bin_;
69
  uint8_t offset_;
70
71
};

Guolin Ke's avatar
Guolin Ke committed
72
template <typename VAL_T>
73
class SparseBin : public Bin {
74
 public:
Guolin Ke's avatar
Guolin Ke committed
75
76
  friend class SparseBinIterator<VAL_T>;

77
  explicit SparseBin(data_size_t num_data) : num_data_(num_data) {
78
    int num_threads = OMP_NUM_THREADS();
Guolin Ke's avatar
Guolin Ke committed
79
    push_buffers_.resize(num_threads);
Guolin Ke's avatar
Guolin Ke committed
80
81
  }

82
  ~SparseBin() {}
Guolin Ke's avatar
Guolin Ke committed
83

84
85
86
87
88
89
  void InitStreaming(uint32_t num_thread) override {
    // Each thread needs its own push buffer, so allocate external num_thread times the number of OMP threads
    int num_omp_threads = OMP_NUM_THREADS();
    push_buffers_.resize(num_omp_threads * num_thread);
  };

90
  void ReSize(data_size_t num_data) override { num_data_ = num_data; }
Guolin Ke's avatar
Guolin Ke committed
91
92

  void Push(int tid, data_size_t idx, uint32_t value) override {
93
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
94
    if (cur_bin != 0) {
95
96
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
97
98
  }

99
100
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin,
                           uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
101

102
#define ACC_GH(hist, i, g, h)               \
103
  const auto ti = static_cast<int>(i) << 1; \
104
105
  hist[ti] += g;                            \
  hist[ti + 1] += h;
106

107
108
109
110
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
111
112
113
114
115
116
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
117
118
119
        if (i_delta >= num_vals_) {
          break;
        }
120
      } else if (cur_pos > data_indices[i]) {
121
122
123
        if (++i >= end) {
          break;
        }
124
125
126
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
127
128
129
        if (++i >= end) {
          break;
        }
130
        cur_pos += deltas_[++i_delta];
131
132
133
        if (i_delta >= num_vals_) {
          break;
        }
134
135
      }
    }
Guolin Ke's avatar
Guolin Ke committed
136
137
  }

138
  void ConstructHistogram(data_size_t start, data_size_t end,
139
140
141
                          const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
142
143
144
145
146
147
148
149
150
151
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], ordered_hessians[cur_pos]);
      cur_pos += deltas_[++i_delta];
    }
152
153
  }

154
155
156
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
157
158
159
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
Guolin Ke's avatar
Guolin Ke committed
160
161
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
162
163
164
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
165
166
167
        if (i_delta >= num_vals_) {
          break;
        }
168
      } else if (cur_pos > data_indices[i]) {
169
170
171
        if (++i >= end) {
          break;
        }
172
      } else {
Guolin Ke's avatar
Guolin Ke committed
173
174
175
        const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
        grad[ti] += ordered_gradients[i];
        ++cnt[ti];
176
177
178
        if (++i >= end) {
          break;
        }
179
        cur_pos += deltas_[++i_delta];
180
181
182
        if (i_delta >= num_vals_) {
          break;
        }
183
184
      }
    }
185
186
  }

187
  void ConstructHistogram(data_size_t start, data_size_t end,
188
189
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
190
191
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
192
193
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
194
195
196
197
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
Guolin Ke's avatar
Guolin Ke committed
198
199
200
      const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
      grad[ti] += ordered_gradients[cur_pos];
      ++cnt[ti];
201
202
      cur_pos += deltas_[++i_delta];
    }
203
  }
204
#undef ACC_GH
205

206
207
  inline void NextNonzeroFast(data_size_t* i_delta,
                              data_size_t* cur_pos) const {
208
209
210
    *cur_pos += deltas_[++(*i_delta)];
    if (*i_delta >= num_vals_) {
      *cur_pos = num_data_;
211
    }
212
213
  }

214
  inline bool NextNonzero(data_size_t* i_delta, data_size_t* cur_pos) const {
215
    *cur_pos += deltas_[++(*i_delta)];
216
    if (*i_delta < num_vals_) {
217
218
      return true;
    } else {
219
      *cur_pos = num_data_;
220
221
222
223
      return false;
    }
  }

224
225
226
227
228
229
230
231
232
233
  template <bool MISS_IS_ZERO, bool MISS_IS_NA, bool MFB_IS_ZERO,
            bool MFB_IS_NA, bool USE_MIN_BIN>
  data_size_t SplitInner(uint32_t min_bin, uint32_t max_bin,
                         uint32_t default_bin, uint32_t most_freq_bin,
                         bool default_left, uint32_t threshold,
                         const data_size_t* data_indices, data_size_t cnt,
                         data_size_t* lte_indices,
                         data_size_t* gt_indices) const {
    auto th = static_cast<VAL_T>(threshold + min_bin);
    auto t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
Guolin Ke's avatar
Guolin Ke committed
234
    if (most_freq_bin == 0) {
235
236
      --th;
      --t_zero_bin;
Guolin Ke's avatar
Guolin Ke committed
237
    }
238
239
    const auto minb = static_cast<VAL_T>(min_bin);
    const auto maxb = static_cast<VAL_T>(max_bin);
Guolin Ke's avatar
Guolin Ke committed
240
241
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
Guolin Ke's avatar
Guolin Ke committed
242
243
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
244
245
246
247
248
249
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
250
    if (MISS_IS_ZERO || MISS_IS_NA) {
251
252
253
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
Guolin Ke's avatar
Guolin Ke committed
254
      }
255
256
257
258
259
260
261
262
263
264
265
266
    }
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (min_bin < max_bin) {
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if ((MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) ||
            (MISS_IS_NA && !MFB_IS_NA && bin == maxb)) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if ((USE_MIN_BIN && (bin < minb || bin > maxb)) ||
                   (!USE_MIN_BIN && bin == 0)) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
267
268
269
270
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
271
272
273
274
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
275
276
277
        }
      }
    } else {
278
279
280
281
282
      data_size_t* max_bin_indices = gt_indices;
      data_size_t* max_bin_count = &gt_count;
      if (maxb <= th) {
        max_bin_indices = lte_indices;
        max_bin_count = &lte_count;
283
      }
284
285
286
287
288
289
290
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if (MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if (bin != maxb) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
Guolin Ke's avatar
Guolin Ke committed
291
292
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
293
            default_indices[(*default_count)++] = idx;
Guolin Ke's avatar
Guolin Ke committed
294
          }
295
296
        } else {
          if (MISS_IS_NA && !MFB_IS_NA) {
Guolin Ke's avatar
Guolin Ke committed
297
298
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
299
            max_bin_indices[(*max_bin_count)++] = idx;
Guolin Ke's avatar
Guolin Ke committed
300
          }
301
        }
Guolin Ke's avatar
Guolin Ke committed
302
      }
303
    }
304
305
306
    return lte_count;
  }

307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
  data_size_t Split(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                        \
  min_bin, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, true>(ARGUMENTS);
    } else if (missing_type == MissingType::Zero) {
      if (default_bin == most_freq_bin) {
        return SplitInner<true, false, true, false, true>(ARGUMENTS);
      } else {
        return SplitInner<true, false, false, false, true>(ARGUMENTS);
      }
    } else {
      if (max_bin == most_freq_bin + min_bin && most_freq_bin > 0) {
        return SplitInner<false, true, false, true, true>(ARGUMENTS);
      } else {
        return SplitInner<false, true, false, false, true>(ARGUMENTS);
      }
    }
#undef ARGUMENTS
  }

  data_size_t Split(uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                  \
  1, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, false>(ARGUMENTS);
    } else if (missing_type == MissingType::Zero) {
      if (default_bin == most_freq_bin) {
        return SplitInner<true, false, true, false, false>(ARGUMENTS);
      } else {
        return SplitInner<true, false, false, false, false>(ARGUMENTS);
      }
    } else {
      if (max_bin == most_freq_bin + 1 && most_freq_bin > 0) {
        return SplitInner<false, true, false, true, false>(ARGUMENTS);
      } else {
        return SplitInner<false, true, false, false, false>(ARGUMENTS);
      }
    }
#undef ARGUMENTS
  }
  template <bool USE_MIN_BIN>
  data_size_t SplitCategoricalInner(uint32_t min_bin, uint32_t max_bin,
                                    uint32_t most_freq_bin,
                                    const uint32_t* threshold,
Nikita Titov's avatar
Nikita Titov committed
364
                                    int num_threshold,
365
366
367
                                    const data_size_t* data_indices,
                                    data_size_t cnt, data_size_t* lte_indices,
                                    data_size_t* gt_indices) const {
368
369
370
371
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
372
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
373
    int8_t offset = most_freq_bin == 0 ? 1 : 0;
Nikita Titov's avatar
Nikita Titov committed
374
    if (most_freq_bin > 0 && Common::FindInBitset(threshold, num_threshold, most_freq_bin)) {
375
376
377
      default_indices = lte_indices;
      default_count = &lte_count;
    }
378
    for (data_size_t i = 0; i < cnt; ++i) {
379
      const data_size_t idx = data_indices[i];
380
381
382
383
      const uint32_t bin = iterator.RawGet(idx);
      if (USE_MIN_BIN && (bin < min_bin || bin > max_bin)) {
        default_indices[(*default_count)++] = idx;
      } else if (!USE_MIN_BIN && bin == 0) {
384
        default_indices[(*default_count)++] = idx;
Nikita Titov's avatar
Nikita Titov committed
385
      } else if (Common::FindInBitset(threshold, num_threshold,
386
                                      bin - min_bin + offset)) {
387
388
389
390
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
Guolin Ke's avatar
Guolin Ke committed
391
392
393
394
    }
    return lte_count;
  }

395
396
  data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
                               uint32_t most_freq_bin,
Nikita Titov's avatar
Nikita Titov committed
397
                               const uint32_t* threshold, int num_threshold,
398
399
400
401
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<true>(min_bin, max_bin, most_freq_bin,
Nikita Titov's avatar
Nikita Titov committed
402
                                       threshold, num_threshold, data_indices,
403
404
405
406
                                       cnt, lte_indices, gt_indices);
  }

  data_size_t SplitCategorical(uint32_t max_bin, uint32_t most_freq_bin,
Nikita Titov's avatar
Nikita Titov committed
407
                               const uint32_t* threshold, int num_threshold,
408
409
410
411
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<false>(1, max_bin, most_freq_bin, threshold,
Nikita Titov's avatar
Nikita Titov committed
412
                                        num_threshold, data_indices, cnt,
413
414
415
                                        lte_indices, gt_indices);
  }

Guolin Ke's avatar
Guolin Ke committed
416
417
  data_size_t num_data() const override { return num_data_; }

418
419
  void* get_data() override { return nullptr; }

Guolin Ke's avatar
Guolin Ke committed
420
421
  void FinishLoad() override {
    // get total non zero size
422
    size_t pair_cnt = 0;
423
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
424
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
425
    }
426
427
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs =
        push_buffers_[0];
428
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
429
430

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
431
432
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(),
                           push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
433
434
435
436
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
437
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
438
439
440
441
              [](const std::pair<data_size_t, VAL_T>& a,
                 const std::pair<data_size_t, VAL_T>& b) {
                return a.first < b.first;
              });
zhangyafeikimi's avatar
zhangyafeikimi committed
442
    // load delta array
443
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
444
445
  }

446
447
  void LoadFromPair(
      const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
448
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
449
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
450
451
    deltas_.reserve(idx_val_pairs.size());
    vals_.reserve(idx_val_pairs.size());
Guolin Ke's avatar
Guolin Ke committed
452
453
    // transform to delta array
    data_size_t last_idx = 0;
454
455
456
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
457
      data_size_t cur_delta = cur_idx - last_idx;
458
      // disallow the multi-val in one row
459
460
461
      if (i > 0 && cur_delta == 0) {
        continue;
      }
462
      while (cur_delta >= 256) {
463
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
464
        vals_.push_back(0);
465
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
466
      }
467
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
468
469
470
471
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
472
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
473
474
475
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
476
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
494
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
495
    data_size_t cur_pos = 0;
496
497
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
498
      while (next_threshold <= cur_pos) {
499
500
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
501
502
503
      }
    }
    // avoid out of range
504
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
505
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
506
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
507
508
509
510
    }
    fast_index_.shrink_to_fit();
  }

511
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
512
513
514
    writer->AlignedWrite(&num_vals_, sizeof(num_vals_));
    writer->AlignedWrite(deltas_.data(), sizeof(uint8_t) * (num_vals_ + 1));
    writer->AlignedWrite(vals_.data(), sizeof(VAL_T) * num_vals_);
Guolin Ke's avatar
Guolin Ke committed
515
516
517
  }

  size_t SizesInByte() const override {
518
519
520
    return VirtualFileWriter::AlignedSize(sizeof(num_vals_)) +
           VirtualFileWriter::AlignedSize(sizeof(uint8_t) * (num_vals_ + 1)) +
           VirtualFileWriter::AlignedSize(sizeof(VAL_T) * num_vals_);
Guolin Ke's avatar
Guolin Ke committed
521
522
  }

523
524
525
  void LoadFromMemory(
      const void* memory,
      const std::vector<data_size_t>& local_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
526
527
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
528
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(tmp_num_vals));
Guolin Ke's avatar
Guolin Ke committed
529
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
530
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(uint8_t) * (tmp_num_vals + 1));
Guolin Ke's avatar
Guolin Ke committed
531
532
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

533
534
535
536
537
538
539
540
541
542
543
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
544

Guolin Ke's avatar
Guolin Ke committed
545
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
546
547
548
549
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
550
551
      data_size_t cur_pos = 0;
      data_size_t j = -1;
552
553
      for (data_size_t i = 0;
           i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
554
        const data_size_t idx = local_used_indices[i];
555
556
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
557
        }
558
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
559
          // new row index is i
560
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
561
562
563
564
        }
      }
      LoadFromPair(tmp_pair);
    }
565
  }
Guolin Ke's avatar
Guolin Ke committed
566

567
568
  void CopySubrow(const Bin* full_bin, const data_size_t* used_indices,
                  data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
569
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
570
571
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
572
573
574
575
576
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
577
578
    // transform to delta array
    data_size_t last_idx = 0;
579
    for (data_size_t i = 0; i < num_used_indices; ++i) {
580
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
581
      if (bin > 0) {
582
583
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
584
          deltas_.push_back(255);
585
          vals_.push_back(0);
586
          cur_delta -= 255;
587
588
589
590
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
591
592
      }
    }
593
594
595
596
597
598
599
600
601
602
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
603
604
  }

605
606
607
  SparseBin<VAL_T>* Clone() override;

  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
608
609
610
611
612
613
614
615
616
617
      : num_data_(other.num_data_),
        deltas_(other.deltas_),
        vals_(other.vals_),
        num_vals_(other.num_vals_),
        push_buffers_(other.push_buffers_),
        fast_index_(other.fast_index_),
        fast_index_shift_(other.fast_index_shift_) {}

  void InitIndex(data_size_t start_idx, data_size_t* i_delta,
                 data_size_t* cur_pos) const {
618
619
620
621
622
623
624
625
626
627
628
    auto idx = start_idx >> fast_index_shift_;
    if (static_cast<size_t>(idx) < fast_index_.size()) {
      const auto fast_pair = fast_index_[start_idx >> fast_index_shift_];
      *i_delta = fast_pair.first;
      *cur_pos = fast_pair.second;
    } else {
      *i_delta = -1;
      *cur_pos = 0;
    }
  }

629
630
631
632
  const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, std::vector<BinIterator*>* bin_iterator, const int num_threads) const override;

  const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, BinIterator** bin_iterator) const override;

633
 private:
Guolin Ke's avatar
Guolin Ke committed
634
  data_size_t num_data_;
635
636
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>>
      deltas_;
637
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
Guolin Ke's avatar
Guolin Ke committed
638
639
640
641
642
643
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

644
template <typename VAL_T>
645
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
646
647
648
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
649
template <typename VAL_T>
650
651
652
653
654
655
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

/*!
 * \brief Raw stored bin of row idx, or 0 when row idx has no stored entry.
 *
 * The cursor only moves forward: rows should be queried in non-decreasing
 * order after each Reset(). A row behind the cursor yields 0.
 */
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  while (cur_pos_ < idx) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  if (cur_pos_ == idx) {
    return bin_data_->vals_[i_delta_];
  } else {
    return 0;
  }
}

/*!
 * \brief Reposition the cursor at the bin's fast-index entry for the bucket
 *        containing start_idx.
 */
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  bin_data_->InitIndex(start_idx, &i_delta_, &cur_pos_);
}

template <typename VAL_T>
672
673
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin,
                                           uint32_t most_freq_bin) const {
Guolin Ke's avatar
Guolin Ke committed
674
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
675
676
677
}

}  // namespace LightGBM
678

679
#endif  // LIGHTGBM_IO_SPARSE_BIN_HPP_