gbdt.cpp 36.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
9
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
10
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
11
12
13
14
15
16
17

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
18
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
19
20
21

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
22
23
24
#ifdef TIMETAG
// Per-phase profiling accumulators, in milliseconds. They only exist when the
// build defines TIMETAG; the GBDT destructor logs their totals.
std::chrono::duration<double, std::milli> boosting_time;          // Boosting() calls
std::chrono::duration<double, std::milli> train_score_time;       // training-score updates
std::chrono::duration<double, std::milli> out_of_bag_score_time;  // out-of-bag score updates
std::chrono::duration<double, std::milli> valid_score_time;       // validation-score updates
std::chrono::duration<double, std::milli> metric_time;            // metric evaluation
std::chrono::duration<double, std::milli> bagging_time;           // Bagging()
std::chrono::duration<double, std::milli> sub_gradient_time;      // gathering bagged gradients
std::chrono::duration<double, std::milli> tree_time;              // tree learning
#endif  // TIMETAG

33
GBDT::GBDT()
  : iter_(0),
    train_data_(nullptr),
    objective_function_(nullptr),
    early_stopping_round_(0),
    max_feature_idx_(0),
    num_tree_per_iteration_(1),
    num_class_(1),
    num_iteration_for_pred_(0),
    shrinkage_rate_(0.1f),
    num_init_iteration_(0),
    boost_from_average_(false) {
  // Record the size of the default OpenMP team once. Only the master thread of
  // the parallel region writes it. The per-thread bagging buffers are later
  // sized with this value.
  #pragma omp parallel
  {
    #pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
  }
}

GBDT::~GBDT() {
  #ifdef TIMETAG
  // The accumulators are std::chrono durations measured in milliseconds.
  // A duration object must not be passed through printf-style varargs for a
  // %f conversion, so extract the raw count and convert it to seconds first.
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

65
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
66
                const std::vector<const Metric*>& training_metrics) {
67
  iter_ = 0;
wxchan's avatar
wxchan committed
68
  num_iteration_for_pred_ = 0;
69
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
70
71
  num_class_ = config->num_class;
  train_data_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
72
  gbdt_config_ = nullptr;
73
  tree_learner_ = nullptr;
74
  ResetTrainingData(config, train_data, objective_function, training_metrics);
wxchan's avatar
wxchan committed
75
76
}

77
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
78
                             const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
79
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
wxchan's avatar
wxchan committed
80
81
82
  if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
Guolin Ke's avatar
Guolin Ke committed
83
84
85
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;

86
87
88
89
  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
90
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
91
92
93
  } else {
    is_constant_hessian_ = false;
  }
Guolin Ke's avatar
Guolin Ke committed
94

Guolin Ke's avatar
Guolin Ke committed
95
  if (train_data_ != train_data && train_data != nullptr) {
96
    if (tree_learner_ == nullptr) {
97
      tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, new_config->device_type, &new_config->tree_config));
Guolin Ke's avatar
Guolin Ke committed
98
99
    }
    // init tree learner
100
    tree_learner_->Init(train_data, is_constant_hessian_);
Guolin Ke's avatar
Guolin Ke committed
101

Guolin Ke's avatar
Guolin Ke committed
102
103
104
105
106
107
    // push training metrics
    training_metrics_.clear();
    for (const auto& metric : training_metrics) {
      training_metrics_.push_back(metric);
    }
    training_metrics_.shrink_to_fit();
wxchan's avatar
wxchan committed
108
109
    // not same training data, need reset score and others
    // create score tracker
110
    train_score_updater_.reset(new ScoreUpdater(train_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
111
112
    // update score
    for (int i = 0; i < iter_; ++i) {
113
114
115
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
116
117
118
119
      }
    }
    num_data_ = train_data->num_data();
    // create buffer for gradients and hessians
120
121
122
123
124
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
wxchan's avatar
wxchan committed
125
126
127
128
    // get max feature index
    max_feature_idx_ = train_data->num_total_features() - 1;
    // get label index
    label_idx_ = train_data->label_idx();
129
130
    // get feature names
    feature_names_ = train_data->feature_names();
Guolin Ke's avatar
Guolin Ke committed
131
132

    feature_infos_ = train_data->feature_infos();
Guolin Ke's avatar
Guolin Ke committed
133
134
  }

