"vscode:/vscode.git/clone" did not exist on "a97c444b4cf9d2755bd888911ce65ace1fe13e4b"
gbdt.cpp 35.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
9
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
10
11
12
13
14
15
16

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
17
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
18
19
20

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
21
22
23
#ifdef TIMETAG
// Accumulated wall-clock time (milliseconds) of the individual GBDT training
// phases. Only present in TIMETAG builds; the totals are logged by ~GBDT().
std::chrono::duration<double, std::milli> boosting_time;          // Boosting() calls in TrainOneIter
std::chrono::duration<double, std::milli> train_score_time;       // training-score update in UpdateScore
std::chrono::duration<double, std::milli> out_of_bag_score_time;  // UpdateScoreOutOfBag
std::chrono::duration<double, std::milli> valid_score_time;       // validation-score update in UpdateScore
std::chrono::duration<double, std::milli> metric_time;            // OutputMetric in EvalAndCheckEarlyStopping
std::chrono::duration<double, std::milli> bagging_time;           // Bagging
std::chrono::duration<double, std::milli> sub_gradient_time;      // gathering bagged gradients for a subset
std::chrono::duration<double, std::milli> tree_time;              // tree_learner_->Train
#endif  // TIMETAG

32
GBDT::GBDT()
  : iter_(0),
    train_data_(nullptr),
    objective_function_(nullptr),
    early_stopping_round_(0),
    max_feature_idx_(0),
    num_tree_per_iteration_(1),
    num_class_(1),
    num_iteration_for_pred_(0),
    shrinkage_rate_(0.1f),
    num_init_iteration_(0),
    boost_from_average_(false) {
  // Record the size of the OpenMP thread team. Only the master thread of the
  // parallel region performs the write, so the member is assigned exactly once.
  #pragma omp parallel
  #pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

GBDT::~GBDT() {
  #ifdef TIMETAG
  // Report the per-phase timers in seconds. Log::Info takes a printf-style
  // format, so pass the raw millisecond count (a double) converted to seconds.
  // Passing the std::chrono::duration object itself through the variadic
  // argument list to a %f conversion is not portable (undefined/conditionally
  // supported behavior).
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

64
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
65
                const std::vector<const Metric*>& training_metrics) {
66
  iter_ = 0;
wxchan's avatar
wxchan committed
67
  num_iteration_for_pred_ = 0;
68
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
69
70
  num_class_ = config->num_class;
  train_data_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
71
  gbdt_config_ = nullptr;
72
  tree_learner_ = nullptr;
73
  ResetTrainingData(config, train_data, objective_function, training_metrics);
wxchan's avatar
wxchan committed
74
75
}

76
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
77
                             const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
78
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
wxchan's avatar
wxchan committed
79
80
81
  if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
Guolin Ke's avatar
Guolin Ke committed
82
83
84
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;

85
86
87
88
  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
89
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
90
91
92
  } else {
    is_constant_hessian_ = false;
  }
Guolin Ke's avatar
Guolin Ke committed
93

Guolin Ke's avatar
Guolin Ke committed
94
  if (train_data_ != train_data && train_data != nullptr) {
95
    if (tree_learner_ == nullptr) {
96
      tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, new_config->device_type, &new_config->tree_config));
Guolin Ke's avatar
Guolin Ke committed
97
98
    }
    // init tree learner
99
    tree_learner_->Init(train_data, is_constant_hessian_);
Guolin Ke's avatar
Guolin Ke committed
100

Guolin Ke's avatar
Guolin Ke committed
101
102
103
104
105
106
    // push training metrics
    training_metrics_.clear();
    for (const auto& metric : training_metrics) {
      training_metrics_.push_back(metric);
    }
    training_metrics_.shrink_to_fit();
wxchan's avatar
wxchan committed
107
108
    // not same training data, need reset score and others
    // create score tracker
