sparse_bin.hpp 15.7 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

#include <LightGBM/bin.h>
9
#include <LightGBM/utils/log.h>
10
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
#include <limits>
13
14
15
16
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
17
18
19
20
#include <vector>

namespace LightGBM {

21
template <typename VAL_T> class SparseBin;
22

Guolin Ke's avatar
Guolin Ke committed
23
24
/*!
 * \brief Divisor used by SparseBin::GetFastIndex() to size the fast index:
 *        the row range is cut into roughly this many slots (the slot width is
 *        then rounded up to a power of two).
 */
const size_t kNumFastIndex = 64;

25
26
template <typename VAL_T>
class SparseBinIterator: public BinIterator {
Nikita Titov's avatar
Nikita Titov committed
27
 public:
Guolin Ke's avatar
Guolin Ke committed
28
  SparseBinIterator(const SparseBin<VAL_T>* bin_data,
Guolin Ke's avatar
Guolin Ke committed
29
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
Guolin Ke's avatar
Guolin Ke committed
30
31
    : bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
    max_bin_(static_cast<VAL_T>(max_bin)),
Guolin Ke's avatar
Guolin Ke committed
32
33
    most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
    if (most_freq_bin_ == 0) {
34
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
35
    } else {
36
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
37
38
39
    }
    Reset(0);
  }
40
41
42
43
44
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
    : bin_data_(bin_data) {
    Reset(start_idx);
  }

45
46
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
47

48
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
49
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
50
    if (ret >= min_bin_ && ret <= max_bin_) {
51
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
52
    } else {
Guolin Ke's avatar
Guolin Ke committed
53
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
54
    }
55
56
  }

Guolin Ke's avatar
Guolin Ke committed
57
  inline void Reset(data_size_t idx) override;
58

Nikita Titov's avatar
Nikita Titov committed
59
 private:
60
61
62
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
63
64
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
65
  VAL_T most_freq_bin_;
66
  uint8_t offset_;
67
68
69
70
};

template <typename VAL_T>
class OrderedSparseBin;
Guolin Ke's avatar
Guolin Ke committed
71
72

template <typename VAL_T>
73
class SparseBin: public Bin {
Nikita Titov's avatar
Nikita Titov committed
74
 public:
Guolin Ke's avatar
Guolin Ke committed
75
  friend class SparseBinIterator<VAL_T>;
76
  friend class OrderedSparseBin<VAL_T>;
Guolin Ke's avatar
Guolin Ke committed
77

Guolin Ke's avatar
Guolin Ke committed
78
  /*!
   * \brief Create an empty sparse bin for num_data rows.
   *
   * One push buffer is allocated per thread of an OpenMP parallel region.
   * num_vals_ and fast_index_shift_ start at 0 so that they are never read
   * while indeterminate before the first load.
   * \param num_data Number of rows
   */
  explicit SparseBin(data_size_t num_data)
    : num_data_(num_data), num_vals_(0), fast_index_shift_(0) {
    int num_threads = 1;
    // query the team size from inside a parallel region; only the master writes
#pragma omp parallel
#pragma omp master
    {
      num_threads = omp_get_num_threads();
    }
    push_buffers_.resize(num_threads);
  }

  /*! \brief Nothing to release by hand; all storage lives in member containers */
  ~SparseBin() {}

  /*!
   * \brief Record a new row count; the stored entries are left untouched
   * \param num_data New number of rows
   */
  void ReSize(data_size_t num_data) override { num_data_ = num_data; }

