gbdt.cpp 29.6 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
8
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
9
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
10
11
12
13
14
15
16

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
17
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
18
19
20

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
21
22
23
#ifdef TIMETAG
// Per-phase wall-clock accumulators (milliseconds). They are increased with
// `+=` throughout this file and reported by GBDT::~GBDT() in TIMETAG builds.
std::chrono::duration<double, std::milli> boosting_time;          // Boosting()
std::chrono::duration<double, std::milli> train_score_time;       // in-bag train scores
std::chrono::duration<double, std::milli> out_of_bag_score_time;  // out-of-bag train scores
std::chrono::duration<double, std::milli> valid_score_time;       // validation scores
std::chrono::duration<double, std::milli> metric_time;            // metric evaluation
std::chrono::duration<double, std::milli> bagging_time;           // Bagging()
std::chrono::duration<double, std::milli> tree_time;              // tree learning
#endif  // TIMETAG

Guolin Ke's avatar
Guolin Ke committed
31
// Default-construct an untrained booster; most state is filled in by Init().
GBDT::GBDT()
    : iter_(0),
      train_data_(nullptr),
      objective_function_(nullptr),
      early_stopping_round_(0),
      max_feature_idx_(0),
      num_tree_per_iteration_(1),
      num_class_(1),
      num_iteration_for_pred_(0),
      shrinkage_rate_(0.1f),
      num_init_iteration_(0),
      need_re_bagging_(false) {
  // Ask an OpenMP team how many threads it runs with; only the master thread
  // writes the member, so there is a single writer.
  #pragma omp parallel
  #pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
  average_output_ = false;
  tree_learner_ = nullptr;
}

// Destructor. In TIMETAG builds, reports the accumulated time of each phase.
GBDT::~GBDT() {
  #ifdef TIMETAG
  // The accumulators are std::chrono durations in milliseconds. Log::Info takes a
  // printf-style format, so hand "%f" a plain double (seconds). Passing the
  // duration object itself through varargs and reading it back as a double is a
  // type mismatch (undefined behavior).
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

Guolin Ke's avatar
Guolin Ke committed
64
void GBDT::Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
65
                const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
66
  CHECK(train_data != nullptr);
Guolin Ke's avatar
Guolin Ke committed
67
  CHECK(train_data->num_features() > 0);
68
  train_data_ = train_data;
69
  iter_ = 0;
wxchan's avatar
wxchan committed
70
  num_iteration_for_pred_ = 0;
71
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
72
  num_class_ = config->num_class;
Guolin Ke's avatar
Guolin Ke committed
73
74
75
  config_ = std::unique_ptr<Config>(new Config(*config));
  early_stopping_round_ = config_->early_stopping_round;
  shrinkage_rate_ = config_->learning_rate;
76

77
78
79
80
81
82
83
84
85
86
  std::string forced_splits_path = config->forcedsplits_filename;
  //load forced_splits file
  if (forced_splits_path != "") {
      std::ifstream forced_splits_file(forced_splits_path.c_str());
      std::stringstream buffer;
      buffer << forced_splits_file.rdbuf();
      std::string err;
      forced_splits_json_ = Json::parse(buffer.str(), err);
  }

87
88
89
90
  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
91
    num_tree_per_iteration_ = objective_function_->NumModelPerIteration();
92
93
94
95
  } else {
    is_constant_hessian_ = false;
  }

Guolin Ke's avatar
Guolin Ke committed
96
  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, config_.get()));
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125

  // init tree learner
  tree_learner_->Init(train_data_, is_constant_hessian_);

  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();

  train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));

  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
  if (objective_function_ != nullptr) {
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    gradients_.resize(total_size);
    hessians_.resize(total_size);
  }
  // get max feature index
  max_feature_idx_ = train_data_->num_total_features() - 1;
  // get label index
  label_idx_ = train_data_->label_idx();
  // get feature names
  feature_names_ = train_data_->feature_names();
  feature_infos_ = train_data_->feature_infos();

  // if need bagging, create buffer
Guolin Ke's avatar
Guolin Ke committed
126
  ResetBaggingConfig(config_.get(), true);
