gbdt.cpp 41.4 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
9
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
10
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
11
12
13
14
15
16
17

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
18
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
19
20
21

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
22
23
24
#ifdef TIMETAG
// Accumulated wall-clock time (in milliseconds) spent in each phase of boosting.
// The totals are reported by GBDT::~GBDT() when TIMETAG is defined.
std::chrono::duration<double, std::milli> boosting_time{};
std::chrono::duration<double, std::milli> train_score_time{};
std::chrono::duration<double, std::milli> out_of_bag_score_time{};
std::chrono::duration<double, std::milli> valid_score_time{};
std::chrono::duration<double, std::milli> metric_time{};
std::chrono::duration<double, std::milli> bagging_time{};
std::chrono::duration<double, std::milli> sub_gradient_time{};
std::chrono::duration<double, std::milli> tree_time{};
#endif  // TIMETAG

33
GBDT::GBDT()
    : iter_(0)
    , train_data_(nullptr)
    , objective_function_(nullptr)
    , early_stopping_round_(0)
    , max_feature_idx_(0)
    , num_tree_per_iteration_(1)
    , num_class_(1)
    , num_iteration_for_pred_(0)
    , shrinkage_rate_(0.1f)
    , num_init_iteration_(0)
    , boost_from_average_(false) {
  // Members that are not in the initializer list get their starting state here.
  average_output_ = false;
  tree_learner_.reset();
  // Record the size of the default OpenMP team; only the master thread writes it.
  #pragma omp parallel
  #pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

GBDT::~GBDT() {
  #ifdef TIMETAG
  // Report the accumulated phase timings in seconds. The timers are
  // std::chrono::duration objects (class types); a printf-style "%f" needs a plain
  // double. So pass the raw millisecond count rather than the duration object
  // itself, which is not portable through a variadic call.
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

67
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
68
                const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
69
  CHECK(train_data->num_features() > 0);
70
  train_data_ = train_data;
71
  iter_ = 0;
wxchan's avatar
wxchan committed
72
  num_iteration_for_pred_ = 0;
73
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
74
  num_class_ = config->num_class;
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
  gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
  early_stopping_round_ = gbdt_config_->early_stopping_round;
  shrinkage_rate_ = gbdt_config_->learning_rate;

  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
  } else {
    is_constant_hessian_ = false;
  }

  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));

  // init tree learner
  tree_learner_->Init(train_data_, is_constant_hessian_);

  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();

  train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));

  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
  if (objective_function_ != nullptr) {
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    gradients_.resize(total_size);
    hessians_.resize(total_size);
  }
  // get max feature index
  max_feature_idx_ = train_data_->num_total_features() - 1;
  // get label index
  label_idx_ = train_data_->label_idx();
  // get feature names
  feature_names_ = train_data_->feature_names();
  feature_infos_ = train_data_->feature_infos();

  // if need bagging, create buffer
  ResetBaggingConfig(gbdt_config_.get());

  // reset config for tree learner
  class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
  if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
    CHECK(num_tree_per_iteration_ == num_class_);
    // + 1 here for the binary classification
    class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
    auto label = train_data_->metadata().label();
    if (num_tree_per_iteration_ > 1) {
      // multi-class
      std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
      for (data_size_t i = 0; i < num_data_; ++i) {
        int index = static_cast<int>(label[i]);
        CHECK(index < num_tree_per_iteration_);
        ++cnt_per_class[index];
      }
      for (int i = 0; i < num_tree_per_iteration_; ++i) {
        if (cnt_per_class[i] == num_data_) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(kEpsilon);
        } else if (cnt_per_class[i] == 0) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
        }
      }
    } else {
      // binary class
      data_size_t cnt_pos = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (label[i] > 0) {
          ++cnt_pos;
        }
      }
      if (cnt_pos == 0) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
      } else if (cnt_pos == num_data_) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(kEpsilon);
      }
    }
  }
wxchan's avatar
wxchan committed
161
162
}

163
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
164
                             const std::vector<const Metric*>& training_metrics) {
165
  if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
wxchan's avatar
wxchan committed
166
167
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
Guolin Ke's avatar
Guolin Ke committed
168
  CHECK(train_data->num_features() > 0);
169
170
171
  objective_function_ = objective_function;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
172
    CHECK(num_tree_per_iteration_ == objective_function_->NumTreePerIteration());
173
174
175
  } else {
    is_constant_hessian_ = false;
  }
Guolin Ke's avatar
Guolin Ke committed
176

177
178
179
180
181
182
  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
183

