sparse_bin.hpp 22.3 KB
Newer Older
1
2
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
Guolin Ke's avatar
Guolin Ke committed
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

9
10
11
12
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

13
14
15
#include <algorithm>
#include <cstdint>
#include <cstring>
16
#include <limits>
17
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
#include <vector>

namespace LightGBM {

22
23
// Forward declaration: SparseBinIterator (defined next) holds a pointer to a
// SparseBin before the full class definition is available.
template <typename VAL_T>
class SparseBin;
24

Guolin Ke's avatar
Guolin Ke committed
25
26
// Target number of buckets for the fast index built by
// SparseBin::GetFastIndex(): the rows per bucket are
// ceil(num_data / kNumFastIndex) rounded up to a power of two.
const size_t kNumFastIndex = 64;

27
template <typename VAL_T>
28
class SparseBinIterator : public BinIterator {
29
 public:
30
31
32
33
34
35
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, uint32_t min_bin,
                    uint32_t max_bin, uint32_t most_freq_bin)
      : bin_data_(bin_data),
        min_bin_(static_cast<VAL_T>(min_bin)),
        max_bin_(static_cast<VAL_T>(max_bin)),
        most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
Guolin Ke's avatar
Guolin Ke committed
36
    if (most_freq_bin_ == 0) {
37
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
38
    } else {
39
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
40
41
42
    }
    Reset(0);
  }
43
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
44
      : bin_data_(bin_data) {
45
46
47
    Reset(start_idx);
  }

48
49
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
50

51
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
52
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
53
    if (ret >= min_bin_ && ret <= max_bin_) {
54
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
55
    } else {
Guolin Ke's avatar
Guolin Ke committed
56
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
57
    }
58
59
  }

Guolin Ke's avatar
Guolin Ke committed
60
  inline void Reset(data_size_t idx) override;
61

62
 private:
63
64
65
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
66
67
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
68
  VAL_T most_freq_bin_;
69
  uint8_t offset_;
70
71
};

Guolin Ke's avatar
Guolin Ke committed
72
template <typename VAL_T>
73
class SparseBin : public Bin {
74
 public:
Guolin Ke's avatar
Guolin Ke committed
75
76
  friend class SparseBinIterator<VAL_T>;

77
  /*!
   * \brief Creates an empty sparse bin for num_data rows.
   *
   * One push buffer is allocated per OpenMP thread. num_vals_ and
   * fast_index_shift_ are zeroed so that InitIndex(), SizesInByte() and
   * SaveBinaryToFile() do not read indeterminate values before data is loaded.
   */
  explicit SparseBin(data_size_t num_data)
      : num_data_(num_data), num_vals_(0), fast_index_shift_(0) {
    int num_threads = OMP_NUM_THREADS();
    push_buffers_.resize(num_threads);
  }

82
  // Nothing to release by hand: every member cleans up after itself.
  ~SparseBin() = default;
Guolin Ke's avatar
Guolin Ke committed
83

84
  // Records the new row count only; stored entries are left untouched.
  void ReSize(data_size_t num_data) override {
    num_data_ = num_data;
  }
Guolin Ke's avatar
Guolin Ke committed
85
86

  void Push(int tid, data_size_t idx, uint32_t value) override {
87
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
88
    if (cur_bin != 0) {
89
90
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
91
92
  }

93
94
  // Returns a SparseBinIterator over this bin, allocated with `new` (see the
  // out-of-class definition below).
  // NOTE(review): the caller presumably owns the returned pointer - confirm.
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin,
                           uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
95

96
// Adds gradient g to hist[2 * i] and hessian h to hist[2 * i + 1].
// Wrapped in do { } while (0) so it behaves as a single statement (safe in an
// unbraced if/else) and does not leak its temporary into the caller's scope;
// arguments are parenthesized so expressions can be passed safely.
#define ACC_GH(hist, i, g, h)                        \
  do {                                               \
    const auto acc_gh_ti = static_cast<int>(i) << 1; \
    (hist)[acc_gh_ti] += (g);                        \
    (hist)[acc_gh_ti + 1] += (h);                    \
  } while (0)
100

101
102
103
104
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
105
106
107
108
109
110
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
111
112
113
        if (i_delta >= num_vals_) {
          break;
        }
114
      } else if (cur_pos > data_indices[i]) {
115
116
117
        if (++i >= end) {
          break;
        }
118
119
120
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
121
122
123
        if (++i >= end) {
          break;
        }
124
        cur_pos += deltas_[++i_delta];
125
126
127
        if (i_delta >= num_vals_) {
          break;
        }
128
129
      }
    }