Guolin Ke's avatar
Guolin Ke committed
135
  if ((train_data_ != train_data && train_data != nullptr)
136
      || (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
wxchan's avatar
wxchan committed
137
    // if need bagging, create buffer
Guolin Ke's avatar
Guolin Ke committed
138
    if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
139
140
      bag_data_cnt_ =
        static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
141
      bag_data_indices_.resize(num_data_);
142
143
144
145
146
147
      tmp_indices_.resize(num_data_);
      offsets_buf_.resize(num_threads_);
      left_cnts_buf_.resize(num_threads_);
      right_cnts_buf_.resize(num_threads_);
      left_write_pos_buf_.resize(num_threads_);
      right_write_pos_buf_.resize(num_threads_);
Guolin Ke's avatar
Guolin Ke committed
148
149
      double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
      is_use_subset_ = false;
150
      if (average_bag_rate <= 0.5) {
Guolin Ke's avatar
Guolin Ke committed
151
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
152
        tmp_subset_->CopyFeatureMapperFrom(train_data);
Guolin Ke's avatar
Guolin Ke committed
153
154
155
        is_use_subset_ = true;
        Log::Debug("use subset for bagging");
      }
wxchan's avatar
wxchan committed
156
157
158
    } else {
      bag_data_cnt_ = num_data_;
      bag_data_indices_.clear();
159
      tmp_indices_.clear();
Guolin Ke's avatar
Guolin Ke committed
160
      is_use_subset_ = false;
wxchan's avatar
wxchan committed
161
    }
Guolin Ke's avatar
Guolin Ke committed
162
  }
wxchan's avatar
wxchan committed
163
  train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
164
165
  if (train_data_ != nullptr) {
    // reset config for tree learner
166
    tree_learner_->ResetConfig(&new_config->tree_config);
167
168
169
    class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
    if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
      CHECK(num_tree_per_iteration_ == num_class_);
170
      // + 1 here for the binary classification
Guolin Ke's avatar
Guolin Ke committed
171
      class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
172
      auto label = train_data_->metadata().label();
173
      if (num_tree_per_iteration_ > 1) {
Guolin Ke's avatar
Guolin Ke committed
174
175
176
        // multi-class
        std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
        for (data_size_t i = 0; i < num_data_; ++i) {
177
178
179
          int index = static_cast<int>(label[i]);
          CHECK(index < num_tree_per_iteration_);
          ++cnt_per_class[index];
Guolin Ke's avatar
Guolin Ke committed
180
        }
181
        for (int i = 0; i < num_tree_per_iteration_; ++i) {
182
183
184
185
186
          if (cnt_per_class[i] == num_data_) {
            class_need_train_[i] = false;
            class_default_output_[i] = -std::log(kEpsilon);
          } else if (cnt_per_class[i] == 0) {
            class_need_train_[i] = false;
187
            class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
188
189
190
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
195
196
197
198
        // binary class
        data_size_t cnt_pos = 0;
        for (data_size_t i = 0; i < num_data_; ++i) {
          if (label[i] > 0) {
            ++cnt_pos;
          }
        }
        if (cnt_pos == 0) {
199
200
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
Guolin Ke's avatar
Guolin Ke committed
201
        } else if (cnt_pos == num_data_) {
202
203
204
205
206
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(kEpsilon);
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
207
  }
Guolin Ke's avatar
Guolin Ke committed
208
  gbdt_config_.reset(new_config.release());
Guolin Ke's avatar
Guolin Ke committed
209
210
}

wxchan's avatar
wxchan committed
211
void GBDT::AddValidDataset(const Dataset* valid_data,
212
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
213
214
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
215
  }
Guolin Ke's avatar
Guolin Ke committed
216
  // for a validation dataset, we need its score and metric
217
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
218
219
  // update score
  for (int i = 0; i < iter_; ++i) {
220
221
222
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
223
224
    }
  }
Guolin Ke's avatar
Guolin Ke committed
225
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
226
  valid_metrics_.emplace_back();
227
228
229
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
230
    best_msg_.emplace_back();
231
  }
Guolin Ke's avatar
Guolin Ke committed
232
233
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
234
235
236
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
237
      best_msg_.back().emplace_back();
238
    }
