// sparse_bin.hpp (11.5 KB) - source text recovered from a GitLab blame view
// (blame annotations by Guolin Ke and zhangyafeikimi; gutter line numbers removed).
#ifndef LIGHTGBM_IO_SPARSE_BIN_HPP_
#define LIGHTGBM_IO_SPARSE_BIN_HPP_

#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <utility>
#include <vector>

namespace LightGBM {

17
template <typename VAL_T> class SparseBin;
18

Guolin Ke's avatar
Guolin Ke committed
19
20
const size_t kNumFastIndex = 64;

21
22
23
template <typename VAL_T>
class SparseBinIterator: public BinIterator {
public:
Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
  SparseBinIterator(const SparseBin<VAL_T>* bin_data,
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin)
    : bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
    max_bin_(static_cast<VAL_T>(max_bin)),
zhangyafeikimi's avatar
zhangyafeikimi committed
28
    default_bin_(static_cast<VAL_T>(default_bin)) {
Guolin Ke's avatar
Guolin Ke committed
29
30
31
32
33
34
35
    if (default_bin_ == 0) {
      bias_ = 1;
    } else {
      bias_ = 0;
    }
    Reset(0);
  }
36
37
38
39
40
  SparseBinIterator(const SparseBin<VAL_T>* bin_data, data_size_t start_idx)
    : bin_data_(bin_data) {
    Reset(start_idx);
  }

Guolin Ke's avatar
Guolin Ke committed
41
  inline VAL_T RawGet(data_size_t idx);
42

Guolin Ke's avatar
Guolin Ke committed
43
44
45
46
47
48
49
  inline uint32_t Get( data_size_t idx) override {
    VAL_T ret = RawGet(idx);
    if (ret >= min_bin_ && ret <= max_bin_) {
      return ret - min_bin_ + bias_;
    } else {
      return default_bin_;
    }
50
51
  }

Guolin Ke's avatar
Guolin Ke committed
52
  inline void Reset(data_size_t idx) override;
53
54
55
56
private:
  const SparseBin<VAL_T>* bin_data_;
  data_size_t cur_pos_;
  data_size_t i_delta_;
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
  VAL_T min_bin_;
  VAL_T max_bin_;
  VAL_T default_bin_;
  uint8_t bias_;
61
62
63
64
};

template <typename VAL_T>
class OrderedSparseBin;
Guolin Ke's avatar
Guolin Ke committed
65
66

template <typename VAL_T>
67
class SparseBin: public Bin {
Guolin Ke's avatar
Guolin Ke committed
68
69
public:
  friend class SparseBinIterator<VAL_T>;
70
  friend class OrderedSparseBin<VAL_T>;
Guolin Ke's avatar
Guolin Ke committed
71

Guolin Ke's avatar
Guolin Ke committed
72
  SparseBin(data_size_t num_data)
Guolin Ke's avatar
Guolin Ke committed
73
    : num_data_(num_data) {
Guolin Ke's avatar
Guolin Ke committed
74
    int num_threads = 1;
75
76
#pragma omp parallel
#pragma omp master
Guolin Ke's avatar
Guolin Ke committed
77
    {
Guolin Ke's avatar
Guolin Ke committed
78
      num_threads = omp_get_num_threads();
Guolin Ke's avatar
Guolin Ke committed
79
    }
Guolin Ke's avatar
Guolin Ke committed
80
    push_buffers_.resize(num_threads);
Guolin Ke's avatar
Guolin Ke committed
81
82
83
  }

  ~SparseBin() {
Guolin Ke's avatar
Guolin Ke committed
84
85
86
87
88

  }

  void ReSize(data_size_t num_data) override {
    num_data_ = num_data;
Guolin Ke's avatar
Guolin Ke committed
89
90
91
  }