127
128
129
130
131

  // reset config for tree learner
  class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
  if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
    CHECK(num_tree_per_iteration_ == num_class_);
Guolin Ke's avatar
Guolin Ke committed
132

133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
    auto label = train_data_->metadata().label();
    if (num_tree_per_iteration_ > 1) {
      // multi-class
      std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
      for (data_size_t i = 0; i < num_data_; ++i) {
        int index = static_cast<int>(label[i]);
        CHECK(index < num_tree_per_iteration_);
        ++cnt_per_class[index];
      }
      for (int i = 0; i < num_tree_per_iteration_; ++i) {
        if (cnt_per_class[i] == num_data_) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(kEpsilon);
        } else if (cnt_per_class[i] == 0) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
        }
      }
    } else {
      // binary class
      data_size_t cnt_pos = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (label[i] > 0) {
          ++cnt_pos;
        }
      }
      if (cnt_pos == 0) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
      } else if (cnt_pos == num_data_) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(kEpsilon);
      }
    }
  }
wxchan's avatar
wxchan committed
169
170
171
}

void GBDT::AddValidDataset(const Dataset* valid_data,
172
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
173
  if (!train_data_->CheckAlign(*valid_data)) {
174
    Log::Fatal("Cannot add validation data, since it has different bin mappers with training data");
175
  }
Guolin Ke's avatar
Guolin Ke committed
176
  // for a validation dataset, we need its score and metric
177
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
178
179
  // update score
  for (int i = 0; i < iter_; ++i) {
180
181
182
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
183
184
    }
  }
Guolin Ke's avatar
Guolin Ke committed
185
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
186
  valid_metrics_.emplace_back();
187
188
189
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
190
    best_msg_.emplace_back();
191
  }
Guolin Ke's avatar
Guolin Ke committed
192
193
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
194
195
196
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
197
      best_msg_.back().emplace_back();
198
    }
Guolin Ke's avatar
Guolin Ke committed
199
  }
Guolin Ke's avatar
Guolin Ke committed
200
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
201
202
}

Guolin Ke's avatar
Guolin Ke committed
203
204
205
206
207
208
209
210
211
212
// Compute gradients and hessians of the current training scores into the
// internal gradients_/hessians_ buffers. Requires an objective function.
void GBDT::Boosting() {
  if (objective_function_ == nullptr) {
    Log::Fatal("No object function provided");
  }
  int64_t num_score = 0;  // length is not needed here
  const double* score = GetTrainingScore(&num_score);
  objective_function_->GetGradients(score, gradients_.data(), hessians_.data());
}

213
// Bag the rows [start, start + cnt).
// Exactly floor(bagging_fraction * cnt) rows are selected: each row is kept with
// probability (still needed) / (rows left), so the count always hits the target.
// In-bag row ids are written to the front of `buffer` and out-of-bag ids right
// after them, both in ascending order. `buffer` must hold `cnt` entries.
// Returns the number of in-bag rows.
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) {
    return 0;
  }
  const data_size_t target_in_bag = static_cast<data_size_t>(config_->bagging_fraction * cnt);
  data_size_t* out_of_bag = buffer + target_in_bag;
  data_size_t n_in = 0;
  data_size_t n_out = 0;
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    const data_size_t remaining = cnt - offset;
    const float keep_prob = (target_in_bag - n_in) / static_cast<float>(remaining);
    const data_size_t row = start + offset;
    if (cur_rand.NextFloat() < keep_prob) {
      buffer[n_in++] = row;
    } else {
      out_of_bag[n_out++] = row;
    }
  }
  CHECK(n_in == target_in_bag);
  return n_in;
}
Guolin Ke's avatar
Guolin Ke committed
233

234
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
235
  // if need bagging
Guolin Ke's avatar
Guolin Ke committed
236
  if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0)