Guolin Ke's avatar
Guolin Ke committed
130
131
  }

132
  /*!
   * \brief Accumulates gradients and hessians of rows [start, end) into out.
   *
   * For every stored entry whose row lies in the range, out[2 * bin] receives
   * the gradient and out[2 * bin + 1] the hessian; both input arrays are
   * indexed by row position.
   */
  void ConstructHistogram(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
    data_size_t entry, row;
    InitIndex(start, &entry, &row);
    // Skip stored entries that lie before the requested range.
    for (; row < start && entry < num_vals_; row += deltas_[++entry]) {
    }
    // Accumulate every stored entry inside the range.
    for (; row < end && entry < num_vals_; row += deltas_[++entry]) {
      const VAL_T bin = vals_[entry];
      ACC_GH(out, bin, ordered_gradients[row], ordered_hessians[row]);
    }
  }

148
149
150
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
151
152
153
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
Guolin Ke's avatar
Guolin Ke committed
154
155
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
156
157
158
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
159
160
161
        if (i_delta >= num_vals_) {
          break;
        }
162
      } else if (cur_pos > data_indices[i]) {
163
164
165
        if (++i >= end) {
          break;
        }
166
      } else {
Guolin Ke's avatar
Guolin Ke committed
167
168
169
        const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
        grad[ti] += ordered_gradients[i];
        ++cnt[ti];
170
171
172
        if (++i >= end) {
          break;
        }
173
        cur_pos += deltas_[++i_delta];
174
175
176
        if (i_delta >= num_vals_) {
          break;
        }
177
178
      }
    }
179
180
  }

181
  void ConstructHistogram(data_size_t start, data_size_t end,
182
183
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
184
185
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
186
187
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
188
189
190
191
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
Guolin Ke's avatar
Guolin Ke committed
192
193
194
      const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
      grad[ti] += ordered_gradients[cur_pos];
      ++cnt[ti];
195
196
      cur_pos += deltas_[++i_delta];
    }
197
  }
198
#undef ACC_GH
199

200
201
  /*!
   * \brief Moves the cursor to the next stored entry.
   *
   * Once the entries are exhausted, *cur_pos is set to num_data_, which is
   * past every valid row.
   */
  inline void NextNonzeroFast(data_size_t* i_delta,
                              data_size_t* cur_pos) const {
    const data_size_t next = ++(*i_delta);
    *cur_pos += deltas_[next];
    if (next >= num_vals_) {
      *cur_pos = num_data_;
    }
  }

208
  /*!
   * \brief Moves the cursor to the next stored entry.
   * \return true if the cursor now points at a stored entry; false once the
   *         entries are exhausted, in which case *cur_pos is set to num_data_
   */
  inline bool NextNonzero(data_size_t* i_delta, data_size_t* cur_pos) const {
    const data_size_t next = ++(*i_delta);
    *cur_pos += deltas_[next];
    if (next >= num_vals_) {
      *cur_pos = num_data_;
      return false;
    }
    return true;
  }

