sparse_bin.hpp 17.3 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

#include <LightGBM/bin.h>
9
#include <LightGBM/utils/log.h>
10
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
#include <limits>
13
14
15
16
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
17
18
19
20
#include <vector>

namespace LightGBM {

21
// Forward declaration: SparseBinIterator (below) keeps a pointer to a SparseBin.
template <typename VAL_T>
class SparseBin;

/*!
 * \brief Upper bound on the number of fast-index buckets built per SparseBin.
 *        The bucket width is rounded up to a power of two (see GetFastIndex).
 */
constexpr size_t kNumFastIndex = 64;

25
26
template <typename VAL_T>
class SparseBinIterator: public BinIterator {
27
public:
Guolin Ke's avatar
Guolin Ke committed
28
  SparseBinIterator(const SparseBin<VAL_T>* bin_data,
Guolin Ke's avatar
Guolin Ke committed
29
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
Guolin Ke's avatar
Guolin Ke committed
30
31
    : bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
    max_bin_(static_cast<VAL_T>(max_bin)),
Guolin Ke's avatar
Guolin Ke committed
32
33
    most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
    if (most_freq_bin_ == 0) {
34
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
35
    } else {
36
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
37
38
39
    }
    Reset(0);
  }
40
41
42
43
44
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
    : bin_data_(bin_data) {
    Reset(start_idx);
  }

45
46
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
47

48
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
49
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
50
    if (ret >= min_bin_ && ret <= max_bin_) {
51
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
52
    } else {
Guolin Ke's avatar
Guolin Ke committed
53
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
54
    }
55
56
  }

Guolin Ke's avatar
Guolin Ke committed
57
  inline void Reset(data_size_t idx) override;
58

59
private:
60
61
62
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
63
64
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
65
  VAL_T most_freq_bin_;
66
  uint8_t offset_;
67
68
};

Guolin Ke's avatar
Guolin Ke committed
69
template <typename VAL_T>
70
class SparseBin: public Bin {
71
public:
Guolin Ke's avatar
Guolin Ke committed
72
73
  friend class SparseBinIterator<VAL_T>;

Guolin Ke's avatar
Guolin Ke committed
74
  explicit SparseBin(data_size_t num_data)
Guolin Ke's avatar
Guolin Ke committed
75
    : num_data_(num_data) {
Guolin Ke's avatar
Guolin Ke committed
76
    int num_threads = 1;
77
78
    #pragma omp parallel
    #pragma omp master
Guolin Ke's avatar
Guolin Ke committed
79
    {
Guolin Ke's avatar
Guolin Ke committed
80
      num_threads = omp_get_num_threads();
Guolin Ke's avatar
Guolin Ke committed
81
    }
Guolin Ke's avatar
Guolin Ke committed
82
    push_buffers_.resize(num_threads);
Guolin Ke's avatar
Guolin Ke committed
83
84
85
  }

  ~SparseBin() {
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
  }

  void ReSize(data_size_t num_data) override {
    num_data_ = num_data;
Guolin Ke's avatar
Guolin Ke committed
90
91
92
  }

  void Push(int tid, data_size_t idx, uint32_t value) override {
93
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
94
    if (cur_bin != 0) {
95
96
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
97
98
  }

Guolin Ke's avatar
Guolin Ke committed
99
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
100

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
  #define ACC_GH(hist, i, g, h) \
  const auto ti = static_cast<int>(i) << 1; \
  hist[ti] += g; \
  hist[ti + 1] += h; \

  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
    const score_t* ordered_gradients, const score_t* ordered_hessians,
    hist_t* out) const override {
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      } else if (cur_pos > data_indices[i]) {
        if (++i >= end) { break; }
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
        if (++i >= end) { break; }
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
126
127
  }

128
129
130
131
132
133
134
135
136
137
138
139
140
  void ConstructHistogram(data_size_t start, data_size_t end,
    const score_t* ordered_gradients, const score_t* ordered_hessians,
    hist_t* out) const override {
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], ordered_hessians[cur_pos]);
      cur_pos += deltas_[++i_delta];
    }
141
142
  }

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
    const score_t* ordered_gradients,
    hist_t* out) const override {
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      } else if (cur_pos > data_indices[i]) {
        if (++i >= end) { break; }
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], 1.0f);
        if (++i >= end) { break; }
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      }
    }
163
164
  }

165
166
167
168
169
170
171
172
173
174
175
176
177
  void ConstructHistogram(data_size_t start, data_size_t end,
    const score_t* ordered_gradients,
    hist_t* out) const override {
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], 1.0f);
      cur_pos += deltas_[++i_delta];
    }
178
  }
179
  #undef ACC_GH
180