  void Push(int tid, data_size_t idx, uint32_t value) override {
97
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
98
    if (cur_bin != 0) {
99
100
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
101
102
  }

Guolin Ke's avatar
Guolin Ke committed
103
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
104

105
  void ConstructHistogram(const data_size_t*, data_size_t, data_size_t, const score_t*,
106
    const score_t*, HistogramBinEntry*) const override {
Guolin Ke's avatar
Guolin Ke committed
107
    // Will use OrderedSparseBin->ConstructHistogram() instead
Guolin Ke's avatar
Guolin Ke committed
108
    Log::Fatal("Using OrderedSparseBin->ConstructHistogram() instead");
Guolin Ke's avatar
Guolin Ke committed
109
110
  }

111
  /*!
   * \brief Not supported on SparseBin (overload with a row range and two score arrays);
   *        always reports a fatal error pointing at OrderedSparseBin
   */
  void ConstructHistogram(data_size_t, data_size_t,
                          const score_t*, const score_t*,
                          HistogramBinEntry*) const override {
    Log::Fatal("Using OrderedSparseBin->ConstructHistogram() instead");
  }

117
  void ConstructHistogram(const data_size_t*, data_size_t, data_size_t, const score_t*,
118
119
                          HistogramBinEntry*) const override {
    // Will use OrderedSparseBin->ConstructHistogram() instead
120
121
122
    Log::Fatal("Using OrderedSparseBin->ConstructHistogram() instead");
  }

123
  void ConstructHistogram(data_size_t, data_size_t, const score_t*,
124
125
                          HistogramBinEntry*) const override {
    // Will use OrderedSparseBin->ConstructHistogram() instead
126
127
128
    Log::Fatal("Using OrderedSparseBin->ConstructHistogram() instead");
  }

129
  inline bool NextNonzero(data_size_t* i_delta,
Guolin Ke's avatar
Guolin Ke committed
130
                          data_size_t* cur_pos) const {
131
    ++(*i_delta);
132
133
    data_size_t shift = 0;
    data_size_t delta = deltas_[*i_delta];
Guolin Ke's avatar
Guolin Ke committed
134
    while (*i_delta < num_vals_ && vals_[*i_delta] == 0) {
135
      ++(*i_delta);
136
      shift += 8;
Guolin Ke's avatar
Guolin Ke committed
137
      delta |= static_cast<data_size_t>(deltas_[*i_delta]) << shift;
138
    }
139
140
    *cur_pos += delta;
    if (*i_delta < num_vals_) {
141
142
      return true;
    } else {
143
      *cur_pos = num_data_;
144
145
146
147
      return false;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
148

Guolin Ke's avatar
Guolin Ke committed
149
  /*!
   * \brief Partition rows into a "less than or equal" and a "greater than" side of a numerical threshold.
   *
   * Rows whose raw bin falls outside [min_bin, max_bin] or equals the shifted
   * most frequent bin follow the side chosen for most_freq_bin. Rows treated as
   * missing follow the side selected from missing_type, default_left and
   * default_bin. All other rows are compared against the shifted threshold.
   * \param data_indices Rows to partition; queried through a forward-only iterator
   *        starting at data_indices[0]
   * \param num_data Number of entries in data_indices
   * \param lte_indices Output buffer for the "less than or equal" side
   * \param gt_indices Output buffer for the "greater than" side
   * \return Number of rows written to lte_indices
   */
  data_size_t Split(
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
    uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }

    // translate feature-local bins into stored bins; stored bins are one lower
    // when the most frequent bin is 0
    const auto to_stored = [&](uint32_t local_bin) {
      VAL_T stored = static_cast<VAL_T>(min_bin + local_bin);
      if (most_freq_bin == 0) {
        stored -= 1;
      }
      return stored;
    };
    const VAL_T minb = static_cast<VAL_T>(min_bin);
    const VAL_T maxb = static_cast<VAL_T>(max_bin);
    const VAL_T th = to_stored(threshold);
    const VAL_T t_default_bin = to_stored(default_bin);
    const VAL_T t_most_freq_bin = to_stored(most_freq_bin);

    data_size_t lte_count = 0;
    data_size_t gt_count = 0;

    // side taken by rows that carry the most frequent bin
    const bool freq_goes_left = most_freq_bin <= threshold;
    data_size_t* const default_indices = freq_goes_left ? lte_indices : gt_indices;
    data_size_t* const default_count = freq_goes_left ? &lte_count : &gt_count;

    // side taken by rows that are treated as missing
    const bool nan_missing = (missing_type == MissingType::NaN);
    const bool zero_missing = (missing_type == MissingType::Zero);
    const bool missing_goes_left =
      (nan_missing || zero_missing) ? default_left : (default_bin <= threshold);
    data_size_t* const missing_default_indices = missing_goes_left ? lte_indices : gt_indices;
    data_size_t* const missing_default_count = missing_goes_left ? &lte_count : &gt_count;

    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    const auto is_most_freq = [&](VAL_T b) {
      return b < minb || b > maxb || t_most_freq_bin == b;
    };

    if (nan_missing) {
      // the last bin of the range holds the missing rows
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t row = data_indices[i];
        const VAL_T bin = iterator.InnerRawGet(row);
        if (bin == maxb) {
          missing_default_indices[(*missing_default_count)++] = row;
        } else if (is_most_freq(bin)) {
          default_indices[(*default_count)++] = row;
        } else if (bin > th) {
          gt_indices[gt_count++] = row;
        } else {
          lte_indices[lte_count++] = row;
        }
      }
    } else if (default_bin == most_freq_bin) {
      // default bin and most frequent bin coincide: one test covers both
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t row = data_indices[i];
        const VAL_T bin = iterator.InnerRawGet(row);
        if (is_most_freq(bin)) {
          missing_default_indices[(*missing_default_count)++] = row;
        } else if (bin > th) {
          gt_indices[gt_count++] = row;
        } else {
          lte_indices[lte_count++] = row;
        }
      }
    } else {
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t row = data_indices[i];
        const VAL_T bin = iterator.InnerRawGet(row);
        if (bin == t_default_bin) {
          missing_default_indices[(*missing_default_count)++] = row;
        } else if (is_most_freq(bin)) {
          default_indices[(*default_count)++] = row;
        } else if (bin > th) {
          gt_indices[gt_count++] = row;
        } else {
          lte_indices[lte_count++] = row;
        }
      }
    }
    return lte_count;
  }