Guolin Ke's avatar
Guolin Ke committed
239
  }
Guolin Ke's avatar
Guolin Ke committed
240
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
241
242
}

243
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  // An empty (or negative-sized) chunk contributes nothing.
  if (cnt <= 0) {
    return 0;
  }
  // Exactly (data_size_t)(bagging_fraction * cnt) rows are written, in order,
  // to the front of `buffer` (in-bag). Every other row follows them (out-of-bag).
  const data_size_t need = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t* in_bag = buffer;
  data_size_t* out_of_bag = buffer + need;
  data_size_t taken = 0;
  data_size_t skipped = 0;
  // Selection sampling: each row is kept with probability
  // (rows still needed) / (rows still available), one record at a time.
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    const float keep_prob = (need - taken) / static_cast<float>(cnt - offset);
    const data_size_t row = start + offset;
    if (cur_rand.NextFloat() < keep_prob) {
      in_bag[taken++] = row;
    } else {
      out_of_bag[skipped++] = row;
    }
  }
  CHECK(taken == need);
  return taken;
}
Guolin Ke's avatar
Guolin Ke committed
265

266
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
267
  // if need bagging
268
  if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
Guolin Ke's avatar
Guolin Ke committed
269
    const data_size_t min_inner_size = 1000;
270
271
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
272
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
273
    #pragma omp parallel for schedule(static,1)
274
    for (int i = 0; i < num_threads_; ++i) {
275
      OMP_LOOP_EX_BEGIN();
276
277
278
279
280
281
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
282
283
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
284
285
286
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
287
      OMP_LOOP_EX_END();
288
    }
289
    OMP_THROW_EX();
290
291
292
293
294
295
296
297
298
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
299
    #pragma omp parallel for schedule(static, 1)
300
    for (int i = 0; i < num_threads_; ++i) {
301
      OMP_LOOP_EX_BEGIN();
302
303
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
304
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
305
      }
306
307
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
308
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
309
      }
310
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
311
    }
312
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
313
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
314
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
315
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
320
321
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
322
323
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
324
325
326
  }
}

327
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
328
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
329
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
330
  #endif
331
  // we need to predict out-of-bag scores of data for boosting
Guolin Ke's avatar
Guolin Ke committed
332
  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
333
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
334
  }
Guolin Ke's avatar
Guolin Ke committed
335
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
336
  out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
337
  #endif
Guolin Ke's avatar
Guolin Ke committed
338
339
}

Guolin Ke's avatar
Guolin Ke committed
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
/*!
 * \brief Mean of the labels over all training rows (across machines when distributed).
 * \param label Label array of this machine's rows
 * \param num_data Number of rows in label
 * \return Global label mean, or 0 when there are no rows at all
 *
 * The label sum and the row count are summed across machines and divided once.
 * Averaging per-machine means would be biased when machines hold different
 * numbers of rows; when every machine holds the same rows, the result equals
 * the local mean.
 */