181
182
183
184
185
  inline void NextNonzeroFast(data_size_t* i_delta,
    data_size_t* cur_pos) const {
    *cur_pos += deltas_[++(*i_delta)];
    if (*i_delta >= num_vals_) {
      *cur_pos = num_data_;
186
    }
187
188
189
190
191
  }

  inline bool NextNonzero(data_size_t* i_delta,
    data_size_t* cur_pos) const {
    *cur_pos += deltas_[++(*i_delta)];
192
    if (*i_delta < num_vals_) {
193
194
      return true;
    } else {
195
      *cur_pos = num_data_;
196
197
198
199
      return false;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
200

Guolin Ke's avatar
Guolin Ke committed
201
  data_size_t Split(
Guolin Ke's avatar
Guolin Ke committed
202
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
Guolin Ke's avatar
Guolin Ke committed
203
    uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
204
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
205
    if (num_data <= 0) { return 0; }
Guolin Ke's avatar
Guolin Ke committed
206
    VAL_T th = static_cast<VAL_T>(threshold + min_bin);
207
208
    const VAL_T minb = static_cast<VAL_T>(min_bin);
    const VAL_T maxb = static_cast<VAL_T>(max_bin);
Guolin Ke's avatar
Guolin Ke committed
209
    VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
Guolin Ke's avatar
Guolin Ke committed
210
211
    VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
    if (most_freq_bin == 0) {
Guolin Ke's avatar
Guolin Ke committed
212
      th -= 1;
Guolin Ke's avatar
Guolin Ke committed
213
      t_default_bin -= 1;
Guolin Ke's avatar
Guolin Ke committed
214
      t_most_freq_bin -= 1;
Guolin Ke's avatar
Guolin Ke committed
215
    }
Guolin Ke's avatar
Guolin Ke committed
216
217
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
Guolin Ke's avatar
Guolin Ke committed
218
219
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
220
221
222
223
224
225
226
    data_size_t* missing_default_indices = gt_indices;
    data_size_t* missing_default_count = &gt_count;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (most_freq_bin <= threshold) {
      default_indices = lte_indices;
      default_count = &lte_count;
    }
227
228
229
230
    if (missing_type == MissingType::NaN) {
      if (default_left) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
Guolin Ke's avatar
Guolin Ke committed
231
      }
232
233
234
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        const VAL_T bin = iterator.InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
235
        if (bin == maxb) {
236
          missing_default_indices[(*missing_default_count)++] = idx;
Guolin Ke's avatar
Guolin Ke committed
237
238
        } else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
          default_indices[(*default_count)++] = idx;
239
240
241
242
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
243
244
245
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
246
247
248
249
      if ((default_left && missing_type == MissingType::Zero)
        || (default_bin <= threshold && missing_type != MissingType::Zero)) {
        missing_default_indices = lte_indices;
        missing_default_count = &lte_count;
250
      }
Guolin Ke's avatar
Guolin Ke committed
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
      if (default_bin == most_freq_bin) {
        for (data_size_t i = 0; i < num_data; ++i) {
          const data_size_t idx = data_indices[i];
          const VAL_T bin = iterator.InnerRawGet(idx);
          if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else if (bin > th) {
            gt_indices[gt_count++] = idx;
          } else {
            lte_indices[lte_count++] = idx;
          }
        }
      } else {
        for (data_size_t i = 0; i < num_data; ++i) {
          const data_size_t idx = data_indices[i];
          const VAL_T bin = iterator.InnerRawGet(idx);
          if (bin == t_default_bin) {
            missing_default_indices[(*missing_default_count)++] = idx;
          } else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
            default_indices[(*default_count)++] = idx;
          } else if (bin > th) {
            gt_indices[gt_count++] = idx;
          } else {
            lte_indices[lte_count++] = idx;
          }
276
        }
Guolin Ke's avatar
Guolin Ke committed
277
      }
278
    }
279
280
281
    return lte_count;
  }

Guolin Ke's avatar
Guolin Ke committed
282
  data_size_t SplitCategorical(
Guolin Ke's avatar
Guolin Ke committed
283
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
284
    const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
285
286
287
288
289
290
291
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
Guolin Ke's avatar
Guolin Ke committed
292
    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
293
294
295
296
297
298
299
300
      default_indices = lte_indices;
      default_count = &lte_count;
    }
    for (data_size_t i = 0; i < num_data; ++i) {
      const data_size_t idx = data_indices[i];
      uint32_t bin = iterator.InnerRawGet(idx);
      if (bin < min_bin || bin > max_bin) {
        default_indices[(*default_count)++] = idx;
301
      } else if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
302
303
304
305
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
Guolin Ke's avatar
Guolin Ke committed
306
307
308
309
310
311
312
313
    }
    return lte_count;
  }

  data_size_t num_data() const override { return num_data_; }

  void FinishLoad() override {
    // get total non zero size
314
    size_t pair_cnt = 0;
315
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
316
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
317
    }
Guolin Ke's avatar
Guolin Ke committed
318
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs = push_buffers_[0];
319
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
320
321

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
322
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(), push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
323
324
325
326
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
327
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
328
      [](const std::pair<data_size_t, VAL_T>& a, const std::pair<data_size_t, VAL_T>& b) {
329
330
        return a.first < b.first;
      });
