gbdt.cpp 29.6 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
8
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
9
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
10
11
12
13
14
15
16

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
17
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
18
19
20

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
21
22
23
#ifdef TIMETAG
// Per-phase wall-clock accumulators, expressed in milliseconds.
// They exist only in TIMETAG builds; the training code adds to them and the
// GBDT destructor logs them.

// time spent computing gradients and hessians
std::chrono::duration<double, std::milli> boosting_time;
// time spent adding a new tree to the training score
std::chrono::duration<double, std::milli> train_score_time;
// time spent scoring the rows that were left out of the bag
std::chrono::duration<double, std::milli> out_of_bag_score_time;
// time spent adding a new tree to the validation scores
std::chrono::duration<double, std::milli> valid_score_time;
// time spent evaluating and printing metrics
std::chrono::duration<double, std::milli> metric_time;
// time spent re-sampling the bag
std::chrono::duration<double, std::milli> bagging_time;
// time spent learning trees
std::chrono::duration<double, std::milli> tree_time;
#endif  // TIMETAG

Guolin Ke's avatar
Guolin Ke committed
31
/*!
 * \brief Create an untrained booster.
 *
 * All counters start at zero, one tree per iteration and one class are assumed,
 * the shrinkage rate defaults to 0.1 and no dataset, objective or tree learner
 * is attached yet.
 */
GBDT::GBDT()
    : iter_(0),
      train_data_(nullptr),
      objective_function_(nullptr),
      early_stopping_round_(0),
      max_feature_idx_(0),
      num_tree_per_iteration_(1),
      num_class_(1),
      num_iteration_for_pred_(0),
      shrinkage_rate_(0.1f),
      num_init_iteration_(0),
      need_re_bagging_(false) {
  average_output_ = false;
  tree_learner_ = nullptr;

  // Record the size of an OpenMP parallel region. The master construct makes
  // a single thread perform the write to num_threads_.
  #pragma omp parallel
  #pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

/*!
 * \brief Destructor. In TIMETAG builds it logs the accumulated time of every
 *        training phase; otherwise it does nothing.
 */
GBDT::~GBDT() {
  #ifdef TIMETAG
  // The accumulators are std::chrono::duration objects holding milliseconds.
  // "%f" expects a plain double, so extract the tick count with count() instead
  // of handing the duration object itself to the formatter; the 1e-3 factor
  // scales the millisecond count down to seconds.
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

Guolin Ke's avatar
Guolin Ke committed
64
void GBDT::Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
65
                const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
66
  CHECK(train_data != nullptr);
Guolin Ke's avatar
Guolin Ke committed
67
  CHECK(train_data->num_features() > 0);
68
  train_data_ = train_data;
69
  iter_ = 0;
wxchan's avatar
wxchan committed
70
  num_iteration_for_pred_ = 0;
71
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
72
  num_class_ = config->num_class;
Guolin Ke's avatar
Guolin Ke committed
73
74
75
  config_ = std::unique_ptr<Config>(new Config(*config));
  early_stopping_round_ = config_->early_stopping_round;
  shrinkage_rate_ = config_->learning_rate;
76

77
78
79
80
81
82
83
84
85
86
  std::string forced_splits_path = config->forcedsplits_filename;
  //load forced_splits file
  if (forced_splits_path != "") {
      std::ifstream forced_splits_file(forced_splits_path.c_str());
      std::stringstream buffer;
      buffer << forced_splits_file.rdbuf();
      std::string err;
      forced_splits_json_ = Json::parse(buffer.str(), err);
  }

87
88
89
90
  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
91
    num_tree_per_iteration_ = objective_function_->NumModelPerIteration();
92
93
94
95
  } else {
    is_constant_hessian_ = false;
  }

Guolin Ke's avatar
Guolin Ke committed
96
  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, config_.get()));
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125

  // init tree learner
  tree_learner_->Init(train_data_, is_constant_hessian_);

  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();

  train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));

  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
  if (objective_function_ != nullptr) {
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    gradients_.resize(total_size);
    hessians_.resize(total_size);
  }
  // get max feature index
  max_feature_idx_ = train_data_->num_total_features() - 1;
  // get label index
  label_idx_ = train_data_->label_idx();
  // get feature names
  feature_names_ = train_data_->feature_names();
  feature_infos_ = train_data_->feature_infos();

  // if need bagging, create buffer
