sparse_bin.hpp 22.4 KB
Newer Older
1
2
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
Guolin Ke's avatar
Guolin Ke committed
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

9
10
11
12
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

13
14
15
#include <algorithm>
#include <cstdint>
#include <cstring>
16
#include <limits>
17
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
#include <vector>

namespace LightGBM {

22
23
template <typename VAL_T>
class SparseBin;

/*!
 * \brief Target number of fast-index buckets per SparseBin.
 *
 * SparseBin::GetFastIndex() divides num_data by this value and rounds the
 * bucket width up to a power of two, storing one starting cursor per bucket.
 */
constexpr size_t kNumFastIndex = 64;

27
template <typename VAL_T>
28
class SparseBinIterator : public BinIterator {
29
 public:
30
31
32
33
34
35
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, uint32_t min_bin,
                    uint32_t max_bin, uint32_t most_freq_bin)
      : bin_data_(bin_data),
        min_bin_(static_cast<VAL_T>(min_bin)),
        max_bin_(static_cast<VAL_T>(max_bin)),
        most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
Guolin Ke's avatar
Guolin Ke committed
36
    if (most_freq_bin_ == 0) {
37
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
38
    } else {
39
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
40
41
42
    }
    Reset(0);
  }
43
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
44
      : bin_data_(bin_data) {
45
46
47
    Reset(start_idx);
  }

48
49
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
50

51
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
52
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
53
    if (ret >= min_bin_ && ret <= max_bin_) {
54
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
55
    } else {
Guolin Ke's avatar
Guolin Ke committed
56
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
57
    }
58
59
  }

Guolin Ke's avatar
Guolin Ke committed
60
  inline void Reset(data_size_t idx) override;
61

62
 private:
63
64
65
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
66
67
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
68
  VAL_T most_freq_bin_;
69
  uint8_t offset_;
70
71
};

Guolin Ke's avatar
Guolin Ke committed
72
template <typename VAL_T>
73
class SparseBin : public Bin {
74
 public:
Guolin Ke's avatar
Guolin Ke committed
75
76
  friend class SparseBinIterator<VAL_T>;

77
  explicit SparseBin(data_size_t num_data) : num_data_(num_data) {
78
    int num_threads = OMP_NUM_THREADS();
Guolin Ke's avatar
Guolin Ke committed
79
    push_buffers_.resize(num_threads);
Guolin Ke's avatar
Guolin Ke committed
80
81
  }

82
  ~SparseBin() {}
Guolin Ke's avatar
Guolin Ke committed
83

84
  void ReSize(data_size_t num_data) override { num_data_ = num_data; }
Guolin Ke's avatar
Guolin Ke committed
85
86

  void Push(int tid, data_size_t idx, uint32_t value) override {
87
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
88
    if (cur_bin != 0) {
89
90
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
91
92
  }

93
94
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin,
                           uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
95

96
#define ACC_GH(hist, i, g, h)               \
97
  const auto ti = static_cast<int>(i) << 1; \
98
99
  hist[ti] += g;                            \
  hist[ti + 1] += h;
100

101
102
103
104
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
105
106
107
108
109
110
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
111
112
113
        if (i_delta >= num_vals_) {
          break;
        }
114
      } else if (cur_pos > data_indices[i]) {
115
116
117
        if (++i >= end) {
          break;
        }
118
119
120
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
121
122
123
        if (++i >= end) {
          break;
        }
124
        cur_pos += deltas_[++i_delta];
125
126
127
        if (i_delta >= num_vals_) {
          break;
        }
128
129
      }
    }
Guolin Ke's avatar
Guolin Ke committed
130
131
  }

132
  void ConstructHistogram(data_size_t start, data_size_t end,
133
134
135
                          const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
136
137
138
139
140
141
142
143
144
145
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], ordered_hessians[cur_pos]);
      cur_pos += deltas_[++i_delta];
    }
146
147
  }

148
149
150
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
151
152
153
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
Guolin Ke's avatar
Guolin Ke committed
154
155
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
156
157
158
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
159
160
161
        if (i_delta >= num_vals_) {
          break;
        }
162
      } else if (cur_pos > data_indices[i]) {
163
164
165
        if (++i >= end) {
          break;
        }
166
      } else {
Guolin Ke's avatar
Guolin Ke committed
167
168
169
        const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
        grad[ti] += ordered_gradients[i];
        ++cnt[ti];
170
171
172
        if (++i >= end) {
          break;
        }
173
        cur_pos += deltas_[++i_delta];
174
175
176
        if (i_delta >= num_vals_) {
          break;
        }
177
178
      }
    }
