".github/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "60c92a969a0075cd271d55e40b611aaf52b061ed"
gbdt.cpp 36.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
9
10
11
12
13
14
15
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
16
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
17
18
19

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
21
22
#ifdef TIMETAG
// Accumulated wall-clock time, in milliseconds, of the individual training
// phases. The counters are incremented by the GBDT member functions below
// and reported from GBDT::~GBDT(). They exist only in TIMETAG builds.
std::chrono::duration<double, std::milli> boosting_time{};
std::chrono::duration<double, std::milli> train_score_time{};
std::chrono::duration<double, std::milli> out_of_bag_score_time{};
std::chrono::duration<double, std::milli> valid_score_time{};
std::chrono::duration<double, std::milli> metric_time{};
std::chrono::duration<double, std::milli> bagging_time{};
std::chrono::duration<double, std::milli> sub_gradient_time{};
std::chrono::duration<double, std::milli> tree_time{};
#endif  // TIMETAG

31
/*!
 * \brief Construct an empty booster: no training data, no objective,
 *        one tree per iteration and a 0.1 shrinkage rate.
 */
GBDT::GBDT()
  : iter_(0),
    train_data_(nullptr),
    objective_function_(nullptr),
    early_stopping_round_(0),
    max_feature_idx_(0),
    num_tree_per_iteration_(1),
    num_class_(1),
    num_iteration_for_pred_(0),
    shrinkage_rate_(0.1f),
    num_init_iteration_(0),
    boost_from_average_(false) {
  // Record the size of the OpenMP team; only the master thread writes
  // the member, so there is no concurrent write.
  #pragma omp parallel
  #pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

/*!
 * \brief Destructor. In TIMETAG builds it logs the accumulated time of each
 *        training phase, converted from milliseconds to seconds.
 */
GBDT::~GBDT() {
  #ifdef TIMETAG
  // The counters are std::chrono::duration objects. "%f" expects a plain
  // double, so pass the tick count explicitly rather than the duration object.
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

63
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
64
                const std::vector<const Metric*>& training_metrics) {
65
  iter_ = 0;
wxchan's avatar
wxchan committed
66
  num_iteration_for_pred_ = 0;
67
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
68
69
  num_class_ = config->num_class;
  train_data_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
70
  gbdt_config_ = nullptr;
71
  tree_learner_ = nullptr;
72
  ResetTrainingData(config, train_data, objective_function, training_metrics);
wxchan's avatar
wxchan committed
73
74
}

75
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
76
                             const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
77
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
wxchan's avatar
wxchan committed
78
79
80
  if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
Guolin Ke's avatar
Guolin Ke committed
81
82
83
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;

84
85
86
87
  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
88
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
89
90
91
  } else {
    is_constant_hessian_ = false;
  }
Guolin Ke's avatar
Guolin Ke committed
92

Guolin Ke's avatar
Guolin Ke committed
93
  if (train_data_ != train_data && train_data != nullptr) {
94
    if (tree_learner_ == nullptr) {
95
      tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, new_config->device_type, &new_config->tree_config));
Guolin Ke's avatar
Guolin Ke committed
96
97
    }
    // init tree learner
98
    tree_learner_->Init(train_data, is_constant_hessian_);
Guolin Ke's avatar
Guolin Ke committed
99

Guolin Ke's avatar
Guolin Ke committed
100
101
102
103
104
105
    // push training metrics
    training_metrics_.clear();
    for (const auto& metric : training_metrics) {
      training_metrics_.push_back(metric);
    }
    training_metrics_.shrink_to_fit();
wxchan's avatar
wxchan committed
106
107
    // not same training data, need reset score and others
    // create score tracker
108
    train_score_updater_.reset(new ScoreUpdater(train_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
109
110
    // update score
    for (int i = 0; i < iter_; ++i) {
111
112
113
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
114
115
116
117
      }
    }
    num_data_ = train_data->num_data();
    // create buffer for gradients and hessians
118
119
120
121
122
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
wxchan's avatar
wxchan committed
123
124
125
126
    // get max feature index
    max_feature_idx_ = train_data->num_total_features() - 1;
    // get label index
    label_idx_ = train_data->label_idx();
127
128
    // get feature names
    feature_names_ = train_data->feature_names();
Guolin Ke's avatar
Guolin Ke committed
129
130

    feature_infos_ = train_data->feature_infos();
Guolin Ke's avatar
Guolin Ke committed
131
132
  }

