sparse_bin.hpp 22.2 KB
Newer Older
1
2
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
Guolin Ke's avatar
Guolin Ke committed
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

9
10
11
#include <algorithm>
#include <cstdint>
#include <cstring>
12
#include <limits>
13
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
14
15
#include <vector>

16
17
18
19
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

Guolin Ke's avatar
Guolin Ke committed
20
21
namespace LightGBM {

22
23
template <typename VAL_T>
class SparseBin;
24

Guolin Ke's avatar
Guolin Ke committed
25
26
// Target number of checkpoints in SparseBin's fast index: GetFastIndex()
// derives the row stride between checkpoints from
// ceil(num_data / kNumFastIndex), rounded up to a power of two.
const size_t kNumFastIndex = 64;

27
template <typename VAL_T>
28
class SparseBinIterator : public BinIterator {
29
 public:
30
31
32
33
34
35
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, uint32_t min_bin,
                    uint32_t max_bin, uint32_t most_freq_bin)
      : bin_data_(bin_data),
        min_bin_(static_cast<VAL_T>(min_bin)),
        max_bin_(static_cast<VAL_T>(max_bin)),
        most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
Guolin Ke's avatar
Guolin Ke committed
36
    if (most_freq_bin_ == 0) {
37
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
38
    } else {
39
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
40
41
42
    }
    Reset(0);
  }
43
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
44
      : bin_data_(bin_data) {
45
46
47
    Reset(start_idx);
  }

48
49
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
50

51
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
52
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
53
    if (ret >= min_bin_ && ret <= max_bin_) {
54
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
55
    } else {
Guolin Ke's avatar
Guolin Ke committed
56
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
57
    }
58
59
  }

Guolin Ke's avatar
Guolin Ke committed
60
  inline void Reset(data_size_t idx) override;
61

62
 private:
63
64
65
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
66
67
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
68
  VAL_T most_freq_bin_;
69
  uint8_t offset_;
70
71
};

Guolin Ke's avatar
Guolin Ke committed
72
template <typename VAL_T>
73
class SparseBin : public Bin {
74
 public:
Guolin Ke's avatar
Guolin Ke committed
75
76
  friend class SparseBinIterator<VAL_T>;

77
  /*!
   * \brief Create an empty sparse bin.
   *
   * num_vals_ and fast_index_shift_ are zeroed here so that they never hold
   * indeterminate values before the first load builds the real index.
   * \param num_data Number of rows this bin describes
   */
  explicit SparseBin(data_size_t num_data)
      : num_data_(num_data), num_vals_(0), fast_index_shift_(0) {
    // One staging buffer per OpenMP thread; Push() appends to
    // push_buffers_[tid].
    const int num_threads = OMP_NUM_THREADS();
    push_buffers_.resize(num_threads);
  }

82
  /*! \brief All storage lives in std::vector members; nothing extra to free */
  ~SparseBin() = default;
Guolin Ke's avatar
Guolin Ke committed
83

84
  /*!
   * \brief Record a new row count. Already stored entries are not touched.
   * \param num_data New number of rows
   */
  void ReSize(data_size_t num_data) override {
    num_data_ = num_data;
  }
Guolin Ke's avatar
Guolin Ke committed
85
86

  void Push(int tid, data_size_t idx, uint32_t value) override {
87
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
88
    if (cur_bin != 0) {
89
90
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
91
92
  }

93
94
  // Returns a SparseBinIterator over this bin, allocated with `new` (see the
  // definition at the end of this header); the raw pointer is handed to the
  // caller.
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin,
                           uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
95

96
// Adds gradient g and hessian h to the two interleaved slots of bin i:
// hist[2 * i] receives g and hist[2 * i + 1] receives h.
// The body is wrapped in do/while(0) so the expansion is one statement (safe
// under an unbraced if/else, usable more than once per scope) and the
// temporary index stays out of the caller's scope. Call sites supply the
// trailing semicolon.
#define ACC_GH(hist, i, g, h)                        \
  do {                                               \
    const auto acc_gh_ti = static_cast<int>(i) << 1; \
    (hist)[acc_gh_ti] += (g);                        \
    (hist)[acc_gh_ti + 1] += (h);                    \
  } while (0)
100

101
102
103
104
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
105
106
107
108
109
110
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
111
112
113
        if (i_delta >= num_vals_) {
          break;
        }
114
      } else if (cur_pos > data_indices[i]) {
115
116
117
        if (++i >= end) {
          break;
        }
118
119
120
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
121
122
123
        if (++i >= end) {
          break;
        }
124
        cur_pos += deltas_[++i_delta];
125
126
127
        if (i_delta >= num_vals_) {
          break;
        }
128
129
      }
    }