Guolin Ke's avatar
Guolin Ke committed
237
      || need_re_bagging_) {
Guolin Ke's avatar
Guolin Ke committed
238
    need_re_bagging_ = false;
Guolin Ke's avatar
Guolin Ke committed
239
    const data_size_t min_inner_size = 1000;
240
241
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
242
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
243
    #pragma omp parallel for schedule(static,1)
244
    for (int i = 0; i < num_threads_; ++i) {
245
      OMP_LOOP_EX_BEGIN();
246
247
248
249
250
251
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
252
      Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
Guolin Ke's avatar
Guolin Ke committed
253
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
254
255
256
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
257
      OMP_LOOP_EX_END();
258
    }
259
    OMP_THROW_EX();
260
261
262
263
264
265
266
267
268
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
269
    #pragma omp parallel for schedule(static, 1)
270
    for (int i = 0; i < num_threads_; ++i) {
271
      OMP_LOOP_EX_BEGIN();
272
273
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
274
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
275
      }
276
277
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
278
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
279
      }
280
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
281
    }
282
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
283
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
284
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
285
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
286
287
288
289
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
290
291
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
292
293
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
294
295
296
  }
}

297
/* If the custom "average" is implemented it will be used inplace of the label average (if enabled)
Guolin Ke's avatar
Guolin Ke committed
298
299
300
301
302
303
304
305
306
*
* An improvement to this is to have options to explicitly choose
* (i) standard average
* (ii) custom average if available
* (iii) any user defined scalar bias (e.g. using a new option "init_score" that overrides (i) and (ii) )
*
* (i) and (ii) could be selected as say "auto_init_score" = 0 or 1 etc..
*
*/
307
double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj) {
308
  double init_score = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
309
  if (fobj != nullptr) {
310
    init_score = fobj->BoostFromScore();
Guolin Ke's avatar
Guolin Ke committed
311
312
  }
  if (Network::num_machines() > 1) {
313
    init_score = Network::GlobalSyncUpByMean(init_score);
Guolin Ke's avatar
Guolin Ke committed
314
  }
315
  return init_score;
Guolin Ke's avatar
Guolin Ke committed
316
317
}

Guolin Ke's avatar
Guolin Ke committed
318
319
320
void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
  bool is_finished = false;
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
321
  for (int iter = 0; iter < config_->num_iterations && !is_finished; ++iter) {
Guolin Ke's avatar
Guolin Ke committed
322
323
324
325
    is_finished = TrainOneIter(nullptr, nullptr);
    if (!is_finished) {
      is_finished = EvalAndCheckEarlyStopping();
    }
Guolin Ke's avatar
Guolin Ke committed
326
327
328
329
330
331
332
    auto end_time = std::chrono::steady_clock::now();
    // output used time per iteration
    Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
              std::milli>(end_time - start_time) * 1e-3, iter + 1);
    if (snapshot_freq > 0
        && (iter + 1) % snapshot_freq == 0) {
      std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
333
      SaveModelToFile(0, -1, snapshot_out.c_str());
Guolin Ke's avatar
Guolin Ke committed
334
335
336
337
    }
  }
}

