sparse_bin.hpp 17.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

#include <LightGBM/bin.h>
9
#include <LightGBM/utils/log.h>
10
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
#include <limits>
13
14
15
16
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
17
18
19
20
#include <vector>

namespace LightGBM {

21
template <typename VAL_T> class SparseBin;

// Upper bound on the number of fast-index checkpoints a SparseBin builds:
// GetFastIndex() spaces checkpoints ceil(num_data / kNumFastIndex) rows apart,
// rounded up to a power of two, so at most kNumFastIndex checkpoints exist.
constexpr size_t kNumFastIndex = 64;

template <typename VAL_T>
class SparseBinIterator: public BinIterator {
27
 public:
Guolin Ke's avatar
Guolin Ke committed
28
  SparseBinIterator(const SparseBin<VAL_T>* bin_data,
Guolin Ke's avatar
Guolin Ke committed
29
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
Guolin Ke's avatar
Guolin Ke committed
30
31
    : bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
    max_bin_(static_cast<VAL_T>(max_bin)),
Guolin Ke's avatar
Guolin Ke committed
32
33
    most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
    if (most_freq_bin_ == 0) {
34
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
35
    } else {
36
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
37
38
39
    }
    Reset(0);
  }
40
41
42
43
44
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
    : bin_data_(bin_data) {
    Reset(start_idx);
  }

45
46
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
47

48
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
49
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
50
    if (ret >= min_bin_ && ret <= max_bin_) {
51
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
52
    } else {
Guolin Ke's avatar
Guolin Ke committed
53
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
54
    }
55
56
  }

Guolin Ke's avatar
Guolin Ke committed
57
  inline void Reset(data_size_t idx) override;
58

59
 private:
60
61
62
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
63
64
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
65
  VAL_T most_freq_bin_;
66
  uint8_t offset_;
67
68
};

Guolin Ke's avatar
Guolin Ke committed
69
template <typename VAL_T>
70
class SparseBin: public Bin {
71
 public:
Guolin Ke's avatar
Guolin Ke committed
72
73
  friend class SparseBinIterator<VAL_T>;

Guolin Ke's avatar
Guolin Ke committed
74
  explicit SparseBin(data_size_t num_data)
Guolin Ke's avatar
Guolin Ke committed
75
    : num_data_(num_data) {
Guolin Ke's avatar
Guolin Ke committed
76
    int num_threads = 1;
77
78
    #pragma omp parallel
    #pragma omp master
Guolin Ke's avatar
Guolin Ke committed
79
    {
Guolin Ke's avatar
Guolin Ke committed
80
      num_threads = omp_get_num_threads();
Guolin Ke's avatar
Guolin Ke committed
81
    }
Guolin Ke's avatar
Guolin Ke committed
82
    push_buffers_.resize(num_threads);
Guolin Ke's avatar
Guolin Ke committed
83
84
85
  }

  ~SparseBin() {
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
  }

  void ReSize(data_size_t num_data) override {
    num_data_ = num_data;
Guolin Ke's avatar
Guolin Ke committed
90
91
92
  }

  void Push(int tid, data_size_t idx, uint32_t value) override {
93
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
94
    if (cur_bin != 0) {
95
96
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
97
98
  }

Guolin Ke's avatar
Guolin Ke committed
99
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
100

101
102
103
104
105
106
  #define ACC_GH(hist, i, g, h) \
  const auto ti = static_cast<int>(i) << 1; \
  hist[ti] += g; \
  hist[ti + 1] += h; \

  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
107
    const score_t* ordered_gradients, const score_t* ordered_hessians, hist_t* out) const override {
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      } else if (cur_pos > data_indices[i]) {
        if (++i >= end) { break; }
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
        if (++i >= end) { break; }
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
125
126
  }

127
  void ConstructHistogram(data_size_t start, data_size_t end,
128
    const score_t* ordered_gradients, const score_t* ordered_hessians, hist_t* out) const override {
129
130
131
132
133
134
135
136
137
138
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], ordered_hessians[cur_pos]);
      cur_pos += deltas_[++i_delta];
    }
139
140
  }

141
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
142
    const score_t* ordered_gradients, hist_t* out) const override {
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      } else if (cur_pos > data_indices[i]) {
        if (++i >= end) { break; }
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], 1.0f);
        if (++i >= end) { break; }
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      }
    }
160
161
  }

162
  void ConstructHistogram(data_size_t start, data_size_t end,
163
    const score_t* ordered_gradients, hist_t* out) const override {
164
165
166
167
168
169
170
171
172
173
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], 1.0f);
      cur_pos += deltas_[++i_delta];
    }
174
  }
175
  #undef ACC_GH
176