Guolin Ke's avatar
Guolin Ke committed
130
131
  }

132
  /*!
   * \brief Accumulate gradients and hessians of rows [start, end) into out.
   *
   * Each stored entry whose row lies in the range adds the gradient and
   * hessian at that row index to out[2 * bin] and out[2 * bin + 1].
   * \param start First row (inclusive)
   * \param end Last row (exclusive)
   * \param ordered_gradients Gradients, indexed by row
   * \param ordered_hessians Hessians, indexed by row
   * \param out Histogram with interleaved gradient/hessian slots
   */
  void ConstructHistogram(data_size_t start, data_size_t end,
                          const score_t* ordered_gradients,
                          const score_t* ordered_hessians,
                          hist_t* out) const override {
    data_size_t entry, row;
    InitIndex(start, &entry, &row);
    // Skip the entries located before the requested range.
    while (row < start && entry < num_vals_) {
      ++entry;
      row += deltas_[entry];
    }
    // Consume every entry that falls inside the range.
    while (row < end && entry < num_vals_) {
      const auto slot = static_cast<int>(vals_[entry]) << 1;
      out[slot] += ordered_gradients[row];
      out[slot + 1] += ordered_hessians[row];
      ++entry;
      row += deltas_[entry];
    }
  }

148
149
150
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start,
                          data_size_t end, const score_t* ordered_gradients,
                          hist_t* out) const override {
151
152
153
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
Guolin Ke's avatar
Guolin Ke committed
154
155
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
156
157
158
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
159
160
161
        if (i_delta >= num_vals_) {
          break;
        }
162
      } else if (cur_pos > data_indices[i]) {
163
164
165
        if (++i >= end) {
          break;
        }
166
      } else {
Guolin Ke's avatar
Guolin Ke committed
167
168
169
        const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
        grad[ti] += ordered_gradients[i];
        ++cnt[ti];
170
171
172
        if (++i >= end) {
          break;
        }
173
        cur_pos += deltas_[++i_delta];
174
175
176
        if (i_delta >= num_vals_) {
          break;
        }
177
178
      }
    }
179
180
  }

181
  void ConstructHistogram(data_size_t start, data_size_t end,
182
183
                          const score_t* ordered_gradients,
                          hist_t* out) const override {
184
185
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
186
187
    hist_t* grad = out;
    hist_cnt_t* cnt = reinterpret_cast<hist_cnt_t*>(out + 1);
188
189
190
191
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
Guolin Ke's avatar
Guolin Ke committed
192
193
194
      const uint32_t ti = static_cast<uint32_t>(vals_[i_delta]) << 1;
      grad[ti] += ordered_gradients[cur_pos];
      ++cnt[ti];
195
196
      cur_pos += deltas_[++i_delta];
    }
197
  }
198
#undef ACC_GH
199

200
201
  /*!
   * \brief Advance the cursor to the next stored entry.
   *
   * Once the entries are exhausted, *cur_pos is pinned to num_data_.
   * \param i_delta In/out index into deltas_/vals_
   * \param cur_pos In/out row index of the cursor
   */
  inline void NextNonzeroFast(data_size_t* i_delta,
                              data_size_t* cur_pos) const {
    const data_size_t next_entry = ++(*i_delta);
    *cur_pos += deltas_[next_entry];
    if (next_entry >= num_vals_) {
      *cur_pos = num_data_;
    }
  }

