gbdt.cpp 29.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
9
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
10
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
11
12
13
14
15
16
17

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
18
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
19
20
21

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
22
23
24
#ifdef TIMETAG
// Accumulated wall-clock time, in milliseconds, of each training phase.
// These counters only exist in TIMETAG builds; GBDT::~GBDT() logs them.

// time spent computing gradients and hessians
std::chrono::duration<double, std::milli> boosting_time;
// time spent updating scores of the training data
std::chrono::duration<double, std::milli> train_score_time;
// time spent updating scores of the out-of-bag training data
std::chrono::duration<double, std::milli> out_of_bag_score_time;
// time spent updating scores of the validation data
std::chrono::duration<double, std::milli> valid_score_time;
// time spent evaluating metrics
std::chrono::duration<double, std::milli> metric_time;
// time spent re-sampling the bag
std::chrono::duration<double, std::milli> bagging_time;
// time spent training trees
std::chrono::duration<double, std::milli> tree_time;
#endif  // TIMETAG

Guolin Ke's avatar
Guolin Ke committed
32
/*!
 * \brief Construct an empty booster with default settings.
 *        No training data, objective or tree learner is attached here.
 */
GBDT::GBDT()
    : iter_(0),
      train_data_(nullptr),
      objective_function_(nullptr),
      early_stopping_round_(0),
      max_feature_idx_(0),
      num_tree_per_iteration_(1),
      num_class_(1),
      num_iteration_for_pred_(0),
      shrinkage_rate_(0.1f),
      num_init_iteration_(0),
      need_re_bagging_(false) {
  // members that are assigned in the body instead of the initializer list
  tree_learner_ = nullptr;
  average_output_ = false;

  // query the size of the OpenMP team once; only the master thread stores it
  #pragma omp parallel
  #pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

/*!
 * \brief Destructor. In TIMETAG builds, logs the accumulated time of every
 *        training phase, converted from milliseconds to seconds.
 */
GBDT::~GBDT() {
  #ifdef TIMETAG
  // The counters are std::chrono::duration objects. "%f" needs a plain double,
  // so pass the tick count (milliseconds) scaled to seconds rather than the
  // duration object itself.
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

65
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
66
                const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
67
  CHECK(train_data != nullptr);
Guolin Ke's avatar
Guolin Ke committed
68
  CHECK(train_data->num_features() > 0);
69
  train_data_ = train_data;
70
  iter_ = 0;
wxchan's avatar
wxchan committed
71
  num_iteration_for_pred_ = 0;
72
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
73
  num_class_ = config->num_class;
74
75
76
77
78
79
80
81
  gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
  early_stopping_round_ = gbdt_config_->early_stopping_round;
  shrinkage_rate_ = gbdt_config_->learning_rate;

  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
82
    num_tree_per_iteration_ = objective_function_->NumModelPerIteration();
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
  } else {
    is_constant_hessian_ = false;
  }

  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));

  // init tree learner
  tree_learner_->Init(train_data_, is_constant_hessian_);

  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();

  train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));

  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
  if (objective_function_ != nullptr) {
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    gradients_.resize(total_size);
    hessians_.resize(total_size);
  }
  // get max feature index
  max_feature_idx_ = train_data_->num_total_features() - 1;
  // get label index
  label_idx_ = train_data_->label_idx();
  // get feature names
  feature_names_ = train_data_->feature_names();
  feature_infos_ = train_data_->feature_infos();

  // if need bagging, create buffer
Guolin Ke's avatar
Guolin Ke committed
117
  ResetBaggingConfig(gbdt_config_.get(), true);
118
119
120
121
122

  // reset config for tree learner
  class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
  if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
    CHECK(num_tree_per_iteration_ == num_class_);
Guolin Ke's avatar
Guolin Ke committed
123

