/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#ifndef LIGHTGBM_BOOSTING_BAGGING_HPP_
#define LIGHTGBM_BOOSTING_BAGGING_HPP_

#include <string>
#include <vector>

namespace LightGBM {

class BaggingSampleStrategy : public SampleStrategy {
 public:
  BaggingSampleStrategy(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function, int num_tree_per_iteration)
    : need_re_bagging_(false) {
    config_ = config;
    train_data_ = train_data;
    num_data_ = train_data->num_data();
21
22
    num_queries_ = train_data->metadata().num_queries();
    query_boundaries_ = train_data->metadata().query_boundaries();
23
24
    objective_function_ = objective_function;
    num_tree_per_iteration_ = num_tree_per_iteration;
25
    num_threads_ = OMP_NUM_THREADS();
26
27
28
29
30
31
32
33
  }

  ~BaggingSampleStrategy() {}

  void Bagging(int iter, TreeLearner* tree_learner, score_t* /*gradients*/, score_t* /*hessians*/) override {
    Common::FunctionTimer fun_timer("GBDT::Bagging", global_timer);
    // if need bagging
    if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0) ||
34
      need_re_bagging_) {
35
      need_re_bagging_ = false;
36
37
      if (!config_->bagging_by_query) {
        auto left_cnt = bagging_runner_.Run<true>(
38
39
40
41
42
43
44
45
46
47
48
49
50
          num_data_,
          [=](int, data_size_t cur_start, data_size_t cur_cnt, data_size_t* left,
              data_size_t*) {
            data_size_t cur_left_count = 0;
            if (balanced_bagging_) {
              cur_left_count =
                  BalancedBaggingHelper(cur_start, cur_cnt, left);
            } else {
              cur_left_count = BaggingHelper(cur_start, cur_cnt, left);
            }
            return cur_left_count;
          },
          bag_data_indices_.data());
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
        bag_data_cnt_ = left_cnt;
      } else {
        num_sampled_queries_ = bagging_runner_.Run<true>(
          num_queries_,
          [=](int, data_size_t cur_start, data_size_t cur_cnt, data_size_t* left,
              data_size_t*) {
            data_size_t cur_left_count = 0;
            cur_left_count = BaggingHelper(cur_start, cur_cnt, left);
            return cur_left_count;
          }, bag_query_indices_.data());

        sampled_query_boundaries_[0] = 0;
        OMP_INIT_EX();
        #pragma omp parallel for schedule(static) num_threads(num_threads_)
        for (data_size_t i = 0; i < num_sampled_queries_; ++i) {
          OMP_LOOP_EX_BEGIN();
          sampled_query_boundaries_[i + 1] = query_boundaries_[bag_query_indices_[i] + 1] - query_boundaries_[bag_query_indices_[i]];
          OMP_LOOP_EX_END();
        }
        OMP_THROW_EX();

        const int num_blocks = Threading::For<data_size_t>(0, num_sampled_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
          for (data_size_t i = start_index + 1; i < end_index; ++i) {
            sampled_query_boundaries_[i] += sampled_query_boundaries_[i - 1];
          }
          sampled_query_boundaires_thread_buffer_[thread_index] = sampled_query_boundaries_[end_index - 1];
         });

        for (int thread_index = 1; thread_index < num_blocks; ++thread_index) {
          sampled_query_boundaires_thread_buffer_[thread_index] += sampled_query_boundaires_thread_buffer_[thread_index - 1];
        }

        Threading::For<data_size_t>(0, num_sampled_queries_ + 1, 128, [this](int thread_index, data_size_t start_index, data_size_t end_index) {
          if (thread_index > 0) {
            for (data_size_t i = start_index; i < end_index; ++i) {
              sampled_query_boundaries_[i] += sampled_query_boundaires_thread_buffer_[thread_index - 1];
            }
          }
        });

        bag_data_cnt_ = sampled_query_boundaries_[num_sampled_queries_];

        Threading::For<data_size_t>(0, num_sampled_queries_, 1, [this](int /*thread_index*/, data_size_t start_index, data_size_t end_index) {
          for (data_size_t sampled_query_id = start_index; sampled_query_id < end_index; ++sampled_query_id) {
            const data_size_t query_index = bag_query_indices_[sampled_query_id];
            const data_size_t data_index_start = query_boundaries_[query_index];
            const data_size_t data_index_end = query_boundaries_[query_index + 1];
            const data_size_t sampled_query_start = sampled_query_boundaries_[sampled_query_id];
            for (data_size_t i = data_index_start; i < data_index_end; ++i) {
              bag_data_indices_[sampled_query_start + i - data_index_start] = i;
            }
          }
        });
      }
105
106
107
      Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
      // set bagging data to tree learner
      if (!is_use_subset_) {
108
109
        #ifdef USE_CUDA
        if (config_->device_type == std::string("cuda")) {
110
111
112
          CopyFromHostToCUDADevice<data_size_t>(cuda_bag_data_indices_.RawData(), bag_data_indices_.data(), static_cast<size_t>(num_data_), __FILE__, __LINE__);
          tree_learner->SetBaggingData(nullptr, cuda_bag_data_indices_.RawData(), bag_data_cnt_);
        } else {
113
        #endif  // USE_CUDA
114
          tree_learner->SetBaggingData(nullptr, bag_data_indices_.data(), bag_data_cnt_);
115
        #ifdef USE_CUDA
116
        }
117
        #endif  // USE_CUDA
118
119
120
121
122
      } else {
        // get subset
        tmp_subset_->ReSize(bag_data_cnt_);
        tmp_subset_->CopySubrow(train_data_, bag_data_indices_.data(),
                                bag_data_cnt_, false);
123
124
        #ifdef USE_CUDA
        if (config_->device_type == std::string("cuda")) {
125
126
127
128
          CopyFromHostToCUDADevice<data_size_t>(cuda_bag_data_indices_.RawData(), bag_data_indices_.data(), static_cast<size_t>(num_data_), __FILE__, __LINE__);
          tree_learner->SetBaggingData(tmp_subset_.get(), cuda_bag_data_indices_.RawData(),
                                       bag_data_cnt_);
        } else {
129
        #endif  // USE_CUDA
130
131
          tree_learner->SetBaggingData(tmp_subset_.get(), bag_data_indices_.data(),
                                       bag_data_cnt_);
132
        #ifdef USE_CUDA
133
        }
134
        #endif  // USE_CUDA
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
      }
    }
  }

  void ResetSampleConfig(const Config* config, bool is_change_dataset) override {
    need_resize_gradients_ = false;
    // if need bagging, create buffer
    data_size_t num_pos_data = 0;
    if (objective_function_ != nullptr) {
      num_pos_data = objective_function_->NumPositiveData();
    }
    bool balance_bagging_cond = (config->pos_bagging_fraction < 1.0 || config->neg_bagging_fraction < 1.0) && (num_pos_data > 0);
    if ((config->bagging_fraction < 1.0 || balance_bagging_cond) && config->bagging_freq > 0) {
      need_re_bagging_ = false;
      if (!is_change_dataset &&
        config_ != nullptr && config_->bagging_fraction == config->bagging_fraction && config_->bagging_freq == config->bagging_freq
        && config_->pos_bagging_fraction == config->pos_bagging_fraction && config_->neg_bagging_fraction == config->neg_bagging_fraction) {
        config_ = config;
        return;
      }
      config_ = config;
      if (balance_bagging_cond) {
        balanced_bagging_ = true;
        bag_data_cnt_ = static_cast<data_size_t>(num_pos_data * config_->pos_bagging_fraction)
                        + static_cast<data_size_t>((num_data_ - num_pos_data) * config_->neg_bagging_fraction);
      } else {
        bag_data_cnt_ = static_cast<data_size_t>(config_->bagging_fraction * num_data_);
      }
      bag_data_indices_.resize(num_data_);
164
165
      #ifdef USE_CUDA
      if (config_->device_type == std::string("cuda")) {
166
167
        cuda_bag_data_indices_.Resize(num_data_);
      }
168
      #endif  // USE_CUDA
169
170
171
172
173
174
175
176
      if (!config_->bagging_by_query) {
        bagging_runner_.ReSize(num_data_);
      } else {
        bagging_runner_.ReSize(num_queries_);
        sampled_query_boundaries_.resize(num_queries_ + 1, 0);
        sampled_query_boundaires_thread_buffer_.resize(num_threads_, 0);
        bag_query_indices_.resize(num_data_);
      }
177
178
179
180
181
182
183
184
185
      bagging_rands_.clear();
      for (int i = 0;
          i < (num_data_ + bagging_rand_block_ - 1) / bagging_rand_block_; ++i) {
        bagging_rands_.emplace_back(config_->bagging_seed + i);
      }

      double average_bag_rate =
          (static_cast<double>(bag_data_cnt_) / num_data_) / config_->bagging_freq;
      is_use_subset_ = false;
186
      if (config_->device_type != std::string("cuda")) {
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
        const int group_threshold_usesubset = 100;
        const double average_bag_rate_threshold = 0.5;
        if (average_bag_rate <= average_bag_rate_threshold
            && (train_data_->num_feature_groups() < group_threshold_usesubset)) {
          if (tmp_subset_ == nullptr || is_change_dataset) {
            tmp_subset_.reset(new Dataset(bag_data_cnt_));
            tmp_subset_->CopyFeatureMapperFrom(train_data_);
          }
          is_use_subset_ = true;
          Log::Debug("Use subset for bagging");
        }
      }

      need_re_bagging_ = true;

      if (is_use_subset_ && bag_data_cnt_ < num_data_) {
        // resize gradient vectors to copy the customized gradients for using subset data
        need_resize_gradients_ = true;
      }
    } else {
      bag_data_cnt_ = num_data_;
      bag_data_indices_.clear();
209
      #ifdef USE_CUDA
210
      cuda_bag_data_indices_.Clear();
211
      #endif  // USE_CUDA
212
213
214
215
216
217
218
219
220
      bagging_runner_.ReSize(0);
      is_use_subset_ = false;
    }
  }

  bool IsHessianChange() const override {
    return false;
  }