179
180
  }

181
  void ConstructHistogram(data_size_t start, data_size_t end,
182
183
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
184
185
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
186
187
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
188
189
190
191
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
Guolin Ke's avatar
Guolin Ke committed
192
193
194
      const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
      grad[ti] += ordered_gradients[cur_pos];
      ++cnt[ti];
195
196
      cur_pos += deltas_[++i_delta];
    }
197
  }
198
#undef ACC_GH
199

200
201
  inline void NextNonzeroFast(data_size_t* i_delta,
                              data_size_t* cur_pos) const {
202
203
204
    *cur_pos += deltas_[++(*i_delta)];
    if (*i_delta >= num_vals_) {
      *cur_pos = num_data_;
205
    }
206
207
  }

208
  inline bool NextNonzero(data_size_t* i_delta, data_size_t* cur_pos) const {
209
    *cur_pos += deltas_[++(*i_delta)];
210
    if (*i_delta < num_vals_) {
211
212
      return true;
    } else {
213
      *cur_pos = num_data_;
214
215
216
217
      return false;
    }
  }

218
219
220
221
222
223
224
225
226
227
  template <bool MISS_IS_ZERO, bool MISS_IS_NA, bool MFB_IS_ZERO,
            bool MFB_IS_NA, bool USE_MIN_BIN>
  data_size_t SplitInner(uint32_t min_bin, uint32_t max_bin,
                         uint32_t default_bin, uint32_t most_freq_bin,
                         bool default_left, uint32_t threshold,
                         const data_size_t* data_indices, data_size_t cnt,
                         data_size_t* lte_indices,
                         data_size_t* gt_indices) const {
    auto th = static_cast<VAL_T>(threshold + min_bin);
    auto t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
Guolin Ke's avatar
Guolin Ke committed
228
    if (most_freq_bin == 0) {
229
230
      --th;
      --t_zero_bin;
Guolin Ke's avatar
Guolin Ke committed
231
    }
232
233
    const auto minb = static_cast<VAL_T>(min_bin);
    const auto maxb = static_cast<VAL_T>(max_bin);
Guolin Ke's avatar
Guolin Ke committed
234
235
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
Guolin Ke's avatar
Guolin Ke committed
236
237
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
238
239
240
241
242
243
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
244
    if (MISS_IS_ZERO || MISS_IS_NA) {
245
246
247
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
Guolin Ke's avatar
Guolin Ke committed
248
      }
249
250
251
252
253
254
255
256
257
258
259
260
    }
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (min_bin < max_bin) {
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if ((MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) ||
            (MISS_IS_NA && !MFB_IS_NA && bin == maxb)) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if ((USE_MIN_BIN && (bin < minb || bin > maxb)) ||
                   (!USE_MIN_BIN && bin == 0)) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
261
262
263
264
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
265
266
267
268
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
269
270
271
        }
      }
    } else {
272
273
274
275
276
      data_size_t* max_bin_indices = gt_indices;
      data_size_t* max_bin_count = &gt_count;
      if (maxb <= th) {
        max_bin_indices = lte_indices;
        max_bin_count = &lte_count;
277
      }
278
279
280
281
282
283
284
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if (MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if (bin != maxb) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
Guolin Ke's avatar
Guolin Ke committed
285
286
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
287
            default_indices[(*default_count)++] = idx;
Guolin Ke's avatar
Guolin Ke committed
288
          }
289
290
        } else {
          if (MISS_IS_NA && !MFB_IS_NA) {
Guolin Ke's avatar
Guolin Ke committed
291
292
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
293
            max_bin_indices[(*max_bin_count)++] = idx;
Guolin Ke's avatar
Guolin Ke committed
294
          }
295
        }
Guolin Ke's avatar
Guolin Ke committed
296
      }
297
    }
298
299
300
    return lte_count;
  }

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
  data_size_t Split(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                        \
  min_bin, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, true>(ARGUMENTS);
    } else if (missing_type == MissingType::Zero) {
      if (default_bin == most_freq_bin) {
        return SplitInner<true, false, true, false, true>(ARGUMENTS);
      } else {
        return SplitInner<true, false, false, false, true>(ARGUMENTS);
      }
    } else {
      if (max_bin == most_freq_bin + min_bin && most_freq_bin > 0) {
        return SplitInner<false, true, false, true, true>(ARGUMENTS);
      } else {
        return SplitInner<false, true, false, false, true>(ARGUMENTS);
      }
    }