  void Push(int tid, data_size_t idx, uint32_t value) override {
92
    auto cur_bin = static_cast<VAL_T>(value);
Guolin Ke's avatar
Guolin Ke committed
93
    if (cur_bin != 0) {
94
95
      push_buffers_[tid].emplace_back(idx, cur_bin);
    }
Guolin Ke's avatar
Guolin Ke committed
96
97
  }

Guolin Ke's avatar
Guolin Ke committed
98
  BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override;
Guolin Ke's avatar
Guolin Ke committed
99

100
101
  void ConstructHistogram(const data_size_t*, data_size_t, const score_t*,
    const score_t*, HistogramBinEntry*) const override {
Guolin Ke's avatar
Guolin Ke committed
102
    // Will use OrderedSparseBin->ConstructHistogram() instead
Guolin Ke's avatar
Guolin Ke committed
103
    Log::Fatal("Using OrderedSparseBin->ConstructHistogram() instead");
Guolin Ke's avatar
Guolin Ke committed
104
105
  }

106
107
108
109
110
111
  void ConstructHistogram(const data_size_t*, data_size_t, const score_t*,
                          HistogramBinEntry*) const override {
    // Will use OrderedSparseBin->ConstructHistogram() instead
    Log::Fatal("Using OrderedSparseBin->ConstructHistogram() instead");
  }

112
113
114
  inline bool NextNonzero(data_size_t* i_delta,
    data_size_t* cur_pos) const {
    ++(*i_delta);
115
116
    data_size_t shift = 0;
    data_size_t delta = deltas_[*i_delta];
Guolin Ke's avatar
Guolin Ke committed
117
    while (*i_delta < num_vals_ && vals_[*i_delta] == 0) {
118
      ++(*i_delta);
119
120
      shift += 8;
      delta |=  static_cast<data_size_t>(deltas_[*i_delta]) << shift;
121
    }
122
123
    *cur_pos += delta;
    if (*i_delta < num_vals_) {
124
125
      return true;
    } else {
126
      *cur_pos = num_data_;
127
128
129
130
      return false;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
131
132
133
  virtual data_size_t Split(
    uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
    uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
134
    data_size_t* lte_indices, data_size_t* gt_indices, BinType bin_type) const override {
135
136
    // not need to split
    if (num_data <= 0) { return 0; }
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
142
    VAL_T th = static_cast<VAL_T>(threshold + min_bin);
    VAL_T minb = static_cast<VAL_T>(min_bin);
    VAL_T maxb = static_cast<VAL_T>(max_bin);
    if (default_bin == 0) {
      th -= 1;
    }
143
    SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
Guolin Ke's avatar
Guolin Ke committed
144
145
    data_size_t lte_count = 0;
    data_size_t gt_count = 0;
Guolin Ke's avatar
Guolin Ke committed
146
147
    data_size_t* default_indices = gt_indices;
    data_size_t* default_count = &gt_count;
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
    if (bin_type == BinType::NumericalBin) {
      if (default_bin <= threshold) {
        default_indices = lte_indices;
        default_count = &lte_count;
      }
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        VAL_T bin = iterator.RawGet(idx);
        if (bin > maxb || bin < minb) {
          default_indices[(*default_count)++] = idx;
        } else if (bin > th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
        }
      }
    } else {
      if (default_bin == threshold) {
        default_indices = lte_indices;
        default_count = &lte_count;
      }
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        VAL_T bin = iterator.RawGet(idx);
        if (bin > maxb || bin < minb) {
          default_indices[(*default_count)++] = idx;
        } else if (bin != th) {
          gt_indices[gt_count++] = idx;
        } else {
          lte_indices[lte_count++] = idx;
        }
Guolin Ke's avatar
Guolin Ke committed
179
180
181
182
183
184
185
      }
    }
    return lte_count;
  }

  data_size_t num_data() const override { return num_data_; }

186
  OrderedBin* CreateOrderedBin() const override;
Guolin Ke's avatar
Guolin Ke committed
187
188
189