338
339
340
341
342
343
344
345
346
347
348
349
350
void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) {
  CHECK(tree_leaf_prediction.size() > 0);
  CHECK(static_cast<size_t>(num_data_) == tree_leaf_prediction.size());
  CHECK(static_cast<size_t>(models_.size()) == tree_leaf_prediction[0].size());
  int num_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_);
  std::vector<int> leaf_pred(num_data_);
  for (int iter = 0; iter < num_iterations; ++iter) {
    Boosting();
    for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
      int model_index = iter * num_tree_per_iteration_ + tree_id;
      #pragma omp parallel for schedule(static)
      for (int i = 0; i < num_data_; ++i) {
        leaf_pred[i] = tree_leaf_prediction[i][model_index];
Guolin Ke's avatar
Guolin Ke committed
351
        CHECK(leaf_pred[i] < models_[model_index]->num_leaves());
352
353
354
355
356
357
358
359
360
361
362
      }
      size_t bias = static_cast<size_t>(tree_id) * num_data_;
      auto grad = gradients_.data() + bias;
      auto hess = hessians_.data() + bias;
      auto new_tree = tree_learner_->FitByExistingTree(models_[model_index].get(), leaf_pred, grad, hess);
      train_score_updater_->AddScore(tree_learner_.get(), new_tree, tree_id);
      models_[model_index].reset(new_tree);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
363
double GBDT::BoostFromAverage() {
364
  // boosting from average label; or customized "average" if implemented for the current objective
365
  if (models_.empty() && !train_score_updater_->has_init_score()
366
      && num_class_ <= 1
367
      && objective_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
368
    if (config_->boost_from_average) {
369
370
371
372
373
374
375
376
      double init_score = ObtainAutomaticInitialScore(objective_function_);
      if (std::fabs(init_score) > kEpsilon) {
        train_score_updater_->AddScore(init_score, 0);
        for (auto& score_updater : valid_score_updater_) {
          score_updater->AddScore(init_score, 0);
        }
        Log::Info("Start training from score %lf", init_score);
        return init_score;
Guolin Ke's avatar
Guolin Ke committed
377
      }
378
379
380
    } else if (std::string(objective_function_->GetName()) == std::string("regression_l1")
               || std::string(objective_function_->GetName()) == std::string("quantile")
               || std::string(objective_function_->GetName()) == std::string("mape")) {
381
      Log::Warning("Disabling boost_from_average in %s may cause the slow convergence", objective_function_->GetName());
382
    }
383
  }
Guolin Ke's avatar
Guolin Ke committed
384
385
  return 0.0f;
}
Guolin Ke's avatar
Guolin Ke committed
386

Guolin Ke's avatar
Guolin Ke committed
387
bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
388
  double init_score = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
389
  // boosting first
Guolin Ke's avatar
Guolin Ke committed
390
  if (gradients == nullptr || hessians == nullptr) {
391
    init_score = BoostFromAverage();
Guolin Ke's avatar
Guolin Ke committed
392
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
393
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
394
    #endif
Guolin Ke's avatar
Guolin Ke committed
395

Guolin Ke's avatar
Guolin Ke committed
396
    Boosting();
Guolin Ke's avatar
Guolin Ke committed
397
398
399
    gradients = gradients_.data();
    hessians = hessians_.data();

Guolin Ke's avatar
Guolin Ke committed
400
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
401
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
402
    #endif
Guolin Ke's avatar
Guolin Ke committed
403
  }
Guolin Ke's avatar
Guolin Ke committed
404

Guolin Ke's avatar
Guolin Ke committed
405
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
406
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
407
  #endif
Guolin Ke's avatar
Guolin Ke committed
408

409
410
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
411

Guolin Ke's avatar
Guolin Ke committed
412
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
413
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
414
  #endif
Guolin Ke's avatar
Guolin Ke committed
415

Guolin Ke's avatar
Guolin Ke committed
416
  bool should_continue = false;
417
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
418

Guolin Ke's avatar
Guolin Ke committed
419
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
420
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
421
    #endif
422
    const size_t bias = static_cast<size_t>(cur_tree_id) * num_data_;
423
    std::unique_ptr<Tree> new_tree(new Tree(2));
424
    if (class_need_train_[cur_tree_id]) {
Guolin Ke's avatar
Guolin Ke committed
425
426
427
428
429
430
431
432
433
434
435
436
437
      auto grad = gradients + bias;
      auto hess = hessians + bias;

      // need to copy gradients for bagging subset.
      if (is_use_subset_ && bag_data_cnt_ < num_data_) {
        for (int i = 0; i < bag_data_cnt_; ++i) {
          gradients_[bias + i] = grad[bag_data_indices_[i]];
          hessians_[bias + i] = hess[bag_data_indices_[i]];
        }
        grad = gradients_.data() + bias;
        hess = hessians_.data() + bias;
      }

438
      new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_, forced_splits_json_));
439
    }
Guolin Ke's avatar
Guolin Ke committed
440

Guolin Ke's avatar
Guolin Ke committed
441
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
442
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
443
    #endif
Guolin Ke's avatar
Guolin Ke committed
444
445

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
446
      should_continue = true;
447
448
      tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, train_score_updater_->score() + bias,
                                     num_data_, bag_data_indices_.data(), bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