218
219
220
221
222
223
224
225
226
227
  /*!
   * \brief Partitions rows into "<= threshold" and "> threshold" outputs.
   *
   * Each row of data_indices is appended to lte_indices or gt_indices:
   *  - rows whose raw bin is outside the feature's range (USE_MIN_BIN) or is
   *    0 (!USE_MIN_BIN) go where most_freq_bin falls relative to threshold,
   *    unless the MFB_IS_* flag marks that bin as the missing bin;
   *  - rows in the missing bin (selected by the MISS_IS_* / MFB_IS_* flags)
   *    go left when default_left is set and right otherwise;
   *  - all other rows are compared against the threshold.
   * The row cursor only moves forward, so data_indices must be ascending.
   *
   * \param data_indices Rows to partition
   * \param cnt Number of entries in data_indices
   * \param lte_indices Output for rows sent left
   * \param gt_indices Output for rows sent right
   * \return Number of rows written to lte_indices
   */
  template <bool MISS_IS_ZERO, bool MISS_IS_NA, bool MFB_IS_ZERO,
            bool MFB_IS_NA, bool USE_MIN_BIN>
  data_size_t SplitInner(uint32_t min_bin, uint32_t max_bin,
                         uint32_t default_bin, uint32_t most_freq_bin,
                         bool default_left, uint32_t threshold,
                         const data_size_t* data_indices, data_size_t cnt,
                         data_size_t* lte_indices,
                         data_size_t* gt_indices) const {
    // Nothing to partition; also avoids reading data_indices[0] below.
    if (cnt <= 0) {
      return 0;
    }
    auto th = static_cast<VAL_T>(threshold + min_bin);
    auto t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
    if (most_freq_bin == 0) {
      --th;
      --t_zero_bin;
    }
    const auto minb = static_cast<VAL_T>(min_bin);
    const auto maxb = static_cast<VAL_T>(max_bin);
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    // Destination for rows that hold the most frequent bin.
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
    // Destination for rows in the missing bin.
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    if (MISS_IS_ZERO || MISS_IS_NA) {
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
      }
    }
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (min_bin < max_bin) {
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if ((MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) ||
            (MISS_IS_NA && !MFB_IS_NA && bin == maxb)) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if ((USE_MIN_BIN && (bin < minb || bin > maxb)) ||
                   (!USE_MIN_BIN && bin == 0)) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
        }
      }
    } else {
      // The feature has a single stored bin (min_bin >= max_bin): every row
      // is either that bin or the most frequent bin.
      data_size_t* max_bin_indices = gt_indices;
      data_size_t* max_bin_count = &gt_count;
      if (maxb <= th) {
        max_bin_indices = lte_indices;
        max_bin_count = &lte_count;
      }
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if (MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if (bin != maxb) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
        } else {
          if (MISS_IS_NA && !MFB_IS_NA) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            max_bin_indices[(*max_bin_count)++] = idx;
          }
        }
      }
    }
    return lte_count;
  }

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
  /*!
   * \brief Numerical split for a feature occupying [min_bin, max_bin].
   *
   * Selects the SplitInner instantiation from missing_type and forwards all
   * arguments; USE_MIN_BIN is always true here.
   * \return Number of rows written to lte_indices
   */
  data_size_t Split(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                        \
  min_bin, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, true>(ARGUMENTS);
    }
    if (missing_type == MissingType::Zero) {
      // The zero (default) bin is the missing bin.
      return default_bin == most_freq_bin
                 ? SplitInner<true, false, true, false, true>(ARGUMENTS)
                 : SplitInner<true, false, false, false, true>(ARGUMENTS);
    }
    // Any other missing type: the last bin is the missing bin.
    const bool mfb_is_na =
        most_freq_bin > 0 && max_bin == most_freq_bin + min_bin;
    return mfb_is_na
               ? SplitInner<false, true, false, true, true>(ARGUMENTS)
               : SplitInner<false, true, false, false, true>(ARGUMENTS);
#undef ARGUMENTS
  }

  /*!
   * \brief Numerical split for a bin that holds a single feature.
   *
   * Same dispatch as the min_bin overload, but min_bin is fixed to 1 and
   * USE_MIN_BIN is false.
   * \return Number of rows written to lte_indices
   */
  data_size_t Split(uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define ARGUMENTS                                                  \
  1, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, false>(ARGUMENTS);
    }
    if (missing_type == MissingType::Zero) {
      // The zero (default) bin is the missing bin.
      return default_bin == most_freq_bin
                 ? SplitInner<true, false, true, false, false>(ARGUMENTS)
                 : SplitInner<true, false, false, false, false>(ARGUMENTS);
    }
    // Any other missing type: the last bin is the missing bin.
    const bool mfb_is_na = most_freq_bin > 0 && max_bin == most_freq_bin + 1;
    return mfb_is_na
               ? SplitInner<false, true, false, true, false>(ARGUMENTS)
               : SplitInner<false, true, false, false, false>(ARGUMENTS);