124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
    class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
    auto label = train_data_->metadata().label();
    if (num_tree_per_iteration_ > 1) {
      // multi-class
      std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
      for (data_size_t i = 0; i < num_data_; ++i) {
        int index = static_cast<int>(label[i]);
        CHECK(index < num_tree_per_iteration_);
        ++cnt_per_class[index];
      }
      for (int i = 0; i < num_tree_per_iteration_; ++i) {
        if (cnt_per_class[i] == num_data_) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(kEpsilon);
        } else if (cnt_per_class[i] == 0) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
        }
      }
    } else {
      // binary class
      data_size_t cnt_pos = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (label[i] > 0) {
          ++cnt_pos;
        }
      }
      if (cnt_pos == 0) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
      } else if (cnt_pos == num_data_) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(kEpsilon);
      }
    }
  }
wxchan's avatar
wxchan committed
160
161
162
}

void GBDT::AddValidDataset(const Dataset* valid_data,
163
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
164
165
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
166
  }
Guolin Ke's avatar
Guolin Ke committed
167
  // for a validation dataset, we need its score and metric
168
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
169
170
  // update score
  for (int i = 0; i < iter_; ++i) {
171
172
173
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
174
175
    }
  }
Guolin Ke's avatar
Guolin Ke committed
176
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
177
  valid_metrics_.emplace_back();
178
179
180
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
181
    best_msg_.emplace_back();
182
  }
Guolin Ke's avatar
Guolin Ke committed
183
184
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
185
186
187
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
188
      best_msg_.back().emplace_back();
189
    }
Guolin Ke's avatar
Guolin Ke committed
190
  }
Guolin Ke's avatar
Guolin Ke committed
191
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
192
193
}

Guolin Ke's avatar
Guolin Ke committed
194
195
196
197
198
199
200
201
202
203
/*!
 * \brief Let the objective function fill gradients_ and hessians_ from the
 *        current training scores. Fatal if no objective function is set.
 */
void GBDT::Boosting() {
  if (objective_function_ == nullptr) {
    Log::Fatal("No object function provided");
  }
  // the length is required by GetTrainingScore but is not needed here
  int64_t score_len = 0;
  const double* train_score = GetTrainingScore(&score_len);
  objective_function_->GetGradients(train_score, gradients_.data(), hessians_.data());
}

204
/*!
 * \brief Split the records [start, start + cnt) into in-bag and out-of-bag indices.
 * \param cur_rand Random generator used for the draws
 * \param start Index of the first record of the range
 * \param cnt Number of records in the range
 * \param buffer Output; in-bag indices are written from the front, out-of-bag
 *        indices directly after the in-bag part
 * \return Number of in-bag records, i.e. static_cast<data_size_t>(bagging_fraction * cnt)
 */
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) {
    return 0;
  }
  const data_size_t target_cnt = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t* in_bag = buffer;
  data_size_t* out_of_bag = buffer + target_cnt;
  data_size_t num_in = 0;
  data_size_t num_out = 0;
  // random bagging, minimal unit is one record
  for (data_size_t i = 0; i < cnt; ++i) {
    const data_size_t record = start + i;
    // remaining in-bag slots divided by remaining records
    const float prob = (target_cnt - num_in) / static_cast<float>(cnt - i);
    if (cur_rand.NextFloat() < prob) {
      in_bag[num_in++] = record;
    } else {
      out_of_bag[num_out++] = record;
    }
  }
  CHECK(num_in == target_cnt);
  return num_in;
}
Guolin Ke's avatar
Guolin Ke committed
224

225
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
226
  // if need bagging
Guolin Ke's avatar
Guolin Ke committed
227
  if ((bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0)
Guolin Ke's avatar
Guolin Ke committed
228
      || need_re_bagging_) {
Guolin Ke's avatar
Guolin Ke committed
229
    need_re_bagging_ = false;
Guolin Ke's avatar
Guolin Ke committed
230
    const data_size_t min_inner_size = 1000;
231
232
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
233
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
234
    #pragma omp parallel for schedule(static,1)
235
    for (int i = 0; i < num_threads_; ++i) {
236
      OMP_LOOP_EX_BEGIN();
237
238
239
240
241
242
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
243
244
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
245
246
247
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
248
      OMP_LOOP_EX_END();
249
    }
250
    OMP_THROW_EX();
251
252
253
254
255
256
257
258
259
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
260
    #pragma omp parallel for schedule(static, 1)
261
    for (int i = 0; i < num_threads_; ++i) {
262
      OMP_LOOP_EX_BEGIN();
263
264
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
265
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
266
      }
267
268
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
269
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
270
      }