Guolin Ke's avatar
Guolin Ke committed
133
  if ((train_data_ != train_data && train_data != nullptr)
134
      || (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
wxchan's avatar
wxchan committed
135
    // if need bagging, create buffer
Guolin Ke's avatar
Guolin Ke committed
136
    if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
137
138
      bag_data_cnt_ =
        static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
139
      bag_data_indices_.resize(num_data_);
140
141
142
143
144
145
      tmp_indices_.resize(num_data_);
      offsets_buf_.resize(num_threads_);
      left_cnts_buf_.resize(num_threads_);
      right_cnts_buf_.resize(num_threads_);
      left_write_pos_buf_.resize(num_threads_);
      right_write_pos_buf_.resize(num_threads_);
Guolin Ke's avatar
Guolin Ke committed
146
147
      double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
      is_use_subset_ = false;
148
      if (average_bag_rate <= 0.5) {
Guolin Ke's avatar
Guolin Ke committed
149
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
150
        tmp_subset_->CopyFeatureMapperFrom(train_data);
Guolin Ke's avatar
Guolin Ke committed
151
152
153
        is_use_subset_ = true;
        Log::Debug("use subset for bagging");
      }
wxchan's avatar
wxchan committed
154
155
156
    } else {
      bag_data_cnt_ = num_data_;
      bag_data_indices_.clear();
157
      tmp_indices_.clear();
Guolin Ke's avatar
Guolin Ke committed
158
      is_use_subset_ = false;
wxchan's avatar
wxchan committed
159
    }
Guolin Ke's avatar
Guolin Ke committed
160
  }
wxchan's avatar
wxchan committed
161
  train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
162
163
  if (train_data_ != nullptr) {
    // reset config for tree learner
164
    tree_learner_->ResetConfig(&new_config->tree_config);
165
166
167
    class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
    if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
      CHECK(num_tree_per_iteration_ == num_class_);
168
      // + 1 here for the binary classification
Guolin Ke's avatar
Guolin Ke committed
169
      class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
170
      auto label = train_data_->metadata().label();
171
      if (num_tree_per_iteration_ > 1) {
Guolin Ke's avatar
Guolin Ke committed
172
173
174
        // multi-class
        std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
        for (data_size_t i = 0; i < num_data_; ++i) {
175
176
177
          int index = static_cast<int>(label[i]);
          CHECK(index < num_tree_per_iteration_);
          ++cnt_per_class[index];
Guolin Ke's avatar
Guolin Ke committed
178
        }
179
        for (int i = 0; i < num_tree_per_iteration_; ++i) {
180
181
182
183
184
          if (cnt_per_class[i] == num_data_) {
            class_need_train_[i] = false;
            class_default_output_[i] = -std::log(kEpsilon);
          } else if (cnt_per_class[i] == 0) {
            class_need_train_[i] = false;
185
            class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
186
187
188
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
189
190
191
192
193
194
195
196
        // binary class
        data_size_t cnt_pos = 0;
        for (data_size_t i = 0; i < num_data_; ++i) {
          if (label[i] > 0) {
            ++cnt_pos;
          }
        }
        if (cnt_pos == 0) {
197
198
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
Guolin Ke's avatar
Guolin Ke committed
199
        } else if (cnt_pos == num_data_) {
200
201
202
203
204
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(kEpsilon);
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
205
  }
Guolin Ke's avatar
Guolin Ke committed
206
  gbdt_config_.reset(new_config.release());
Guolin Ke's avatar
Guolin Ke committed
207
208
}

wxchan's avatar
wxchan committed
209
void GBDT::AddValidDataset(const Dataset* valid_data,
210
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
211
212
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
213
  }
Guolin Ke's avatar
Guolin Ke committed
214
  // for a validation dataset, we need its score and metric
215
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
216
217
  // update score
  for (int i = 0; i < iter_; ++i) {
218
219
220
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
221
222
    }
  }
Guolin Ke's avatar
Guolin Ke committed
223
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
224
  valid_metrics_.emplace_back();
225
226
227
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
228
    best_msg_.emplace_back();
229
  }
Guolin Ke's avatar
Guolin Ke committed
230
231
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
232
233
234
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
235
      best_msg_.back().emplace_back();
236
    }