208
  /*!
   * \brief Advance the cursor to the next stored entry and report success.
   * \param i_delta In/out index into deltas_/vals_
   * \param cur_pos In/out row index; set to num_data_ when no entry is left
   * \return true if the cursor now points at a stored entry
   */
  inline bool NextNonzero(data_size_t* i_delta, data_size_t* cur_pos) const {
    const data_size_t next_entry = ++(*i_delta);
    *cur_pos += deltas_[next_entry];
    const bool has_entry = next_entry < num_vals_;
    if (!has_entry) {
      *cur_pos = num_data_;
    }
    return has_entry;
  }

218
219
220
221
222
223
224
225
226
227
  /*!
   * \brief Partition rows into a "less than or equal" and a "greater" list
   *        for a numerical threshold.
   *
   * Template flags, as used by the body:
   *  - MISS_IS_ZERO / MISS_IS_NA: rows holding the zero bin (t_zero_bin) resp.
   *    the last bin (max_bin) go to the side chosen by default_left.
   *  - MFB_IS_ZERO / MFB_IS_NA: the most frequent bin is itself the missing
   *    bin, so rows outside the stored range follow default_left as well.
   *  - USE_MIN_BIN: "outside the feature" means raw bin < min_bin or
   *    > max_bin; otherwise it means raw bin == 0.
   * \param data_indices Rows to partition (cnt entries)
   * \param lte_indices Output rows for the left side; assumed large enough,
   *        not checked here
   * \param gt_indices Output rows for the right side; assumed large enough,
   *        not checked here
   * \return Number of rows written to lte_indices
   */
  template <bool MISS_IS_ZERO, bool MISS_IS_NA, bool MFB_IS_ZERO,
            bool MFB_IS_NA, bool USE_MIN_BIN>
  data_size_t SplitInner(uint32_t min_bin, uint32_t max_bin,
                         uint32_t default_bin, uint32_t most_freq_bin,
                         bool default_left, uint32_t threshold,
                         const data_size_t* data_indices, data_size_t cnt,
                         data_size_t* lte_indices,
                         data_size_t* gt_indices) const {
    // Nothing to partition; this also keeps data_indices[0] from being read
    // below when the input is empty.
    if (cnt <= 0) {
      return 0;
    }
    auto th = static_cast<VAL_T>(threshold + min_bin);
    auto t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
    if (most_freq_bin == 0) {
      // Bin 0 is not stored, so the stored bins are shifted down by one.
      --th;
      --t_zero_bin;
    }
    const auto minb = static_cast<VAL_T>(min_bin);
    const auto maxb = static_cast<VAL_T>(max_bin);
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    // Destination of rows that carry the most frequent (unstored) bin.
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
    // Destination of rows that carry the missing value.
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    if (MISS_IS_ZERO || MISS_IS_NA) {
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
      }
    }
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (min_bin < max_bin) {
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if ((MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) ||
            (MISS_IS_NA && !MFB_IS_NA && bin == maxb)) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if ((USE_MIN_BIN && (bin < minb || bin > maxb)) ||
                   (!USE_MIN_BIN && bin == 0)) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
        }
      }
    } else {
      // Single stored bin: every row either holds max_bin or is "default".
      data_size_t* max_bin_indices = gt_indices;
      data_size_t* max_bin_count = &gt_count;
      if (maxb <= th) {
        max_bin_indices = lte_indices;
        max_bin_count = &lte_count;
      }
      for (data_size_t i = 0; i < cnt; ++i) {
        const data_size_t idx = data_indices[i];
        const auto bin = iterator.InnerRawGet(idx);
        if (MISS_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if (bin != maxb) {
          if ((MISS_IS_NA && MFB_IS_NA) || (MISS_IS_ZERO && MFB_IS_ZERO)) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            default_indices[(*default_count)++] = idx;
          }
        } else {
          if (MISS_IS_NA && !MFB_IS_NA) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else {
            max_bin_indices[(*max_bin_count)++] = idx;
          }
        }
      }
    }
    return lte_count;
  }

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
  /*!
   * \brief Numerical split for a bin that shares storage with other features
   *        (raw bins are interpreted relative to [min_bin, max_bin]).
   *
   * Picks the SplitInner instantiation that matches missing_type and whether
   * the most frequent bin coincides with the missing bin.
   * \return Number of rows written to lte_indices
   */
  data_size_t Split(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define SPLIT_ARGS                                                       \
  min_bin, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, true>(SPLIT_ARGS);
    }
    if (missing_type == MissingType::Zero) {
      const bool mfb_is_zero = (default_bin == most_freq_bin);
      return mfb_is_zero
                 ? SplitInner<true, false, true, false, true>(SPLIT_ARGS)
                 : SplitInner<true, false, false, false, true>(SPLIT_ARGS);
    }
    // Remaining missing types: the missing value occupies the last bin.
    const bool mfb_is_na =
        (max_bin == most_freq_bin + min_bin) && (most_freq_bin > 0);
    return mfb_is_na
               ? SplitInner<false, true, false, true, true>(SPLIT_ARGS)
               : SplitInner<false, true, false, false, true>(SPLIT_ARGS);