109
    train_score_updater_.reset(new ScoreUpdater(train_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
110
111
    // update score
    for (int i = 0; i < iter_; ++i) {
112
113
114
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
115
116
117
118
      }
    }
    num_data_ = train_data->num_data();
    // create buffer for gradients and hessians
119
120
121
122
123
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
wxchan's avatar
wxchan committed
124
125
126
127
    // get max feature index
    max_feature_idx_ = train_data->num_total_features() - 1;
    // get label index
    label_idx_ = train_data->label_idx();
128
129
    // get feature names
    feature_names_ = train_data->feature_names();
Guolin Ke's avatar
Guolin Ke committed
130
131

    feature_infos_ = train_data->feature_infos();
Guolin Ke's avatar
Guolin Ke committed
132
133
  }

Guolin Ke's avatar
Guolin Ke committed
134
  if ((train_data_ != train_data && train_data != nullptr)
135
      || (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
wxchan's avatar
wxchan committed
136
    // if need bagging, create buffer
Guolin Ke's avatar
Guolin Ke committed
137
    if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
138
139
      bag_data_cnt_ =
        static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
140
      bag_data_indices_.resize(num_data_);
141
142
143
144
145
146
      tmp_indices_.resize(num_data_);
      offsets_buf_.resize(num_threads_);
      left_cnts_buf_.resize(num_threads_);
      right_cnts_buf_.resize(num_threads_);
      left_write_pos_buf_.resize(num_threads_);
      right_write_pos_buf_.resize(num_threads_);
Guolin Ke's avatar
Guolin Ke committed
147
148
      double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
      is_use_subset_ = false;
149
      if (average_bag_rate <= 0.5) {
Guolin Ke's avatar
Guolin Ke committed
150
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
151
        tmp_subset_->CopyFeatureMapperFrom(train_data);
Guolin Ke's avatar
Guolin Ke committed
152
153
154
        is_use_subset_ = true;
        Log::Debug("use subset for bagging");
      }
wxchan's avatar
wxchan committed
155
156
157
    } else {
      bag_data_cnt_ = num_data_;
      bag_data_indices_.clear();
158
      tmp_indices_.clear();
Guolin Ke's avatar
Guolin Ke committed
159
      is_use_subset_ = false;
wxchan's avatar
wxchan committed
160
    }
Guolin Ke's avatar
Guolin Ke committed
161
  }
wxchan's avatar
wxchan committed
162
  train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
163
164
  if (train_data_ != nullptr) {
    // reset config for tree learner
165
    tree_learner_->ResetConfig(&new_config->tree_config);
166
167
168
    class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
    if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
      CHECK(num_tree_per_iteration_ == num_class_);
169
      // + 1 here for the binary classification
Guolin Ke's avatar
Guolin Ke committed
170
      class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
171
      auto label = train_data_->metadata().label();
172
      if (num_tree_per_iteration_ > 1) {
Guolin Ke's avatar
Guolin Ke committed
173
174
175
        // multi-class
        std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
        for (data_size_t i = 0; i < num_data_; ++i) {
176
177
178
          int index = static_cast<int>(label[i]);
          CHECK(index < num_tree_per_iteration_);
          ++cnt_per_class[index];
Guolin Ke's avatar
Guolin Ke committed
179
        }
180
        for (int i = 0; i < num_tree_per_iteration_; ++i) {
181
182
183
184
185
          if (cnt_per_class[i] == num_data_) {
            class_need_train_[i] = false;
            class_default_output_[i] = -std::log(kEpsilon);
          } else if (cnt_per_class[i] == 0) {
            class_need_train_[i] = false;
186
            class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
187
188
189
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
190
191
192
193
194
195
196
197
        // binary class
        data_size_t cnt_pos = 0;
        for (data_size_t i = 0; i < num_data_; ++i) {
          if (label[i] > 0) {
            ++cnt_pos;
          }
        }
        if (cnt_pos == 0) {
198
199
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
Guolin Ke's avatar
Guolin Ke committed
200
        } else if (cnt_pos == num_data_) {
201
202
203
204
205
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(kEpsilon);
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
206
  }
Guolin Ke's avatar
Guolin Ke committed
207
  gbdt_config_.reset(new_config.release());
Guolin Ke's avatar
Guolin Ke committed
208
209
}

wxchan's avatar
wxchan committed
210
void GBDT::AddValidDataset(const Dataset* valid_data,
211
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
212
213
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
214
  }
Guolin Ke's avatar
Guolin Ke committed
215
  // for a validation dataset, we need its score and metric
216
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
217
218
  // update score
  for (int i = 0; i < iter_; ++i) {
219
220
221
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
222
223
    }
  }
Guolin Ke's avatar
Guolin Ke committed
224
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
225
  valid_metrics_.emplace_back();
226
227
228
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
229
    best_msg_.emplace_back();
230
  }
Guolin Ke's avatar
Guolin Ke committed
231
232
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
233
234
235
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
236
      best_msg_.back().emplace_back();
237
    }