Guolin Ke's avatar
Guolin Ke committed
237
  }
Guolin Ke's avatar
Guolin Ke committed
238
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
239
240
}

241
/*!
 * \brief Randomly split the records [start, start + cnt) into in-bag and
 *        out-of-bag indices.
 *
 * In-bag indices are written to the front of \p buffer, out-of-bag indices
 * directly after the in-bag region.
 *
 * \param cur_rand Random generator used for the draws
 * \param start Index of the first record of the range
 * \param cnt Number of records in the range
 * \param buffer Output buffer with room for cnt indices
 * \return Number of in-bag records
 */
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) { return 0; }

  const data_size_t bag_data_cnt =
    static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t* const out_of_bag = buffer + bag_data_cnt;
  data_size_t cur_left_cnt = 0;
  data_size_t num_out_of_bag = 0;

  // random bagging, minimal unit is one record
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    // in-bag slots still to fill, divided by records still to visit
    const float prob =
      (bag_data_cnt - cur_left_cnt) / static_cast<float>(cnt - offset);
    const data_size_t record = start + offset;
    if (cur_rand.NextFloat() < prob) {
      buffer[cur_left_cnt++] = record;
    } else {
      out_of_bag[num_out_of_bag++] = record;
    }
  }
  CHECK(cur_left_cnt == bag_data_cnt);
  return cur_left_cnt;
}
Guolin Ke's avatar
Guolin Ke committed
263

Guolin Ke's avatar
Guolin Ke committed
264
265


266
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
267
  // if need bagging
268
  if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
Guolin Ke's avatar
Guolin Ke committed
269
    const data_size_t min_inner_size = 1000;
270
271
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
272
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
273
    #pragma omp parallel for schedule(static,1)
274
    for (int i = 0; i < num_threads_; ++i) {
275
      OMP_LOOP_EX_BEGIN();
276
277
278
279
280
281
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
282
283
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
284
285
286
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
287
      OMP_LOOP_EX_END();
288
    }
289
    OMP_THROW_EX();
290
291
292
293
294
295
296
297
298
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
299
    #pragma omp parallel for schedule(static, 1)
300
    for (int i = 0; i < num_threads_; ++i) {
301
      OMP_LOOP_EX_BEGIN();
302
303
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
304
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
305
      }
306
307
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
308
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
309
      }
310
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
311
    }
312
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
313
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
314
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
315
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
320
321
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
322
323
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
324
325
326
  }
}

327
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
328
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
329
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
330
  #endif
331
  // we need to predict out-of-bag scores of data for boosting
Guolin Ke's avatar
Guolin Ke committed
332
  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
333
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
334
  }
Guolin Ke's avatar
Guolin Ke committed
335
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
336
  out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
337
  #endif
Guolin Ke's avatar
Guolin Ke committed
338
339
}

340
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
341
  // boosting from average prediction. It doesn't work well for classification, remove it for now.
Guolin Ke's avatar
Guolin Ke committed
342
343
  if (models_.empty()
      && gbdt_config_->boost_from_average
344
      && !train_score_updater_->has_init_score()
345
346
347
      && num_class_ <= 1
      && objective_function_ != nullptr
      && objective_function_->BoostFromAverage()) {
348
    double init_score = 0.0f;
349
    auto label = train_data_->metadata().label();
350
351
352
    #pragma omp parallel for schedule(static) reduction(+:init_score)
    for (data_size_t i = 0; i < num_data_; ++i) {
      init_score += label[i];
353
    }
354
355
    init_score /= num_data_;
    std::unique_ptr<Tree> new_tree(new Tree(2));
Guolin Ke's avatar
Guolin Ke committed
356
    new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, num_data_, -1);
357
358
359
360
361
    train_score_updater_->AddScore(init_score, 0);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(init_score, 0);
    }
    models_.push_back(std::move(new_tree));