271
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
272
    }
273
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
274
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
275
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
276
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
277
278
279
280
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
281
282
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
283
284
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
285
286
287
  }
}

288
/* If the custom "average" is implemented it will be used inplace of the label average (if enabled)
Guolin Ke's avatar
Guolin Ke committed
289
290
291
292
293
294
295
296
297
*
* An improvement to this is to have options to explicitly choose
* (i) standard average
* (ii) custom average if available
* (iii) any user defined scalar bias (e.g. using a new option "init_score" that overrides (i) and (ii) )
*
* (i) and (ii) could be selected as say "auto_init_score" = 0 or 1 etc..
*
*/
298
double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const label_t* label, data_size_t num_data) {
299
300
  double init_score = 0.0f;
  bool got_custom = false;
Guolin Ke's avatar
Guolin Ke committed
301
302
  if (fobj != nullptr) {
    got_custom = fobj->GetCustomAverage(&init_score);
303
304
305
306
307
308
309
310
  }
  if (!got_custom) {
    double sum_label = 0.0f;
    #pragma omp parallel for schedule(static) reduction(+:sum_label)
    for (data_size_t i = 0; i < num_data; ++i) {
      sum_label += label[i];
    }
    init_score = sum_label / num_data;
Guolin Ke's avatar
Guolin Ke committed
311
312
313
314
315
316
  }
  if (Network::num_machines() > 1) {
    double global_init_score = 0.0f;
    Network::Allreduce(reinterpret_cast<char*>(&init_score),
                       sizeof(init_score), sizeof(init_score),
                       reinterpret_cast<char*>(&global_init_score),
Guolin Ke's avatar
Guolin Ke committed
317
318
                       [](const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
      const double *p1;
      double *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const double *>(src);
        p2 = reinterpret_cast<double *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global_init_score / Network::num_machines();
  } else {
    return init_score;
  }
}

Guolin Ke's avatar
Guolin Ke committed
336
337
338
339
void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
  bool is_finished = false;
  auto start_time = std::chrono::steady_clock::now();
  for (int iter = 0; iter < gbdt_config_->num_iterations && !is_finished; ++iter) {
Guolin Ke's avatar
Guolin Ke committed
340
341
342
343
    is_finished = TrainOneIter(nullptr, nullptr);
    if (!is_finished) {
      is_finished = EvalAndCheckEarlyStopping();
    }
Guolin Ke's avatar
Guolin Ke committed
344
345
346
347
348
349
350
351
352
353
354
355
    auto end_time = std::chrono::steady_clock::now();
    // output used time per iteration
    Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
              std::milli>(end_time - start_time) * 1e-3, iter + 1);
    if (snapshot_freq > 0
        && (iter + 1) % snapshot_freq == 0) {
      std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
      SaveModelToFile(-1, snapshot_out.c_str());
    }
  }
}

