sparse_bin.hpp 17.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

#include <LightGBM/bin.h>
9
#include <LightGBM/utils/log.h>
10
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
#include <limits>
13
14
15
16
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
17
18
19
20
#include <vector>

namespace LightGBM {

21
// Forward declaration: SparseBinIterator below holds a pointer to SparseBin
// before the class itself is defined.
template <typename VAL_T> class SparseBin;
22

Guolin Ke's avatar
Guolin Ke committed
23
24
// Divisor used by SparseBin::GetFastIndex() to size the row range covered by
// each fast-index slot (ceil(num_data / kNumFastIndex), rounded up to a power of two).
const size_t kNumFastIndex = 64;

25
26
template <typename VAL_T>
class SparseBinIterator: public BinIterator {
27
 public:
Guolin Ke's avatar
Guolin Ke committed
28
  SparseBinIterator(const SparseBin<VAL_T>* bin_data,
Guolin Ke's avatar
Guolin Ke committed
29
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
Guolin Ke's avatar
Guolin Ke committed
30
31
    : bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
    max_bin_(static_cast<VAL_T>(max_bin)),
Guolin Ke's avatar
Guolin Ke committed
32
33
    most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
    if (most_freq_bin_ == 0) {
34
      offset_ = 1;
Guolin Ke's avatar
Guolin Ke committed
35
    } else {
36
      offset_ = 0;
Guolin Ke's avatar
Guolin Ke committed
37
38
39
    }
    Reset(0);
  }
40
41
42
43
44
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
    : bin_data_(bin_data) {
    Reset(start_idx);
  }

45
46
  inline uint32_t RawGet(data_size_t idx) override;
  inline VAL_T InnerRawGet(data_size_t idx);
47

48
  inline uint32_t Get(data_size_t idx) override {
Guolin Ke's avatar
Guolin Ke committed
49
    VAL_T ret = InnerRawGet(idx);
Guolin Ke's avatar
Guolin Ke committed
50
    if (ret >= min_bin_ && ret <= max_bin_) {
51
      return ret - min_bin_ + offset_;
Guolin Ke's avatar
Guolin Ke committed
52
    } else {
Guolin Ke's avatar
Guolin Ke committed
53
      return most_freq_bin_;
Guolin Ke's avatar
Guolin Ke committed
54
    }
55
56
  }

Guolin Ke's avatar
Guolin Ke committed
57
  inline void Reset(data_size_t idx) override;
58

59
 private:
60
61
62
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
63
64
  VAL_T min_bin_;
  VAL_T max_bin_;
Guolin Ke's avatar
Guolin Ke committed
65
  VAL_T most_freq_bin_;
66
  uint8_t offset_;
67
68
};

Guolin Ke's avatar
Guolin Ke committed
69
template <typename VAL_T>
70
class SparseBin: public Bin {
71
 public:
Guolin Ke's avatar
Guolin Ke committed
72
73
  friend class SparseBinIterator<VAL_T>;

Guolin Ke's avatar
Guolin Ke committed
74
  /*!
   * \brief Create an empty bin for num_data rows.
   *
   * One push buffer is allocated per OpenMP thread so that Push() can be
   * called with a thread id without locking. num_vals_ and fast_index_shift_
   * start at zero so that copying or indexing a bin that has not been loaded
   * yet never reads indeterminate values.
   * \param num_data Number of rows covered by this bin
   */
  explicit SparseBin(data_size_t num_data)
    : num_data_(num_data), num_vals_(0), fast_index_shift_(0) {
    int thread_count = 1;
    // query the team size from inside a parallel region; only the master writes
    #pragma omp parallel
    #pragma omp master
    {
      thread_count = omp_get_num_threads();
    }
    push_buffers_.resize(thread_count);
  }

  /*! \brief Nothing to release by hand; every member cleans up after itself */
  ~SparseBin() = default;

  /*!
   * \brief Change the recorded number of rows.
   *
   * Only the row count is updated here; the encoded entries and the fast
   * index are left untouched.
   * \param num_data New number of rows
   */
  void ReSize(data_size_t num_data) override {
    this->num_data_ = num_data;
  }