#undef SPLIT_ARGS
  }

  /*!
   * \brief Numerical split for a bin that holds a single feature; SplitInner
   *        is called with min_bin fixed to 1 and USE_MIN_BIN disabled.
   *
   * Picks the SplitInner instantiation that matches missing_type and whether
   * the most frequent bin coincides with the missing bin.
   * \return Number of rows written to lte_indices
   */
  data_size_t Split(uint32_t max_bin, uint32_t default_bin,
                    uint32_t most_freq_bin, MissingType missing_type,
                    bool default_left, uint32_t threshold,
                    const data_size_t* data_indices, data_size_t cnt,
                    data_size_t* lte_indices,
                    data_size_t* gt_indices) const override {
#define SPLIT_ARGS                                                 \
  1, max_bin, default_bin, most_freq_bin, default_left, threshold, \
      data_indices, cnt, lte_indices, gt_indices
    if (missing_type == MissingType::None) {
      return SplitInner<false, false, false, false, false>(SPLIT_ARGS);
    }
    if (missing_type == MissingType::Zero) {
      const bool mfb_is_zero = (default_bin == most_freq_bin);
      return mfb_is_zero
                 ? SplitInner<true, false, true, false, false>(SPLIT_ARGS)
                 : SplitInner<true, false, false, false, false>(SPLIT_ARGS);
    }
    // Remaining missing types: the missing value occupies the last bin.
    const bool mfb_is_na =
        (max_bin == most_freq_bin + 1) && (most_freq_bin > 0);
    return mfb_is_na
               ? SplitInner<false, true, false, true, false>(SPLIT_ARGS)
               : SplitInner<false, true, false, false, false>(SPLIT_ARGS);
#undef SPLIT_ARGS
  }
  /*!
   * \brief Partition rows for a categorical split described by a bitset.
   *
   * Rows whose raw bin lies outside the feature (bin < min_bin or > max_bin
   * when USE_MIN_BIN, bin == 0 otherwise) go to the side selected by
   * most_freq_bin's membership in the bitset. All other rows go left when
   * (bin - min_bin) is in the bitset and right otherwise.
   *
   * NOTE(review): SplitInner shifts its threshold by one when
   * most_freq_bin == 0; no comparable adjustment is applied to
   * (bin - min_bin) here - confirm this is intended.
   * \param threshold Bitset words describing the categories that go left
   * \param num_threahold Number of words in threshold
   * \param data_indices Rows to partition (cnt entries)
   * \return Number of rows written to lte_indices
   */
  template <bool USE_MIN_BIN>
  data_size_t SplitCategoricalInner(uint32_t min_bin, uint32_t max_bin,
                                    uint32_t most_freq_bin,
                                    const uint32_t* threshold,
                                    int num_threahold,
                                    const data_size_t* data_indices,
                                    data_size_t cnt, data_size_t* lte_indices,
                                    data_size_t* gt_indices) const {
    // Nothing to partition; this also keeps data_indices[0] from being read
    // below when the input is empty.
    if (cnt <= 0) {
      return 0;
    }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    for (data_size_t i = 0; i < cnt; ++i) {
      const data_size_t idx = data_indices[i];
      const uint32_t bin = iterator.RawGet(idx);
      if (USE_MIN_BIN && (bin < min_bin || bin > max_bin)) {
        default_indices[(*default_count)++] = idx;
      } else if (!USE_MIN_BIN && bin == 0) {
        default_indices[(*default_count)++] = idx;
      } else if (Common::FindInBitset(threshold, num_threahold,
                                      bin - min_bin)) {
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
    }
    return lte_count;
  }

388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
  /*!
   * \brief Categorical split that interprets raw bins relative to
   *        [min_bin, max_bin] (forwards to SplitCategoricalInner<true>).
   * \return Number of rows written to lte_indices
   */
  data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
                               uint32_t most_freq_bin,
                               const uint32_t* threshold, int num_threahold,
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    const data_size_t num_left = SplitCategoricalInner<true>(
        min_bin, max_bin, most_freq_bin, threshold, num_threahold,
        data_indices, cnt, lte_indices, gt_indices);
    return num_left;
  }

  /*!
   * \brief Categorical split with min_bin fixed to 1 (forwards to
   *        SplitCategoricalInner<false>).
   * \return Number of rows written to lte_indices
   */
  data_size_t SplitCategorical(uint32_t max_bin, uint32_t most_freq_bin,
                               const uint32_t* threshold, int num_threahold,
                               const data_size_t* data_indices, data_size_t cnt,
                               data_size_t* lte_indices,
                               data_size_t* gt_indices) const override {
    const data_size_t num_left = SplitCategoricalInner<false>(
        1, max_bin, most_freq_bin, threshold, num_threahold, data_indices,
        cnt, lte_indices, gt_indices);
    return num_left;
  }