449
450
451
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
452
      UpdateScore(new_tree.get(), cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
453
454
455
      if (std::fabs(init_score) > kEpsilon) {
        new_tree->AddBias(init_score);
      }
456
457
    } else {
      // only add default score one-time
458
459
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
460
        new_tree->AsConstantTree(output);
Guolin Ke's avatar
Guolin Ke committed
461
        // updates scores
462
        train_score_updater_->AddScore(output, cur_tree_id);
463
        for (auto& score_updater : valid_score_updater_) {
464
          score_updater->AddScore(output, cur_tree_id);
465
466
467
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
468
469
470
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
471

Guolin Ke's avatar
Guolin Ke committed
472
  if (!should_continue) {
473
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements");
474
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
475
476
477
478
      models_.pop_back();
    }
    return true;
  }
479

Guolin Ke's avatar
Guolin Ke committed
480
481
  ++iter_;
  return false;
Guolin Ke's avatar
Guolin Ke committed
482
}
483

wxchan's avatar
wxchan committed
484
void GBDT::RollbackOneIter() {
485
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
486
  // reset score
487
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
488
    auto curr_tree = models_.size() - num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
489
    models_[curr_tree]->Shrinkage(-1.0);
490
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
491
    for (auto& score_updater : valid_score_updater_) {
492
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
493
494
495
    }
  }
  // remove model
496
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
497
498
499
500
501
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
502
bool GBDT::EvalAndCheckEarlyStopping() {
503
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
504

Guolin Ke's avatar
Guolin Ke committed
505
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
506
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
507
  #endif
Guolin Ke's avatar
Guolin Ke committed
508

509
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
510
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
511

Guolin Ke's avatar
Guolin Ke committed
512
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
513
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
514
  #endif
Guolin Ke's avatar
Guolin Ke committed
515

Guolin Ke's avatar
Guolin Ke committed
516
  is_met_early_stopping = !best_msg.empty();
517
518
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
519
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
520
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
521
    // pop last early_stopping_round_ models
522
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
523
524
525
526
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
527
528
}

529
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
530

Guolin Ke's avatar
Guolin Ke committed
531
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
532
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
533
  #endif
Guolin Ke's avatar
Guolin Ke committed
534

Guolin Ke's avatar
Guolin Ke committed
535
  // update training score
Guolin Ke's avatar
Guolin Ke committed
536
  if (!is_use_subset_) {
537
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555

    #ifdef TIMETAG
    train_score_time += std::chrono::steady_clock::now() - start_time;
    #endif

    #ifdef TIMETAG
    start_time = std::chrono::steady_clock::now();
    #endif

    // we need to predict out-of-bag scores of data for boosting
    if (num_data_ - bag_data_cnt_ > 0) {
      train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
    }

    #ifdef TIMETAG
    out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
    #endif

Guolin Ke's avatar
Guolin Ke committed
556
  } else {
557
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
558
559
560
561

    #ifdef TIMETAG
    train_score_time += std::chrono::steady_clock::now() - start_time;
    #endif
Guolin Ke's avatar
Guolin Ke committed
562
  }
Guolin Ke's avatar
Guolin Ke committed
563
564


Guolin Ke's avatar
Guolin Ke committed
565
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
566
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
567
  #endif
Guolin Ke's avatar
Guolin Ke committed
568

Guolin Ke's avatar
Guolin Ke committed
569
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
570
  for (auto& score_updater : valid_score_updater_) {
571
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
572
  }
Guolin Ke's avatar
Guolin Ke committed
573

Guolin Ke's avatar
Guolin Ke committed
574
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
575
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
576
  #endif
Guolin Ke's avatar
Guolin Ke committed
577
578
}

Guolin Ke's avatar
Guolin Ke committed
579
580
581
582
// Evaluate one metric on `score`; the objective function is forwarded to the metric.
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
  std::vector<double> results = metric->Eval(score, objective_function_);
  return results;
}

Guolin Ke's avatar
Guolin Ke committed
583
std::string GBDT::OutputMetric(int iter) {
Guolin Ke's avatar
Guolin Ke committed
584
  bool need_output = (iter % config_->metric_freq) == 0;
Guolin Ke's avatar
Guolin Ke committed
585
586
  std::string ret = "";
  std::stringstream msg_buf;
587
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
588
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
589
  if (need_output) {
590
591
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
592
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
593
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
594
595
596
597
598
599
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
600
          msg_buf << tmp_buf.str() << '\n';
Guolin Ke's avatar
Guolin Ke committed
601
        }
602
      }
603
    }