184
185
  if (train_data != train_data_) {
    train_data_ = train_data;
wxchan's avatar
wxchan committed
186
187
    // not same training data, need reset score and others
    // create score tracker
188
    train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
189
190
    // update score
    for (int i = 0; i < iter_; ++i) {
191
192
193
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
194
195
      }
    }
196
197
    num_data_ = train_data_->num_data();

wxchan's avatar
wxchan committed
198
    // create buffer for gradients and hessians
199
200
201
202
203
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
204

wxchan's avatar
wxchan committed
205
    // get max feature index
206
    max_feature_idx_ = train_data_->num_total_features() - 1;
wxchan's avatar
wxchan committed
207
    // get label index
208
    label_idx_ = train_data_->label_idx();
209
    // get feature names
210
211
212
213
214
    feature_names_ = train_data_->feature_names();

    feature_infos_ = train_data_->feature_infos();

    ResetBaggingConfig(gbdt_config_.get());
Guolin Ke's avatar
Guolin Ke committed
215

216
    tree_learner_->ResetTrainingData(train_data);
Guolin Ke's avatar
Guolin Ke committed
217
  }
218
}
Guolin Ke's avatar
Guolin Ke committed
219

220
221
222
223
void GBDT::ResetConfig(const BoostingConfig* config) {
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;
Guolin Ke's avatar
Guolin Ke committed
224
225
226
227
  if (tree_learner_ != nullptr) {
    ResetBaggingConfig(new_config.get());
    tree_learner_->ResetConfig(&new_config->tree_config);
  }
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
  gbdt_config_.reset(new_config.release());
}

void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
  const bool bagging_enabled = config->bagging_fraction < 1.0 && config->bagging_freq > 0;
  if (!bagging_enabled) {
    // no bagging: every row is in the bag and no scratch buffers are needed
    bag_data_cnt_ = num_data_;
    bag_data_indices_.clear();
    tmp_indices_.clear();
    is_use_subset_ = false;
    return;
  }

  // bagging: per-row index buffers plus per-thread bookkeeping used by Bagging()
  bag_data_cnt_ = static_cast<data_size_t>(config->bagging_fraction * num_data_);
  bag_data_indices_.resize(num_data_);
  tmp_indices_.resize(num_data_);
  offsets_buf_.resize(num_threads_);
  left_cnts_buf_.resize(num_threads_);
  right_cnts_buf_.resize(num_threads_);
  left_write_pos_buf_.resize(num_threads_);
  right_write_pos_buf_.resize(num_threads_);

  const double average_bag_rate = config->bagging_fraction / config->bagging_freq;
  const int num_groups = train_data_->num_feature_groups();
  int sparse_group = 0;
  for (int gid = 0; gid < num_groups; ++gid) {
    if (train_data_->FeatureGroupIsSparse(gid)) {
      ++sparse_group;
    }
  }

  // Materialize a bagged copy of the data only when bags are small and the feature
  // layout makes copying worthwhile (few groups, or few sparse groups).
  const int group_threshold_usesubset = 100;
  const int sparse_group_threshold_usesubset = num_groups / 4;
  const bool few_groups = num_groups < group_threshold_usesubset;
  const bool few_sparse_groups = sparse_group < sparse_group_threshold_usesubset;
  is_use_subset_ = false;
  if (average_bag_rate <= 0.5 && (few_groups || few_sparse_groups)) {
    tmp_subset_.reset(new Dataset(bag_data_cnt_));
    tmp_subset_->CopyFeatureMapperFrom(train_data_);
    is_use_subset_ = true;
    Log::Debug("use subset for bagging");
  }
}

wxchan's avatar
wxchan committed
268
void GBDT::AddValidDataset(const Dataset* valid_data,
269
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
270
271
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
272
  }
Guolin Ke's avatar
Guolin Ke committed
273
  // for a validation dataset, we need its score and metric
274
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
275
276
  // update score
  for (int i = 0; i < iter_; ++i) {
277
278
279
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
280
281
    }
  }
Guolin Ke's avatar
Guolin Ke committed
282
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
283
  valid_metrics_.emplace_back();
284
285
286
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
287
    best_msg_.emplace_back();
288
  }
Guolin Ke's avatar
Guolin Ke committed
289
290
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
291
292
293
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
294
      best_msg_.back().emplace_back();
295
    }
Guolin Ke's avatar
Guolin Ke committed
296
  }
Guolin Ke's avatar
Guolin Ke committed
297
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
298
299
}