  void FinishLoad() override {
    // get total non zero size
190
    size_t pair_cnt = 0;
191
    for (size_t i = 0; i < push_buffers_.size(); ++i) {
192
      pair_cnt += push_buffers_[i].size();
Guolin Ke's avatar
Guolin Ke committed
193
    }
Guolin Ke's avatar
Guolin Ke committed
194
    std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs = push_buffers_[0];
195
    idx_val_pairs.reserve(pair_cnt);
Guolin Ke's avatar
Guolin Ke committed
196
197

    for (size_t i = 1; i < push_buffers_.size(); ++i) {
198
      idx_val_pairs.insert(idx_val_pairs.end(), push_buffers_[i].begin(), push_buffers_[i].end());
Guolin Ke's avatar
Guolin Ke committed
199
200
201
202
      push_buffers_[i].clear();
      push_buffers_[i].shrink_to_fit();
    }
    // sort by data index
203
    std::sort(idx_val_pairs.begin(), idx_val_pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
204
205
206
      [](const std::pair<data_size_t, VAL_T>& a, const std::pair<data_size_t, VAL_T>& b) {
      return a.first < b.first;
    });
zhangyafeikimi's avatar
zhangyafeikimi committed
207
    // load delta array
208
    LoadFromPair(idx_val_pairs);
Guolin Ke's avatar
Guolin Ke committed
209
210
  }

211
  void LoadFromPair(const std::vector<std::pair<data_size_t, VAL_T>>& idx_val_pairs) {
212
    deltas_.clear();
Guolin Ke's avatar
Guolin Ke committed
213
214
215
    vals_.clear();
    // transform to delta array
    data_size_t last_idx = 0;
216
217
218
    for (size_t i = 0; i < idx_val_pairs.size(); ++i) {
      const data_size_t cur_idx = idx_val_pairs[i].first;
      const VAL_T bin = idx_val_pairs[i].second;
Guolin Ke's avatar
Guolin Ke committed
219
      data_size_t cur_delta = cur_idx - last_idx;
220
221
      while (cur_delta >= 256) {
        deltas_.push_back(cur_delta & 0xff);
Guolin Ke's avatar
Guolin Ke committed
222
        vals_.push_back(0);
223
        cur_delta >>= 8;
Guolin Ke's avatar
Guolin Ke committed
224
      }
225
      deltas_.push_back(static_cast<uint8_t>(cur_delta));
Guolin Ke's avatar
Guolin Ke committed
226
227
228
229
      vals_.push_back(bin);
      last_idx = cur_idx;
    }
    // avoid out of range
230
    deltas_.push_back(0);
Guolin Ke's avatar
Guolin Ke committed
231
232
233
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
234
    deltas_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
235
236
237
238
239
240
241
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
  }

  void GetFastIndex() {
Guolin Ke's avatar
Guolin Ke committed
242

Guolin Ke's avatar
Guolin Ke committed
243
244
245
246
247
248
249
250
251
252
    fast_index_.clear();
    // get shift cnt
    data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
    data_size_t pow2_mod_size = 1;
    fast_index_shift_ = 0;
    while (pow2_mod_size < mod_size) {
      pow2_mod_size <<= 1;
      ++fast_index_shift_;
    }
    // build fast index
253
    data_size_t i_delta = -1;
Guolin Ke's avatar
Guolin Ke committed
254
    data_size_t cur_pos = 0;
255
256
    data_size_t next_threshold = 0;
    while (NextNonzero(&i_delta, &cur_pos)) {
Guolin Ke's avatar
Guolin Ke committed
257
      while (next_threshold <= cur_pos) {
258
259
        fast_index_.emplace_back(i_delta, cur_pos);
        next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
260
261
262
      }
    }
    // avoid out of range
263
    while (next_threshold < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
264
      fast_index_.emplace_back(num_vals_ - 1, cur_pos);
265
      next_threshold += pow2_mod_size;
Guolin Ke's avatar
Guolin Ke committed
266
267
268
269
270
271
    }
    fast_index_.shrink_to_fit();
  }

  void SaveBinaryToFile(FILE* file) const override {
    fwrite(&num_vals_, sizeof(num_vals_), 1, file);
272
    fwrite(deltas_.data(), sizeof(uint8_t), num_vals_ + 1, file);
Guolin Ke's avatar
Guolin Ke committed
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
    fwrite(vals_.data(), sizeof(VAL_T), num_vals_, file);
  }

  size_t SizesInByte() const override {
    return sizeof(num_vals_) + sizeof(uint8_t) * (num_vals_ + 1)
      + sizeof(VAL_T) * num_vals_;
  }