Guolin Ke's avatar
Guolin Ke committed
604
605
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
606
  if (need_output || early_stopping_round_ > 0) {
607
608
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
609
        auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score());
Guolin Ke's avatar
Guolin Ke committed
610
611
612
613
614
615
616
617
618
619
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
620
            msg_buf << tmp_buf.str() << '\n';
621
          }
wxchan's avatar
wxchan committed
622
        }
Guolin Ke's avatar
Guolin Ke committed
623
        if (ret.empty() && early_stopping_round_ > 0) {
624
625
626
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
627
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
628
            meet_early_stopping_pairs.emplace_back(i, j);
629
          } else {
Guolin Ke's avatar
Guolin Ke committed
630
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
631
          }
wxchan's avatar
wxchan committed
632
633
        }
      }
Guolin Ke's avatar
Guolin Ke committed
634
635
    }
  }
Guolin Ke's avatar
Guolin Ke committed
636
637
638
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
639
  return ret;
Guolin Ke's avatar
Guolin Ke committed
640
641
}

642
/*! \brief Get eval result */
643
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
644
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
645
646
  std::vector<double> ret;
  if (data_idx == 0) {
647
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
648
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
649
650
651
      for (auto score : scores) {
        ret.push_back(score);
      }
652
    }
653
  } else {
654
655
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
656
      auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score());
657
658
659
      for (auto score : test_scores) {
        ret.push_back(score);
      }
660
661
662
663
664
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
665
/*! \brief Get training scores result */
666
const double* GBDT::GetTrainingScore(int64_t* out_len) {
667
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
668
  return train_score_updater_->score();
669
670
}

671
672
673
void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
  int early_stop_round_counter = 0;
  // set zero
Guolin Ke's avatar
Guolin Ke committed
674
675
  const int num_features = max_feature_idx_ + 1;
  std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1));
676
677
678
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
    // predict all the trees for one iteration
    for (int k = 0; k < num_tree_per_iteration_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
679
      models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features + 1));
680
681
682
683
684
685
686
687
688
689
690
691
    }
    // check early stopping
    ++early_stop_round_counter;
    if (early_stop->round_period == early_stop_round_counter) {
      if (early_stop->callback_function(output, num_tree_per_iteration_)) {
        return;
      }
      early_stop_round_counter = 0;
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
692
693
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
694

695
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
696
697
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
698
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
699
700
701
702
703
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
704
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
705
  }
Guolin Ke's avatar
Guolin Ke committed
706
  if (objective_function_ != nullptr && !average_output_) {
Guolin Ke's avatar
Guolin Ke committed
707
708
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
709
      std::vector<double> tree_pred(num_tree_per_iteration_);
710
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
711
        tree_pred[j] = raw_scores[j * num_data + i];
712
      }
Guolin Ke's avatar
Guolin Ke committed
713
714
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
715
      for (int j = 0; j < num_class_; ++j) {
716
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
717
718
      }
    }
719
  } else {
Guolin Ke's avatar
Guolin Ke committed
720
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
721
    for (data_size_t i = 0; i < num_data; ++i) {
722
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
723
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
724
725
726
727
728
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
729
730
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                             const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
731

Guolin Ke's avatar
Guolin Ke committed
732
  if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
733
    Log::Fatal("Cannot reset training data, since new training data has different bin mappers");
wxchan's avatar
wxchan committed
734
735
  }

Guolin Ke's avatar
Guolin Ke committed
736
737
738
739
740
741
  objective_function_ = objective_function;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
    CHECK(num_tree_per_iteration_ == objective_function_->NumModelPerIteration());
  } else {
    is_constant_hessian_ = false;
742
743
  }

Guolin Ke's avatar
Guolin Ke committed
744
745
746
747
  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
748
  }
Guolin Ke's avatar
Guolin Ke committed
749
  training_metrics_.shrink_to_fit();
750

Guolin Ke's avatar
Guolin Ke committed
751
752
753
754
755
  if (train_data != train_data_) {
    train_data_ = train_data;
    // not same training data, need reset score and others
    // create score tracker
    train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
756

Guolin Ke's avatar
Guolin Ke committed
757
758
759
760
761
762
    // update score
    for (int i = 0; i < iter_; ++i) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
763
764
    }