300
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) {
    return 0;
  }
  // Exactly floor(bagging_fraction * cnt) rows of [start, start + cnt) are chosen.
  // In-bag row ids fill buffer from the front; out-of-bag ids are written right after them.
  const data_size_t bag_data_cnt = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t* in_bag_out = buffer;
  data_size_t* out_of_bag_out = buffer + bag_data_cnt;
  data_size_t n_in_bag = 0;
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    // selection sampling: keep a row with probability (rows still needed) / (rows still available)
    const float prob = (bag_data_cnt - n_in_bag) / static_cast<float>(cnt - offset);
    const data_size_t row = start + offset;
    if (cur_rand.NextFloat() < prob) {
      in_bag_out[n_in_bag++] = row;
    } else {
      *out_of_bag_out++ = row;
    }
  }
  CHECK(n_in_bag == bag_data_cnt);
  return n_in_bag;
}
Guolin Ke's avatar
Guolin Ke committed
322

323
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
324
  // if need bagging
325
  if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
Guolin Ke's avatar
Guolin Ke committed
326
    const data_size_t min_inner_size = 1000;
327
328
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
329
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
330
    #pragma omp parallel for schedule(static,1)
331
    for (int i = 0; i < num_threads_; ++i) {
332
      OMP_LOOP_EX_BEGIN();
333
334
335
336
337
338
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
339
340
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
341
342
343
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
344
      OMP_LOOP_EX_END();
345
    }
346
    OMP_THROW_EX();
347
348
349
350
351
352
353
354
355
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
356
    #pragma omp parallel for schedule(static, 1)
357
    for (int i = 0; i < num_threads_; ++i) {
358
      OMP_LOOP_EX_BEGIN();
359
360
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
361
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
362
      }
363
364
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
365
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
366
      }
367
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
368
    }
369
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
370
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
371
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
372
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
373
374
375
376
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
377
378
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
379
380
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
381
382
383
  }
}

384
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
385
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
386
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
387
  #endif
388
  // we need to predict out-of-bag scores of data for boosting
Guolin Ke's avatar
Guolin Ke committed
389
  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
390
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
391
  }
Guolin Ke's avatar
Guolin Ke committed
392
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
393
  out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
394
  #endif
Guolin Ke's avatar
Guolin Ke committed
395
396
}

397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
/* Returns the constant initial score used when boosting from the average.
 *
 * If the objective implements a custom "average" (GetCustomAverage returns true), that value
 * is used in place of the label average. With more than one machine, the local values are
 * summed across machines and divided by the number of machines (an unweighted mean of the
 * per-machine values).
 *
 * An improvement to this is to have options to explicitly choose
 * (i) standard average
 * (ii) custom average if available
 * (iii) any user defined scalar bias (e.g. using a new option "init_score" that overrides (i) and (ii) )
 *
 * (i) and (ii) could be selected as say "auto_init_score" = 0 or 1 etc..
 */
double ObtainAutomaticInitialScore(const ObjectiveFunction* objf, const float* label, data_size_t num_data) {
  double init_score = 0.0f;
  const bool got_custom = (objf != nullptr) && objf->GetCustomAverage(&init_score);
  if (!got_custom) {
    // plain mean of the local labels
    double sum_label = 0.0f;
    #pragma omp parallel for schedule(static) reduction(+:sum_label)
    for (data_size_t i = 0; i < num_data; ++i) {
      sum_label += label[i];
    }
    init_score = sum_label / num_data;
  }
  if (Network::num_machines() <= 1) {
    return init_score;
  }
  // Distributed: sum the per-machine values, then divide by the machine count.
  double global_init_score = 0.0f;
  Network::Allreduce(reinterpret_cast<char*>(&init_score),
                     sizeof(init_score), sizeof(init_score),
                     reinterpret_cast<char*>(&global_init_score),
                     [] (const char* src, char* dst, int len) {
    // element-wise sum of two buffers of doubles
    const int type_size = sizeof(double);
    for (int used_size = 0; used_size < len; used_size += type_size) {
      const double* in = reinterpret_cast<const double*>(src + used_size);
      double* acc = reinterpret_cast<double*>(dst + used_size);
      *acc += *in;
    }
  });
  return global_init_score / Network::num_machines();
}