221
222
223
224
225
226
227
228
  data_size_t num_sampled_queries() const override {
    return num_sampled_queries_;
  }

  const data_size_t* sampled_query_indices() const override {
    return bag_query_indices_.data();
  }

229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
 private:
  data_size_t BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer) {
    if (cnt <= 0) {
      return 0;
    }
    data_size_t cur_left_cnt = 0;
    data_size_t cur_right_pos = cnt;
    // random bagging, minimal unit is one record
    for (data_size_t i = 0; i < cnt; ++i) {
      auto cur_idx = start + i;
      if (bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() < config_->bagging_fraction) {
        buffer[cur_left_cnt++] = cur_idx;
      } else {
        buffer[--cur_right_pos] = cur_idx;
      }
    }
    return cur_left_cnt;
  }

  data_size_t BalancedBaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer) {
    if (cnt <= 0) {
      return 0;
    }
    auto label_ptr = train_data_->metadata().label();
    data_size_t cur_left_cnt = 0;
    data_size_t cur_right_pos = cnt;
    // random bagging, minimal unit is one record
    for (data_size_t i = 0; i < cnt; ++i) {
      auto cur_idx = start + i;
      bool is_pos = label_ptr[start + i] > 0;
      bool is_in_bag = false;
      if (is_pos) {
        is_in_bag = bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() <
                    config_->pos_bagging_fraction;
      } else {
        is_in_bag = bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() <
                    config_->neg_bagging_fraction;
      }
      if (is_in_bag) {
        buffer[cur_left_cnt++] = cur_idx;
      } else {
        buffer[--cur_right_pos] = cur_idx;
      }
    }
    return cur_left_cnt;
  }

  /*! \brief whether need restart bagging in continued training */
  bool need_re_bagging_;
278
279
280
281
282
283
284
285
286
287
288
289
290
291
  /*! \brief number of threads */
  int num_threads_;
  /*! \brief query boundaries of the in-bag queries */
  std::vector<data_size_t> sampled_query_boundaries_;
  /*! \brief buffer for calculating sampled_query_boundaries_ */
  std::vector<data_size_t> sampled_query_boundaires_thread_buffer_;
  /*! \brief in-bag query indices */
  std::vector<data_size_t, Common::AlignmentAllocator<data_size_t, kAlignedSize>> bag_query_indices_;
  /*! \brief number of queries in the training dataset */
  data_size_t num_queries_;
  /*! \brief number of in-bag queries */
  data_size_t num_sampled_queries_;
  /*! \brief query boundaries of the whole training dataset */
  const data_size_t* query_boundaries_;
292
293
294
295
296
};

}  // namespace LightGBM

#endif  // LIGHTGBM_BOOSTING_BAGGING_HPP_