Guolin Ke's avatar
Guolin Ke committed
765
    num_data_ = train_data_->num_data();
766

Guolin Ke's avatar
Guolin Ke committed
767
768
769
770
771
772
    // create buffer for gradients and hessians
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
773

Guolin Ke's avatar
Guolin Ke committed
774
775
776
777
    max_feature_idx_ = train_data_->num_total_features() - 1;
    label_idx_ = train_data_->label_idx();
    feature_names_ = train_data_->feature_names();
    feature_infos_ = train_data_->feature_infos();
778

Guolin Ke's avatar
Guolin Ke committed
779
    tree_learner_->ResetTrainingData(train_data);
Guolin Ke's avatar
Guolin Ke committed
780
    ResetBaggingConfig(config_.get(), true);
781
  }
782
783
}

Guolin Ke's avatar
Guolin Ke committed
784
785
void GBDT::ResetConfig(const Config* config) {
  auto new_config = std::unique_ptr<Config>(new Config(*config));
Guolin Ke's avatar
Guolin Ke committed
786
787
788
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;
  if (tree_learner_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
789
    tree_learner_->ResetConfig(new_config.get());
790
  }
Guolin Ke's avatar
Guolin Ke committed
791
792
  if (train_data_ != nullptr) {
    ResetBaggingConfig(new_config.get(), false);
793
  }
Guolin Ke's avatar
Guolin Ke committed
794
  config_.reset(new_config.release());
Guolin Ke's avatar
Guolin Ke committed
795
796
}

Guolin Ke's avatar
Guolin Ke committed
797
void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
Guolin Ke's avatar
Guolin Ke committed
798
799
800
801
802
803
  // if need bagging, create buffer
  if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
    bag_data_cnt_ =
      static_cast<data_size_t>(config->bagging_fraction * num_data_);
    bag_data_indices_.resize(num_data_);
    tmp_indices_.resize(num_data_);
804

Guolin Ke's avatar
Guolin Ke committed
805
806
807
808
809
    offsets_buf_.resize(num_threads_);
    left_cnts_buf_.resize(num_threads_);
    right_cnts_buf_.resize(num_threads_);
    left_write_pos_buf_.resize(num_threads_);
    right_write_pos_buf_.resize(num_threads_);
810

Guolin Ke's avatar
Guolin Ke committed
811
812
813
814
815
816
    double average_bag_rate = config->bagging_fraction / config->bagging_freq;
    int sparse_group = 0;
    for (int i = 0; i < train_data_->num_feature_groups(); ++i) {
      if (train_data_->FeatureGroupIsSparse(i)) {
        ++sparse_group;
      }
Guolin Ke's avatar
Guolin Ke committed
817
    }
Guolin Ke's avatar
Guolin Ke committed
818
819
820
821
822
823
824
825
826
827
    is_use_subset_ = false;
    const int group_threshold_usesubset = 100;
    const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
    if (average_bag_rate <= 0.5
        && (train_data_->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
      if (tmp_subset_ == nullptr || is_change_dataset) {
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
        tmp_subset_->CopyFeatureMapperFrom(train_data_);
      }
      is_use_subset_ = true;
828
      Log::Debug("Use subset for bagging");
Guolin Ke's avatar
Guolin Ke committed
829
830
    }

Guolin Ke's avatar
Guolin Ke committed
831
832
    if (is_change_dataset) {
      need_re_bagging_ = true;
Guolin Ke's avatar
Guolin Ke committed
833
    }
834

Guolin Ke's avatar
Guolin Ke committed
835
836
837
838
839
    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
      if (objective_function_ == nullptr) {
        size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
        gradients_.resize(total_size);
        hessians_.resize(total_size);
840
      }
841
    }
842
  } else {
Guolin Ke's avatar
Guolin Ke committed
843
844
845
846
    bag_data_cnt_ = num_data_;
    bag_data_indices_.clear();
    tmp_indices_.clear();
    is_use_subset_ = false;
847
  }
wxchan's avatar
wxchan committed
848
849
}

Guolin Ke's avatar
Guolin Ke committed
850
}  // namespace LightGBM