Guolin Ke's avatar
Guolin Ke committed
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
  bool is_finished = false;
  const bool need_eval = true;
  const auto start_time = std::chrono::steady_clock::now();
  for (int iter = 0; iter < gbdt_config_->num_iterations && !is_finished; ++iter) {
    is_finished = TrainOneIter(nullptr, nullptr, need_eval);
    const auto end_time = std::chrono::steady_clock::now();
    // Log the cumulative wall-clock time since training started, in seconds. The
    // printf-style logger needs a plain double for "%f", so pass the duration's count,
    // not the std::chrono::duration object itself.
    const double elapsed_sec =
      std::chrono::duration<double, std::milli>(end_time - start_time).count() * 1e-3;
    Log::Info("%f seconds elapsed, finished iteration %d", elapsed_sec, iter + 1);
    // optionally save an intermediate model every snapshot_freq iterations
    if (snapshot_freq > 0
        && (iter + 1) % snapshot_freq == 0) {
      std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
      SaveModelToFile(-1, snapshot_out.c_str());
    }
  }
  SaveModelToFile(-1, model_output_path.c_str());
}

465
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
466
  // boosting from average label; or customized "average" if implemented for the current objective
Guolin Ke's avatar
Guolin Ke committed
467
468
  if (models_.empty()
      && gbdt_config_->boost_from_average
469
      && !train_score_updater_->has_init_score()
470
471
472
      && num_class_ <= 1
      && objective_function_ != nullptr
      && objective_function_->BoostFromAverage()) {
473
474
    auto label = train_data_->metadata().label();
    double init_score = ObtainAutomaticInitialScore(objective_function_, label, num_data_);
475
    std::unique_ptr<Tree> new_tree(new Tree(2));
Guolin Ke's avatar
Guolin Ke committed
476
    new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, 0, -1, MissingType::None, true);
477
478
479
480
481
    train_score_updater_->AddScore(init_score, 0);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(init_score, 0);
    }
    models_.push_back(std::move(new_tree));
482
483
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
484

Guolin Ke's avatar
Guolin Ke committed
485
486
  // boosting first
  if (gradient == nullptr || hessian == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
487
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
488
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
489
    #endif
Guolin Ke's avatar
Guolin Ke committed
490
    Boosting();
491
492
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
493
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
494
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
495
    #endif
Guolin Ke's avatar
Guolin Ke committed
496
  }
Guolin Ke's avatar
Guolin Ke committed
497
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
498
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
499
  #endif
500
501
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
502
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
503
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
504
  #endif
Guolin Ke's avatar
Guolin Ke committed
505
  // need to use subset gradient and hessian
Guolin Ke's avatar
Guolin Ke committed
506
  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
507
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
508
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
509
    #endif
510
511
512
513
514
    if (gradients_.empty()) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
Guolin Ke's avatar
Guolin Ke committed
515
    // get sub gradients
516
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
517
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
518
      // cannot multi-threading here.
Guolin Ke's avatar
Guolin Ke committed
519
      for (int i = 0; i < bag_data_cnt_; ++i) {
520
521
        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
Guolin Ke's avatar
Guolin Ke committed
522
523
      }
    }
524
525
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
526
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
527
    sub_gradient_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
528
    #endif
Guolin Ke's avatar
Guolin Ke committed
529
  }
Guolin Ke's avatar
Guolin Ke committed
530
  bool should_continue = false;
531
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
532
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
533
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
534
    #endif
535
    std::unique_ptr<Tree> new_tree(new Tree(2));
536
    if (class_need_train_[cur_tree_id]) {
537
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
538
      new_tree.reset(
539
        tree_learner_->Train(gradient + bias, hessian + bias, is_constant_hessian_));
540
    }
Guolin Ke's avatar
Guolin Ke committed
541
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
542
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
543
    #endif
Guolin Ke's avatar
Guolin Ke committed
544
545

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
546
547
548
549
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
550
551
      UpdateScore(new_tree.get(), cur_tree_id);
      UpdateScoreOutOfBag(new_tree.get(), cur_tree_id);
552
553
    } else {
      // only add default score one-time
554
555
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
Guolin Ke's avatar
Guolin Ke committed
556
        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
Guolin Ke's avatar
Guolin Ke committed
557
                        output, output, 0, 0, -1, MissingType::None, true);
558
        train_score_updater_->AddScore(output, cur_tree_id);
559
        for (auto& score_updater : valid_score_updater_) {
560
          score_updater->AddScore(output, cur_tree_id);
561
562
563
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
564
565
566
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
567
  if (!should_continue) {
Guolin Ke's avatar
Guolin Ke committed
568
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
569
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
570
571
572
573
      models_.pop_back();
    }
    return true;
  }
Guolin Ke's avatar
Guolin Ke committed
574
575
576
577
578
579
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
580

Guolin Ke's avatar
Guolin Ke committed
581
}
582