Guolin Ke's avatar
Guolin Ke committed
238
  }
Guolin Ke's avatar
Guolin Ke committed
239
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
240
241
}

242
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) {
    return 0;
  }
  // Exact number of records of [start, start + cnt) that go into the bag.
  const data_size_t need = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  // In-bag indices fill buffer[0, need); out-of-bag indices follow directly.
  data_size_t* in_bag = buffer;
  data_size_t* out_of_bag = buffer + need;
  data_size_t taken = 0;
  data_size_t skipped = 0;
  // Sequential selection sampling: accept each record with probability
  // (still needed) / (still remaining), which yields exactly `need` picks.
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    const float accept_prob = (need - taken) / static_cast<float>(cnt - offset);
    if (cur_rand.NextFloat() < accept_prob) {
      in_bag[taken++] = start + offset;
    } else {
      out_of_bag[skipped++] = start + offset;
    }
  }
  CHECK(taken == need);
  return taken;
}
Guolin Ke's avatar
Guolin Ke committed
264

265
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
266
  // if need bagging
267
  if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
Guolin Ke's avatar
Guolin Ke committed
268
    const data_size_t min_inner_size = 1000;
269
270
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
271
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
272
    #pragma omp parallel for schedule(static,1)
273
    for (int i = 0; i < num_threads_; ++i) {
274
      OMP_LOOP_EX_BEGIN();
275
276
277
278
279
280
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
281
282
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
283
284
285
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
286
      OMP_LOOP_EX_END();
287
    }
288
    OMP_THROW_EX();
289
290
291
292
293
294
295
296
297
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
298
    #pragma omp parallel for schedule(static, 1)
299
    for (int i = 0; i < num_threads_; ++i) {
300
      OMP_LOOP_EX_BEGIN();
301
302
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
303
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
304
      }
305
306
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
307
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
308
      }
309
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
310
    }
311
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
312
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
313
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
314
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
319
320
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
321
322
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
323
324
325
  }
}

326
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
327
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
328
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
329
  #endif
330
  // we need to predict out-of-bag scores of data for boosting
Guolin Ke's avatar
Guolin Ke committed
331
  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
332
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
333
  }
Guolin Ke's avatar
Guolin Ke committed
334
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
335
  out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
336
  #endif
Guolin Ke's avatar
Guolin Ke committed
337
338
}

339
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
340
  // boosting from average prediction. It doesn't work well for classification, remove it for now.
Guolin Ke's avatar
Guolin Ke committed
341
342
  if (models_.empty()
      && gbdt_config_->boost_from_average
343
      && !train_score_updater_->has_init_score()
344
345
346
      && num_class_ <= 1
      && objective_function_ != nullptr
      && objective_function_->BoostFromAverage()) {
347
    double init_score = 0.0f;
348
    auto label = train_data_->metadata().label();
349
350
351
    #pragma omp parallel for schedule(static) reduction(+:init_score)
    for (data_size_t i = 0; i < num_data_; ++i) {
      init_score += label[i];
352
    }
353
354
    init_score /= num_data_;
    std::unique_ptr<Tree> new_tree(new Tree(2));
Guolin Ke's avatar
Guolin Ke committed
355
    new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, num_data_, -1, 0, 0, 0);