  void Push(int tid, data_size_t idx, uint32_t value) override {
93
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
94
    if (cur_bin != 0) {
95
96
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
97
98
  }

Guolin Ke's avatar
Guolin Ke committed
99
  // Returns a heap-allocated SparseBinIterator over this bin (see the
  // out-of-class definition, which uses `new`); the caller takes the pointer.
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
100

101
102
103
104
105
106
  // Accumulate gradient g into hist[2 * i] and hessian h into hist[2 * i + 1].
  // Wrapped in do { } while (0) so that one invocation is exactly one
  // statement: it is safe in an unbraced if/else and can be used several
  // times in the same scope. Arguments are parenthesized before use.
  #define ACC_GH(hist, i, g, h) \
  do { \
    const auto acc_gh_ti_ = static_cast<int>(i) << 1; \
    (hist)[acc_gh_ti_] += (g); \
    (hist)[acc_gh_ti_ + 1] += (h); \
  } while (0)

  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
107
    const score_t* ordered_gradients, const score_t* ordered_hessians, hist_t* out) const override {
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      } else if (cur_pos > data_indices[i]) {
        if (++i >= end) { break; }
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], ordered_hessians[i]);
        if (++i >= end) { break; }
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
125
126
  }

127
  /*!
   * \brief Accumulate gradients and hessians of the rows in [start, end) into out.
   * \param start First row to use
   * \param end One past the last row to use
   * \param ordered_gradients Gradients, indexed by row
   * \param ordered_hessians Hessians, indexed by row
   * \param out Histogram buffer, two slots (gradient, hessian) per bin
   */
  void ConstructHistogram(data_size_t start, data_size_t end,
    const score_t* ordered_gradients, const score_t* ordered_hessians, hist_t* out) const override {
    data_size_t entry, row;
    InitIndex(start, &entry, &row);
    // skip stored entries that lie before the requested range
    for (; row < start && entry < num_vals_; row += deltas_[++entry]) {}
    for (; row < end && entry < num_vals_; row += deltas_[++entry]) {
      const VAL_T bin = vals_[entry];
      ACC_GH(out, bin, ordered_gradients[row], ordered_hessians[row]);
    }
  }

141
  void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
142
    const score_t* ordered_gradients, hist_t* out) const override {
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
    data_size_t i_delta, cur_pos;
    InitIndex(data_indices[start], &i_delta, &cur_pos);
    data_size_t i = start;
    for (;;) {
      if (cur_pos < data_indices[i]) {
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      } else if (cur_pos > data_indices[i]) {
        if (++i >= end) { break; }
      } else {
        const VAL_T bin = vals_[i_delta];
        ACC_GH(out, bin, ordered_gradients[i], 1.0f);
        if (++i >= end) { break; }
        cur_pos += deltas_[++i_delta];
        if (i_delta >= num_vals_) { break; }
      }
    }
160
161
  }

162
  void ConstructHistogram(data_size_t start, data_size_t end,
163
    const score_t* ordered_gradients, hist_t* out) const override {
164
165
166
167
168
169
170
171
172
173
    data_size_t i_delta, cur_pos;
    InitIndex(start, &i_delta, &cur_pos);
    while (cur_pos < start && i_delta < num_vals_) {
      cur_pos += deltas_[++i_delta];
    }
    while (cur_pos < end && i_delta < num_vals_) {
      const VAL_T bin = vals_[i_delta];
      ACC_GH(out, bin, ordered_gradients[cur_pos], 1.0f);
      cur_pos += deltas_[++i_delta];
    }
174
  }
175
  #undef ACC_GH
176

177
  /*!
   * \brief Advance the cursor by one stored entry.
   *
   * Once the cursor moves past the last entry, *cur_pos is pinned to num_data_.
   * \param i_delta In/out: position in deltas_/vals_
   * \param cur_pos In/out: row index of the current entry
   */
  inline void NextNonzeroFast(data_size_t* i_delta, data_size_t* cur_pos) const {
    ++(*i_delta);
    *cur_pos += deltas_[*i_delta];
    if (*i_delta >= num_vals_) {
      *cur_pos = num_data_;
    }
  }

  /*!
   * \brief Advance the cursor by one stored entry and report whether it is valid.
   * \param i_delta In/out: position in deltas_/vals_
   * \param cur_pos In/out: row index of the current entry; set to num_data_
   *        when the cursor has moved past the last entry
   * \return true while the cursor still points at a stored entry
   */
  inline bool NextNonzero(data_size_t* i_delta,
    data_size_t* cur_pos) const {
    ++(*i_delta);
    *cur_pos += deltas_[*i_delta];
    const bool still_inside = *i_delta < num_vals_;
    if (!still_inside) {
      *cur_pos = num_data_;
    }
    return still_inside;
  }