wxchan's avatar
wxchan committed
583
void GBDT::RollbackOneIter() {
584
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
585
586
  int cur_iter = iter_ + num_init_iteration_ - 1;
  // reset score
587
588
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
    auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
589
    models_[curr_tree]->Shrinkage(-1.0);
590
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
591
    for (auto& score_updater : valid_score_updater_) {
592
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
593
594
595
    }
  }
  // remove model
596
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
597
598
599
600
601
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
602
bool GBDT::EvalAndCheckEarlyStopping() {
603
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
604
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
605
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
606
  #endif
607
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
608
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
609
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
610
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
611
  #endif
Guolin Ke's avatar
Guolin Ke committed
612
  is_met_early_stopping = !best_msg.empty();
613
614
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
615
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
616
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
617
    // pop last early_stopping_round_ models
618
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
619
620
621
622
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
623
624
}

625
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
626
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
627
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
628
  #endif
Guolin Ke's avatar
Guolin Ke committed
629
  // update training score
Guolin Ke's avatar
Guolin Ke committed
630
  if (!is_use_subset_) {
631
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
632
  } else {
633
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
634
  }
Guolin Ke's avatar
Guolin Ke committed
635
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
636
  train_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
637
638
  #endif
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
639
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
640
  #endif
Guolin Ke's avatar
Guolin Ke committed
641
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
642
  for (auto& score_updater : valid_score_updater_) {
643
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
644
  }
Guolin Ke's avatar
Guolin Ke committed
645
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
646
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
647
  #endif
Guolin Ke's avatar
Guolin Ke committed
648
649
}

Guolin Ke's avatar
Guolin Ke committed
650
651
652
653
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
  // Evaluate `metric` on the given score buffer, passing along the current
  // objective function (which may be nullptr).
  std::vector<double> values = metric->Eval(score, objective_function_);
  return values;
}

Guolin Ke's avatar
Guolin Ke committed
654
655
656
657
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
658
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
659
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
660
  if (need_output) {
661
662
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
663
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
664
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
665
666
667
668
669
670
671
672
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
          msg_buf << tmp_buf.str() << std::endl;
        }
673
      }
674
    }
Guolin Ke's avatar
Guolin Ke committed
675
676
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
677
  if (need_output || early_stopping_round_ > 0) {
678
679
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
680
        auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score());
Guolin Ke's avatar
Guolin Ke committed
681
682
683
684
685
686
687
688
689
690
691
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
            msg_buf << tmp_buf.str() << std::endl;
692
          }
wxchan's avatar
wxchan committed
693
        }
Guolin Ke's avatar
Guolin Ke committed
694
        if (ret.empty() && early_stopping_round_ > 0) {
695
696
697
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
698
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
699
            meet_early_stopping_pairs.emplace_back(i, j);
700
          } else {
Guolin Ke's avatar
Guolin Ke committed
701
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
702
          }
wxchan's avatar
wxchan committed
703
704
        }
      }
Guolin Ke's avatar
Guolin Ke committed
705
706
    }
  }
Guolin Ke's avatar
Guolin Ke committed
707
708
709
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
710
  return ret;
Guolin Ke's avatar
Guolin Ke committed
711
712
}

713
/*! \brief Get eval result */
714
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
715
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
716
717
  std::vector<double> ret;
  if (data_idx == 0) {
718
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
719
      auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
720
721
722
      for (auto score : scores) {
        ret.push_back(score);
      }
723
    }
724
  } else {
725
726
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
727
      auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score());
728
729
730
      for (auto score : test_scores) {
        ret.push_back(score);
      }
731
732
733
734
735
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
736
/*! \brief Get training scores result */
737
const double* GBDT::GetTrainingScore(int64_t* out_len) {
738
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
739
  return train_score_updater_->score();
740
741
}

742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
  // Feature contributions for one row.
  // `output` holds num_tree_per_iteration_ consecutive segments of
  // (num_features + 1) doubles, one segment per tree of an iteration. Each
  // tree's Tree::PredictContrib accumulates into its segment.
  const int num_features = max_feature_idx_ + 1;
  const int segment_size = num_features + 1;
  std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * segment_size);

  // The early-stop callback expects one value per class (as in PredictRaw).
  // The first entries of `output` are per-feature contributions of class 0,
  // not class scores. The segments are therefore summed into per-class
  // totals before calling back. This assumes each segment sums to that
  // class's raw score (SHAP local accuracy) -- TODO confirm against
  // Tree::PredictContrib.
  std::vector<double> class_scores(num_tree_per_iteration_);
  int early_stop_round_counter = 0;
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
    // predict all the trees for one iteration
    for (int k = 0; k < num_tree_per_iteration_; ++k) {
      models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k * segment_size);
    }
    if (early_stop == nullptr) {
      continue;
    }
    // check early stopping
    ++early_stop_round_counter;
    if (early_stop->round_period == early_stop_round_counter) {
      for (int k = 0; k < num_tree_per_iteration_; ++k) {
        const double* segment = output + k * segment_size;
        double total = 0.0;
        for (int f = 0; f < segment_size; ++f) {
          total += segment[f];
        }
        class_scores[k] = total;
      }
      if (early_stop->callback_function(class_scores.data(), num_tree_per_iteration_)) {
        return;
      }
      early_stop_round_counter = 0;
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
763
764
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
765

766
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
767
768
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
769
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
770
771
772
773
774
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
775
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
776
  }