356
357
358
359
360
    train_score_updater_->AddScore(init_score, 0);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(init_score, 0);
    }
    models_.push_back(std::move(new_tree));
361
362
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
363
364
  // boosting first
  if (gradient == nullptr || hessian == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
365
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
366
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
367
    #endif
Guolin Ke's avatar
Guolin Ke committed
368
    Boosting();
369
370
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
371
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
372
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
373
    #endif
Guolin Ke's avatar
Guolin Ke committed
374
  }
Guolin Ke's avatar
Guolin Ke committed
375
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
376
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
377
  #endif
378
379
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
380
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
381
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
382
  #endif
Guolin Ke's avatar
Guolin Ke committed
383
  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
384
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
385
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
386
    #endif
387
388
389
390
391
    if (gradients_.empty()) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
Guolin Ke's avatar
Guolin Ke committed
392
    // get sub gradients
393
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
394
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
395
      // cannot multi-threading here.
Guolin Ke's avatar
Guolin Ke committed
396
      for (int i = 0; i < bag_data_cnt_; ++i) {
397
398
        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
Guolin Ke's avatar
Guolin Ke committed
399
400
      }
    }
401
402
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
403
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
404
    sub_gradient_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
405
    #endif
Guolin Ke's avatar
Guolin Ke committed
406
  }
Guolin Ke's avatar
Guolin Ke committed
407
  bool should_continue = false;
408
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
409
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
410
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
411
    #endif
412
    std::unique_ptr<Tree> new_tree(new Tree(2));
413
    if (class_need_train_[cur_tree_id]) {
414
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
415
      new_tree.reset(
416
        tree_learner_->Train(gradient + bias, hessian + bias, is_constant_hessian_));
417
    }
Guolin Ke's avatar
Guolin Ke committed
418
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
419
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
420
    #endif
Guolin Ke's avatar
Guolin Ke committed
421
422

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
423
424
425
426
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
427
428
      UpdateScore(new_tree.get(), cur_tree_id);
      UpdateScoreOutOfBag(new_tree.get(), cur_tree_id);
429
430
    } else {
      // only add default score one-time
431
432
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
Guolin Ke's avatar
Guolin Ke committed
433
        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
Guolin Ke's avatar
Guolin Ke committed
434
                        output, output, 0, num_data_, -1, 0, 0, 0);
435
        train_score_updater_->AddScore(output, cur_tree_id);
436
        for (auto& score_updater : valid_score_updater_) {
437
          score_updater->AddScore(output, cur_tree_id);
438
439
440
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
441
442
443
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
444
  if (!should_continue) {
Guolin Ke's avatar
Guolin Ke committed
445
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
446
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
447
448
449
450
      models_.pop_back();
    }
    return true;
  }
Guolin Ke's avatar
Guolin Ke committed
451
452
453
454
455
456
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
457

Guolin Ke's avatar
Guolin Ke committed
458
}
459

wxchan's avatar
wxchan committed
460
void GBDT::RollbackOneIter() {
461
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
462
463
  int cur_iter = iter_ + num_init_iteration_ - 1;
  // reset score
464
465
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
    auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
466
    models_[curr_tree]->Shrinkage(-1.0);
467
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
468
    for (auto& score_updater : valid_score_updater_) {
469
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
470
471
472
    }
  }
  // remove model
473
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
474
475
476
477
478
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
479
bool GBDT::EvalAndCheckEarlyStopping() {
480
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
481
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
482
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
483
  #endif
484
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
485
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
486
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
487
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
488
  #endif
Guolin Ke's avatar
Guolin Ke committed
489
  is_met_early_stopping = !best_msg.empty();
490
491
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
492
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
493
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
494
    // pop last early_stopping_round_ models
495
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
496
497
498
499
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
500
501
}

502
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
503
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
504
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
505
  #endif
Guolin Ke's avatar
Guolin Ke committed
506
  // update training score