#undef ARGUMENTS
  }
  /*!
   * \brief Partitions rows for a categorical split described by a bitset.
   *
   * Rows whose feature-local bin is set in the threshold bitset go to
   * lte_indices, the others to gt_indices. Rows holding the most frequent
   * bin (raw bin outside [min_bin, max_bin] when USE_MIN_BIN, raw bin 0
   * otherwise) go left only if most_freq_bin is non-zero and set in the
   * bitset. The row cursor only moves forward, so data_indices must be
   * ascending.
   *
   * \param threshold Bitset words passed to Common::FindInBitset
   * \param num_threahold Word count passed to Common::FindInBitset
   * \param data_indices Rows to partition
   * \param cnt Number of entries in data_indices
   * \return Number of rows written to lte_indices
   */
  template <bool USE_MIN_BIN>
  data_size_t SplitCategoricalInner(uint32_t min_bin, uint32_t max_bin,
                                    uint32_t most_freq_bin,
                                    const uint32_t* threshold,
                                    int num_threahold,
                                    const data_size_t* data_indices,
                                    data_size_t cnt, data_size_t* lte_indices,
                                    data_size_t* gt_indices) const {
    // Nothing to partition; also avoids reading data_indices[0] below.
    if (cnt <= 0) {
      return 0;
    }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    // Shift applied when mapping a raw bin to its feature-local bin.
    int8_t offset = most_freq_bin == 0 ? 1 : 0;
    if (most_freq_bin > 0 &&
        Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    for (data_size_t i = 0; i < cnt; ++i) {
      const data_size_t idx = data_indices[i];
      const uint32_t bin = iterator.RawGet(idx);
      if (USE_MIN_BIN && (bin < min_bin || bin > max_bin)) {
        default_indices[(*default_count)++] = idx;
      } else if (!USE_MIN_BIN && bin == 0) {
        default_indices[(*default_count)++] = idx;
      } else if (Common::FindInBitset(threshold, num_threahold,
                                      bin - min_bin + offset)) {
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
    }
    return lte_count;
  }

389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
  // Categorical split for a feature occupying [min_bin, max_bin]; forwards to
  // SplitCategoricalInner with USE_MIN_BIN = true.
  // Returns the number of rows written to lte_indices.
  data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
                               uint32_t most_freq_bin,
                               const uint32_t* threshold, int num_threahold,
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<true>(min_bin, max_bin, most_freq_bin,
                                       threshold, num_threahold, data_indices,
                                       cnt, lte_indices, gt_indices);
  }

  // Categorical split without an explicit min_bin; forwards to
  // SplitCategoricalInner with min_bin fixed to 1 and USE_MIN_BIN = false.
  // Returns the number of rows written to lte_indices.
  data_size_t SplitCategorical(uint32_t max_bin, uint32_t most_freq_bin,
                               const uint32_t* threshold, int num_threahold,
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    return SplitCategoricalInner<false>(1, max_bin, most_freq_bin, threshold,
                                        num_threahold, data_indices, cnt,
                                        lte_indices, gt_indices);
  }

Guolin Ke's avatar
Guolin Ke committed
410
411
412
413
  // Returns the row count set by the constructor or by ReSize().
  data_size_t num_data() const override { return num_data_; }

  void FinishLoad() override {
    // get total non zero size
414
    size_t pair_cnt = 0;
415
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
416
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
417
    }
418
419
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs =
        push_buffers_[0];
420
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
421
422

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
423
424
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(),
                           push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
425
426
427
428
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
429
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
430
431
432
433
              [](const std::pair<data_size_t, VAL_T>& a,
                 const std::pair<data_size_t, VAL_T>& b) {
                return a.first < b.first;
              });
zhangyafeikimi's avatar
zhangyafeikimi committed
434
    // load delta array
435
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
436
437
  }

438
439
  void LoadFromPair(
      const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
440
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
441
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
442
443
    deltas_.reserve(idx_val_pairs.size());
    vals_.reserve(idx_val_pairs.size());
Guolin Ke's avatar
Guolin Ke committed
444
445
    // transform to delta array
    data_size_t last_idx = 0;
446
447
448
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
449
      data_size_t cur_delta = cur_idx - last_idx;
450
      // disallow the multi-val in one row
451
452
453
      if (i > 0 && cur_delta == 0) {
        continue;
      }
454
      while (cur_delta >= 256) {
455
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
456
        vals_.push_back(0);
457
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
458
      }
459
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
460
461
462
463
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
464
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
465
466
467
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
468
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
486
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
487
    data_size_t cur_pos = 0;
488
489
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
490
      while (next_threshold <= cur_pos) {
491
492
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
493
494
495
      }
    }
    // avoid out of range
496
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
497
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
498
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
499
500
501
502
    }
    fast_index_.shrink_to_fit();
  }

503
504
505
506
  /*!
   * \brief Writes the bin to writer in the layout LoadFromMemory() reads.
   *
   * Layout: num_vals_, then num_vals_ + 1 delta bytes (including the trailing
   * sentinel), then num_vals_ bin values.
   */
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    const size_t delta_bytes = sizeof(uint8_t) * (num_vals_ + 1);
    const size_t value_bytes = sizeof(VAL_T) * num_vals_;
    writer->Write(&num_vals_, sizeof(num_vals_));
    writer->Write(deltas_.data(), delta_bytes);
    writer->Write(vals_.data(), value_bytes);
  }

  size_t SizesInByte() const override {
510
511
    return sizeof(num_vals_) + sizeof(uint8_t) * (num_vals_ + 1) +
           sizeof(VAL_T) * num_vals_;
Guolin Ke's avatar
Guolin Ke committed
512
513
  }