Guolin Ke's avatar
Guolin Ke committed
126
  ResetBaggingConfig(config_.get(), true);
127
128
129
130
131

  // reset config for tree learner
  class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
  if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
    CHECK(num_tree_per_iteration_ == num_class_);
Guolin Ke's avatar
Guolin Ke committed
132

133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
    auto label = train_data_->metadata().label();
    if (num_tree_per_iteration_ > 1) {
      // multi-class
      std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
      for (data_size_t i = 0; i < num_data_; ++i) {
        int index = static_cast<int>(label[i]);
        CHECK(index < num_tree_per_iteration_);
        ++cnt_per_class[index];
      }
      for (int i = 0; i < num_tree_per_iteration_; ++i) {
        if (cnt_per_class[i] == num_data_) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(kEpsilon);
        } else if (cnt_per_class[i] == 0) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
        }
      }
    } else {
      // binary class
      data_size_t cnt_pos = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (label[i] > 0) {
          ++cnt_pos;
        }
      }
      if (cnt_pos == 0) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
      } else if (cnt_pos == num_data_) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(kEpsilon);
      }
    }
  }
wxchan's avatar
wxchan committed
169
170
171
}

void GBDT::AddValidDataset(const Dataset* valid_data,
172
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
173
  if (!train_data_->CheckAlign(*valid_data)) {
174
    Log::Fatal("Cannot add validation data, since it has different bin mappers with training data");
175
  }
Guolin Ke's avatar
Guolin Ke committed
176
  // for a validation dataset, we need its score and metric
177
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
178
179
  // update score
  for (int i = 0; i < iter_; ++i) {
180
181
182
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
183
184
    }
  }
Guolin Ke's avatar
Guolin Ke committed
185
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
186
  valid_metrics_.emplace_back();
187
188
189
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
190
    best_msg_.emplace_back();
191
  }
Guolin Ke's avatar
Guolin Ke committed
192
193
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
194
195
196
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
197
      best_msg_.back().emplace_back();
198
    }
Guolin Ke's avatar
Guolin Ke committed
199
  }
Guolin Ke's avatar
Guolin Ke committed
200
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
201
202
}

Guolin Ke's avatar
Guolin Ke committed
203
204
205
206
207
208
209
210
211
212
/*!
 * \brief Fill gradients_ and hessians_ from the current training scores.
 *        Reports a fatal error when no objective function is attached.
 */
void GBDT::Boosting() {
  if (objective_function_ == nullptr) {
    Log::Fatal("No object function provided");
  }
  // GetTrainingScore needs somewhere to write the score length; it is not used here
  int64_t score_len = 0;
  const double* train_scores = GetTrainingScore(&score_len);
  // the objective function calculates gradients and hessians from the scores
  objective_function_->GetGradients(train_scores, gradients_.data(), hessians_.data());
}

213
/*!
 * \brief Randomly split the index range [start, start + cnt) into an in-bag and
 *        an out-of-bag part.
 * \param cur_rand Random generator to draw from
 * \param start First data index of the range
 * \param cnt Number of indices in the range
 * \param buffer Output; in-bag indices are written from the front, out-of-bag
 *        indices right after the in-bag part
 * \return Number of in-bag indices, 0 when cnt is not positive
 */
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) {
    return 0;
  }
  const data_size_t bag_data_cnt = static_cast<data_size_t>(config_->bagging_fraction * cnt);
  data_size_t* const out_of_bag = buffer + bag_data_cnt;
  data_size_t cur_left_cnt = 0;
  data_size_t num_out_of_bag = 0;
  // random bagging, minimal unit is one record
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    // remaining in-bag slots divided by remaining records
    const float prob = (bag_data_cnt - cur_left_cnt) / static_cast<float>(cnt - offset);
    const data_size_t record = start + offset;
    if (cur_rand.NextFloat() < prob) {
      buffer[cur_left_cnt++] = record;
    } else {
      out_of_bag[num_out_of_bag++] = record;
    }
  }
  CHECK(cur_left_cnt == bag_data_cnt);
  return cur_left_cnt;
}
Guolin Ke's avatar
Guolin Ke committed
233

234
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
235
  // if need bagging
Guolin Ke's avatar
Guolin Ke committed
236
  if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0)