356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) {
  CHECK(tree_leaf_prediction.size() > 0);
  CHECK(static_cast<size_t>(num_data_) == tree_leaf_prediction.size());
  CHECK(static_cast<size_t>(models_.size()) == tree_leaf_prediction[0].size());
  int num_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_);
  std::vector<int> leaf_pred(num_data_);
  for (int iter = 0; iter < num_iterations; ++iter) {
    Boosting();
    for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
      int model_index = iter * num_tree_per_iteration_ + tree_id;
      #pragma omp parallel for schedule(static)
      for (int i = 0; i < num_data_; ++i) {
        leaf_pred[i] = tree_leaf_prediction[i][model_index];
      }
      size_t bias = static_cast<size_t>(tree_id) * num_data_;
      auto grad = gradients_.data() + bias;
      auto hess = hessians_.data() + bias;
      auto new_tree = tree_learner_->FitByExistingTree(models_[model_index].get(), leaf_pred, grad, hess);
      train_score_updater_->AddScore(tree_learner_.get(), new_tree, tree_id);
      models_[model_index].reset(new_tree);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
380
double GBDT::BoostFromAverage() {
381
  // boosting from average label; or customized "average" if implemented for the current objective
Guolin Ke's avatar
Guolin Ke committed
382
383
  if (models_.empty()
      && gbdt_config_->boost_from_average
384
      && !train_score_updater_->has_init_score()
385
386
387
      && num_class_ <= 1
      && objective_function_ != nullptr
      && objective_function_->BoostFromAverage()) {
Guolin Ke's avatar
Guolin Ke committed
388

389
390
    auto label = train_data_->metadata().label();
    double init_score = ObtainAutomaticInitialScore(objective_function_, label, num_data_);
Guolin Ke's avatar
Guolin Ke committed
391
392
393
394
395
396
    if (std::fabs(init_score) > kEpsilon) {
      train_score_updater_->AddScore(init_score, 0);
      for (auto& score_updater : valid_score_updater_) {
        score_updater->AddScore(init_score, 0);
      }
      return init_score;
397
    }
398
  }
Guolin Ke's avatar
Guolin Ke committed
399
400
  return 0.0f;
}
Guolin Ke's avatar
Guolin Ke committed
401

Guolin Ke's avatar
Guolin Ke committed
402
403
bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
  auto init_score = BoostFromAverage();
Guolin Ke's avatar
Guolin Ke committed
404
  // boosting first
Guolin Ke's avatar
Guolin Ke committed
405
406
  if (gradients == nullptr || hessians == nullptr) {

Guolin Ke's avatar
Guolin Ke committed
407
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
408
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
409
    #endif
Guolin Ke's avatar
Guolin Ke committed
410

Guolin Ke's avatar
Guolin Ke committed
411
    Boosting();
Guolin Ke's avatar
Guolin Ke committed
412
413
414
    gradients = gradients_.data();
    hessians = hessians_.data();

Guolin Ke's avatar
Guolin Ke committed
415
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
416
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
417
    #endif
Guolin Ke's avatar
Guolin Ke committed
418
  }
Guolin Ke's avatar
Guolin Ke committed
419

Guolin Ke's avatar
Guolin Ke committed
420
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
421
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
422
  #endif
Guolin Ke's avatar
Guolin Ke committed
423

424
425
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
426

Guolin Ke's avatar
Guolin Ke committed
427
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
428
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
429
  #endif
Guolin Ke's avatar
Guolin Ke committed
430

Guolin Ke's avatar
Guolin Ke committed
431
  bool should_continue = false;
432
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
433

Guolin Ke's avatar
Guolin Ke committed
434
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
435
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
436
    #endif
Guolin Ke's avatar
Guolin Ke committed
437

438
    std::unique_ptr<Tree> new_tree(new Tree(2));
439
    if (class_need_train_[cur_tree_id]) {
440
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
Guolin Ke's avatar
Guolin Ke committed
441
442
443
444
445
446
447
448
449
450
451
452
453
454
      auto grad = gradients + bias;
      auto hess = hessians + bias;

      // need to copy gradients for bagging subset.
      if (is_use_subset_ && bag_data_cnt_ < num_data_) {
        for (int i = 0; i < bag_data_cnt_; ++i) {
          gradients_[bias + i] = grad[bag_data_indices_[i]];
          hessians_[bias + i] = hess[bag_data_indices_[i]];
        }
        grad = gradients_.data() + bias;
        hess = hessians_.data() + bias;
      }

      new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_));
455
    }
Guolin Ke's avatar
Guolin Ke committed
456

Guolin Ke's avatar
Guolin Ke committed
457
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
458
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
459
    #endif