Guolin Ke's avatar
Guolin Ke committed
777
  if (objective_function_ != nullptr && !average_output_) {
Guolin Ke's avatar
Guolin Ke committed
778
779
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
780
      std::vector<double> tree_pred(num_tree_per_iteration_);
781
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
782
        tree_pred[j] = raw_scores[j * num_data + i];
783
      }
Guolin Ke's avatar
Guolin Ke committed
784
785
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
786
      for (int j = 0; j < num_class_; ++j) {
787
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
788
789
      }
    }
790
  } else {
Guolin Ke's avatar
Guolin Ke committed
791
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
792
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
793
      std::vector<double> tmp_result(num_tree_per_iteration_);
794
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
795
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
796
797
798
799
800
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
801
void GBDT::Boosting() {
802
  if (objective_function_ == nullptr) {
803
804
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
805
  // objective function will calculate gradients and hessians
806
  int64_t num_score = 0;
807
  objective_function_->
Guolin Ke's avatar
Guolin Ke committed
808
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
809
810
}

811
std::string GBDT::DumpModel(int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
812
  std::stringstream str_buf;
wxchan's avatar
wxchan committed
813

Guolin Ke's avatar
Guolin Ke committed
814
  str_buf << "{";
Guolin Ke's avatar
Guolin Ke committed
815
  str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
816
  str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
817
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
818
819
  str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
wxchan's avatar
wxchan committed
820

821
822
823
  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
    << std::endl;
Guolin Ke's avatar
Guolin Ke committed
824

Guolin Ke's avatar
Guolin Ke committed
825
  str_buf << "\"tree_info\":[";
826
827
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
Guolin Ke's avatar
Guolin Ke committed
828
    num_iteration += boost_from_average_ ? 1 : 0;
829
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
830
  }
831
  for (int i = 0; i < num_used_model; ++i) {
wxchan's avatar
wxchan committed
832
    if (i > 0) {
Guolin Ke's avatar
Guolin Ke committed
833
      str_buf << ",";
wxchan's avatar
wxchan committed
834
    }
Guolin Ke's avatar
Guolin Ke committed
835
836
837
838
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
wxchan's avatar
wxchan committed
839
  }
Guolin Ke's avatar
Guolin Ke committed
840
  str_buf << "]" << std::endl;
wxchan's avatar
wxchan committed
841

Guolin Ke's avatar
Guolin Ke committed
842
  str_buf << "}" << std::endl;
wxchan's avatar
wxchan committed
843

Guolin Ke's avatar
Guolin Ke committed
844
  return str_buf.str();
wxchan's avatar
wxchan committed
845
846
}

847
848
849
std::string GBDT::ModelToIfElse(int num_iteration) const {
  // Generate C++ source that hard-codes the first trees as if/else functions.
  // The source also defines GBDT::PredictRaw, GBDT::Predict and
  // GBDT::PredictLeafIndex on top of those functions.
  std::stringstream str_buf;

  // Header of the generated translation unit. <cstring> is emitted because
  // the generated PredictRaw calls std::memset.
  str_buf << "#include \"gbdt.h\"" << std::endl;
  str_buf << "#include <LightGBM/utils/common.h>" << std::endl;
  str_buf << "#include <LightGBM/objective_function.h>" << std::endl;
  str_buf << "#include <LightGBM/metric.h>" << std::endl;
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << std::endl;
  str_buf << "#include <ctime>" << std::endl;
  str_buf << "#include <cstring>" << std::endl;
  str_buf << "#include <sstream>" << std::endl;
  str_buf << "#include <chrono>" << std::endl;
  str_buf << "#include <string>" << std::endl;
  str_buf << "#include <vector>" << std::endl;
  str_buf << "#include <utility>" << std::endl;
  str_buf << "namespace LightGBM {" << std::endl;

  // num_iteration <= 0 uses every tree. Otherwise it uses num_iteration
  // iterations (plus one when boost_from_average_ is set), capped at the
  // number of trees available.
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // Emits "double (*<table>[])(const double*) = { PredictTree0<suffix> , ... };"
  auto write_ptr_table = [&](const char* table, const char* suffix) {
    str_buf << "double (*" << table << "[])(const double*) = { ";
    for (int i = 0; i < num_used_model; ++i) {
      if (i > 0) {
        str_buf << " , ";
      }
      str_buf << "PredictTree" << i << suffix;
    }
    str_buf << " };" << std::endl << std::endl;
  };

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
    str_buf << models_[i]->ToIfElse(i, false) << std::endl;
  }
  write_ptr_table("PredictTreePtr", "");

  std::stringstream pred_str_buf;
  pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
  pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << std::endl;
  pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
  pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
  pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
  pred_str_buf << "\t\t" << "}" << std::endl;
  pred_str_buf << "\t\t" << "++early_stop_round_counter;" << std::endl;
  pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
  pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
  pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
  pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
  pred_str_buf << "\t\t" << "}" << std::endl;
  pred_str_buf << "\t" << "}" << std::endl;

  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
  str_buf << pred_str_buf.str();
  str_buf << "}" << std::endl;
  str_buf << std::endl;

  // Predict
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
  str_buf << "\t" << "if (average_output_) {" << std::endl;
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << std::endl;
  str_buf << "\t\t" << "}" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "\t" << "else if (objective_function_ != nullptr) {" << std::endl;
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "}" << std::endl;
  str_buf << std::endl;

  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
    str_buf << models_[i]->ToIfElse(i, true) << std::endl;
  }
  write_ptr_table("PredictTreeLeafPtr", "Leaf");

  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "}" << std::endl;

  str_buf << "}  // namespace LightGBM" << std::endl;

  return str_buf.str();
}

bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  // Write the if/else (hard-coded) model source to `filename`.
  // If the file already exists, its previous contents are kept under
  // "#ifndef USE_HARD_CODE" and the generated code goes in the "#else"
  // branch. Returns false if opening, writing or closing the output failed.
  std::string origin;
  bool has_origin = false;
  {
    std::ifstream ifs(filename);
    if (ifs.good()) {
      origin.assign(std::istreambuf_iterator<char>(ifs),
                    std::istreambuf_iterator<char>());
      has_origin = true;
    }
  }  // the input stream is closed here, before the same file is truncated

  /*! \brief File to write models */
  std::ofstream output_file(filename);
  if (has_origin) {
    output_file << "#define USE_HARD_CODE 0" << std::endl;
    output_file << "#ifndef USE_HARD_CODE" << std::endl;
    output_file << origin << std::endl;
    output_file << "#else" << std::endl;
    output_file << ModelToIfElse(num_iteration);
    output_file << "#endif" << std::endl;
  } else {
    output_file << ModelToIfElse(num_iteration);
  }
  output_file.close();

  return static_cast<bool>(output_file);
}

Guolin Ke's avatar
Guolin Ke committed
969
std::string GBDT::SaveModelToString(int num_iteration) const {
970
  std::stringstream ss;
971

972
973
974
975
  // output model type
  ss << SubModelName() << std::endl;
  // output number of class
  ss << "num_class=" << num_class_ << std::endl;
976
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
977
978
979
980
  // output label index
  ss << "label_index=" << label_idx_ << std::endl;
  // output max_feature_idx
  ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
981
982
983
  // output objective
  if (objective_function_ != nullptr) {
    ss << "objective=" << objective_function_->ToString() << std::endl;
984
  }
985

986
987
988
  if (boost_from_average_) {
    ss << "boost_from_average" << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
989

Guolin Ke's avatar
Guolin Ke committed
990
991
992
993
  if (average_output_) {
    ss << "average_output" << std::endl;
  }

994
  ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
995

996
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
997

998
999
  ss << std::endl;
  int num_used_model = static_cast<int>(models_.size());
Guolin Ke's avatar
Guolin Ke committed
1000
1001
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
1002
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
1003
1004
1005
1006
1007
1008
1009
  }
  // output tree models
  for (int i = 0; i < num_used_model; ++i) {
    ss << "Tree=" << i << std::endl;
    ss << models_[i]->ToString() << std::endl;
  }

wxchan's avatar
wxchan committed
1010
  std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance(num_used_model);
1011
1012
1013
1014
1015
1016
  ss << std::endl << "feature importances:" << std::endl;
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
  }

  return ss.str();
1017
1018
}