Guolin Ke's avatar
Guolin Ke committed
195

Guolin Ke's avatar
Guolin Ke committed
196
  /*!
   * \brief Partition rows into a "less than or equal" side and a "greater" side.
   *
   * Each row's raw bin is classified in this order:
   *   1. the "missing" group, sent to the side chosen from missing_type /
   *      default_left / default_bin;
   *   2. the "most frequent" group (raw bin outside [min_bin, max_bin] or equal
   *      to the shifted most frequent bin), sent left when
   *      most_freq_bin <= threshold;
   *   3. otherwise compared against the shifted threshold.
   * When the missing bin coincides with the most frequent bin, groups 1 and 2
   * are the same set and follow the missing direction.
   * \param data_indices Rows to partition; visited front to back
   * \param num_data Number of entries in data_indices
   * \param lte_indices Output: rows placed on the left side
   * \param gt_indices Output: rows placed on the right side
   * \return Number of rows written to lte_indices
   */
  data_size_t Split(
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin,
    MissingType missing_type, bool default_left,
    uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }
    const VAL_T minb = static_cast<VAL_T>(min_bin);
    const VAL_T maxb = static_cast<VAL_T>(max_bin);
    VAL_T th = static_cast<VAL_T>(threshold + min_bin);
    VAL_T t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
    VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
    if (most_freq_bin == 0) {
      // shift all three by one when the most frequent bin is 0
      th -= 1;
      t_zero_bin -= 1;
      t_most_freq_bin -= 1;
    }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;

    // side that receives the most-frequent group
    const bool most_freq_goes_left = most_freq_bin <= threshold;
    data_size_t* const default_indices = most_freq_goes_left ? lte_indices : gt_indices;
    data_size_t* const default_count = most_freq_goes_left ? &lte_count : &gt_count;

    // side that receives the missing group
    const bool nan_missing = missing_type == MissingType::NaN;
    const bool missing_goes_left = nan_missing
      ? default_left
      : ((default_left && missing_type == MissingType::Zero)
         || (default_bin <= threshold && missing_type != MissingType::Zero));
    data_size_t* const missing_default_indices = missing_goes_left ? lte_indices : gt_indices;
    data_size_t* const missing_default_count = missing_goes_left ? &lte_count : &gt_count;

    // raw bin that marks a missing value, and whether it is the most frequent bin
    const VAL_T missing_bin = nan_missing ? maxb : t_zero_bin;
    const bool missing_is_most_freq = nan_missing
      ? (t_most_freq_bin == maxb)
      : (default_bin == most_freq_bin);

    const auto in_most_freq_group = [minb, maxb, t_most_freq_bin](VAL_T bin) {
      return bin < minb || bin > maxb || t_most_freq_bin == bin;
    };

    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    if (missing_is_most_freq) {
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        const VAL_T bin = iterator.InnerRawGet(idx);
        if (in_most_freq_group(bin)) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
        }
      }
    } else {
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        const VAL_T bin = iterator.InnerRawGet(idx);
        if (bin == missing_bin) {
          missing_default_indices[(*missing_default_count)++] = idx;
        } else if (in_most_freq_group(bin)) {
          default_indices[(*default_count)++] = idx;
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
        }
      }
    }
    return lte_count;
  }

Guolin Ke's avatar
Guolin Ke committed
292
  /*!
   * \brief Partition rows by membership of their bin in a category bitset.
   *
   * Rows whose raw bin lies outside [min_bin, max_bin] follow most_freq_bin:
   * they go left when most_freq_bin is in the bitset, right otherwise. All
   * other rows go left when (raw bin - min_bin) is in the bitset.
   * \param threshold Bitset words holding the categories that go left
   * \param num_threahold Number of words in threshold
   * \param data_indices Rows to partition; visited front to back
   * \param num_data Number of entries in data_indices
   * \param lte_indices Output: rows placed on the left side
   * \param gt_indices Output: rows placed on the right side
   * \return Number of rows written to lte_indices
   */
  data_size_t SplitCategorical(
    uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
    const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    if (num_data <= 0) { return 0; }
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
    const bool most_freq_goes_left = Common::FindInBitset(threshold, num_threahold, most_freq_bin);
    data_size_t* const default_indices = most_freq_goes_left ? lte_indices : gt_indices;
    data_size_t* const default_count = most_freq_goes_left ? &lte_count : &gt_count;
    for (data_size_t i = 0; i < num_data; ++i) {
      const data_size_t idx = data_indices[i];
      const uint32_t bin = iterator.InnerRawGet(idx);
      const bool outside_window = bin < min_bin || bin > max_bin;
      if (outside_window) {
        default_indices[(*default_count)++] = idx;
      } else if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
        lte_indices[lte_count++] = idx;
      } else {
        gt_indices[gt_count++] = idx;
      }
    }
    return lte_count;
  }

  // Number of rows this bin was created (or last resized) for.
  data_size_t num_data() const override { return num_data_; }

  void FinishLoad() override {
    // get total non zero size
324
    size_t pair_cnt = 0;
325
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
326
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
327
    }