Guolin Ke's avatar
Guolin Ke committed
460
461

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
462
463
464
465
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
466
      UpdateScore(new_tree.get(), cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
467
468
469
      if (std::fabs(init_score) > kEpsilon) {
        new_tree->AddBias(init_score);
      }
470
471
    } else {
      // only add default score one-time
472
473
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
474
        new_tree->AsConstantTree(output);
Guolin Ke's avatar
Guolin Ke committed
475
        // updates scores
476
        train_score_updater_->AddScore(output, cur_tree_id);
477
        for (auto& score_updater : valid_score_updater_) {
478
          score_updater->AddScore(output, cur_tree_id);
479
480
481
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
482
483
484
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
485

Guolin Ke's avatar
Guolin Ke committed
486
  if (!should_continue) {
Guolin Ke's avatar
Guolin Ke committed
487
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
488
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
489
490
491
492
      models_.pop_back();
    }
    return true;
  }
493

Guolin Ke's avatar
Guolin Ke committed
494
495
  ++iter_;
  return false;
Guolin Ke's avatar
Guolin Ke committed
496
}
497

wxchan's avatar
wxchan committed
498
void GBDT::RollbackOneIter() {
499
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
500
  // reset score
501
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
502
    auto curr_tree = models_.size() - num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
503
    models_[curr_tree]->Shrinkage(-1.0);
504
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
505
    for (auto& score_updater : valid_score_updater_) {
506
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
507
508
509
    }
  }
  // remove model
510
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
511
512
513
514
515
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
516
bool GBDT::EvalAndCheckEarlyStopping() {
517
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
518

Guolin Ke's avatar
Guolin Ke committed
519
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
520
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
521
  #endif
Guolin Ke's avatar
Guolin Ke committed
522

523
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
524
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
525

Guolin Ke's avatar
Guolin Ke committed
526
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
527
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
528
  #endif
Guolin Ke's avatar
Guolin Ke committed
529

Guolin Ke's avatar
Guolin Ke committed
530
  is_met_early_stopping = !best_msg.empty();
531
532
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
533
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
534
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
535
    // pop last early_stopping_round_ models
536
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
537
538
539
540
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
541
542
}

543
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
544

Guolin Ke's avatar
Guolin Ke committed
545
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
546
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
547
  #endif
Guolin Ke's avatar
Guolin Ke committed
548

Guolin Ke's avatar
Guolin Ke committed
549
  // update training score
Guolin Ke's avatar
Guolin Ke committed
550
  if (!is_use_subset_) {
551
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569

    #ifdef TIMETAG
    train_score_time += std::chrono::steady_clock::now() - start_time;
    #endif

    #ifdef TIMETAG
    start_time = std::chrono::steady_clock::now();
    #endif

    // we need to predict out-of-bag scores of data for boosting
    if (num_data_ - bag_data_cnt_ > 0) {
      train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
    }

    #ifdef TIMETAG
    out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
    #endif

Guolin Ke's avatar
Guolin Ke committed
570
  } else {
571
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
572
573
574
575

    #ifdef TIMETAG
    train_score_time += std::chrono::steady_clock::now() - start_time;
    #endif
Guolin Ke's avatar
Guolin Ke committed
576
  }
Guolin Ke's avatar
Guolin Ke committed
577
578


Guolin Ke's avatar
Guolin Ke committed
579
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
580
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
581
  #endif
Guolin Ke's avatar
Guolin Ke committed
582

Guolin Ke's avatar
Guolin Ke committed
583
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
584
  for (auto& score_updater : valid_score_updater_) {
585
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
586
  }
Guolin Ke's avatar
Guolin Ke committed
587

Guolin Ke's avatar
Guolin Ke committed
588
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
589
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
590
  #endif
Guolin Ke's avatar
Guolin Ke committed
591
592
}

Guolin Ke's avatar
Guolin Ke committed
593
594
595
596
/*!
 * \brief Evaluate one metric on a score buffer, using the current objective function.
 * \param metric Metric to evaluate
 * \param score Scores to evaluate the metric on
 * \return The values returned by the metric
 */
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
  std::vector<double> metric_values = metric->Eval(score, objective_function_);
  return metric_values;
}

Guolin Ke's avatar
Guolin Ke committed
597
598
599
600
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
601
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
602
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
603
  if (need_output) {
604
605
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
606
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
607
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
608
609
610
611
612
613
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
614
          msg_buf << tmp_buf.str() << '\n';
Guolin Ke's avatar
Guolin Ke committed
615
        }