#undef ARGUMENTS
  }

  data_size_t Split(uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                  \
  1, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, false>(ARGUMENTS);
    } else if (missing_type == MissingType::Zero) {
      if (default_bin == most_freq_bin) {
        return SplitInner<true, false, true, false, false>(ARGUMENTS);
      } else {
        return SplitInner<true, false, false, false, false>(ARGUMENTS);
      }
    } else {
      if (max_bin == most_freq_bin + 1 && most_freq_bin > 0) {
        return SplitInner<false, true, false, true, false>(ARGUMENTS);
      } else {
        return SplitInner<false, true, false, false, false>(ARGUMENTS);
      }
    }
#undef ARGUMENTS
  }
  template <bool USE_MIN_BIN>
  data_size_t SplitCategoricalInner(uint32_t min_bin, uint32_t max_bin,
                                    uint32_t most_freq_bin,
                                    const uint32_t* threshold,
Nikita Titov's avatar
Nikita Titov committed
358
                                    int num_threshold,
359
360
361
                                    const data_size_t* data_indices,
                                    data_size_t cnt, data_size_t* lte_indices,
                                    data_size_t* gt_indices) const {
362
363
364
365
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
366
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
367
    int8_t offset = most_freq_bin == 0 ? 1 : 0;
Nikita Titov's avatar
Nikita Titov committed
368
    if (most_freq_bin > 0 && Common::FindInBitset(threshold, num_threshold, most_freq_bin)) {
369
370
371
      default_indices = lte_indices;
      default_count = &lte_count;
    }
372
    for (data_size_t i = 0; i < cnt; ++i) {
373
      const data_size_t idx = data_indices[i];
374
375
376
377
      const uint32_t bin = iterator.RawGet(idx);
      if (USE_MIN_BIN && (bin < min_bin || bin > max_bin)) {
        default_indices[(*default_count)++] = idx;
      } else if (!USE_MIN_BIN && bin == 0) {
378
        default_indices[(*default_count)++] = idx;
Nikita Titov's avatar
Nikita Titov committed
379
      } else if (Common::FindInBitset(threshold, num_threshold,
380
                                      bin - min_bin + offset)) {
381
382
383
384
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
Guolin Ke's avatar
Guolin Ke committed
385
386
387
388
    }
    return lte_count;
  }

389
390
  data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
                               uint32_t most_freq_bin,
Nikita Titov's avatar
Nikita Titov committed
391
                               const uint32_t* threshold, int num_threshold,
392
393
394
395
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<true>(min_bin, max_bin, most_freq_bin,
Nikita Titov's avatar
Nikita Titov committed
396
                                       threshold, num_threshold, data_indices,
397
398
399
400
                                       cnt, lte_indices, gt_indices);
  }

  data_size_t SplitCategorical(uint32_t max_bin, uint32_t most_freq_bin,
Nikita Titov's avatar
Nikita Titov committed
401
                               const uint32_t* threshold, int num_threshold,
402
403
404
405
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<false>(1, max_bin, most_freq_bin, threshold,
Nikita Titov's avatar
Nikita Titov committed
406
                                        num_threshold, data_indices, cnt,
407
408
409
                                        lte_indices, gt_indices);
  }

Guolin Ke's avatar
Guolin Ke committed
410
411
  data_size_t num_data() const override { return num_data_; }

412
413
  void* get_data() override { return nullptr; }

Guolin Ke's avatar
Guolin Ke committed
414
415
  void FinishLoad() override {
    // get total non zero size
416
    size_t pair_cnt = 0;
417
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
418
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
419
    }
420
421
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs =
        push_buffers_[0];
422
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
423
424

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
425
426
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(),
                           push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
427
428
429
430
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
431
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
432
433
434
435
              [](const std::pair<data_size_t, VAL_T>& a,
                 const std::pair<data_size_t, VAL_T>& b) {
                return a.first < b.first;
              });
zhangyafeikimi's avatar
zhangyafeikimi committed
436
    // load delta array
437
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
438
439
  }

440
441
  void LoadFromPair(
      const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
442
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
443
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
444
445
    deltas_.reserve(idx_val_pairs.size());
    vals_.reserve(idx_val_pairs.size());
Guolin Ke's avatar
Guolin Ke committed
446
447
    // transform to delta array
    data_size_t last_idx = 0;
448
449
450
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
451
      data_size_t cur_delta = cur_idx - last_idx;
452
      // disallow the multi-val in one row
453
454
455
      if (i > 0 && cur_delta == 0) {
        continue;
      }
456
      while (cur_delta >= 256) {
457
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
458
        vals_.push_back(0);
459
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
460
      }
461
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
462
463
464
465
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
466
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
467
468
469
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
470
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
488
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
489
    data_size_t cur_pos = 0;
490
491
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
492
      while (next_threshold <= cur_pos) {
493
494
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
495
496
497
      }
    }
    // avoid out of range