362
363
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
364
365
  // boosting first
  if (gradient == nullptr || hessian == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
366
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
367
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
368
    #endif
Guolin Ke's avatar
Guolin Ke committed
369
    Boosting();
370
371
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
372
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
373
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
374
    #endif
Guolin Ke's avatar
Guolin Ke committed
375
  }
Guolin Ke's avatar
Guolin Ke committed
376
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
377
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
378
  #endif
379
380
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
381
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
382
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
383
  #endif
Guolin Ke's avatar
Guolin Ke committed
384
  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
385
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
386
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
387
    #endif
388
389
390
391
392
    if (gradients_.empty()) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
Guolin Ke's avatar
Guolin Ke committed
393
    // get sub gradients
394
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
395
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
396
      // cannot multi-threading here.
Guolin Ke's avatar
Guolin Ke committed
397
      for (int i = 0; i < bag_data_cnt_; ++i) {
398
399
        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
Guolin Ke's avatar
Guolin Ke committed
400
401
      }
    }
402
403
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
404
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
405
    sub_gradient_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
406
    #endif
Guolin Ke's avatar
Guolin Ke committed
407
  }
Guolin Ke's avatar
Guolin Ke committed
408
  bool should_continue = false;
409
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
410
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
411
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
412
    #endif
413
    std::unique_ptr<Tree> new_tree(new Tree(2));
414
    if (class_need_train_[cur_tree_id]) {
415
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
416
      new_tree.reset(
417
        tree_learner_->Train(gradient + bias, hessian + bias, is_constant_hessian_));
418
    }
Guolin Ke's avatar
Guolin Ke committed
419
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
420
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
421
    #endif
Guolin Ke's avatar
Guolin Ke committed
422
423

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
424
425
426
427
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
428
429
      UpdateScore(new_tree.get(), cur_tree_id);
      UpdateScoreOutOfBag(new_tree.get(), cur_tree_id);
430
431
    } else {
      // only add default score one-time
432
433
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
Guolin Ke's avatar
Guolin Ke committed
434
        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
Guolin Ke's avatar
Guolin Ke committed
435
                        output, output, 0, num_data_, -1);
436
        train_score_updater_->AddScore(output, cur_tree_id);
437
        for (auto& score_updater : valid_score_updater_) {
438
          score_updater->AddScore(output, cur_tree_id);
439
440
441
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
442
443
444
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
445
  if (!should_continue) {
Guolin Ke's avatar
Guolin Ke committed
446
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
447
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
448
449
450
451
      models_.pop_back();
    }
    return true;
  }
Guolin Ke's avatar
Guolin Ke committed
452
453
454
455
456
457
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
458

Guolin Ke's avatar
Guolin Ke committed
459
}
460

wxchan's avatar
wxchan committed
461
void GBDT::RollbackOneIter() {
462
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
463
464
  int cur_iter = iter_ + num_init_iteration_ - 1;
  // reset score
465
466
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
    auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
467
    models_[curr_tree]->Shrinkage(-1.0);
468
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
469
    for (auto& score_updater : valid_score_updater_) {
470
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
471
472
473
    }
  }
  // remove model
474
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
475
476
477
478
479
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
480
bool GBDT::EvalAndCheckEarlyStopping() {
481
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
482
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
483
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
484
  #endif
485
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
486
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
487
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
488
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
489
  #endif
Guolin Ke's avatar
Guolin Ke committed
490
  is_met_early_stopping = !best_msg.empty();
491
492
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
493
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
494
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
495
    // pop last early_stopping_round_ models
496
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
497
498
499
500
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
501
502
}

503
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
504
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
505
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
506
  #endif
Guolin Ke's avatar
Guolin Ke committed
507
  // update training score
Guolin Ke's avatar
Guolin Ke committed
508
  if (!is_use_subset_) {
509
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
510
  } else {
511
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
512
  }
Guolin Ke's avatar
Guolin Ke committed
513
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
514
  train_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
515
516
  #endif
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
517
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
518
  #endif
Guolin Ke's avatar
Guolin Ke committed
519
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
520
  for (auto& score_updater : valid_score_updater_) {
521
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
522
  }
Guolin Ke's avatar
Guolin Ke committed
523
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
524
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
525
  #endif
Guolin Ke's avatar
Guolin Ke committed
526
527
}