Guolin Ke's avatar
Guolin Ke committed
237
      || need_re_bagging_) {
Guolin Ke's avatar
Guolin Ke committed
238
    need_re_bagging_ = false;
Guolin Ke's avatar
Guolin Ke committed
239
    const data_size_t min_inner_size = 1000;
240
241
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
242
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
243
    #pragma omp parallel for schedule(static,1)
244
    for (int i = 0; i < num_threads_; ++i) {
245
      OMP_LOOP_EX_BEGIN();
246
247
248
249
250
251
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
252
      Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
Guolin Ke's avatar
Guolin Ke committed
253
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
254
255
256
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
257
      OMP_LOOP_EX_END();
258
    }
259
    OMP_THROW_EX();
260
261
262
263
264
265
266
267
268
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
269
    #pragma omp parallel for schedule(static, 1)
270
    for (int i = 0; i < num_threads_; ++i) {
271
      OMP_LOOP_EX_BEGIN();
272
273
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
274
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
275
      }
276
277
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
278
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
279
      }
280
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
281
    }
282
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
283
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
284
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
285
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
286
287
288
289
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
290
291
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
292
293
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
294
295
296
  }
}

297
/* If the custom "average" is implemented it will be used inplace of the label average (if enabled)
Guolin Ke's avatar
Guolin Ke committed
298
299
300
301
302
303
304
305
306
*
* An improvement to this is to have options to explicitly choose
* (i) standard average
* (ii) custom average if available
* (iii) any user defined scalar bias (e.g. using a new option "init_score" that overrides (i) and (ii) )
*
* (i) and (ii) could be selected as say "auto_init_score" = 0 or 1 etc..
*
*/
307
double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj) {
308
  double init_score = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
309
  if (fobj != nullptr) {
310
    init_score = fobj->BoostFromScore();
Guolin Ke's avatar
Guolin Ke committed
311
312
  }
  if (Network::num_machines() > 1) {
313
    init_score = Network::GlobalSyncUpByMean(init_score);
Guolin Ke's avatar
Guolin Ke committed
314
  }
315
  return init_score;
Guolin Ke's avatar
Guolin Ke committed
316
317
}

Guolin Ke's avatar
Guolin Ke committed
318
319
320
void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
  bool is_finished = false;
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
321
  for (int iter = 0; iter < config_->num_iterations && !is_finished; ++iter) {
Guolin Ke's avatar
Guolin Ke committed
322
323
324
325
    is_finished = TrainOneIter(nullptr, nullptr);
    if (!is_finished) {
      is_finished = EvalAndCheckEarlyStopping();
    }
Guolin Ke's avatar
Guolin Ke committed
326
327
328
329
330
331
332
333
334
335
336
337
    auto end_time = std::chrono::steady_clock::now();
    // output used time per iteration
    Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
              std::milli>(end_time - start_time) * 1e-3, iter + 1);
    if (snapshot_freq > 0
        && (iter + 1) % snapshot_freq == 0) {
      std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
      SaveModelToFile(-1, snapshot_out.c_str());
    }
  }
}

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
/*!
 * \brief Refit every existing tree on the current training data.
 * \param tree_leaf_prediction One entry per data row; each entry holds one leaf
 *        index per model
 */