498
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
499
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
500
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
501
502
503
504
    }
    fast_index_.shrink_to_fit();
  }

505
506
507
508
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    writer->Write(&num_vals_, sizeof(num_vals_));
    writer->Write(deltas_.data(), sizeof(uint8_t) * (num_vals_ + 1));
    writer->Write(vals_.data(), sizeof(VAL_T) * num_vals_);
Guolin Ke's avatar
Guolin Ke committed
509
510
511
  }

  size_t SizesInByte() const override {
512
513
    return sizeof(num_vals_) + sizeof(uint8_t) * (num_vals_ + 1) +
           sizeof(VAL_T) * num_vals_;
Guolin Ke's avatar
Guolin Ke committed
514
515
  }

516
517
518
  void LoadFromMemory(
      const void* memory,
      const std::vector<data_size_t>& local_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
519
520
521
522
523
524
525
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

526
527
528
529
530
531
532
533
534
535
536
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
537

Guolin Ke's avatar
Guolin Ke committed
538
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
539
540
541
542
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
543
544
      data_size_t cur_pos = 0;
      data_size_t j = -1;
545
546
      for (data_size_t i = 0;
           i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
547
        const data_size_t idx = local_used_indices[i];
548
549
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
550
        }
551
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
552
          // new row index is i
553
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
554
555
556
557
        }
      }
      LoadFromPair(tmp_pair);
    }
558
  }
Guolin Ke's avatar
Guolin Ke committed
559

560
561
  void CopySubrow(const Bin* full_bin, const data_size_t* used_indices,
                  data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
562
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
563
564
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
565
566
567
568
569
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
570
571
    // transform to delta array
    data_size_t last_idx = 0;
572
    for (data_size_t i = 0; i < num_used_indices; ++i) {
573
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
574
      if (bin > 0) {
575
576
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
577
          deltas_.push_back(255);
578
          vals_.push_back(0);
579
          cur_delta -= 255;
580
581
582
583
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
584
585
      }
    }
586
587
588
589
590
591
592
593
594
595
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
596
597
  }

598
599
600
  SparseBin<VAL_T>* Clone() override;

  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
601
602
603
604
605
606
607
608
609
610
      : num_data_(other.num_data_),
        deltas_(other.deltas_),
        vals_(other.vals_),
        num_vals_(other.num_vals_),
        push_buffers_(other.push_buffers_),
        fast_index_(other.fast_index_),
        fast_index_shift_(other.fast_index_shift_) {}

  void InitIndex(data_size_t start_idx, data_size_t* i_delta,
                 data_size_t* cur_pos) const {
611
612
613
614
615
616
617
618
619
620
621
    auto idx = start_idx >> fast_index_shift_;
    if (static_cast<size_t>(idx) < fast_index_.size()) {
      const auto fast_pair = fast_index_[start_idx >> fast_index_shift_];
      *i_delta = fast_pair.first;
      *cur_pos = fast_pair.second;
    } else {
      *i_delta = -1;
      *cur_pos = 0;
    }
  }

622
 private:
Guolin Ke's avatar
Guolin Ke committed
623
  data_size_t num_data_;
624
625
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>>
      deltas_;
626
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
Guolin Ke's avatar
Guolin Ke committed
627
628
629
630
631
632
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

633
template <typename VAL_T>
634
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
635
636
637
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
638
template <typename VAL_T>
639
640
641
642
643
644
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  // Walk the cursor forward until it reaches or passes row idx. The cursor
  // never moves backwards, so rows must be requested in non-decreasing order.
  for (; cur_pos_ < idx;) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  // A row without a stored entry has bin 0.
  return cur_pos_ == idx ? bin_data_->vals_[i_delta_] : static_cast<VAL_T>(0);
}
Guolin Ke's avatar
Guolin Ke committed
654

655
656
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  // Seed the cursor from the owning bin: a fast-index lookup, or the
  // before-first position when start_idx is outside the index.
  const SparseBin<VAL_T>* owner = bin_data_;
  owner->InitIndex(start_idx, &this->i_delta_, &this->cur_pos_);
}
Guolin Ke's avatar
Guolin Ke committed
659
660

template <typename VAL_T>
661
662
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin,
                                           uint32_t most_freq_bin) const {
Guolin Ke's avatar
Guolin Ke committed
663
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
664
665
666
}

}  // namespace LightGBM
667
#endif  // LightGBM_IO_SPARSE_BIN_HPP_