514
515
516
  void LoadFromMemory(
      const void* memory,
      const std::vector<data_size_t>& local_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
517
518
519
520
521
522
523
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

524
525
526
527
528
529
530
531
532
533
534
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
535

Guolin Ke's avatar
Guolin Ke committed
536
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
537
538
539
540
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
541
542
      data_size_t cur_pos = 0;
      data_size_t j = -1;
543
544
      for (data_size_t i = 0;
           i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
545
        const data_size_t idx = local_used_indices[i];
546
547
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
548
        }
549
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
550
          // new row index is i
551
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
552
553
554
555
        }
      }
      LoadFromPair(tmp_pair);
    }
556
  }
Guolin Ke's avatar
Guolin Ke committed
557

558
559
  void CopySubrow(const Bin* full_bin, const data_size_t* used_indices,
                  data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
560
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
561
562
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
563
564
565
566
567
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
568
569
    // transform to delta array
    data_size_t last_idx = 0;
570
    for (data_size_t i = 0; i < num_used_indices; ++i) {
571
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
572
      if (bin > 0) {
573
574
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
575
          deltas_.push_back(255);
576
          vals_.push_back(0);
577
          cur_delta -= 255;
578
579
580
581
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
582
583
      }
    }
584
585
586
587
588
589
590
591
592
593
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
594
595
  }

596
597
598
  // Returns a copy of this bin allocated with `new` (defined after the class).
  SparseBin<VAL_T>* Clone() override;

  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
599
600
601
602
603
604
605
606
607
608
      : num_data_(other.num_data_),
        deltas_(other.deltas_),
        vals_(other.vals_),
        num_vals_(other.num_vals_),
        push_buffers_(other.push_buffers_),
        fast_index_(other.fast_index_),
        fast_index_shift_(other.fast_index_shift_) {}

  /*!
   * \brief Positions a cursor using the fast-index checkpoint for start_idx.
   *
   * If the bucket of start_idx has no checkpoint, the cursor is set to the
   * state before the first stored entry (*i_delta = -1, *cur_pos = 0).
   * \param start_idx Row the caller wants to start from
   * \param i_delta Receives the entry index of the cursor
   * \param cur_pos Receives the row position of the cursor
   */
  void InitIndex(data_size_t start_idx, data_size_t* i_delta,
                 data_size_t* cur_pos) const {
    const auto bucket = start_idx >> fast_index_shift_;
    if (static_cast<size_t>(bucket) >= fast_index_.size()) {
      *i_delta = -1;
      *cur_pos = 0;
      return;
    }
    const auto& checkpoint = fast_index_[bucket];
    *i_delta = checkpoint.first;
    *cur_pos = checkpoint.second;
  }

620
 private:
Guolin Ke's avatar
Guolin Ke committed
621
  data_size_t num_data_;
622
623
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>>
      deltas_;
624
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
Guolin Ke's avatar
Guolin Ke committed
625
626
627
628
629
630
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

631
template <typename VAL_T>
632
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
633
634
635
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
636
template <typename VAL_T>
637
638
639
640
641
642
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

// Advances the cursor until it reaches or passes row idx, then returns the
// stored bin if the cursor sits exactly on idx and 0 otherwise. The cursor
// never moves backwards.
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  while (cur_pos_ < idx) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  if (cur_pos_ != idx) {
    return 0;
  }
  return bin_data_->vals_[i_delta_];
}
Guolin Ke's avatar
Guolin Ke committed
652

653
654
// Repositions the cursor via the bin's fast index so that a walk can start
// at row start_idx.
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  data_size_t* const entry = &i_delta_;
  data_size_t* const row = &cur_pos_;
  bin_data_->InitIndex(start_idx, entry, row);
}
Guolin Ke's avatar
Guolin Ke committed
657
658

template <typename VAL_T>
659
660
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin,
                                           uint32_t most_freq_bin) const {
Guolin Ke's avatar
Guolin Ke committed
661
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
662
663
664
}

}  // namespace LightGBM
665
#endif  // LIGHTGBM_IO_SPARSE_BIN_HPP_