Guolin Ke's avatar
Guolin Ke committed
328
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs = push_buffers_[0];
329
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
330
331

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
332
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(), push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
333
334
335
336
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
337
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
338
      [](const std::pair<data_size_t, VAL_T>& a, const std::pair<data_size_t, VAL_T>& b) {
339
340
        return a.first < b.first;
      });
zhangyafeikimi's avatar
zhangyafeikimi committed
341
    // load delta array
342
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
343
344
  }

345
  void LoadFromPair(const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
346
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
347
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
348
349
    deltas_.reserve(idx_val_pairs.size());
    vals_.reserve(idx_val_pairs.size());
Guolin Ke's avatar
Guolin Ke committed
350
351
    // transform to delta array
    data_size_t last_idx = 0;
352
353
354
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
355
      data_size_t cur_delta = cur_idx - last_idx;
356
      // disallow the multi-val in one row
Guolin Ke's avatar
Guolin Ke committed
357
      if (i > 0 && cur_delta == 0) { continue; }
358
      while (cur_delta >= 256) {
359
        deltas_.push_back(255);
Guolin Ke's avatar
Guolin Ke committed
360
        vals_.push_back(0);
361
        cur_delta -= 255;
Guolin Ke's avatar
Guolin Ke committed
362
      }
363
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
364
365
366
367
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
368
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
369
370
371
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
372
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
390
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
391
    data_size_t cur_pos = 0;
392
393
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
394
      while (next_threshold <= cur_pos) {
395
396
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
397
398
399
      }
    }
    // avoid out of range
400
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
401
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
402
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
403
404
405
406
    }
    fast_index_.shrink_to_fit();
  }

407
408
409
410
  /*!
   * \brief Write the encoded entries to writer.
   *
   * Layout: num_vals_, then num_vals_ + 1 delta bytes (including the trailing
   * sentinel), then num_vals_ bin values.
   * \param writer Destination
   */
  void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
    const size_t delta_bytes = sizeof(uint8_t) * (num_vals_ + 1);
    const size_t value_bytes = sizeof(VAL_T) * num_vals_;
    writer->Write(&num_vals_, sizeof(num_vals_));
    writer->Write(deltas_.data(), delta_bytes);
    writer->Write(vals_.data(), value_bytes);
  }

  /*! \brief Number of bytes SaveBinaryToFile() writes: header + deltas (with sentinel) + values */
  size_t SizesInByte() const override {
    const size_t header_bytes = sizeof(num_vals_);
    const size_t delta_bytes = sizeof(uint8_t) * (num_vals_ + 1);
    const size_t value_bytes = sizeof(VAL_T) * num_vals_;
    return header_bytes + delta_bytes + value_bytes;
  }

  void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

426
427
428
429
430
431
432
433
434
435
436
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
437

Guolin Ke's avatar
Guolin Ke committed
438
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
439
440
441
442
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
443
444
      data_size_t cur_pos = 0;
      data_size_t j = -1;
Guolin Ke's avatar
Guolin Ke committed
445
446
      for (data_size_t i = 0; i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
        const data_size_t idx = local_used_indices[i];
447
448
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
449
        }
450
        if (cur_pos == idx && j < num_vals_ && vals_[j] > 0) {
Guolin Ke's avatar
Guolin Ke committed
451
          // new row index is i
452
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
453
454
455
456
        }
      }
      LoadFromPair(tmp_pair);
    }
457
  }
Guolin Ke's avatar
Guolin Ke committed
458

459
  void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
Guolin Ke's avatar
Guolin Ke committed
460
    auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
Guolin Ke's avatar
Guolin Ke committed
461
462
    deltas_.clear();
    vals_.clear();
Guolin Ke's avatar
Guolin Ke committed
463
464
465
466
467
    data_size_t start = 0;
    if (num_used_indices > 0) {
      start = used_indices[0];
    }
    SparseBinIterator<VAL_T> iterator(other_bin, start);