zhangyafeikimi's avatar
zhangyafeikimi committed
331
    // load delta array
332
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
333
334
  }

335
  void LoadFromPair(const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
336
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
337
338
339
    vals_.clear();
    // transform to delta array
    data_size_t last_idx = 0;
340
341
342
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
343
      data_size_t cur_delta = cur_idx - last_idx;
344
      // disallow the multi-val in one row
Guolin Ke's avatar
Guolin Ke committed
345
      if (i > 0 && cur_delta == 0) { continue; }
346
      while (cur_delta >= 256) {
347
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
348
        vals_.push_back(0);
349
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
350
      }
351
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
352
353
354
355
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
356
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
357
358
359
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
360
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
378
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
379
    data_size_t cur_pos = 0;
380
381
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
382
      while (next_threshold <= cur_pos) {
383
384
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
385
386
387
      }
    }
    // avoid out of range
388
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
389
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
390
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
391
392
393
394
    }
    fast_index_.shrink_to_fit();
  }

395
396
397
398
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    writer->Write(&num_vals_, sizeof(num_vals_));
    writer->Write(deltas_.data(), sizeof(uint8_t) * (num_vals_ + 1));
    writer->Write(vals_.data(), sizeof(VAL_T) * num_vals_);
Guolin Ke's avatar
Guolin Ke committed
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
  }

  size_t SizesInByte() const override {
    return sizeof(num_vals_) + sizeof(uint8_t) * (num_vals_ + 1)
      + sizeof(VAL_T) * num_vals_;
  }

  void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

414
415
416
417
418
419
420
421
422
423
424
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
425

Guolin Ke's avatar
Guolin Ke committed
426
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
427
428
429
430
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
431
432
      data_size_t cur_pos = 0;
      data_size_t j = -1;
Guolin Ke's avatar
Guolin Ke committed
433
434
      for (data_size_t i = 0; i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
        const data_size_t idx = local_used_indices[i];
435
436
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
437
        }
438
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
439
          // new row index is i
440
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
441
442
443
444
        }
      }
      LoadFromPair(tmp_pair);
    }
445
  }
Guolin Ke's avatar
Guolin Ke committed
446

447
  void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
448
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
449
450
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
451
452
453
454
455
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
456
457
    // transform to delta array
    data_size_t last_idx = 0;
458
    for (data_size_t i = 0; i < num_used_indices; ++i) {
459
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
460
      if (bin > 0) {
461
462
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
463
          deltas_.push_back(255);
464
          vals_.push_back(0);
465
          cur_delta -= 255;
466
467
468
469
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
470
471
      }
    }
472
473
474
475
476
477
478
479
480
481
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
482
483
  }

484
485
486
487
  SparseBin<VAL_T>* Clone() override;

  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
    : num_data_(other.num_data_), deltas_(other.deltas_), vals_(other.vals_),
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
    num_vals_(other.num_vals_), push_buffers_(other.push_buffers_),
    fast_index_(other.fast_index_), fast_index_shift_(other.fast_index_shift_) {
  }

  void InitIndex(data_size_t start_idx, data_size_t * i_delta, data_size_t * cur_pos) const {
    auto idx = start_idx >> fast_index_shift_;
    if (static_cast<size_t>(idx) < fast_index_.size()) {
      const auto fast_pair = fast_index_[start_idx >> fast_index_shift_];
      *i_delta = fast_pair.first;
      *cur_pos = fast_pair.second;
    } else {
      *i_delta = -1;
      *cur_pos = 0;
    }
  }

private:
505

Guolin Ke's avatar
Guolin Ke committed
506
  data_size_t num_data_;
507
508
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>> deltas_;
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
Guolin Ke's avatar
Guolin Ke committed
509
510
511
512
513
514
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

515
template<typename VAL_T>
516
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
517
518
519
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
520
template <typename VAL_T>
521
522
523
524
525
526
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

/*!
 * \brief Moves the cursor forward to row idx and returns its stored bin.
 *
 * Returns 0 when no entry is stored at idx. The cursor never moves backward,
 * so rows must be requested in non-decreasing order between Reset() calls.
 */
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  while (idx > cur_pos_) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  return (idx == cur_pos_) ? bin_data_->vals_[i_delta_] : static_cast<VAL_T>(0);
}
Guolin Ke's avatar
Guolin Ke committed
536

537
538
/*!
 * \brief Re-seeds the cursor from the bin's fast index so that rows
 *        >= start_idx can be read.
 */
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  const SparseBin<VAL_T>* source = bin_data_;
  source->InitIndex(start_idx, &i_delta_, &cur_pos_);
}
Guolin Ke's avatar
Guolin Ke committed
541
542

template <typename VAL_T>
Guolin Ke's avatar
Guolin Ke committed
543
544
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
545
546
547
}

}  // namespace LightGBM
zhangyafeikimi's avatar
zhangyafeikimi committed
548
#endif   // LIGHTGBM_IO_SPARSE_BIN_HPP_