  void LoadFromMemory(const void* memory, const std::vector<data_size_t>& local_used_indices) override {
    const char* mem_ptr = reinterpret_cast<const char*>(memory);
    data_size_t tmp_num_vals = *(reinterpret_cast<const data_size_t*>(mem_ptr));
    mem_ptr += sizeof(tmp_num_vals);
    const uint8_t* tmp_delta = reinterpret_cast<const uint8_t*>(mem_ptr);
    mem_ptr += sizeof(uint8_t) * (tmp_num_vals + 1);
    const VAL_T* tmp_vals = reinterpret_cast<const VAL_T*>(mem_ptr);

289
290
291
292
293
294
295
296
297
298
299
    deltas_.clear();
    vals_.clear();
    num_vals_ = tmp_num_vals;
    for (data_size_t i = 0; i < num_vals_; ++i) {
      deltas_.push_back(tmp_delta[i]);
      vals_.push_back(tmp_vals[i]);
    }
    deltas_.push_back(0);
    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
300

Guolin Ke's avatar
Guolin Ke committed
301
    if (local_used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
302
303
304
305
      // generate fast index
      GetFastIndex();
    } else {
      std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
306
307
      data_size_t cur_pos = 0;
      data_size_t j = -1;
Guolin Ke's avatar
Guolin Ke committed
308
309
      for (data_size_t i = 0; i < static_cast<data_size_t>(local_used_indices.size()); ++i) {
        const data_size_t idx = local_used_indices[i];
310
311
        while (cur_pos < idx && j < num_vals_) {
          NextNonzero(&j, &cur_pos);
Guolin Ke's avatar
Guolin Ke committed
312
        }
313
        if (cur_pos == idx && j < num_vals_) {
Guolin Ke's avatar
Guolin Ke committed
314
          // new row index is i
315
          tmp_pair.emplace_back(i, vals_[j]);
Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
        }
      }
      LoadFromPair(tmp_pair);
    }
320
  }
Guolin Ke's avatar
Guolin Ke committed
321

322
323
324
  void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
    auto other_bin = reinterpret_cast<const SparseBin<VAL_T>*>(full_bin);
    SparseBinIterator<VAL_T> iterator(other_bin, used_indices[0]);
325
326
327
328
    deltas_.clear();
    vals_.clear();
    // transform to delta array
    data_size_t last_idx = 0;
329
    for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
330
331
      VAL_T bin = iterator.RawGet(used_indices[i]);
      if (bin > 0) {
332
333
334
335
336
337
338
339
340
        data_size_t cur_delta = i - last_idx;
        while (cur_delta >= 256) {
          deltas_.push_back(cur_delta & 0xff);
          vals_.push_back(0);
          cur_delta >>= 8;
        }
        deltas_.push_back(static_cast<uint8_t>(cur_delta));
        vals_.push_back(bin);
        last_idx = i;
341
342
      }
    }
343
344
345
346
347
348
349
350
351
352
    // avoid out of range
    deltas_.push_back(0);
    num_vals_ = static_cast<data_size_t>(vals_.size());

    // reduce memory cost
    deltas_.shrink_to_fit();
    vals_.shrink_to_fit();

    // generate fast index
    GetFastIndex();
Guolin Ke's avatar
Guolin Ke committed
353
354
  }

Guolin Ke's avatar
Guolin Ke committed
355
protected:
Guolin Ke's avatar
Guolin Ke committed
356
  data_size_t num_data_;
357
  std::vector<uint8_t> deltas_;
Guolin Ke's avatar
Guolin Ke committed
358
359
360
361
362
363
364
365
  std::vector<VAL_T> vals_;
  data_size_t num_vals_;
  std::vector<std::vector<std::pair<data_size_t, VAL_T>>> push_buffers_;
  std::vector<std::pair<data_size_t, data_size_t>> fast_index_;
  data_size_t fast_index_shift_;
};

template <typename VAL_T>
Guolin Ke's avatar
Guolin Ke committed
366
inline VAL_T SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
367
  while (cur_pos_ < idx) {
368
    bin_data_->NextNonzero(&i_delta_, &cur_pos_);
Guolin Ke's avatar
Guolin Ke committed
369
  }
370
  if (cur_pos_ == idx) {
371
372
    return bin_data_->vals_[i_delta_];
  } else {
Guolin Ke's avatar
Guolin Ke committed
373
    return 0;
Guolin Ke's avatar
Guolin Ke committed
374
  }
375
}
Guolin Ke's avatar
Guolin Ke committed
376

377
378
379
380
381
382
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
  const auto& fast_index = bin_data_->fast_index_;
  const size_t bucket = static_cast<size_t>(start_idx >> bin_data_->fast_index_shift_);
  if (bucket < fast_index.size()) {
    // (i_delta, row) of the first stored entry at or after the bucket start
    i_delta_ = fast_index[bucket].first;
    cur_pos_ = fast_index[bucket].second;
  } else {
    // No bucket covers start_idx (e.g. an empty fast index when num_data_ == 0):
    // park the iterator past the end, the same state GetFastIndex() stores for
    // buckets after the last stored entry.
    i_delta_ = bin_data_->num_vals_ - 1;
    cur_pos_ = bin_data_->num_data_;
  }
}

template <typename VAL_T>
Guolin Ke's avatar
Guolin Ke committed
385
386
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const {
  return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, default_bin);
Guolin Ke's avatar
Guolin Ke committed
387
388
389
}

}  // namespace LightGBM
#endif   // LightGBM_IO_SPARSE_BIN_HPP_