void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) {
  CHECK(tree_leaf_prediction.size() > 0);
  CHECK(static_cast<size_t>(num_data_) == tree_leaf_prediction.size());
  CHECK(static_cast<size_t>(models_.size()) == tree_leaf_prediction[0].size());
  const int total_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_);
  std::vector<int> leaf_of_row(num_data_);
  for (int iter = 0; iter < total_iterations; ++iter) {
    // recompute gradients from the scores accumulated so far
    Boosting();
    for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
      const int model_index = iter * num_tree_per_iteration_ + tree_id;
      // gather this model's column of the leaf-prediction matrix
      #pragma omp parallel for schedule(static)
      for (int row = 0; row < num_data_; ++row) {
        leaf_of_row[row] = tree_leaf_prediction[row][model_index];
      }
      // gradients and hessians of this tree slot start at tree_id * num_data_
      const size_t offset = static_cast<size_t>(tree_id) * num_data_;
      auto refitted = tree_learner_->FitByExistingTree(models_[model_index].get(), leaf_of_row,
                                                       gradients_.data() + offset, hessians_.data() + offset);
      train_score_updater_->AddScore(tree_learner_.get(), refitted, tree_id);
      models_[model_index].reset(refitted);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
362
double GBDT::BoostFromAverage() {
363
  // boosting from average label; or customized "average" if implemented for the current objective
364
  if (models_.empty() && !train_score_updater_->has_init_score()
365
      && num_class_ <= 1
366
      && objective_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
367
    if (config_->boost_from_average) {
368
369
370
371
372
373
374
375
      double init_score = ObtainAutomaticInitialScore(objective_function_);
      if (std::fabs(init_score) > kEpsilon) {
        train_score_updater_->AddScore(init_score, 0);
        for (auto& score_updater : valid_score_updater_) {
          score_updater->AddScore(init_score, 0);
        }
        Log::Info("Start training from score %lf", init_score);
        return init_score;
Guolin Ke's avatar
Guolin Ke committed
376
      }
377
378
379
    } else if (std::string(objective_function_->GetName()) == std::string("regression_l1")
               || std::string(objective_function_->GetName()) == std::string("quantile")
               || std::string(objective_function_->GetName()) == std::string("mape")) {
380
      Log::Warning("Disabling boost_from_average in %s may cause the slow convergence", objective_function_->GetName());
381
    }
382
  }
Guolin Ke's avatar
Guolin Ke committed
383
384
  return 0.0f;
}
Guolin Ke's avatar
Guolin Ke committed
385

Guolin Ke's avatar
Guolin Ke committed
386
bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
387
  double init_score = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
388
  // boosting first
Guolin Ke's avatar
Guolin Ke committed
389
  if (gradients == nullptr || hessians == nullptr) {
390
    init_score = BoostFromAverage();
Guolin Ke's avatar
Guolin Ke committed
391
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
392
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
393
    #endif
Guolin Ke's avatar
Guolin Ke committed
394

Guolin Ke's avatar
Guolin Ke committed
395
    Boosting();
Guolin Ke's avatar
Guolin Ke committed
396
397
398
    gradients = gradients_.data();
    hessians = hessians_.data();

Guolin Ke's avatar
Guolin Ke committed
399
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
400
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
401
    #endif
Guolin Ke's avatar
Guolin Ke committed
402
  }
Guolin Ke's avatar
Guolin Ke committed
403

Guolin Ke's avatar
Guolin Ke committed
404
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
405
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
406
  #endif
Guolin Ke's avatar
Guolin Ke committed
407

408
409
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
410

Guolin Ke's avatar
Guolin Ke committed
411
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
412
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
413
  #endif
Guolin Ke's avatar
Guolin Ke committed
414

Guolin Ke's avatar
Guolin Ke committed
415
  bool should_continue = false;
416
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
417

Guolin Ke's avatar
Guolin Ke committed
418
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
419
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
420
    #endif
421
    const size_t bias = static_cast<size_t>(cur_tree_id) * num_data_;
422
    std::unique_ptr<Tree> new_tree(new Tree(2));
423
    if (class_need_train_[cur_tree_id]) {
Guolin Ke's avatar
Guolin Ke committed
424
425
426
427
428
429
430
431
432
433
434
435
436
      auto grad = gradients + bias;
      auto hess = hessians + bias;

      // need to copy gradients for bagging subset.
      if (is_use_subset_ && bag_data_cnt_ < num_data_) {
        for (int i = 0; i < bag_data_cnt_; ++i) {
          gradients_[bias + i] = grad[bag_data_indices_[i]];
          hessians_[bias + i] = hess[bag_data_indices_[i]];
        }
        grad = gradients_.data() + bias;
        hess = hessians_.data() + bias;
      }

437
      new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_, forced_splits_json_));
438
    }
Guolin Ke's avatar
Guolin Ke committed
439

Guolin Ke's avatar
Guolin Ke committed
440
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
441
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
442
    #endif
Guolin Ke's avatar
Guolin Ke committed
443
444

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
445
      should_continue = true;
446
447
      tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, train_score_updater_->score() + bias,
                                     num_data_, bag_data_indices_.data(), bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