Guolin Ke's avatar
Guolin Ke committed
528
529
530
531
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
532
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
533
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
534
  if (need_output) {
535
536
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
537
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
Guolin Ke's avatar
Guolin Ke committed
538
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
539
540
541
542
543
544
545
546
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
          msg_buf << tmp_buf.str() << std::endl;
        }
547
      }
548
    }
Guolin Ke's avatar
Guolin Ke committed
549
550
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
551
  if (need_output || early_stopping_round_ > 0) {
552
553
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
554
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(),
Guolin Ke's avatar
Guolin Ke committed
555
                                                      objective_function_);
Guolin Ke's avatar
Guolin Ke committed
556
557
558
559
560
561
562
563
564
565
566
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
            msg_buf << tmp_buf.str() << std::endl;
567
          }
wxchan's avatar
wxchan committed
568
        }
Guolin Ke's avatar
Guolin Ke committed
569
        if (ret.empty() && early_stopping_round_ > 0) {
570
571
572
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
573
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
574
            meet_early_stopping_pairs.emplace_back(i, j);
575
          } else {
Guolin Ke's avatar
Guolin Ke committed
576
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
577
          }
wxchan's avatar
wxchan committed
578
579
        }
      }
Guolin Ke's avatar
Guolin Ke committed
580
581
    }
  }
Guolin Ke's avatar
Guolin Ke committed
582
583
584
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
585
  return ret;
Guolin Ke's avatar
Guolin Ke committed
586
587
}

588
/*! \brief Get eval result */
589
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
590
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
591
592
  std::vector<double> ret;
  if (data_idx == 0) {
593
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
594
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
595
596
597
      for (auto score : scores) {
        ret.push_back(score);
      }
598
    }
599
  } else {
600
601
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
602
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(),
Guolin Ke's avatar
Guolin Ke committed
603
                                                           objective_function_);
604
605
606
      for (auto score : test_scores) {
        ret.push_back(score);
      }
607
608
609
610
611
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
612
/*! \brief Get training scores result */
613
const double* GBDT::GetTrainingScore(int64_t* out_len) {
614
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
615
  return train_score_updater_->score();
616
617
}

Guolin Ke's avatar
Guolin Ke committed
618
619
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
620

621
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
622
623
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
624
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
625
626
627
628
629
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
630
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
631
  }
632
  if (objective_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
633
634
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
635
      std::vector<double> tree_pred(num_tree_per_iteration_);
636
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
637
        tree_pred[j] = raw_scores[j * num_data + i];
638
      }
Guolin Ke's avatar
Guolin Ke committed
639
640
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
641
      for (int j = 0; j < num_class_; ++j) {
642
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
643
644
      }
    }