double LabelAverage(const float* label, data_size_t num_data) {
  double sum_label = 0.0;
  #pragma omp parallel for schedule(static) reduction(+:sum_label)
  for (data_size_t i = 0; i < num_data; ++i) {
    sum_label += label[i];
  }
  // [0] = sum of labels, [1] = number of rows
  double local_stats[2] = {sum_label, static_cast<double>(num_data)};
  double global_stats[2] = {0.0, 0.0};
  if (Network::num_machines() > 1) {
    Network::Allreduce(reinterpret_cast<char*>(local_stats),
                       sizeof(local_stats), sizeof(double),
                       reinterpret_cast<char*>(global_stats),
                       [](const char* src, char* dst, int len) {
      // element-wise sum of doubles
      int used_size = 0;
      const int type_size = sizeof(double);
      const double *p1;
      double *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const double *>(src);
        p2 = reinterpret_cast<double *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
  } else {
    global_stats[0] = local_stats[0];
    global_stats[1] = local_stats[1];
  }
  // No rows anywhere: avoid 0/0 and fall back to a zero initial score.
  if (global_stats[1] <= 0.0) {
    return 0.0;
  }
  return global_stats[0] / global_stats[1];
}

372
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
373
  // boosting from average prediction. It doesn't work well for classification, remove it for now.
Guolin Ke's avatar
Guolin Ke committed
374
375
  if (models_.empty()
      && gbdt_config_->boost_from_average
376
      && !train_score_updater_->has_init_score()
377
378
379
      && num_class_ <= 1
      && objective_function_ != nullptr
      && objective_function_->BoostFromAverage()) {
380
    auto label = train_data_->metadata().label();
Guolin Ke's avatar
Guolin Ke committed
381
    double init_score = LabelAverage(label, num_data_);
382
    std::unique_ptr<Tree> new_tree(new Tree(2));
Guolin Ke's avatar
Guolin Ke committed
383
    new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, 0, -1, 0, 0, 0);
384
385
386
387
388
    train_score_updater_->AddScore(init_score, 0);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(init_score, 0);
    }
    models_.push_back(std::move(new_tree));
389
390
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
391
392
  // boosting first
  if (gradient == nullptr || hessian == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
393
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
394
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
395
    #endif
Guolin Ke's avatar
Guolin Ke committed
396
    Boosting();
397
398
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
399
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
400
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
401
    #endif
Guolin Ke's avatar
Guolin Ke committed
402
  }
Guolin Ke's avatar
Guolin Ke committed
403
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
404
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
405
  #endif
406
407
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
408
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
409
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
410
  #endif
Guolin Ke's avatar
Guolin Ke committed
411
  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
412
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
413
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
414
    #endif
415
416
417
418
419
    if (gradients_.empty()) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
Guolin Ke's avatar
Guolin Ke committed
420
    // get sub gradients
421
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
422
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
423
      // cannot multi-threading here.
Guolin Ke's avatar
Guolin Ke committed
424
      for (int i = 0; i < bag_data_cnt_; ++i) {
425
426
        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
Guolin Ke's avatar
Guolin Ke committed
427
428
      }
    }
429
430
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
431
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
432
    sub_gradient_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
433
    #endif
Guolin Ke's avatar
Guolin Ke committed
434
  }
Guolin Ke's avatar
Guolin Ke committed
435
  bool should_continue = false;
436
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
437
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
438
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
439
    #endif
440
    std::unique_ptr<Tree> new_tree(new Tree(2));
441
    if (class_need_train_[cur_tree_id]) {
442
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
443
      new_tree.reset(
444
        tree_learner_->Train(gradient + bias, hessian + bias, is_constant_hessian_));
445
    }
Guolin Ke's avatar
Guolin Ke committed
446
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
447
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
448
    #endif
Guolin Ke's avatar
Guolin Ke committed
449
450

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
451
452
453
454
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
455
456
      UpdateScore(new_tree.get(), cur_tree_id);
      UpdateScoreOutOfBag(new_tree.get(), cur_tree_id);
457
458
    } else {
      // only add default score one-time
459
460
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
Guolin Ke's avatar
Guolin Ke committed
461
        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
Guolin Ke's avatar
Guolin Ke committed
462
                        output, output, 0, 0, -1, 0, 0, 0);
463
        train_score_updater_->AddScore(output, cur_tree_id);
464
        for (auto& score_updater : valid_score_updater_) {
465
          score_updater->AddScore(output, cur_tree_id);
466
467
468
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
469
470
471
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
472
  if (!should_continue) {
Guolin Ke's avatar
Guolin Ke committed
473
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
474
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
475
476
477
478
      models_.pop_back();
    }
    return true;
  }
Guolin Ke's avatar
Guolin Ke committed
479
480
481
482
483
484
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
485

Guolin Ke's avatar
Guolin Ke committed
486
}
487

wxchan's avatar
wxchan committed
488
void GBDT::RollbackOneIter() {
489
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
490
491
  int cur_iter = iter_ + num_init_iteration_ - 1;
  // reset score
492
493
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
    auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
494
    models_[curr_tree]->Shrinkage(-1.0);
495
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
496
    for (auto& score_updater : valid_score_updater_) {
497
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
498
499
500
    }
  }
  // remove model