448
449
450
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
451
      UpdateScore(new_tree.get(), cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
452
453
454
      if (std::fabs(init_score) > kEpsilon) {
        new_tree->AddBias(init_score);
      }
455
456
    } else {
      // only add default score one-time
457
458
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
459
        new_tree->AsConstantTree(output);
Guolin Ke's avatar
Guolin Ke committed
460
        // updates scores
461
        train_score_updater_->AddScore(output, cur_tree_id);
462
        for (auto& score_updater : valid_score_updater_) {
463
          score_updater->AddScore(output, cur_tree_id);
464
465
466
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
467
468
469
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
470

Guolin Ke's avatar
Guolin Ke committed
471
  if (!should_continue) {
472
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements");
473
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
474
475
476
477
      models_.pop_back();
    }
    return true;
  }
478

Guolin Ke's avatar
Guolin Ke committed
479
480
  ++iter_;
  return false;
Guolin Ke's avatar
Guolin Ke committed
481
}
482

wxchan's avatar
wxchan committed
483
void GBDT::RollbackOneIter() {
484
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
485
  // reset score
486
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
487
    auto curr_tree = models_.size() - num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
488
    models_[curr_tree]->Shrinkage(-1.0);
489
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
490
    for (auto& score_updater : valid_score_updater_) {
491
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
492
493
494
    }
  }
  // remove model
495
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
496
497
498
499
500
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
501
bool GBDT::EvalAndCheckEarlyStopping() {
502
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
503

Guolin Ke's avatar
Guolin Ke committed
504
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
505
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
506
  #endif
Guolin Ke's avatar
Guolin Ke committed
507

508
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
509
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
510

Guolin Ke's avatar
Guolin Ke committed
511
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
512
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
513
  #endif
Guolin Ke's avatar
Guolin Ke committed
514

Guolin Ke's avatar
Guolin Ke committed
515
  is_met_early_stopping = !best_msg.empty();
516
517
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
518
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
519
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
520
    // pop last early_stopping_round_ models
521
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
522
523
524
525
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
526
527
}

528
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
529

Guolin Ke's avatar
Guolin Ke committed
530
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
531
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
532
  #endif
Guolin Ke's avatar
Guolin Ke committed
533

Guolin Ke's avatar
Guolin Ke committed
534
  // update training score
Guolin Ke's avatar
Guolin Ke committed
535
  if (!is_use_subset_) {
536
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554

    #ifdef TIMETAG
    train_score_time += std::chrono::steady_clock::now() - start_time;
    #endif

    #ifdef TIMETAG
    start_time = std::chrono::steady_clock::now();
    #endif

    // we need to predict out-of-bag scores of data for boosting
    if (num_data_ - bag_data_cnt_ > 0) {
      train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
    }

    #ifdef TIMETAG
    out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
    #endif

Guolin Ke's avatar
Guolin Ke committed
555
  } else {
556
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
557
558
559
560

    #ifdef TIMETAG
    train_score_time += std::chrono::steady_clock::now() - start_time;
    #endif
Guolin Ke's avatar
Guolin Ke committed
561
  }
Guolin Ke's avatar
Guolin Ke committed
562
563


Guolin Ke's avatar
Guolin Ke committed
564
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
565
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
566
  #endif
Guolin Ke's avatar
Guolin Ke committed
567

Guolin Ke's avatar
Guolin Ke committed
568
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
569
  for (auto& score_updater : valid_score_updater_) {
570
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
571
  }
Guolin Ke's avatar
Guolin Ke committed
572

Guolin Ke's avatar
Guolin Ke committed
573
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
574
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
575
  #endif
Guolin Ke's avatar
Guolin Ke committed
576
577
}

Guolin Ke's avatar
Guolin Ke committed
578
579
580
581
/*!
 * \brief Evaluate one metric on a score buffer, passing the current objective function along.
 * \param metric Metric to evaluate
 * \param score Score buffer handed to the metric
 * \return The values produced by the metric
 */
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
  std::vector<double> values = metric->Eval(score, objective_function_);
  return values;
}

Guolin Ke's avatar
Guolin Ke committed
582
std::string GBDT::OutputMetric(int iter) {
Guolin Ke's avatar
Guolin Ke committed
583
  bool need_output = (iter % config_->metric_freq) == 0;
Guolin Ke's avatar
Guolin Ke committed
584
585
  std::string ret = "";
  std::stringstream msg_buf;
586
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
587
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
588
  if (need_output) {
589
590
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
591
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
592
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
593
594
595
596
597
598
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
599
          msg_buf << tmp_buf.str() << '\n';
Guolin Ke's avatar
Guolin Ke committed
600
        }
601
      }