645
  } else {
Guolin Ke's avatar
Guolin Ke committed
646
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
647
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
648
      std::vector<double> tmp_result(num_tree_per_iteration_);
649
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
650
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
651
652
653
654
655
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
656
void GBDT::Boosting() {
657
  if (objective_function_ == nullptr) {
658
659
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
660
  // objective function will calculate gradients and hessians
661
  int64_t num_score = 0;
662
  objective_function_->
Guolin Ke's avatar
Guolin Ke committed
663
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
664
665
}

666
std::string GBDT::DumpModel(int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
667
  std::stringstream str_buf;
wxchan's avatar
wxchan committed
668

Guolin Ke's avatar
Guolin Ke committed
669
  str_buf << "{";
Guolin Ke's avatar
Guolin Ke committed
670
  str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
671
  str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
672
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
673
674
  str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
wxchan's avatar
wxchan committed
675

676
677
678
  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
    << std::endl;
Guolin Ke's avatar
Guolin Ke committed
679

Guolin Ke's avatar
Guolin Ke committed
680
  str_buf << "\"tree_info\":[";
681
682
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
Guolin Ke's avatar
Guolin Ke committed
683
    num_iteration += boost_from_average_ ? 1 : 0;
684
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
685
  }
686
  for (int i = 0; i < num_used_model; ++i) {
wxchan's avatar
wxchan committed
687
    if (i > 0) {
Guolin Ke's avatar
Guolin Ke committed
688
      str_buf << ",";
wxchan's avatar
wxchan committed
689
    }
Guolin Ke's avatar
Guolin Ke committed
690
691
692
693
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
wxchan's avatar
wxchan committed
694
  }
Guolin Ke's avatar
Guolin Ke committed
695
  str_buf << "]" << std::endl;
wxchan's avatar
wxchan committed
696

Guolin Ke's avatar
Guolin Ke committed
697
  str_buf << "}" << std::endl;
wxchan's avatar
wxchan committed
698

Guolin Ke's avatar
Guolin Ke committed
699
  return str_buf.str();
wxchan's avatar
wxchan committed
700
701
}

702
703
704
/*!
 * \brief Generate C++ source text that hard-codes the stored trees.
 *
 * The emitted text contains one function per tree (produced by Tree::ToIfElse),
 * function-pointer tables over those functions, and definitions of
 * GBDT::PredictRaw, GBDT::Predict and GBDT::PredictLeafIndex that dispatch
 * through the tables.
 *
 * \param num_iteration When > 0, caps the number of iterations emitted (one
 *        extra iteration is counted when boost_from_average_ is set);
 *        otherwise every stored tree is emitted
 * \return The generated source text
 */
std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

  // Writes one line of generated code: `depth` tab characters, the text, a newline.
  auto emit = [](std::stringstream& out, int depth, const char* text) {
    for (int d = 0; d < depth; ++d) {
      out << "\t";
    }
    out << text << '\n';
  };

  // Preamble of the generated translation unit.
  static const char* const kPreamble[] = {
    "#include \"gbdt.h\"",
    "#include <LightGBM/utils/openmp_wrapper.h>",
    "#include <LightGBM/utils/common.h>",
    "#include <LightGBM/objective_function.h>",
    "#include <LightGBM/metric.h>",
    "#include <ctime>",
    "#include <sstream>",
    "#include <chrono>",
    "#include <string>",
    "#include <vector>",
    "#include <utility>",
    "namespace LightGBM {",
  };
  for (const char* preamble_line : kPreamble) {
    emit(str_buf, 0, preamble_line);
  }

  int tree_count = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    if (boost_from_average_) {
      ++num_iteration;
    }
    tree_count = std::min(num_iteration * num_tree_per_iteration_, tree_count);
  }

  // Writes a function-pointer table whose entries are PredictTree<i><suffix>.
  auto emit_table = [&str_buf, tree_count](const char* declaration, const char* suffix) {
    str_buf << declaration;
    for (int t = 0; t < tree_count; ++t) {
      if (t > 0) {
        str_buf << " , ";
      }
      str_buf << "PredictTree" << t << suffix;
    }
    str_buf << " };" << '\n' << '\n';
  };

  // PredictRaw
  for (int t = 0; t < tree_count; ++t) {
    str_buf << models_[t]->ToIfElse(t, false) << '\n';
  }
  emit_table("double (*PredictTreePtr[])(const double*) = { ", "");

  // Accumulation loop shared by the generated PredictRaw and Predict bodies.
  std::stringstream pred_str_buf;
  emit(pred_str_buf, 1, "if (num_threads_ <= num_tree_per_iteration_) {");
  emit(pred_str_buf, 2, "#pragma omp parallel for schedule(static)");
  emit(pred_str_buf, 2, "for (int k = 0; k < num_tree_per_iteration_; ++k) {");
  emit(pred_str_buf, 3, "for (int i = 0; i < num_iteration_for_pred_; ++i) {");
  emit(pred_str_buf, 4, "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);");
  emit(pred_str_buf, 3, "}");
  emit(pred_str_buf, 2, "}");
  emit(pred_str_buf, 1, "} else {");
  emit(pred_str_buf, 2, "for (int k = 0; k < num_tree_per_iteration_; ++k) {");
  emit(pred_str_buf, 3, "double t = 0.0f;");
  emit(pred_str_buf, 3, "#pragma omp parallel for schedule(static) reduction(+:t)");
  emit(pred_str_buf, 3, "for (int i = 0; i < num_iteration_for_pred_; ++i) {");
  emit(pred_str_buf, 4, "t += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);");
  emit(pred_str_buf, 3, "}");
  emit(pred_str_buf, 3, "output[k] = t;");
  emit(pred_str_buf, 2, "}");
  emit(pred_str_buf, 1, "}");
  const std::string accumulate_body = pred_str_buf.str();

  emit(str_buf, 0, "void GBDT::PredictRaw(const double* features, double *output) const {");
  str_buf << accumulate_body;
  emit(str_buf, 0, "}");
  str_buf << '\n';

  // Predict
  emit(str_buf, 0, "void GBDT::Predict(const double* features, double *output) const {");
  str_buf << accumulate_body;
  emit(str_buf, 1, "if (objective_function_ != nullptr) {");
  emit(str_buf, 2, "objective_function_->ConvertOutput(output, output);");
  emit(str_buf, 1, "}");
  emit(str_buf, 0, "}");
  str_buf << '\n';

  // PredictLeafIndex
  // NOTE(review): the `true` flag presumably selects the leaf-index variant of
  // each tree function, matching the PredictTree<i>Leaf names below - confirm in Tree::ToIfElse.
  for (int t = 0; t < tree_count; ++t) {
    str_buf << models_[t]->ToIfElse(t, true) << '\n';
  }
  emit_table("double (*PredictTreeLeafPtr[])(const double*) = { ", "Leaf");

  emit(str_buf, 0, "void GBDT::PredictLeafIndex(const double* features, double *output) const {");
  emit(str_buf, 1, "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;");
  emit(str_buf, 1, "#pragma omp parallel for schedule(static)");
  emit(str_buf, 1, "for (int i = 0; i < total_tree; ++i) {");
  emit(str_buf, 2, "output[i] = (*PredictTreeLeafPtr[i])(features);");
  emit(str_buf, 1, "}");
  emit(str_buf, 0, "}");

  emit(str_buf, 0, "}  // namespace LightGBM");

  return str_buf.str();
}