Guolin Ke's avatar
Guolin Ke committed
507
  if (!is_use_subset_) {
508
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
509
  } else {
510
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
511
  }
Guolin Ke's avatar
Guolin Ke committed
512
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
513
  train_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
514
515
  #endif
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
516
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
517
  #endif
Guolin Ke's avatar
Guolin Ke committed
518
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
519
  for (auto& score_updater : valid_score_updater_) {
520
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
521
  }
Guolin Ke's avatar
Guolin Ke committed
522
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
523
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
524
  #endif
Guolin Ke's avatar
Guolin Ke committed
525
526
}

Guolin Ke's avatar
Guolin Ke committed
527
528
529
530
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
531
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
532
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
533
  if (need_output) {
534
535
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
536
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
Guolin Ke's avatar
Guolin Ke committed
537
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
538
539
540
541
542
543
544
545
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
          msg_buf << tmp_buf.str() << std::endl;
        }
546
      }
547
    }
Guolin Ke's avatar
Guolin Ke committed
548
549
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
550
  if (need_output || early_stopping_round_ > 0) {
551
552
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
553
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(),
Guolin Ke's avatar
Guolin Ke committed
554
                                                      objective_function_);
Guolin Ke's avatar
Guolin Ke committed
555
556
557
558
559
560
561
562
563
564
565
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
            msg_buf << tmp_buf.str() << std::endl;
566
          }
wxchan's avatar
wxchan committed
567
        }
Guolin Ke's avatar
Guolin Ke committed
568
        if (ret.empty() && early_stopping_round_ > 0) {
569
570
571
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
572
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
573
            meet_early_stopping_pairs.emplace_back(i, j);
574
          } else {
Guolin Ke's avatar
Guolin Ke committed
575
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
576
          }
wxchan's avatar
wxchan committed
577
578
        }
      }
Guolin Ke's avatar
Guolin Ke committed
579
580
    }
  }
Guolin Ke's avatar
Guolin Ke committed
581
582
583
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
584
  return ret;
Guolin Ke's avatar
Guolin Ke committed
585
586
}

587
/*! \brief Get eval result */
588
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
589
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
590
591
  std::vector<double> ret;
  if (data_idx == 0) {
592
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
593
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
594
595
596
      for (auto score : scores) {
        ret.push_back(score);
      }
597
    }
598
  } else {
599
600
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
601
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(),
Guolin Ke's avatar
Guolin Ke committed
602
                                                           objective_function_);
603
604
605
      for (auto score : test_scores) {
        ret.push_back(score);
      }
606
607
608
609
610
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
611
/*! \brief Get training scores result */
612
const double* GBDT::GetTrainingScore(int64_t* out_len) {
613
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
614
  return train_score_updater_->score();
615
616
}

Guolin Ke's avatar
Guolin Ke committed
617
618
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
619

620
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
621
622
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
623
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
624
625
626
627
628
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
629
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
630
  }
631
  if (objective_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
632
633
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
634
      std::vector<double> tree_pred(num_tree_per_iteration_);
635
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
636
        tree_pred[j] = raw_scores[j * num_data + i];
637
      }
Guolin Ke's avatar
Guolin Ke committed
638
639
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
640
      for (int j = 0; j < num_class_; ++j) {
641
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
642
643
      }
    }