Guolin Ke's avatar
Guolin Ke committed
409
410
411
412
  /*! \brief Number of rows recorded for this bin */
  data_size_t num_data() const override {
    return num_data_;
  }

  void FinishLoad() override {
    // get total non zero size
413
    size_t pair_cnt = 0;
414
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
415
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
416
    }
417
418
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs =
        push_buffers_[0];
419
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
420
421

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
422
423
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(),
                           push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
424
425
426
427
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
428
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
429
430
431
432
              [](const std::pair<data_size_t, VAL_T>& a,
                 const std::pair<data_size_t, VAL_T>& b) {
                return a.first < b.first;
              });
zhangyafeikimi's avatar
zhangyafeikimi committed
433
    // load delta array
434
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
435
436
  }

437
438
  void LoadFromPair(
      const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
439
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
440
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
441
442
    deltas_.reserve(idx_val_pairs.size());
    vals_.reserve(idx_val_pairs.size());
Guolin Ke's avatar
Guolin Ke committed
443
444
    // transform to delta array
    data_size_t last_idx = 0;
445
446
447
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
448
      data_size_t cur_delta = cur_idx - last_idx;
449
      // disallow the multi-val in one row
450
451
452
      if (i > 0 && cur_delta == 0) {
        continue;
      }
453
      while (cur_delta >= 256) {
454
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
455
        vals_.push_back(0);
456
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
457
      }
458
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
459
460
461
462
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
463
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
464
465
466
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
467
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
485
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
486
    data_size_t cur_pos = 0;
487
488
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
489
      while (next_threshold <= cur_pos) {
490
491
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
492
493
494
      }
    }
    // avoid out of range
495
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
496
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
497
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
498
499
500
501
    }
    fast_index_.shrink_to_fit();
  }

502
503
504
505
  /*!
   * \brief Write the bin to writer as: num_vals_, then num_vals_ + 1 delta
   *        bytes, then num_vals_ bin values, with no padding in between.
   * \param writer Destination of the bytes
   */
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    const size_t delta_bytes = sizeof(uint8_t) * (num_vals_ + 1);
    const size_t value_bytes = sizeof(VAL_T) * num_vals_;
    writer->Write(&num_vals_, sizeof(num_vals_));
    writer->Write(deltas_.data(), delta_bytes);
    writer->Write(vals_.data(), value_bytes);
  }

  size_t SizesInByte() const override {
509
510
    return sizeof(num_vals_) + sizeof(uint8_t) * (num_vals_ + 1) +
           sizeof(VAL_T) * num_vals_;
Guolin Ke's avatar
Guolin Ke committed
511
512
  }