602
    }
Guolin Ke's avatar
Guolin Ke committed
603
604
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
605
  if (need_output || early_stopping_round_ > 0) {
606
607
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
608
        auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score());
Guolin Ke's avatar
Guolin Ke committed
609
610
611
612
613
614
615
616
617
618
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
619
            msg_buf << tmp_buf.str() << '\n';
620
          }
wxchan's avatar
wxchan committed
621
        }
Guolin Ke's avatar
Guolin Ke committed
622
        if (ret.empty() && early_stopping_round_ > 0) {
623
624
625
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
626
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
627
            meet_early_stopping_pairs.emplace_back(i, j);
628
          } else {
Guolin Ke's avatar
Guolin Ke committed
629
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
630
          }
wxchan's avatar
wxchan committed
631
632
        }
      }
Guolin Ke's avatar
Guolin Ke committed
633
634
    }
  }
Guolin Ke's avatar
Guolin Ke committed
635
636
637
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
638
  return ret;
Guolin Ke's avatar
Guolin Ke committed
639
640
}

641
/*! \brief Get eval result */
642
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
643
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
644
645
  std::vector<double> ret;
  if (data_idx == 0) {
646
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
647
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
648
649
650
      for (auto score : scores) {
        ret.push_back(score);
      }
651
    }
652
  } else {
653
654
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
655
      auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score());
656
657
658
      for (auto score : test_scores) {
        ret.push_back(score);
      }
659
660
661
662
663
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
664
/*! \brief Get training scores result */
665
const double* GBDT::GetTrainingScore(int64_t* out_len) {
666
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
667
  return train_score_updater_->score();
668
669
}

670
671
672
void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
  int early_stop_round_counter = 0;
  // set zero
Guolin Ke's avatar
Guolin Ke committed
673
674
  const int num_features = max_feature_idx_ + 1;
  std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1));
675
676
677
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
    // predict all the trees for one iteration
    for (int k = 0; k < num_tree_per_iteration_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
678
      models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features + 1));
679
680
681
682
683
684
685
686
687
688
689
690
    }
    // check early stopping
    ++early_stop_round_counter;
    if (early_stop->round_period == early_stop_round_counter) {
      if (early_stop->callback_function(output, num_tree_per_iteration_)) {
        return;
      }
      early_stop_round_counter = 0;
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
691
692
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
693

694
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
695
696
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
697
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
698
699
700
701
702
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
703
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
704
  }
Guolin Ke's avatar
Guolin Ke committed
705
  if (objective_function_ != nullptr && !average_output_) {
Guolin Ke's avatar
Guolin Ke committed
706
707
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
708
      std::vector<double> tree_pred(num_tree_per_iteration_);
709
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
710
        tree_pred[j] = raw_scores[j * num_data + i];
711
      }
Guolin Ke's avatar
Guolin Ke committed
712
713
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
714
      for (int j = 0; j < num_class_; ++j) {
715
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
716
717
      }
    }
718
  } else {
Guolin Ke's avatar
Guolin Ke committed
719
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
720
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
721
      std::vector<double> tmp_result(num_tree_per_iteration_);
722
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
723
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
724
725
726
727
728
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
729
730
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                             const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
731

Guolin Ke's avatar
Guolin Ke committed
732
  if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
733
    Log::Fatal("Cannot reset training data, since new training data has different bin mappers");
wxchan's avatar
wxchan committed
734
735
  }

Guolin Ke's avatar
Guolin Ke committed
736
737
738
739
740
741
  objective_function_ = objective_function;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
    CHECK(num_tree_per_iteration_ == objective_function_->NumModelPerIteration());
  } else {
    is_constant_hessian_ = false;
742
743
  }

Guolin Ke's avatar
Guolin Ke committed
744
745
746
747
  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
748
  }
Guolin Ke's avatar
Guolin Ke committed
749
  training_metrics_.shrink_to_fit();
750

Guolin Ke's avatar
Guolin Ke committed
751
752
753
754
755
  if (train_data != train_data_) {
    train_data_ = train_data;
    // not same training data, need reset score and others
    // create score tracker
    train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
756

Guolin Ke's avatar
Guolin Ke committed
757
758
759
760
761
762
    // update score
    for (int i = 0; i < iter_; ++i) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
763
764
    }