644
  } else {
Guolin Ke's avatar
Guolin Ke committed
645
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
646
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
647
      std::vector<double> tmp_result(num_tree_per_iteration_);
648
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
649
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
650
651
652
653
654
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
655
void GBDT::Boosting() {
656
  if (objective_function_ == nullptr) {
657
658
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
659
  // objective function will calculate gradients and hessians
660
  int64_t num_score = 0;
661
  objective_function_->
Guolin Ke's avatar
Guolin Ke committed
662
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
663
664
}

665
std::string GBDT::DumpModel(int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
666
  std::stringstream str_buf;
wxchan's avatar
wxchan committed
667

Guolin Ke's avatar
Guolin Ke committed
668
  str_buf << "{";
Guolin Ke's avatar
Guolin Ke committed
669
  str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
670
  str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
671
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
672
673
  str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
wxchan's avatar
wxchan committed
674

675
676
677
  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
    << std::endl;
Guolin Ke's avatar
Guolin Ke committed
678

Guolin Ke's avatar
Guolin Ke committed
679
  str_buf << "\"tree_info\":[";
680
681
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
Guolin Ke's avatar
Guolin Ke committed
682
    num_iteration += boost_from_average_ ? 1 : 0;
683
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
684
  }
685
  for (int i = 0; i < num_used_model; ++i) {
wxchan's avatar
wxchan committed
686
    if (i > 0) {
Guolin Ke's avatar
Guolin Ke committed
687
      str_buf << ",";
wxchan's avatar
wxchan committed
688
    }
Guolin Ke's avatar
Guolin Ke committed
689
690
691
692
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
wxchan's avatar
wxchan committed
693
  }
Guolin Ke's avatar
Guolin Ke committed
694
  str_buf << "]" << std::endl;
wxchan's avatar
wxchan committed
695

Guolin Ke's avatar
Guolin Ke committed
696
  str_buf << "}" << std::endl;
wxchan's avatar
wxchan committed
697

Guolin Ke's avatar
Guolin Ke committed
698
  return str_buf.str();
wxchan's avatar
wxchan committed
699
700
}

701
702
703
/*!
* \brief Generate C++ source that hard-codes the trees as if/else functions.
*        The generated code defines PredictRaw, Predict and PredictLeafIndex for GBDT.
* \param num_iteration Number of iterations to export; <= 0 exports all trees
* \return Generated source text
*/
std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream code;

  // Writes each entry of text_lines followed by a newline.
  auto emit_lines = [&code](std::initializer_list<const char*> text_lines) {
    for (const char* text : text_lines) {
      code << text << std::endl;
    }
  };

  // Preamble: the generated file uses the same includes as this one.
  emit_lines({
    "#include \"gbdt.h\"",
    "#include <LightGBM/utils/common.h>",
    "#include <LightGBM/objective_function.h>",
    "#include <LightGBM/metric.h>",
    "#include <LightGBM/prediction_early_stop.h>",
    "#include <ctime>",
    "#include <sstream>",
    "#include <chrono>",
    "#include <string>",
    "#include <vector>",
    "#include <utility>",
    "namespace LightGBM {",
  });

  // Number of trees to export. boost_from_average counts as one extra iteration.
  const int all_models = static_cast<int>(models_.size());
  const int model_count = (num_iteration > 0)
      ? std::min((num_iteration + (boost_from_average_ ? 1 : 0)) * num_tree_per_iteration_, all_models)
      : all_models;

  // Writes "<declaration> = { PredictTree0<suffix> , PredictTree1<suffix> ... };" and a blank line.
  auto emit_dispatch_table = [&code, model_count](const char* declaration, const char* suffix) {
    code << declaration << " = { ";
    for (int tree_idx = 0; tree_idx < model_count; ++tree_idx) {
      if (tree_idx != 0) {
        code << " , ";
      }
      code << "PredictTree" << tree_idx << suffix;
    }
    code << " };" << std::endl << std::endl;
  };

  // PredictRaw: one function per tree, a pointer table, then the driver loop.
  for (int tree_idx = 0; tree_idx < model_count; ++tree_idx) {
    code << models_[tree_idx]->ToIfElse(tree_idx, false) << std::endl;
  }
  emit_dispatch_table("double (*PredictTreePtr[])(const double*)", "");
  emit_lines({
    "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {",
    "\tint early_stop_round_counter = 0;",
    "\tfor (int i = 0; i < num_iteration_for_pred_; ++i) {",
    "\t\tfor (int k = 0; k < num_tree_per_iteration_; ++k) {",
    "\t\t\toutput[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);",
    "\t\t}",
    "\t\t++early_stop_round_counter;",
    "\t\tif (early_stop->round_period == early_stop_round_counter) {",
    "\t\t\tif (early_stop->callback_function(output, num_tree_per_iteration_))",
    "\t\t\t\treturn;",
    "\t\t\tearly_stop_round_counter = 0;",
    "\t\t}",
    "\t}",
    "}",
    "",
  });

  // Predict: raw prediction followed by the objective's output conversion.
  emit_lines({
    "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {",
    "\tPredictRaw(features, output, early_stop);",
    "\tif (objective_function_ != nullptr) {",
    "\t\tobjective_function_->ConvertOutput(output, output);",
    "\t}",
    "}",
    "",
  });

  // PredictLeafIndex: leaf-index variants of every tree plus their pointer table.
  for (int tree_idx = 0; tree_idx < model_count; ++tree_idx) {
    code << models_[tree_idx]->ToIfElse(tree_idx, true) << std::endl;
  }
  emit_dispatch_table("double (*PredictTreeLeafPtr[])(const double*)", "Leaf");
  emit_lines({
    "void GBDT::PredictLeafIndex(const double* features, double *output) const {",
    "\tint total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;",
    "\tfor (int i = 0; i < total_tree; ++i) {",
    "\t\toutput[i] = (*PredictTreeLeafPtr[i])(features);",
    "\t}",
    "}",
    "}  // namespace LightGBM",
  });

  return code.str();
}