513
514
515
  void LoadFromMemory(
      const void* memory,
      const std::vector<data_size_t>& local_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
516
517
518
519
520
521
522
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

523
524
525
526
527
528
529
530
531
532
533
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
534

Guolin Ke's avatar
Guolin Ke committed
535
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
536
537
538
539
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
540
541
      data_size_t cur_pos = 0;
      data_size_t j = -1;
542
543
      for (data_size_t i = 0;
           i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
544
        const data_size_t idx = local_used_indices[i];
545
546
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
547
        }
548
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
549
          // new row index is i
550
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
551
552
553
554
        }
      }
      LoadFromPair(tmp_pair);
    }
555
  }
Guolin Ke's avatar
Guolin Ke committed
556

557
558
  void CopySubrow(const Bin* full_bin, const data_size_t* used_indices,
                  data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
559
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
560
561
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
562
563
564
565
566
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
567
568
    // transform to delta array
    data_size_t last_idx = 0;
569
    for (data_size_t i = 0; i < num_used_indices; ++i) {
570
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
571
      if (bin > 0) {
572
573
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
574
          deltas_.push_back(255);
575
          vals_.push_back(0);
576
          cur_delta -= 255;
577
578
579
580
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
581
582
      }
    }
583
584
585
586
587
588
589
590
591
592
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
593
594
  }

595
596
597
  // Returns a copy of this bin created with `new` via the copy constructor
  // (defined after the class body).
  SparseBin<VAL_T>* Clone() override;

  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
598
599
600
601
602
603
604
605
606
607
      : num_data_(other.num_data_),
        deltas_(other.deltas_),
        vals_(other.vals_),
        num_vals_(other.num_vals_),
        push_buffers_(other.push_buffers_),
        fast_index_(other.fast_index_),
        fast_index_shift_(other.fast_index_shift_) {}

  /*!
   * \brief Position a cursor using the checkpoint that covers start_idx.
   *
   * When no checkpoint exists for start_idx, the cursor is placed before the
   * first entry (*i_delta = -1, *cur_pos = 0).
   * \param start_idx Row index the caller is about to look up
   * \param i_delta Output index into deltas_/vals_
   * \param cur_pos Output row index of the cursor
   */
  void InitIndex(data_size_t start_idx, data_size_t* i_delta,
                 data_size_t* cur_pos) const {
    const size_t slot = static_cast<size_t>(start_idx >> fast_index_shift_);
    if (slot >= fast_index_.size()) {
      *i_delta = -1;
      *cur_pos = 0;
      return;
    }
    const auto checkpoint = fast_index_[slot];
    *i_delta = checkpoint.first;
    *cur_pos = checkpoint.second;
  }

619
 private:
Guolin Ke's avatar
Guolin Ke committed
620
  data_size_t num_data_;
621
622
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>>
      deltas_;
623
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
Guolin Ke's avatar
Guolin Ke committed
624
625
626
627
628
629
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

630
template <typename VAL_T>
631
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
632
633
634
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
635
template <typename VAL_T>
636
637
638
639
640
641
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

/*!
 * \brief Raw stored bin of row idx in the storage type.
 *
 * Advances the cursor until it reaches or passes idx; the cursor never moves
 * backwards.
 * \return The stored bin when an entry exists exactly at idx, 0 otherwise
 */
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  for (; cur_pos_ < idx;) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  if (cur_pos_ != idx) {
    return 0;
  }
  return bin_data_->vals_[i_delta_];
}
Guolin Ke's avatar
Guolin Ke committed
651

652
653
/*!
 * \brief Re-position the cursor for lookups starting at start_idx, using the
 *        bin's checkpoint table.
 */
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  const SparseBin<VAL_T>* source = bin_data_;
  source->InitIndex(start_idx, &i_delta_, &cur_pos_);
}
Guolin Ke's avatar
Guolin Ke committed
656
657

template <typename VAL_T>
658
659
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin,
                                           uint32_t most_freq_bin) const {
Guolin Ke's avatar
Guolin Ke committed
660
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
661
662
663
}

}  // namespace LightGBM
664
#endif  // LIGHTGBM_IO_SPARSE_BIN_HPP_