Guolin Ke's avatar
Guolin Ke committed
765
    num_data_ = train_data_->num_data();
766

Guolin Ke's avatar
Guolin Ke committed
767
768
769
770
771
772
    // create buffer for gradients and hessians
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
773

Guolin Ke's avatar
Guolin Ke committed
774
775
776
777
    max_feature_idx_ = train_data_->num_total_features() - 1;
    label_idx_ = train_data_->label_idx();
    feature_names_ = train_data_->feature_names();
    feature_infos_ = train_data_->feature_infos();
778

Guolin Ke's avatar
Guolin Ke committed
779
    tree_learner_->ResetTrainingData(train_data);
Guolin Ke's avatar
Guolin Ke committed
780
    ResetBaggingConfig(config_.get(), true);
781
  }
782
783
}

Guolin Ke's avatar
Guolin Ke committed
784
785
void GBDT::ResetConfig(const Config* config) {
  auto new_config = std::unique_ptr<Config>(new Config(*config));
Guolin Ke's avatar
Guolin Ke committed
786
787
788
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;
  if (tree_learner_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
789
    tree_learner_->ResetConfig(new_config.get());
790
  }
Guolin Ke's avatar
Guolin Ke committed
791
792
  if (train_data_ != nullptr) {
    ResetBaggingConfig(new_config.get(), false);
793
  }
Guolin Ke's avatar
Guolin Ke committed
794
  config_.reset(new_config.release());
Guolin Ke's avatar
Guolin Ke committed
795
796
}

Guolin Ke's avatar
Guolin Ke committed
797
void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
Guolin Ke's avatar
Guolin Ke committed
798
799
800
801
802
803
  // if need bagging, create buffer
  if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
    bag_data_cnt_ =
      static_cast<data_size_t>(config->bagging_fraction * num_data_);
    bag_data_indices_.resize(num_data_);
    tmp_indices_.resize(num_data_);
804

Guolin Ke's avatar
Guolin Ke committed
805
806
807
808
809
    offsets_buf_.resize(num_threads_);
    left_cnts_buf_.resize(num_threads_);
    right_cnts_buf_.resize(num_threads_);
    left_write_pos_buf_.resize(num_threads_);
    right_write_pos_buf_.resize(num_threads_);
810

Guolin Ke's avatar
Guolin Ke committed
811
812
813
814
815
816
    double average_bag_rate = config->bagging_fraction / config->bagging_freq;
    int sparse_group = 0;
    for (int i = 0; i < train_data_->num_feature_groups(); ++i) {
      if (train_data_->FeatureGroupIsSparse(i)) {
        ++sparse_group;
      }
Guolin Ke's avatar
Guolin Ke committed
817
    }
Guolin Ke's avatar
Guolin Ke committed
818
819
820
821
822
823
824
825
826
827
    is_use_subset_ = false;
    const int group_threshold_usesubset = 100;
    const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
    if (average_bag_rate <= 0.5
        && (train_data_->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
      if (tmp_subset_ == nullptr || is_change_dataset) {
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
        tmp_subset_->CopyFeatureMapperFrom(train_data_);
      }
      is_use_subset_ = true;
828
      Log::Debug("Use subset for bagging");
Guolin Ke's avatar
Guolin Ke committed
829
830
    }

Guolin Ke's avatar
Guolin Ke committed
831
832
    if (is_change_dataset) {
      need_re_bagging_ = true;
Guolin Ke's avatar
Guolin Ke committed
833
    }
834

Guolin Ke's avatar
Guolin Ke committed
835
836
837
838
839
    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
      if (objective_function_ == nullptr) {
        size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
        gradients_.resize(total_size);
        hessians_.resize(total_size);
840
      }
841
    }
842
  } else {
Guolin Ke's avatar
Guolin Ke committed
843
844
845
846
    bag_data_cnt_ = num_data_;
    bag_data_indices_.clear();
    tmp_indices_.clear();
    is_use_subset_ = false;
847
  }
wxchan's avatar
wxchan committed
848
849
}

Guolin Ke's avatar
Guolin Ke committed
850
}  // namespace LightGBM