501
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
502
503
504
505
506
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
507
bool GBDT::EvalAndCheckEarlyStopping() {
508
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
509
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
510
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
511
  #endif
512
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
513
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
514
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
515
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
516
  #endif
Guolin Ke's avatar
Guolin Ke committed
517
  is_met_early_stopping = !best_msg.empty();
518
519
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
520
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
521
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
522
    // pop last early_stopping_round_ models
523
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
524
525
526
527
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
528
529
}

530
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
531
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
532
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
533
  #endif
Guolin Ke's avatar
Guolin Ke committed
534
  // update training score
Guolin Ke's avatar
Guolin Ke committed
535
  if (!is_use_subset_) {
536
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
537
  } else {
538
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
539
  }
Guolin Ke's avatar
Guolin Ke committed
540
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
541
  train_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
542
543
  #endif
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
544
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
545
  #endif
Guolin Ke's avatar
Guolin Ke committed
546
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
547
  for (auto& score_updater : valid_score_updater_) {
548
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
549
  }
Guolin Ke's avatar
Guolin Ke committed
550
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
551
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
552
  #endif
Guolin Ke's avatar
Guolin Ke committed
553
554
}

Guolin Ke's avatar
Guolin Ke committed
555
556
557
558
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
559
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
560
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
561
  if (need_output) {
562
563
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
564
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
Guolin Ke's avatar
Guolin Ke committed
565
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
566
567
568
569
570
571
572
573
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
          msg_buf << tmp_buf.str() << std::endl;
        }
574
      }
575
    }
Guolin Ke's avatar
Guolin Ke committed
576
577
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
578
  if (need_output || early_stopping_round_ > 0) {
579
580
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
581
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(),
Guolin Ke's avatar
Guolin Ke committed
582
                                                      objective_function_);
Guolin Ke's avatar
Guolin Ke committed
583
584
585
586
587
588
589
590
591
592
593
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
            msg_buf << tmp_buf.str() << std::endl;
594
          }
wxchan's avatar
wxchan committed
595
        }
Guolin Ke's avatar
Guolin Ke committed
596
        if (ret.empty() && early_stopping_round_ > 0) {
597
598
599
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
600
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
601
            meet_early_stopping_pairs.emplace_back(i, j);
602
          } else {
Guolin Ke's avatar
Guolin Ke committed
603
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
604
          }
wxchan's avatar
wxchan committed
605
606
        }
      }
Guolin Ke's avatar
Guolin Ke committed
607
608
    }
  }
Guolin Ke's avatar
Guolin Ke committed
609
610
611
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
612
  return ret;
Guolin Ke's avatar
Guolin Ke committed
613
614
}

615
/*! \brief Get eval result */
616
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
617
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
618
619
  std::vector<double> ret;
  if (data_idx == 0) {
620
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
621
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
622
623
624
      for (auto score : scores) {
        ret.push_back(score);
      }
625
    }
626
  } else {
627
628
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
629
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(),
Guolin Ke's avatar
Guolin Ke committed
630
                                                           objective_function_);
631
632
633
      for (auto score : test_scores) {
        ret.push_back(score);
      }
634
635
636
637
638
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
639
/*! \brief Get training scores result */
640
const double* GBDT::GetTrainingScore(int64_t* out_len) {
641
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
642
  return train_score_updater_->score();
643
644
}

Guolin Ke's avatar
Guolin Ke committed
645
646
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
647

648
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
649
650
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
651
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
652
653
654
655
656
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
657
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
658
  }
659
  if (objective_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
660
661
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
662
      std::vector<double> tree_pred(num_tree_per_iteration_);
663
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
664
        tree_pred[j] = raw_scores[j * num_data + i];
665
      }
Guolin Ke's avatar
Guolin Ke committed
666
667
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
668
      for (int j = 0; j < num_class_; ++j) {
669
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
670
671
      }
    }