Guolin Ke's avatar
Guolin Ke committed
230
  /*!
   * \brief Partition rows by membership of their bin in a category bitset.
   *
   * Rows whose raw bin lies outside [min_bin, max_bin] follow the side chosen
   * for most_freq_bin; all other rows go to lte_indices when (bin - min_bin) is
   * found in the bitset and to gt_indices otherwise.
   * \param threshold Bitset words tested with Common::FindInBitset
   * \param num_threahold Number of words in threshold
   * \param data_indices Rows to partition; queried through a forward-only iterator
   *        starting at data_indices[0]
   * \param num_data Number of entries in data_indices
   * \return Number of rows written to lte_indices
   */
  data_size_t SplitCategorical(
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
    const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }

    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);

    // side taken by rows that carry the most frequent bin
    const bool freq_in_set = Common::FindInBitset(threshold, num_threahold, most_freq_bin);
    data_size_t* const default_indices = freq_in_set ? lte_indices : gt_indices;
    data_size_t* const default_count = freq_in_set ? &lte_count : &gt_count;

    for (data_size_t i = 0; i < num_data; ++i) {
      const data_size_t row = data_indices[i];
      const uint32_t bin = iterator.InnerRawGet(row);
      const bool out_of_range = (bin < min_bin) || (bin > max_bin);
      if (out_of_range) {
        default_indices[(*default_count)++] = row;
        continue;
      }
      if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
        lte_indices[lte_count++] = row;
      } else {
        gt_indices[gt_count++] = row;
      }
    }
    return lte_count;
  }

  /*! \brief Row count currently recorded for this bin (the value of num_data_) */
  data_size_t num_data() const override { return num_data_; }

260
  OrderedBin* CreateOrderedBin() const override;
Guolin Ke's avatar
Guolin Ke committed
261
262
263

  void FinishLoad() override {
    // get total non zero size
264
    size_t pair_cnt = 0;
265
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
266
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
267
    }
Guolin Ke's avatar
Guolin Ke committed
268
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs = push_buffers_[0];
269
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
270
271

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
272
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(), push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
273
274
275
276
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
277
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
278
279
280
      [](const std::pair<data_size_t, VAL_T>& a, const std::pair<data_size_t, VAL_T>& b) {
      return a.first < b.first;
    });
zhangyafeikimi's avatar
zhangyafeikimi committed
281
    // load delta array
282
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
283
284
  }

285
  void LoadFromPair(const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
286
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
287
288
289
    vals_.clear();
    // transform to delta array
    data_size_t last_idx = 0;
290
291
292
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
293
      data_size_t cur_delta = cur_idx - last_idx;
Guolin Ke's avatar
Guolin Ke committed
294
      if (i > 0 && cur_delta == 0) { continue; }
295
296
      while (cur_delta >= 256) {
        deltas_.push_back(cur_delta & 0xff);
Guolin Ke's avatar
Guolin Ke committed
297
        vals_.push_back(0);
298
        cur_delta >>= 8;
Guolin Ke's avatar
Guolin Ke committed
299
      }
300
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
301
302
303
304
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
305
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
306
307
308
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
309
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
327
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
328
    data_size_t cur_pos = 0;
329
330
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
331
      while (next_threshold <= cur_pos) {
332
333
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
334
335
336
      }
    }
    // avoid out of range
337
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
338
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
339
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
340
341
342
343
    }
    fast_index_.shrink_to_fit();
  }