468
469
    // transform to delta array
    data_size_t last_idx = 0;
470
    for (data_size_t i = 0; i < num_used_indices; ++i) {
471
      auto bin = iterator.InnerRawGet(used_indices[i]);
Guolin Ke's avatar
Guolin Ke committed
472
      if (bin > 0) {
473
474
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
475
          deltas_.push_back(255);
476
          vals_.push_back(0);
477
          cur_delta -= 255;
478
479
480
481
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
482
483
      }
    }
484
485
486
487
488
489
490
491
492
493
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
494
495
  }

496
497
498
499
  // Returns a heap-allocated copy of this bin made with the copy constructor
  // (see the out-of-class definition); the caller takes the pointer.
  SparseBin<VAL_T>* Clone() override;

  SparseBin<VAL_T>(const SparseBin<VAL_T>& other)
    : num_data_(other.num_data_), deltas_(other.deltas_), vals_(other.vals_),
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
    num_vals_(other.num_vals_), push_buffers_(other.push_buffers_),
    fast_index_(other.fast_index_), fast_index_shift_(other.fast_index_shift_) {
  }

  /*!
   * \brief Position a cursor for reading rows at or after start_idx.
   *
   * Uses the fast-index slot of start_idx when there is one; otherwise the
   * cursor is put before the first stored entry (*i_delta = -1, *cur_pos = 0).
   * \param start_idx Row the caller wants to start from
   * \param i_delta Output: position in deltas_/vals_
   * \param cur_pos Output: row index belonging to that position
   */
  void InitIndex(data_size_t start_idx, data_size_t * i_delta, data_size_t * cur_pos) const {
    const auto slot = start_idx >> fast_index_shift_;
    if (static_cast<size_t>(slot) >= fast_index_.size()) {
      *i_delta = -1;
      *cur_pos = 0;
      return;
    }
    const auto& checkpoint = fast_index_[slot];
    *i_delta = checkpoint.first;
    *cur_pos = checkpoint.second;
  }

516
 private:
  /*! \brief Number of rows covered by this bin */
  data_size_t num_data_;
  /*! \brief Row gap from the previous stored entry, one byte each, plus a trailing 0 sentinel */
  std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>> deltas_;
  /*! \brief Bin value of each stored entry; 0 for filler entries that only bridge a long gap */
  std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> vals_;
  /*! \brief Number of stored entries (size of vals_) */
  data_size_t num_vals_;
  /*! \brief Per-thread (row, bin) pairs collected by Push() until FinishLoad() */
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  /*! \brief Cursor checkpoints (entry position, row), one per slot of 2^fast_index_shift_ rows */
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  /*! \brief Right shift that maps a row index to its fast_index_ slot */
  data_size_t fast_index_shift_;
};

526
template<typename VAL_T>
527
SparseBin<VAL_T>* SparseBin<VAL_T>::Clone() {
528
529
530
  return new SparseBin(*this);
}

Guolin Ke's avatar
Guolin Ke committed
531
template <typename VAL_T>
532
533
534
535
536
537
inline uint32_t SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
  return InnerRawGet(idx);
}

/*!
 * \brief Raw bin of row idx in the storage type.
 *
 * Moves the cursor forward until it reaches or passes idx; it never moves
 * backwards.
 * \return The stored value when an entry sits exactly on idx, otherwise 0
 */
template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
  for (; cur_pos_ < idx; ) {
    bin_data_->NextNonzeroFast(&i_delta_, &cur_pos_);
  }
  const bool on_stored_entry = (cur_pos_ == idx);
  return on_stored_entry ? bin_data_->vals_[i_delta_] : static_cast<VAL_T>(0);
}
Guolin Ke's avatar
Guolin Ke committed
547

548
549
/*! \brief Reposition the cursor for reads starting at start_idx, using the bin's fast index. */
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  const SparseBin<VAL_T>* const source = bin_data_;
  source->InitIndex(start_idx, &i_delta_, &cur_pos_);
}
Guolin Ke's avatar
Guolin Ke committed
552
553

template <typename VAL_T>
Guolin Ke's avatar
Guolin Ke committed
554
555
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
Guolin Ke's avatar
Guolin Ke committed
556
557
558
}

}  // namespace LightGBM
zhangyafeikimi's avatar
zhangyafeikimi committed
559
#endif   // LIGHTGBM_IO_SPARSE_BIN_HPP_