672
  } else {
Guolin Ke's avatar
Guolin Ke committed
673
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
674
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
675
      std::vector<double> tmp_result(num_tree_per_iteration_);
676
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
677
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
678
679
680
681
682
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
683
void GBDT::Boosting() {
684
  if (objective_function_ == nullptr) {
685
686
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
687
  // objective function will calculate gradients and hessians
688
  int64_t num_score = 0;
689
  objective_function_->
Guolin Ke's avatar
Guolin Ke committed
690
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
691
692
}

693
std::string GBDT::DumpModel(int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
694
  std::stringstream str_buf;
wxchan's avatar
wxchan committed
695

Guolin Ke's avatar
Guolin Ke committed
696
  str_buf << "{";
Guolin Ke's avatar
Guolin Ke committed
697
  str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
698
  str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
699
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
700
701
  str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
wxchan's avatar
wxchan committed
702

703
704
705
  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
    << std::endl;
Guolin Ke's avatar
Guolin Ke committed
706

Guolin Ke's avatar
Guolin Ke committed
707
  str_buf << "\"tree_info\":[";
708
709
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
Guolin Ke's avatar
Guolin Ke committed
710
    num_iteration += boost_from_average_ ? 1 : 0;
711
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
712
  }
713
  for (int i = 0; i < num_used_model; ++i) {
wxchan's avatar
wxchan committed
714
    if (i > 0) {
Guolin Ke's avatar
Guolin Ke committed
715
      str_buf << ",";
wxchan's avatar
wxchan committed
716
    }
Guolin Ke's avatar
Guolin Ke committed
717
718
719
720
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
wxchan's avatar
wxchan committed
721
  }
Guolin Ke's avatar
Guolin Ke committed
722
  str_buf << "]" << std::endl;
wxchan's avatar
wxchan committed
723

Guolin Ke's avatar
Guolin Ke committed
724
  str_buf << "}" << std::endl;
wxchan's avatar
wxchan committed
725

Guolin Ke's avatar
Guolin Ke committed
726
  return str_buf.str();
wxchan's avatar
wxchan committed
727
728
}

729
730
731
/*!
* \brief Render the model as C++ source with every tree expanded into if/else code.
*
* The generated text contains:
* - one function per tree (from Tree::ToIfElse) and a table of pointers to them;
* - replacement bodies for GBDT::PredictRaw and GBDT::Predict that dispatch through
*   that table;
* - a second set of per-tree functions in leaf-index mode with their own table;
* - a replacement GBDT::PredictLeafIndex.
* \param num_iteration Number of iterations to emit; <= 0 emits every tree. When
*        boost_from_average_ is set, one more iteration's worth of trees is emitted.
* \return Generated C++ source text
*/
std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream src;

  // Appends each entry of `lines`, terminating every one with a newline.
  auto emit_lines = [&src](const std::vector<const char*>& lines) {
    for (const char* text : lines) {
      src << text << std::endl;
    }
  };

  int tree_count = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    const int iterations = num_iteration + (boost_from_average_ ? 1 : 0);
    tree_count = std::min(iterations * num_tree_per_iteration_, tree_count);
  }

  // Emits one generated function per tree, then a function-pointer table over them.
  auto emit_tree_functions = [&](bool leaf_index_mode, const char* table_name, const char* suffix) {
    for (int t = 0; t < tree_count; ++t) {
      src << models_[t]->ToIfElse(t, leaf_index_mode) << std::endl;
    }
    src << "double (*" << table_name << "[])(const double*) = { ";
    for (int t = 0; t < tree_count; ++t) {
      if (t != 0) {
        src << " , ";
      }
      src << "PredictTree" << t << suffix;
    }
    src << " };" << std::endl << std::endl;
  };

  // Preamble of the generated translation unit.
  emit_lines({
    "#include \"gbdt.h\"",
    "#include <LightGBM/utils/common.h>",
    "#include <LightGBM/objective_function.h>",
    "#include <LightGBM/metric.h>",
    "#include <LightGBM/prediction_early_stop.h>",
    "#include <ctime>",
    "#include <sstream>",
    "#include <chrono>",
    "#include <string>",
    "#include <vector>",
    "#include <utility>",
    "namespace LightGBM {",
  });

  // Raw-score trees, then PredictRaw (with the early-stop callback) and Predict.
  emit_tree_functions(false, "PredictTreePtr", "");
  emit_lines({
    "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {",
    "\tint early_stop_round_counter = 0;",
    "\tfor (int i = 0; i < num_iteration_for_pred_; ++i) {",
    "\t\tfor (int k = 0; k < num_tree_per_iteration_; ++k) {",
    "\t\t\toutput[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);",
    "\t\t}",
    "\t\t++early_stop_round_counter;",
    "\t\tif (early_stop->round_period == early_stop_round_counter) {",
    "\t\t\tif (early_stop->callback_function(output, num_tree_per_iteration_))",
    "\t\t\t\treturn;",
    "\t\t\tearly_stop_round_counter = 0;",
    "\t\t}",
    "\t}",
    "}",
    "",
    "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {",
    "\tPredictRaw(features, output, early_stop);",
    "\tif (objective_function_ != nullptr) {",
    "\t\tobjective_function_->ConvertOutput(output, output);",
    "\t}",
    "}",
    "",
  });

  // Leaf-index trees, then PredictLeafIndex, then the closing namespace.
  emit_tree_functions(true, "PredictTreeLeafPtr", "Leaf");
  emit_lines({
    "void GBDT::PredictLeafIndex(const double* features, double *output) const {",
    "\tint total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;",
    "\tfor (int i = 0; i < total_tree; ++i) {",
    "\t\toutput[i] = (*PredictTreeLeafPtr[i])(features);",
    "\t}",
    "}",
    "}  // namespace LightGBM",
  });

  return src.str();
}