344
345
346
347
  /*!
   * \brief Write the bin as: num_vals_, then num_vals_ + 1 delta bytes, then num_vals_ values
   * \param writer Destination writer
   */
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    const auto delta_bytes = sizeof(uint8_t) * (num_vals_ + 1);
    const auto value_bytes = sizeof(VAL_T) * num_vals_;
    writer->Write(&num_vals_, sizeof(num_vals_));
    writer->Write(deltas_.data(), delta_bytes);
    writer->Write(vals_.data(), value_bytes);
  }

  /*! \brief Number of bytes SaveBinaryToFile() writes: header + delta bytes + values */
  size_t SizesInByte() const override {
    const size_t header_bytes = sizeof(num_vals_);
    const size_t delta_bytes = sizeof(uint8_t) * (num_vals_ + 1);
    const size_t value_bytes = sizeof(VAL_T) * num_vals_;
    return header_bytes + delta_bytes + value_bytes;
  }

  void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

363
364
365
366
367
368
369
370
371
372
373
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
374

Guolin Ke's avatar
Guolin Ke committed
375
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
376
377
378
379
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
380
381
      data_size_t cur_pos = 0;
      data_size_t j = -1;
Guolin Ke's avatar
Guolin Ke committed
382
383
      for (data_size_t i = 0; i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
        const data_size_t idx = local_used_indices[i];
384
385
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
386
        }
387
        if (cur_pos == idx && j < num_vals_) {
Guolin Ke's avatar
Guolin Ke committed
388
          // new row index is i
389
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
390
391
392
393
        }
      }
      LoadFromPair(tmp_pair);
    }
394
  }
Guolin Ke's avatar
Guolin Ke committed
395

396
  void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
397
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
398
399
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
404
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
405
406
    // transform to delta array
    data_size_t last_idx = 0;
407
    for (data_size_t i = 0; i < num_used_indices; ++i) {
408
      VAL_T bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
409
      if (bin > 0) {
410
411
412
413
414
415
416
417
418
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
          deltas_.push_back(cur_delta & 0xff);
          vals_.push_back(0);
          cur_delta >>= 8;
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
419
420
      }
    }
421
422
423
424
425
426
427
428
429
430
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
431
432
  }

433
434
  SparseBin<VAL_T>* Clone() override;

Nikita Titov's avatar
Nikita Titov committed
435
 protected:
436
437
438
  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
    : num_data_(other.num_data_), deltas_(other.deltas_), vals_(other.vals_),
      num_vals_(other.num_vals_), push_buffers_(other.push_buffers_),
439
      fast_index_(other.fast_index_), fast_index_shift_(other.fast_index_shift_) {}
440

Guolin Ke's avatar
Guolin Ke committed
441
  data_size_t num_data_;
442
  std::vector<uint8_t> deltas_;
Guolin Ke's avatar
Guolin Ke committed
443
444
445
446
447
448
449
  std::vector<VAL_T> vals_;
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

450
template<typename VAL_T>
451
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
452
453
454
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
455
template <typename VAL_T>
456
457
458
459
460
461
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

/*!
 * \brief Stored bin of row idx, or 0 if the row has no stored entry.
 *
 * The cursor is advanced until it reaches or passes idx and is never moved
 * back, so rows must be queried in non-decreasing order.
 */
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  for (; cur_pos_ < idx; ) {
    bin_data_->NextNonzero(&i_delta_, &cur_pos_);
  }
  if (cur_pos_ != idx) {
    // cursor stepped past idx: nothing is stored for this row
    return 0;
  }
  return bin_data_->vals_[i_delta_];
}
Guolin Ke's avatar
Guolin Ke committed
471

472
473
/*!
 * \brief Position the cursor from the bin's fast index slot for start_idx;
 *        if that slot does not exist, restart before the first entry.
 */
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  const auto slot = start_idx >> bin_data_->fast_index_shift_;
  const auto& table = bin_data_->fast_index_;
  if (static_cast<size_t>(slot) >= table.size()) {
    i_delta_ = -1;
    cur_pos_ = 0;
    return;
  }
  const auto& entry = table[slot];
  i_delta_ = entry.first;
  cur_pos_ = entry.second;
}
Guolin Ke's avatar
Guolin Ke committed
484
485

template <typename VAL_T>
Guolin Ke's avatar
Guolin Ke committed
486
487
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
488
489
490
}

}  // namespace LightGBM
zhangyafeikimi's avatar
zhangyafeikimi committed
491
#endif   // LIGHTGBM_IO_SPARSE_BIN_HPP_