/*!
 * \brief Write the generated if-else model code to a file.
 *
 * If the file can already be read, its previous content is kept and wrapped
 * together with the generated code in a USE_HARD_CODE preprocessor
 * conditional; otherwise the file receives only the generated code.
 *
 * \param num_iteration Forwarded to ModelToIfElse
 * \param filename Path of the file to (re)write
 * \return State of the output stream after it has been closed
 */
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  std::ofstream output_file;
  std::ifstream ifs(filename);

  if (!ifs.good()) {
    // Nothing readable at this path: write just the generated code.
    output_file.open(filename);
    output_file << ModelToIfElse(num_iteration);
  } else {
    // Read the whole existing file before it is reopened for writing.
    const std::istreambuf_iterator<char> content_begin(ifs);
    const std::istreambuf_iterator<char> content_end;
    std::string origin(content_begin, content_end);

    output_file.open(filename);
    output_file << "#define USE_HARD_CODE 0" << '\n';
    output_file << "#ifndef USE_HARD_CODE" << '\n';
    output_file << origin << '\n';
    output_file << "#else" << '\n';
    output_file << ModelToIfElse(num_iteration);
    output_file << "#endif" << '\n';
  }

  ifs.close();
  output_file.close();

  return static_cast<bool>(output_file);
}

Guolin Ke's avatar
Guolin Ke committed
824
std::string GBDT::SaveModelToString(int num_iteration) const {
825
  std::stringstream ss;
826

827
828
829
830
  // output model type
  ss << SubModelName() << std::endl;
  // output number of class
  ss << "num_class=" << num_class_ << std::endl;
831
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
832
833
834
835
  // output label index
  ss << "label_index=" << label_idx_ << std::endl;
  // output max_feature_idx
  ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
836
837
838
  // output objective
  if (objective_function_ != nullptr) {
    ss << "objective=" << objective_function_->ToString() << std::endl;
839
  }
840

841
842
843
  if (boost_from_average_) {
    ss << "boost_from_average" << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
844

845
  ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
846

847
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
848

849
850
  ss << std::endl;
  int num_used_model = static_cast<int>(models_.size());
Guolin Ke's avatar
Guolin Ke committed
851
852
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
853
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
854
855
856
857
858
859
860
861
862
863
864
865
866
867
  }
  // output tree models
  for (int i = 0; i < num_used_model; ++i) {
    ss << "Tree=" << i << std::endl;
    ss << models_[i]->ToString() << std::endl;
  }

  std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
  ss << std::endl << "feature importances:" << std::endl;
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
  }

  return ss.str();