177
  inline void NextNonzeroFast(data_size_t* i_delta, data_size_t* cur_pos) const {
178
179
180
    *cur_pos += deltas_[++(*i_delta)];
    if (*i_delta >= num_vals_) {
      *cur_pos = num_data_;
181
    }
182
183
184
185
186
  }

  inline bool NextNonzero(data_size_t* i_delta,
    data_size_t* cur_pos) const {
    *cur_pos += deltas_[++(*i_delta)];
187
    if (*i_delta < num_vals_) {
188
189
      return true;
    } else {
190
      *cur_pos = num_data_;
191
192
193
194
      return false;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
195

Guolin Ke's avatar
Guolin Ke committed
196
  data_size_t Split(
197
198
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin,
    MissingType missing_type, bool default_left,
Guolin Ke's avatar
Guolin Ke committed
199
    uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
200
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
201
    if (num_data <= 0) { return 0; }
Guolin Ke's avatar
Guolin Ke committed
202
    VAL_T th = static_cast<VAL_T>(threshold + min_bin);
203
204
    const VAL_T minb = static_cast<VAL_T>(min_bin);
    const VAL_T maxb = static_cast<VAL_T>(max_bin);
Guolin Ke's avatar
Guolin Ke committed
205
    VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
Guolin Ke's avatar
Guolin Ke committed
206
207
    VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
    if (most_freq_bin == 0) {
Guolin Ke's avatar
Guolin Ke committed
208
      th -= 1;
Guolin Ke's avatar
Guolin Ke committed
209
      t_default_bin -= 1;
Guolin Ke's avatar
Guolin Ke committed
210
      t_most_freq_bin -= 1;
Guolin Ke's avatar
Guolin Ke committed
211
    }
Guolin Ke's avatar
Guolin Ke committed
212
213
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
Guolin Ke's avatar
Guolin Ke committed
214
215
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
216
217
218
219
220
221
222
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
223
224
225
226
    if (missing_type == MissingType::NaN) {
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
Guolin Ke's avatar
Guolin Ke committed
227
      }
228
229
230
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        const VAL_T bin = iterator.InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
231
        if (bin == maxb) {
232
          missing_default_indices[(*missing_default_count)++] = idx;
Guolin Ke's avatar
Guolin Ke committed
233
234
        } else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
          default_indices[(*default_count)++] = idx;
235
236
237
238
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
239
240
241
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
242
243
244
245
      if ((default_left && missing_type == MissingType::Zero)
        || (default_bin <= threshold && missing_type != MissingType::Zero)) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
246
      }
Guolin Ke's avatar
Guolin Ke committed
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
      if (default_bin == most_freq_bin) {
        for (data_size_t i = 0; i < num_data; ++i) {
          const data_size_t idx = data_indices[i];
          const VAL_T bin = iterator.InnerRawGet(idx);
          if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else if (bin > th) {
            gt_indices[gt_count++] = idx;
          } else {
            lte_indices[lte_count++] = idx;
          }
        }
      } else {
        for (data_size_t i = 0; i < num_data; ++i) {
          const data_size_t idx = data_indices[i];
          const VAL_T bin = iterator.InnerRawGet(idx);
          if (bin == t_default_bin) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
            default_indices[(*default_count)++] = idx;
          } else if (bin > th) {
            gt_indices[gt_count++] = idx;
          } else {
            lte_indices[lte_count++] = idx;
          }
272
        }
Guolin Ke's avatar
Guolin Ke committed
273
      }
274
    }
275
276
277
    return lte_count;
  }

Guolin Ke's avatar
Guolin Ke committed
278
  data_size_t SplitCategorical(
Guolin Ke's avatar
Guolin Ke committed
279
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
280
    const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
281
282
283
284
285
286
287
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
288
    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
289
290
291
292
293
294
295
296
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    for (data_size_t i = 0; i < num_data; ++i) {
      const data_size_t idx = data_indices[i];
      uint32_t bin = iterator.InnerRawGet(idx);
      if (bin < min_bin || bin > max_bin) {
        default_indices[(*default_count)++] = idx;
297
      } else if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
298
299
300
301
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
Guolin Ke's avatar
Guolin Ke committed
302
303
304
305
306
307
308
309
    }
    return lte_count;
  }

  data_size_t num_data() const override { return num_data_; }

  void FinishLoad() override {
    // get total non zero size
310
    size_t pair_cnt = 0;
311
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
312
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
313
    }
Guolin Ke's avatar
Guolin Ke committed
314
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs = push_buffers_[0];
315
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
316
317

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
318
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(), push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
323
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
324
      [](const std::pair<data_size_t, VAL_T>& a, const std::pair<data_size_t, VAL_T>& b) {
325
326
        return a.first < b.first;
      });
zhangyafeikimi's avatar
zhangyafeikimi committed
327
    // load delta array
328
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
329
330
  }