/*!
* \brief Write the if/else model source (see ModelToIfElse) to filename.
*        If filename already exists, its previous content is kept and wrapped
*        together with the generated code behind the USE_HARD_CODE macro. As
*        emitted, USE_HARD_CODE is defined (to 0), so `#ifndef USE_HARD_CODE`
*        is false and the generated code in the #else branch is compiled.
* \return true if the file was written and closed without stream errors
*/
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  std::string origin;
  bool has_origin = false;
  {
    // Read and close any existing file before reopening (and truncating) the
    // same path for writing, so two handles on one file never coexist.
    std::ifstream ifs(filename);
    if (ifs.good()) {
      has_origin = true;
      origin.assign(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
    }
  }

  /*! \brief File to write models */
  std::ofstream output_file(filename);
  if (has_origin) {
    output_file << "#define USE_HARD_CODE 0" << std::endl;
    output_file << "#ifndef USE_HARD_CODE" << std::endl;
    output_file << origin << std::endl;
    output_file << "#else" << std::endl;
    output_file << ModelToIfElse(num_iteration);
    output_file << "#endif" << std::endl;
  } else {
    output_file << ModelToIfElse(num_iteration);
  }
  output_file.close();
  // close() sets failbit when it fails, so the stream state also reflects close errors.
  return static_cast<bool>(output_file);
}

Guolin Ke's avatar
Guolin Ke committed
817
std::string GBDT::SaveModelToString(int num_iteration) const {
818
  std::stringstream ss;
819

820
821
822
823
  // output model type
  ss << SubModelName() << std::endl;
  // output number of class
  ss << "num_class=" << num_class_ << std::endl;
824
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
825
826
827
828
  // output label index
  ss << "label_index=" << label_idx_ << std::endl;
  // output max_feature_idx
  ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
829
830
831
  // output objective
  if (objective_function_ != nullptr) {
    ss << "objective=" << objective_function_->ToString() << std::endl;
832
  }
833

834
835
836
  if (boost_from_average_) {
    ss << "boost_from_average" << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
837

838
  ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
839

840
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
841

842
843
  ss << std::endl;
  int num_used_model = static_cast<int>(models_.size());
Guolin Ke's avatar
Guolin Ke committed
844
845
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
846
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
847
848
849
850
851
852
853
854
855
856
857
858
859
860
  }
  // output tree models
  for (int i = 0; i < num_used_model; ++i) {
    ss << "Tree=" << i << std::endl;
    ss << models_[i]->ToString() << std::endl;
  }

  std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
  ss << std::endl << "feature importances:" << std::endl;
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
  }

  return ss.str();
861
862
}