616
      }
617
    }
Guolin Ke's avatar
Guolin Ke committed
618
619
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
620
  if (need_output || early_stopping_round_ > 0) {
621
622
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
623
        auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score());
Guolin Ke's avatar
Guolin Ke committed
624
625
626
627
628
629
630
631
632
633
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
634
            msg_buf << tmp_buf.str() << '\n';
635
          }
wxchan's avatar
wxchan committed
636
        }
Guolin Ke's avatar
Guolin Ke committed
637
        if (ret.empty() && early_stopping_round_ > 0) {
638
639
640
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
641
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
642
            meet_early_stopping_pairs.emplace_back(i, j);
643
          } else {
Guolin Ke's avatar
Guolin Ke committed
644
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
645
          }
wxchan's avatar
wxchan committed
646
647
        }
      }
Guolin Ke's avatar
Guolin Ke committed
648
649
    }
  }
Guolin Ke's avatar
Guolin Ke committed
650
651
652
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
653
  return ret;
Guolin Ke's avatar
Guolin Ke committed
654
655
}

656
/*! \brief Get eval result */
657
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
658
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
659
660
  std::vector<double> ret;
  if (data_idx == 0) {
661
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
662
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
663
664
665
      for (auto score : scores) {
        ret.push_back(score);
      }
666
    }
667
  } else {
668
669
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
670
      auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score());
671
672
673
      for (auto score : test_scores) {
        ret.push_back(score);
      }
674
675
676
677
678
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
679
/*! \brief Get training scores result */
680
const double* GBDT::GetTrainingScore(int64_t* out_len) {
681
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
682
  return train_score_updater_->score();
683
684
}

685
686
687
void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
  int early_stop_round_counter = 0;
  // set zero
Guolin Ke's avatar
Guolin Ke committed
688
689
  const int num_features = max_feature_idx_ + 1;
  std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1));
690
691
692
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
    // predict all the trees for one iteration
    for (int k = 0; k < num_tree_per_iteration_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
693
      models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features + 1));
694
695
696
697
698
699
700
701
702
703
704
705
    }
    // check early stopping
    ++early_stop_round_counter;
    if (early_stop->round_period == early_stop_round_counter) {
      if (early_stop->callback_function(output, num_tree_per_iteration_)) {
        return;
      }
      early_stop_round_counter = 0;
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
706
707
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
708

709
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
710
711
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
712
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
713
714
715
716
717
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
718
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
719
  }
Guolin Ke's avatar
Guolin Ke committed
720
  if (objective_function_ != nullptr && !average_output_) {
Guolin Ke's avatar
Guolin Ke committed
721
722
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
723
      std::vector<double> tree_pred(num_tree_per_iteration_);
724
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
725
        tree_pred[j] = raw_scores[j * num_data + i];
726
      }
Guolin Ke's avatar
Guolin Ke committed
727
728
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
729
      for (int j = 0; j < num_class_; ++j) {
730
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
731
732
      }
    }
733
  } else {
Guolin Ke's avatar
Guolin Ke committed
734
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
735
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
736
      std::vector<double> tmp_result(num_tree_per_iteration_);
737
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
738
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
739
740
741
742
743
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
744
745
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                             const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
746

Guolin Ke's avatar
Guolin Ke committed
747
748
  if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
wxchan's avatar
wxchan committed
749
750
  }

Guolin Ke's avatar
Guolin Ke committed
751
752
753
754
755
756
  objective_function_ = objective_function;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
    CHECK(num_tree_per_iteration_ == objective_function_->NumModelPerIteration());
  } else {
    is_constant_hessian_ = false;
757
758
  }

Guolin Ke's avatar
Guolin Ke committed
759
760
761
762
  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
763
  }
Guolin Ke's avatar
Guolin Ke committed
764
  training_metrics_.shrink_to_fit();
765

Guolin Ke's avatar
Guolin Ke committed
766
767
768
769
770
  if (train_data != train_data_) {
    train_data_ = train_data;
    // not same training data, need reset score and others
    // create score tracker
    train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
771

Guolin Ke's avatar
Guolin Ke committed
772
773
774
775
776
777
    // update score
    for (int i = 0; i < iter_; ++i) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
778
779
    }