331
  void LoadFromPair(const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
332
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
333
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
334
335
    deltas_.reserve(idx_val_pairs.size());
    vals_.reserve(idx_val_pairs.size());
Guolin Ke's avatar
Guolin Ke committed
336
337
    // transform to delta array
    data_size_t last_idx = 0;
338
339
340
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
341
      data_size_t cur_delta = cur_idx - last_idx;
342
      // disallow the multi-val in one row
Guolin Ke's avatar
Guolin Ke committed
343
      if (i > 0 && cur_delta == 0) { continue; }
344
      while (cur_delta >= 256) {
345
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
346
        vals_.push_back(0);
347
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
348
      }
349
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
350
351
352
353
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
354
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
355
356
357
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
358
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
376
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
377
    data_size_t cur_pos = 0;
378
379
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
380
      while (next_threshold <= cur_pos) {
381
382
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
383
384
385
      }
    }
    // avoid out of range
386
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
387
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
388
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
389
390
391
392
    }
    fast_index_.shrink_to_fit();
  }

393
394
395
396
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    writer->Write(&num_vals_, sizeof(num_vals_));
    writer->Write(deltas_.data(), sizeof(uint8_t) * (num_vals_ + 1));
    writer->Write(vals_.data(), sizeof(VAL_T) * num_vals_);
Guolin Ke's avatar
Guolin Ke committed
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
  }

  size_t SizesInByte() const override {
    return sizeof(num_vals_) + sizeof(uint8_t) * (num_vals_ + 1)
      + sizeof(VAL_T) * num_vals_;
  }

  void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

412
413
414
415
416
417
418
419
420
421
422
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
423

Guolin Ke's avatar
Guolin Ke committed
424
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
425
426
427
428
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
429
430
      data_size_t cur_pos = 0;
      data_size_t j = -1;
Guolin Ke's avatar
Guolin Ke committed
431
432
      for (data_size_t i = 0; i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
        const data_size_t idx = local_used_indices[i];
433
434
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
435
        }
436
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
437
          // new row index is i
438
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
439
440
441
442
        }
      }
      LoadFromPair(tmp_pair);
    }
443
  }
Guolin Ke's avatar
Guolin Ke committed
444

445
  void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
446
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
447
448
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
449
450
451
452
453
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
454
455
    // transform to delta array
    data_size_t last_idx = 0;
456
    for (data_size_t i = 0; i < num_used_indices; ++i) {
457
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
458
      if (bin > 0) {
459
460
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
461
          deltas_.push_back(255);
462
          vals_.push_back(0);
463
          cur_delta -= 255;
464
465
466
467
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
468
469
      }
    }
470
471
472
473
474
475
476
477
478
479
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
480
481
  }

482
483
484
485
  SparseBin<VAL_T>* Clone() override;

  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
    : num_data_(other.num_data_), deltas_(other.deltas_), vals_(other.vals_),
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
    num_vals_(other.num_vals_), push_buffers_(other.push_buffers_),
    fast_index_(other.fast_index_), fast_index_shift_(other.fast_index_shift_) {
  }

  void InitIndex(data_size_t start_idx, data_size_t * i_delta, data_size_t * cur_pos) const {
    auto idx = start_idx >> fast_index_shift_;
    if (static_cast<size_t>(idx) < fast_index_.size()) {
      const auto fast_pair = fast_index_[start_idx >> fast_index_shift_];
      *i_delta = fast_pair.first;
      *cur_pos = fast_pair.second;
    } else {
      *i_delta = -1;
      *cur_pos = 0;
    }
  }

502
 private:
Guolin Ke's avatar
Guolin Ke committed
503
  data_size_t num_data_;
504
505
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>> deltas_;
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
Guolin Ke's avatar
Guolin Ke committed
506
507
508
509
510
511
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

512
template<typename VAL_T>
513
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
514
515
516
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
517
template <typename VAL_T>
518
519
520
521
522
523
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

/*!
 * \brief Stored bin of row idx, or 0 when the row has no stored entry.
 *
 * Advances the cursor until it reaches or passes idx. The cursor never moves
 * backward: a row before the current cursor position reads as 0 even if it
 * has a stored value, so call Reset() before going back. Gap-padding entries
 * store 0, so they also read back as 0.
 */
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  while (cur_pos_ < idx) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  if (cur_pos_ == idx) {
    return bin_data_->vals_[i_delta_];
  }
  return 0;
}

/*!
 * \brief Reposition the cursor so that rows >= start_idx can be read.
 * The starting entry is chosen by SparseBin::InitIndex.
 */
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  bin_data_->InitIndex(start_idx, &i_delta_, &cur_pos_);
}

template <typename VAL_T>
Guolin Ke's avatar
Guolin Ke committed
540
541
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
542
543
544
}

}  // namespace LightGBM
#endif   // LIGHTGBM_IO_SPARSE_BIN_HPP_