863
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
wxchan's avatar
wxchan committed
864
865
866
  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
867

868
  output_file << SaveModelToString(num_iteration);
869

wxchan's avatar
wxchan committed
870
  output_file.close();
871
872

  return (bool)output_file;
Guolin Ke's avatar
Guolin Ke committed
873
874
}

875
bool GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
876
877
  // use serialized string to restore this object
  models_.clear();
Guolin Ke's avatar
Guolin Ke committed
878
  std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
879
880

  // get number of classes
881
882
883
884
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
885
    Log::Fatal("Model file doesn't specify the number of classes");
886
    return false;
887
  }
888
889
890
891
892
893
894
895

  line = Common::FindFromLines(lines, "num_tree_per_iteration=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
  } else {
    num_tree_per_iteration_ = num_class_;
  }

Guolin Ke's avatar
Guolin Ke committed
896
  // get index of label
897
898
899
900
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
901
    Log::Fatal("Model file doesn't specify the label index");
902
    return false;
Guolin Ke's avatar
Guolin Ke committed
903
  }
Guolin Ke's avatar
Guolin Ke committed
904
  // get max_feature_idx first
905
906
907
908
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
909
    Log::Fatal("Model file doesn't specify max_feature_idx");
910
    return false;
Guolin Ke's avatar
Guolin Ke committed
911
  }
912
913
914
915
916
  // get boost_from_average_
  line = Common::FindFromLines(lines, "boost_from_average");
  if (line.size() > 0) {
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
917
918
919
  // get feature names
  line = Common::FindFromLines(lines, "feature_names=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
920
    feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
921
922
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
923
      return false;
Guolin Ke's avatar
Guolin Ke committed
924
    }
925
  } else {
Guolin Ke's avatar
Guolin Ke committed
926
    Log::Fatal("Model file doesn't contain feature names");
927
    return false;
Guolin Ke's avatar
Guolin Ke committed
928
929
  }

Guolin Ke's avatar
Guolin Ke committed
930
931
  line = Common::FindFromLines(lines, "feature_infos=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
932
    feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
933
934
935
936
937
938
939
940
941
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature infos");
    return false;
  }

942
943
944
945
946
947
948
949
  line = Common::FindFromLines(lines, "objective=");

  if (line.size() > 0) {
    auto str = Common::Split(line.c_str(), '=')[1];
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }

Guolin Ke's avatar
Guolin Ke committed
950
  // get tree models
951
  size_t i = 0;
Guolin Ke's avatar
Guolin Ke committed
952
953
954
955
956
957
958
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
959
      std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
960
      models_.emplace_back(new Tree(tree_str));
Guolin Ke's avatar
Guolin Ke committed
961
962
963
964
    } else {
      ++i;
    }
  }
965
  Log::Info("Finished loading %d models", models_.size());
966
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
wxchan's avatar
wxchan committed
967
  num_init_iteration_ = num_iteration_for_pred_;
968
  iter_ = 0;
969
970

  return true;
Guolin Ke's avatar
Guolin Ke committed
971
972
}

wxchan's avatar
wxchan committed
973
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
974

975
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
976
977
  for (size_t iter = 0; iter < models_.size(); ++iter) {
    for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
Guolin Ke's avatar
Guolin Ke committed
978
979
980
      if (models_[iter]->split_gain(split_idx) > 0) {
        ++feature_importances[models_[iter]->split_feature(split_idx)];
      }
wxchan's avatar
wxchan committed
981
    }
982
983
984
985
986
987
  }
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    if (feature_importances[i] > 0) {
      pairs.emplace_back(feature_importances[i], feature_names_[i]);
988
    }
989
990
991
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
992
993
            [](const std::pair<size_t, std::string>& lhs,
               const std::pair<size_t, std::string>& rhs) {
994
995
996
    return lhs.first > rhs.first;
  });
  return pairs;
wxchan's avatar
wxchan committed
997
998
}

Guolin Ke's avatar
Guolin Ke committed
999
}  // namespace LightGBM