/*!
* \brief Write the if/else source produced by ModelToIfElse to `filename`.
*
* If `filename` already exists, its current text is kept inside a preprocessor
* switch:
* - the old text goes in the `#ifndef USE_HARD_CODE` branch;
* - the generated model goes in the `#else` branch.
* Because the emitted `#define USE_HARD_CODE 0` defines the macro, the generated
* branch is the one that compiles.
* \param num_iteration Forwarded to ModelToIfElse
* \param filename Target file path
* \return true if the file was written and closed without stream errors
*/
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  // Read any existing contents fully and close the input stream before the same
  // path is reopened (and truncated) for writing.
  std::string origin;
  bool has_origin = false;
  {
    std::ifstream ifs(filename);
    if (ifs.good()) {
      has_origin = true;
      origin.assign(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
    }
  }
  // Generate the model source before truncating the target, so a failure (e.g. an
  // exception) during generation cannot leave behind an emptied file.
  const std::string model_src = ModelToIfElse(num_iteration);

  /*! \brief File to write models */
  std::ofstream output_file(filename);
  if (has_origin) {
    output_file << "#define USE_HARD_CODE 0" << std::endl;
    output_file << "#ifndef USE_HARD_CODE" << std::endl;
    output_file << origin << std::endl;
    output_file << "#else" << std::endl;
    output_file << model_src;
    output_file << "#endif" << std::endl;
  } else {
    output_file << model_src;
  }
  output_file.close();

  return static_cast<bool>(output_file);
}

Guolin Ke's avatar
Guolin Ke committed
845
std::string GBDT::SaveModelToString(int num_iteration) const {
846
  std::stringstream ss;
847

848
849
850
851
  // output model type
  ss << SubModelName() << std::endl;
  // output number of class
  ss << "num_class=" << num_class_ << std::endl;
852
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
853
854
855
856
  // output label index
  ss << "label_index=" << label_idx_ << std::endl;
  // output max_feature_idx
  ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
857
858
859
  // output objective
  if (objective_function_ != nullptr) {
    ss << "objective=" << objective_function_->ToString() << std::endl;
860
  }
861

862
863
864
  if (boost_from_average_) {
    ss << "boost_from_average" << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
865

866
  ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
867

868
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
869

870
871
  ss << std::endl;
  int num_used_model = static_cast<int>(models_.size());
Guolin Ke's avatar
Guolin Ke committed
872
873
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
874
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
875
876
877
878
879
880
881
882
883
884
885
886
887
888
  }
  // output tree models
  for (int i = 0; i < num_used_model; ++i) {
    ss << "Tree=" << i << std::endl;
    ss << models_[i]->ToString() << std::endl;
  }

  std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
  ss << std::endl << "feature importances:" << std::endl;
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
  }

  return ss.str();