1019
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
wxchan's avatar
wxchan committed
1020
1021
1022
  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
1023

1024
  output_file << SaveModelToString(num_iteration);
1025

wxchan's avatar
wxchan committed
1026
  output_file.close();
1027
1028

  return (bool)output_file;
Guolin Ke's avatar
Guolin Ke committed
1029
1030
}

1031
bool GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
1032
1033
  // use serialized string to restore this object
  models_.clear();
Guolin Ke's avatar
Guolin Ke committed
1034
  std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
1035
1036

  // get number of classes
1037
1038
1039
1040
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
1041
    Log::Fatal("Model file doesn't specify the number of classes");
1042
    return false;
1043
  }
1044
1045
1046
1047
1048
1049
1050
1051

  line = Common::FindFromLines(lines, "num_tree_per_iteration=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
  } else {
    num_tree_per_iteration_ = num_class_;
  }

Guolin Ke's avatar
Guolin Ke committed
1052
  // get index of label
1053
1054
1055
1056
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
1057
    Log::Fatal("Model file doesn't specify the label index");
1058
    return false;
Guolin Ke's avatar
Guolin Ke committed
1059
  }
Guolin Ke's avatar
Guolin Ke committed
1060
  // get max_feature_idx first
1061
1062
1063
1064
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
1065
    Log::Fatal("Model file doesn't specify max_feature_idx");
1066
    return false;
Guolin Ke's avatar
Guolin Ke committed
1067
  }
1068
1069
1070
1071
1072
  // get boost_from_average_
  line = Common::FindFromLines(lines, "boost_from_average");
  if (line.size() > 0) {
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
1073
1074
1075
1076
1077
  // get average_output
  line = Common::FindFromLines(lines, "average_output");
  if (line.size() > 0) {
    average_output_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
1078
1079
1080
  // get feature names
  line = Common::FindFromLines(lines, "feature_names=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
1081
    feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
1082
1083
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
1084
      return false;
Guolin Ke's avatar
Guolin Ke committed
1085
    }
1086
  } else {
Guolin Ke's avatar
Guolin Ke committed
1087
    Log::Fatal("Model file doesn't contain feature names");
1088
    return false;
Guolin Ke's avatar
Guolin Ke committed
1089
1090
  }

Guolin Ke's avatar
Guolin Ke committed
1091
1092
  line = Common::FindFromLines(lines, "feature_infos=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
1093
    feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
1094
1095
1096
1097
1098
1099
1100
1101
1102
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature infos");
    return false;
  }

1103
1104
1105
1106
1107
1108
1109
1110
  line = Common::FindFromLines(lines, "objective=");

  if (line.size() > 0) {
    auto str = Common::Split(line.c_str(), '=')[1];
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }

Guolin Ke's avatar
Guolin Ke committed
1111
  // get tree models
1112
  size_t i = 0;
Guolin Ke's avatar
Guolin Ke committed
1113
1114
1115
1116
1117
1118
1119
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
1120
      std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
1121
      models_.emplace_back(new Tree(tree_str));
Guolin Ke's avatar
Guolin Ke committed
1122
1123
1124
1125
    } else {
      ++i;
    }
  }
1126
  Log::Info("Finished loading %d models", models_.size());
1127
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
wxchan's avatar
wxchan committed
1128
  num_init_iteration_ = num_iteration_for_pred_;
1129
  iter_ = 0;
1130
1131

  return true;
Guolin Ke's avatar
Guolin Ke committed
1132
1133
}

wxchan's avatar
wxchan committed
1134
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance(int num_used_model) const {
1135

1136
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
wxchan's avatar
wxchan committed
1137
  for (int iter = 0; iter < num_used_model; ++iter) {
1138
    for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
Guolin Ke's avatar
Guolin Ke committed
1139
1140
1141
      if (models_[iter]->split_gain(split_idx) > 0) {
        ++feature_importances[models_[iter]->split_feature(split_idx)];
      }
wxchan's avatar
wxchan committed
1142
    }
1143
1144
1145
1146
1147
1148
  }
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    if (feature_importances[i] > 0) {
      pairs.emplace_back(feature_importances[i], feature_names_[i]);
1149
    }
1150
1151
1152
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
1153
1154
            [] (const std::pair<size_t, std::string>& lhs,
                const std::pair<size_t, std::string>& rhs) {
1155
1156
1157
    return lhs.first > rhs.first;
  });
  return pairs;
wxchan's avatar
wxchan committed
1158
1159
}

Guolin Ke's avatar
Guolin Ke committed
1160
}  // namespace LightGBM