868
869
}

870
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
wxchan's avatar
wxchan committed
871
872
873
  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
874

875
  output_file << SaveModelToString(num_iteration);
876

wxchan's avatar
wxchan committed
877
  output_file.close();
878
879

  return (bool)output_file;
Guolin Ke's avatar
Guolin Ke committed
880
881
}

882
bool GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
883
884
885
  // use serialized string to restore this object
  models_.clear();
  std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
886
887

  // get number of classes
888
889
890
891
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
892
    Log::Fatal("Model file doesn't specify the number of classes");
893
    return false;
894
  }
895
896
897
898
899
900
901
902

  line = Common::FindFromLines(lines, "num_tree_per_iteration=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
  } else {
    num_tree_per_iteration_ = num_class_;
  }

Guolin Ke's avatar
Guolin Ke committed
903
  // get index of label
904
905
906
907
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
908
    Log::Fatal("Model file doesn't specify the label index");
909
    return false;
Guolin Ke's avatar
Guolin Ke committed
910
  }
Guolin Ke's avatar
Guolin Ke committed
911
  // get max_feature_idx first
912
913
914
915
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
916
    Log::Fatal("Model file doesn't specify max_feature_idx");
917
    return false;
Guolin Ke's avatar
Guolin Ke committed
918
  }
919
920
921
922
923
  // get boost_from_average_
  line = Common::FindFromLines(lines, "boost_from_average");
  if (line.size() > 0) {
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
924
925
926
  // get feature names
  line = Common::FindFromLines(lines, "feature_names=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
927
    feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), " ");
Guolin Ke's avatar
Guolin Ke committed
928
929
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
930
      return false;
Guolin Ke's avatar
Guolin Ke committed
931
    }
932
  } else {
Guolin Ke's avatar
Guolin Ke committed
933
    Log::Fatal("Model file doesn't contain feature names");
934
    return false;
Guolin Ke's avatar
Guolin Ke committed
935
936
  }

Guolin Ke's avatar
Guolin Ke committed
937
938
939
940
941
942
943
944
945
946
947
948
  line = Common::FindFromLines(lines, "feature_infos=");
  if (line.size() > 0) {
    feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), " ");
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature infos");
    return false;
  }

949
950
951
952
953
954
955
956
  line = Common::FindFromLines(lines, "objective=");

  if (line.size() > 0) {
    auto str = Common::Split(line.c_str(), '=')[1];
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }

Guolin Ke's avatar
Guolin Ke committed
957
  // get tree models
958
  size_t i = 0;
Guolin Ke's avatar
Guolin Ke committed
959
960
961
962
963
964
965
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
966
      std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
967
      models_.emplace_back(new Tree(tree_str));
Guolin Ke's avatar
Guolin Ke committed
968
969
970
971
    } else {
      ++i;
    }
  }
972
  Log::Info("Finished loading %d models", models_.size());
973
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
wxchan's avatar
wxchan committed
974
  num_init_iteration_ = num_iteration_for_pred_;
975
  iter_ = 0;
976
977

  return true;
Guolin Ke's avatar
Guolin Ke committed
978
979
}

wxchan's avatar
wxchan committed
980
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
981

982
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
983
984
  for (size_t iter = 0; iter < models_.size(); ++iter) {
    for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
Guolin Ke's avatar
Guolin Ke committed
985
986
987
      if (models_[iter]->split_gain(split_idx) > 0) {
        ++feature_importances[models_[iter]->split_feature(split_idx)];
      }
wxchan's avatar
wxchan committed
988
    }
989
990
991
992
993
994
  }
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    if (feature_importances[i] > 0) {
      pairs.emplace_back(feature_importances[i], feature_names_[i]);
995
    }
996
997
998
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
999
1000
            [](const std::pair<size_t, std::string>& lhs,
               const std::pair<size_t, std::string>& rhs) {
1001
1002
1003
    return lhs.first > rhs.first;
  });
  return pairs;
wxchan's avatar
wxchan committed
1004
1005
}

Guolin Ke's avatar
Guolin Ke committed
1006
}  // namespace LightGBM