Guolin Ke's avatar
Guolin Ke committed
780
    num_data_ = train_data_->num_data();
781

Guolin Ke's avatar
Guolin Ke committed
782
783
784
785
786
787
    // create buffer for gradients and hessians
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
788

Guolin Ke's avatar
Guolin Ke committed
789
790
791
792
    max_feature_idx_ = train_data_->num_total_features() - 1;
    label_idx_ = train_data_->label_idx();
    feature_names_ = train_data_->feature_names();
    feature_infos_ = train_data_->feature_infos();
793

Guolin Ke's avatar
Guolin Ke committed
794
795
    tree_learner_->ResetTrainingData(train_data);
    ResetBaggingConfig(gbdt_config_.get(), true);
796
  }
797
798
}

Guolin Ke's avatar
Guolin Ke committed
799
800
801
802
803
804
void GBDT::ResetConfig(const BoostingConfig* config) {
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;
  if (tree_learner_ != nullptr) {
    tree_learner_->ResetConfig(&new_config->tree_config);
805
  }
Guolin Ke's avatar
Guolin Ke committed
806
807
  if (train_data_ != nullptr) {
    ResetBaggingConfig(new_config.get(), false);
808
  }
Guolin Ke's avatar
Guolin Ke committed
809
  gbdt_config_.reset(new_config.release());
Guolin Ke's avatar
Guolin Ke committed
810
811
}

Guolin Ke's avatar
Guolin Ke committed
812
813
814
815
816
817
818
void GBDT::ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset) {
  // if need bagging, create buffer
  if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
    bag_data_cnt_ =
      static_cast<data_size_t>(config->bagging_fraction * num_data_);
    bag_data_indices_.resize(num_data_);
    tmp_indices_.resize(num_data_);
819

Guolin Ke's avatar
Guolin Ke committed
820
821
822
823
824
    offsets_buf_.resize(num_threads_);
    left_cnts_buf_.resize(num_threads_);
    right_cnts_buf_.resize(num_threads_);
    left_write_pos_buf_.resize(num_threads_);
    right_write_pos_buf_.resize(num_threads_);
825

Guolin Ke's avatar
Guolin Ke committed
826
827
828
829
830
831
    double average_bag_rate = config->bagging_fraction / config->bagging_freq;
    int sparse_group = 0;
    for (int i = 0; i < train_data_->num_feature_groups(); ++i) {
      if (train_data_->FeatureGroupIsSparse(i)) {
        ++sparse_group;
      }
Guolin Ke's avatar
Guolin Ke committed
832
    }
Guolin Ke's avatar
Guolin Ke committed
833
834
835
836
837
838
839
840
841
842
843
    is_use_subset_ = false;
    const int group_threshold_usesubset = 100;
    const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
    if (average_bag_rate <= 0.5
        && (train_data_->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
      if (tmp_subset_ == nullptr || is_change_dataset) {
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
        tmp_subset_->CopyFeatureMapperFrom(train_data_);
      }
      is_use_subset_ = true;
      Log::Debug("use subset for bagging");
Guolin Ke's avatar
Guolin Ke committed
844
845
    }

Guolin Ke's avatar
Guolin Ke committed
846
847
    if (is_change_dataset) {
      need_re_bagging_ = true;
Guolin Ke's avatar
Guolin Ke committed
848
    }
849

Guolin Ke's avatar
Guolin Ke committed
850
851
852
853
854
    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
      if (objective_function_ == nullptr) {
        size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
        gradients_.resize(total_size);
        hessians_.resize(total_size);
855
      }
856
    }
857
  } else {
Guolin Ke's avatar
Guolin Ke committed
858
859
860
861
    bag_data_cnt_ = num_data_;
    bag_data_indices_.clear();
    tmp_indices_.clear();
    is_use_subset_ = false;
862
  }
wxchan's avatar
wxchan committed
863
864
}

Guolin Ke's avatar
Guolin Ke committed
865
}  // namespace LightGBM