889
890
}

891
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
wxchan's avatar
wxchan committed
892
893
894
  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
895

896
  output_file << SaveModelToString(num_iteration);
897

wxchan's avatar
wxchan committed
898
  output_file.close();
899
900

  return (bool)output_file;
Guolin Ke's avatar
Guolin Ke committed
901
902
}

903
bool GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
904
905
  // use serialized string to restore this object
  models_.clear();
Guolin Ke's avatar
Guolin Ke committed
906
  std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
907
908

  // get number of classes
909
910
911
912
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
913
    Log::Fatal("Model file doesn't specify the number of classes");
914
    return false;
915
  }
916
917
918
919
920
921
922
923

  line = Common::FindFromLines(lines, "num_tree_per_iteration=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
  } else {
    num_tree_per_iteration_ = num_class_;
  }

Guolin Ke's avatar
Guolin Ke committed
924
  // get index of label
925
926
927
928
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
929
    Log::Fatal("Model file doesn't specify the label index");
930
    return false;
Guolin Ke's avatar
Guolin Ke committed
931
  }
Guolin Ke's avatar
Guolin Ke committed
932
  // get max_feature_idx first
933
934
935
936
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
937
    Log::Fatal("Model file doesn't specify max_feature_idx");
938
    return false;
Guolin Ke's avatar
Guolin Ke committed
939
  }
940
941
942
943
944
  // get boost_from_average_
  line = Common::FindFromLines(lines, "boost_from_average");
  if (line.size() > 0) {
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
945
946
947
  // get feature names
  line = Common::FindFromLines(lines, "feature_names=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
948
    feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
949
950
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
951
      return false;
Guolin Ke's avatar
Guolin Ke committed
952
    }
953
  } else {
Guolin Ke's avatar
Guolin Ke committed
954
    Log::Fatal("Model file doesn't contain feature names");
955
    return false;
Guolin Ke's avatar
Guolin Ke committed
956
957
  }

Guolin Ke's avatar
Guolin Ke committed
958
959
  line = Common::FindFromLines(lines, "feature_infos=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
960
    feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
961
962
963
964
965
966
967
968
969
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature infos");
    return false;
  }

970
971
972
973
974
975
976
977
  line = Common::FindFromLines(lines, "objective=");

  if (line.size() > 0) {
    auto str = Common::Split(line.c_str(), '=')[1];
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }

Guolin Ke's avatar
Guolin Ke committed
978
  // get tree models
979
  size_t i = 0;
Guolin Ke's avatar
Guolin Ke committed
980
981
982
983
984
985
986
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
987
      std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
988
      models_.emplace_back(new Tree(tree_str));
Guolin Ke's avatar
Guolin Ke committed
989
990
991
992
    } else {
      ++i;
    }
  }
993
  Log::Info("Finished loading %d models", models_.size());
994
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
wxchan's avatar
wxchan committed
995
  num_init_iteration_ = num_iteration_for_pred_;
996
  iter_ = 0;
997
998

  return true;
Guolin Ke's avatar
Guolin Ke committed
999
1000
}

wxchan's avatar
wxchan committed
1001
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
1002

1003
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
1004
1005
  for (size_t iter = 0; iter < models_.size(); ++iter) {
    for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
Guolin Ke's avatar
Guolin Ke committed
1006
1007
1008
      if (models_[iter]->split_gain(split_idx) > 0) {
        ++feature_importances[models_[iter]->split_feature(split_idx)];
      }
wxchan's avatar
wxchan committed
1009
    }
1010
1011
1012
1013
1014
1015
  }
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    if (feature_importances[i] > 0) {
      pairs.emplace_back(feature_importances[i], feature_names_[i]);
1016
    }
1017
1018
1019
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
1020
1021
            [](const std::pair<size_t, std::string>& lhs,
               const std::pair<size_t, std::string>& rhs) {
1022
1023
1024
    return lhs.first > rhs.first;
  });
  return pairs;
wxchan's avatar
wxchan committed
1025
1026
}

Guolin Ke's avatar
Guolin Ke committed
1027
}  // namespace LightGBM