"python-package/vscode:/vscode.git/clone" did not exist on "dae7551629d3443a70b3b163b4f2304b9e65b059"
gbdt.cpp 37.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
9
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
10
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
11
12
13
14
15
16
17

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
18
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
19
20
21

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
22
23
24
#ifdef TIMETAG
// Wall-clock accumulators (millisecond resolution), one per training phase.
// They only exist in TIMETAG builds; the totals are reported by GBDT::~GBDT().
std::chrono::duration<double, std::milli> boosting_time;          // computing gradients / hessians
std::chrono::duration<double, std::milli> train_score_time;       // updating training scores
std::chrono::duration<double, std::milli> out_of_bag_score_time;  // updating out-of-bag training scores
std::chrono::duration<double, std::milli> valid_score_time;       // updating validation scores
std::chrono::duration<double, std::milli> metric_time;            // evaluating and printing metrics
std::chrono::duration<double, std::milli> bagging_time;           // re-sampling the bag
std::chrono::duration<double, std::milli> sub_gradient_time;      // gathering gradients of bagged rows
std::chrono::duration<double, std::milli> tree_time;              // fitting trees
#endif  // TIMETAG

33
/*!
 * \brief Create an untrained booster.
 *        Scalar state is set to neutral defaults; training data, objective and
 *        configuration are attached later through Init() / ResetTrainingData().
 */
GBDT::GBDT()
  : iter_(0),
    train_data_(nullptr),
    objective_function_(nullptr),
    early_stopping_round_(0),
    max_feature_idx_(0),
    num_tree_per_iteration_(1),
    num_class_(1),
    num_iteration_for_pred_(0),
    shrinkage_rate_(0.1f),
    num_init_iteration_(0),
    boost_from_average_(false) {
  // record the size of the OpenMP team; only the master thread of the
  // parallel region performs the write
  #pragma omp parallel
  #pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

/*!
 * \brief Destructor.
 *        In TIMETAG builds it logs the accumulated time of every training phase.
 *        The accumulators hold milliseconds, so count() * 1e-3 is the value in seconds.
 */
GBDT::~GBDT() {
  #ifdef TIMETAG
  // Pass plain doubles to the printf-style logger: handing a
  // std::chrono::duration object to a "%f" conversion is not a valid
  // variadic argument for that specifier.
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

65
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
66
                const std::vector<const Metric*>& training_metrics) {
67
  iter_ = 0;
wxchan's avatar
wxchan committed
68
  num_iteration_for_pred_ = 0;
69
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
70
71
  num_class_ = config->num_class;
  train_data_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
72
  gbdt_config_ = nullptr;
73
  tree_learner_ = nullptr;
74
  ResetTrainingData(config, train_data, objective_function, training_metrics);
wxchan's avatar
wxchan committed
75
76
}

77
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
78
                             const std::vector<const Metric*>& training_metrics) {
Guolin Ke's avatar
Guolin Ke committed
79
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
wxchan's avatar
wxchan committed
80
81
82
  if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
Guolin Ke's avatar
Guolin Ke committed
83
84
85
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;

86
87
88
89
  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
90
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
91
92
93
  } else {
    is_constant_hessian_ = false;
  }
Guolin Ke's avatar
Guolin Ke committed
94

Guolin Ke's avatar
Guolin Ke committed
95
  if (train_data_ != train_data && train_data != nullptr) {
96
    if (tree_learner_ == nullptr) {
97
      tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, new_config->device_type, &new_config->tree_config));
Guolin Ke's avatar
Guolin Ke committed
98
99
    }
    // init tree learner
100
    tree_learner_->Init(train_data, is_constant_hessian_);
Guolin Ke's avatar
Guolin Ke committed
101

Guolin Ke's avatar
Guolin Ke committed
102
103
104
105
106
107
    // push training metrics
    training_metrics_.clear();
    for (const auto& metric : training_metrics) {
      training_metrics_.push_back(metric);
    }
    training_metrics_.shrink_to_fit();
wxchan's avatar
wxchan committed
108
109
    // not same training data, need reset score and others
    // create score tracker
110
    train_score_updater_.reset(new ScoreUpdater(train_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
111
112
    // update score
    for (int i = 0; i < iter_; ++i) {
113
114
115
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
116
117
118
119
      }
    }
    num_data_ = train_data->num_data();
    // create buffer for gradients and hessians
120
121
122
123
124
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
wxchan's avatar
wxchan committed
125
126
127
128
    // get max feature index
    max_feature_idx_ = train_data->num_total_features() - 1;
    // get label index
    label_idx_ = train_data->label_idx();
129
130
    // get feature names
    feature_names_ = train_data->feature_names();
Guolin Ke's avatar
Guolin Ke committed
131
132

    feature_infos_ = train_data->feature_infos();
Guolin Ke's avatar
Guolin Ke committed
133
134
  }

Guolin Ke's avatar
Guolin Ke committed
135
  if ((train_data_ != train_data && train_data != nullptr)
136
      || (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
wxchan's avatar
wxchan committed
137
    // if need bagging, create buffer
Guolin Ke's avatar
Guolin Ke committed
138
    if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
139
140
      bag_data_cnt_ =
        static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
141
      bag_data_indices_.resize(num_data_);
142
143
144
145
146
147
      tmp_indices_.resize(num_data_);
      offsets_buf_.resize(num_threads_);
      left_cnts_buf_.resize(num_threads_);
      right_cnts_buf_.resize(num_threads_);
      left_write_pos_buf_.resize(num_threads_);
      right_write_pos_buf_.resize(num_threads_);
Guolin Ke's avatar
Guolin Ke committed
148
      double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
149
150
151
152
153
154
      int sparse_group = 0;
      for (int i = 0; i < train_data->num_feature_groups(); ++i) {
        if (train_data->FeatureGroupIsSparse(i)) {
          ++sparse_group;
        }
      }
Guolin Ke's avatar
Guolin Ke committed
155
      is_use_subset_ = false;
156
157
158
159
      const int group_threshold_usesubset = 100;
      const int sparse_group_threshold_usesubset = train_data->num_feature_groups() / 4;
      if (average_bag_rate <= 0.5 
          && (train_data->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
Guolin Ke's avatar
Guolin Ke committed
160
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
161
        tmp_subset_->CopyFeatureMapperFrom(train_data);
Guolin Ke's avatar
Guolin Ke committed
162
163
164
        is_use_subset_ = true;
        Log::Debug("use subset for bagging");
      }
wxchan's avatar
wxchan committed
165
166
167
    } else {
      bag_data_cnt_ = num_data_;
      bag_data_indices_.clear();
168
      tmp_indices_.clear();
Guolin Ke's avatar
Guolin Ke committed
169
      is_use_subset_ = false;
wxchan's avatar
wxchan committed
170
    }
Guolin Ke's avatar
Guolin Ke committed
171
  }
wxchan's avatar
wxchan committed
172
  train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
173
174
  if (train_data_ != nullptr) {
    // reset config for tree learner
175
    tree_learner_->ResetConfig(&new_config->tree_config);
176
177
178
    class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
    if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
      CHECK(num_tree_per_iteration_ == num_class_);
179
      // + 1 here for the binary classification
Guolin Ke's avatar
Guolin Ke committed
180
      class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
181
      auto label = train_data_->metadata().label();
182
      if (num_tree_per_iteration_ > 1) {
Guolin Ke's avatar
Guolin Ke committed
183
184
185
        // multi-class
        std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
        for (data_size_t i = 0; i < num_data_; ++i) {
186
187
188
          int index = static_cast<int>(label[i]);
          CHECK(index < num_tree_per_iteration_);
          ++cnt_per_class[index];
Guolin Ke's avatar
Guolin Ke committed
189
        }
190
        for (int i = 0; i < num_tree_per_iteration_; ++i) {
191
192
193
194
195
          if (cnt_per_class[i] == num_data_) {
            class_need_train_[i] = false;
            class_default_output_[i] = -std::log(kEpsilon);
          } else if (cnt_per_class[i] == 0) {
            class_need_train_[i] = false;
196
            class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
197
198
199
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
        // binary class
        data_size_t cnt_pos = 0;
        for (data_size_t i = 0; i < num_data_; ++i) {
          if (label[i] > 0) {
            ++cnt_pos;
          }
        }
        if (cnt_pos == 0) {
208
209
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
Guolin Ke's avatar
Guolin Ke committed
210
        } else if (cnt_pos == num_data_) {
211
212
213
214
215
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(kEpsilon);
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
216
  }
Guolin Ke's avatar
Guolin Ke committed
217
  gbdt_config_.reset(new_config.release());
Guolin Ke's avatar
Guolin Ke committed
218
219
}

wxchan's avatar
wxchan committed
220
void GBDT::AddValidDataset(const Dataset* valid_data,
221
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
222
223
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
224
  }
Guolin Ke's avatar
Guolin Ke committed
225
  // for a validation dataset, we need its score and metric
226
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
227
228
  // update score
  for (int i = 0; i < iter_; ++i) {
229
230
231
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
232
233
    }
  }
Guolin Ke's avatar
Guolin Ke committed
234
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
235
  valid_metrics_.emplace_back();
236
237
238
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
239
    best_msg_.emplace_back();
240
  }
Guolin Ke's avatar
Guolin Ke committed
241
242
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
243
244
245
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
246
      best_msg_.back().emplace_back();
247
    }
Guolin Ke's avatar
Guolin Ke committed
248
  }
Guolin Ke's avatar
Guolin Ke committed
249
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
250
251
}

252
/*!
 * \brief Split the row range [start, start + cnt) into in-bag and out-of-bag rows.
 *        In-bag row ids are written to the front of buffer, out-of-bag row ids
 *        right behind the slots reserved for the in-bag ones.
 * \param cur_rand Random generator used for the draw
 * \param start First row id of the range
 * \param cnt Number of rows in the range
 * \param buffer Output buffer; the code writes cnt entries in total
 * \return Number of in-bag rows (0 when cnt <= 0)
 */
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) {
    return 0;
  }
  const data_size_t target_in_bag =
    static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t* const out_of_bag = buffer + target_in_bag;
  data_size_t num_in_bag = 0;
  data_size_t num_out_of_bag = 0;
  // one draw per record; the acceptance probability is
  // (in-bag slots still open) / (records still to visit)
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    const float prob =
      (target_in_bag - num_in_bag) / static_cast<float>(cnt - offset);
    if (cur_rand.NextFloat() < prob) {
      buffer[num_in_bag++] = start + offset;
    } else {
      out_of_bag[num_out_of_bag++] = start + offset;
    }
  }
  CHECK(num_in_bag == target_in_bag);
  return num_in_bag;
}
Guolin Ke's avatar
Guolin Ke committed
274

275
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
276
  // if need bagging
277
  if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
Guolin Ke's avatar
Guolin Ke committed
278
    const data_size_t min_inner_size = 1000;
279
280
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
281
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
282
    #pragma omp parallel for schedule(static,1)
283
    for (int i = 0; i < num_threads_; ++i) {
284
      OMP_LOOP_EX_BEGIN();
285
286
287
288
289
290
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
291
292
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
293
294
295
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
296
      OMP_LOOP_EX_END();
297
    }
298
    OMP_THROW_EX();
299
300
301
302
303
304
305
306
307
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
308
    #pragma omp parallel for schedule(static, 1)
309
    for (int i = 0; i < num_threads_; ++i) {
310
      OMP_LOOP_EX_BEGIN();
311
312
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
313
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
314
      }
315
316
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
317
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
318
      }
319
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
320
    }
321
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
322
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
323
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
324
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
325
326
327
328
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
329
330
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
331
332
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
333
334
335
  }
}

336
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
337
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
338
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
339
  #endif
340
  // we need to predict out-of-bag scores of data for boosting
Guolin Ke's avatar
Guolin Ke committed
341
  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
342
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
343
  }
Guolin Ke's avatar
Guolin Ke committed
344
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
345
  out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
346
  #endif
Guolin Ke's avatar
Guolin Ke committed
347
348
}

Guolin Ke's avatar
Guolin Ke committed
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
double LabelAverage(const float* label, data_size_t num_data) {
  double sum_label = 0.0f;
  #pragma omp parallel for schedule(static) reduction(+:sum_label)
  for (data_size_t i = 0; i < num_data; ++i) {
    sum_label += label[i];
  }
  double init_score = sum_label / num_data;
  if (Network::num_machines() > 1) {
    double global_init_score = 0.0f;
    Network::Allreduce(reinterpret_cast<char*>(&init_score),
                       sizeof(init_score), sizeof(init_score),
                       reinterpret_cast<char*>(&global_init_score),
                       [](const char* src, char* dst, int len) {
      int used_size = 0;
      const int type_size = sizeof(double);
      const double *p1;
      double *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const double *>(src);
        p2 = reinterpret_cast<double *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global_init_score / Network::num_machines();
  } else {
    return init_score;
  }
}

381
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
382
  // boosting from average prediction. It doesn't work well for classification, remove it for now.
Guolin Ke's avatar
Guolin Ke committed
383
384
  if (models_.empty()
      && gbdt_config_->boost_from_average
385
      && !train_score_updater_->has_init_score()
386
387
388
      && num_class_ <= 1
      && objective_function_ != nullptr
      && objective_function_->BoostFromAverage()) {
389
    auto label = train_data_->metadata().label();
Guolin Ke's avatar
Guolin Ke committed
390
    double init_score = LabelAverage(label, num_data_);
391
    std::unique_ptr<Tree> new_tree(new Tree(2));
Guolin Ke's avatar
Guolin Ke committed
392
    new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, 0, -1, 0, 0, 0);
393
394
395
396
397
    train_score_updater_->AddScore(init_score, 0);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(init_score, 0);
    }
    models_.push_back(std::move(new_tree));
398
399
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
400
401
  // boosting first
  if (gradient == nullptr || hessian == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
402
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
403
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
404
    #endif
Guolin Ke's avatar
Guolin Ke committed
405
    Boosting();
406
407
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
408
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
409
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
410
    #endif
Guolin Ke's avatar
Guolin Ke committed
411
  }
Guolin Ke's avatar
Guolin Ke committed
412
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
413
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
414
  #endif
415
416
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
417
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
418
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
419
  #endif
Guolin Ke's avatar
Guolin Ke committed
420
  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
421
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
422
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
423
    #endif
424
425
426
427
428
    if (gradients_.empty()) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
Guolin Ke's avatar
Guolin Ke committed
429
    // get sub gradients
430
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
431
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
432
      // cannot multi-threading here.
Guolin Ke's avatar
Guolin Ke committed
433
      for (int i = 0; i < bag_data_cnt_; ++i) {
434
435
        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
Guolin Ke's avatar
Guolin Ke committed
436
437
      }
    }
438
439
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
440
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
441
    sub_gradient_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
442
    #endif
Guolin Ke's avatar
Guolin Ke committed
443
  }
Guolin Ke's avatar
Guolin Ke committed
444
  bool should_continue = false;
445
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
446
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
447
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
448
    #endif
449
    std::unique_ptr<Tree> new_tree(new Tree(2));
450
    if (class_need_train_[cur_tree_id]) {
451
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
452
      new_tree.reset(
453
        tree_learner_->Train(gradient + bias, hessian + bias, is_constant_hessian_));
454
    }
Guolin Ke's avatar
Guolin Ke committed
455
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
456
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
457
    #endif
Guolin Ke's avatar
Guolin Ke committed
458
459

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
460
461
462
463
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
464
465
      UpdateScore(new_tree.get(), cur_tree_id);
      UpdateScoreOutOfBag(new_tree.get(), cur_tree_id);
466
467
    } else {
      // only add default score one-time
468
469
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
Guolin Ke's avatar
Guolin Ke committed
470
        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
Guolin Ke's avatar
Guolin Ke committed
471
                        output, output, 0, 0, -1, 0, 0, 0);
472
        train_score_updater_->AddScore(output, cur_tree_id);
473
        for (auto& score_updater : valid_score_updater_) {
474
          score_updater->AddScore(output, cur_tree_id);
475
476
477
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
478
479
480
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
481
  if (!should_continue) {
Guolin Ke's avatar
Guolin Ke committed
482
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
483
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
484
485
486
487
      models_.pop_back();
    }
    return true;
  }
Guolin Ke's avatar
Guolin Ke committed
488
489
490
491
492
493
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
494

Guolin Ke's avatar
Guolin Ke committed
495
}
496

wxchan's avatar
wxchan committed
497
void GBDT::RollbackOneIter() {
498
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
499
500
  int cur_iter = iter_ + num_init_iteration_ - 1;
  // reset score
501
502
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
    auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
503
    models_[curr_tree]->Shrinkage(-1.0);
504
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
505
    for (auto& score_updater : valid_score_updater_) {
506
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
507
508
509
    }
  }
  // remove model
510
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
511
512
513
514
515
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
516
bool GBDT::EvalAndCheckEarlyStopping() {
517
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
518
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
519
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
520
  #endif
521
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
522
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
523
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
524
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
525
  #endif
Guolin Ke's avatar
Guolin Ke committed
526
  is_met_early_stopping = !best_msg.empty();
527
528
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
529
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
530
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
531
    // pop last early_stopping_round_ models
532
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
533
534
535
536
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
537
538
}

539
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
540
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
541
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
542
  #endif
Guolin Ke's avatar
Guolin Ke committed
543
  // update training score
Guolin Ke's avatar
Guolin Ke committed
544
  if (!is_use_subset_) {
545
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
546
  } else {
547
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
548
  }
Guolin Ke's avatar
Guolin Ke committed
549
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
550
  train_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
551
552
  #endif
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
553
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
554
  #endif
Guolin Ke's avatar
Guolin Ke committed
555
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
556
  for (auto& score_updater : valid_score_updater_) {
557
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
558
  }
Guolin Ke's avatar
Guolin Ke committed
559
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
560
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
561
  #endif
Guolin Ke's avatar
Guolin Ke committed
562
563
}

Guolin Ke's avatar
Guolin Ke committed
564
565
566
567
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
568
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
569
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
570
  if (need_output) {
571
572
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
573
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
Guolin Ke's avatar
Guolin Ke committed
574
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
575
576
577
578
579
580
581
582
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
          msg_buf << tmp_buf.str() << std::endl;
        }
583
      }
584
    }
Guolin Ke's avatar
Guolin Ke committed
585
586
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
587
  if (need_output || early_stopping_round_ > 0) {
588
589
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
590
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(),
Guolin Ke's avatar
Guolin Ke committed
591
                                                      objective_function_);
Guolin Ke's avatar
Guolin Ke committed
592
593
594
595
596
597
598
599
600
601
602
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
            msg_buf << tmp_buf.str() << std::endl;
603
          }
wxchan's avatar
wxchan committed
604
        }
Guolin Ke's avatar
Guolin Ke committed
605
        if (ret.empty() && early_stopping_round_ > 0) {
606
607
608
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
609
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
610
            meet_early_stopping_pairs.emplace_back(i, j);
611
          } else {
Guolin Ke's avatar
Guolin Ke committed
612
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
613
          }
wxchan's avatar
wxchan committed
614
615
        }
      }
Guolin Ke's avatar
Guolin Ke committed
616
617
    }
  }
Guolin Ke's avatar
Guolin Ke committed
618
619
620
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
621
  return ret;
Guolin Ke's avatar
Guolin Ke committed
622
623
}

624
/*! \brief Get eval result */
625
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
626
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
627
628
  std::vector<double> ret;
  if (data_idx == 0) {
629
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
630
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
631
632
633
      for (auto score : scores) {
        ret.push_back(score);
      }
634
    }
635
  } else {
636
637
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
638
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(),
Guolin Ke's avatar
Guolin Ke committed
639
                                                           objective_function_);
640
641
642
      for (auto score : test_scores) {
        ret.push_back(score);
      }
643
644
645
646
647
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
648
/*! \brief Get training scores result */
649
const double* GBDT::GetTrainingScore(int64_t* out_len) {
650
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
651
  return train_score_updater_->score();
652
653
}

Guolin Ke's avatar
Guolin Ke committed
654
655
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
656

657
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
658
659
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
660
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
661
662
663
664
665
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
666
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
667
  }
668
  if (objective_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
669
670
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
671
      std::vector<double> tree_pred(num_tree_per_iteration_);
672
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
673
        tree_pred[j] = raw_scores[j * num_data + i];
674
      }
Guolin Ke's avatar
Guolin Ke committed
675
676
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
677
      for (int j = 0; j < num_class_; ++j) {
678
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
679
680
      }
    }
681
  } else {
Guolin Ke's avatar
Guolin Ke committed
682
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
683
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
684
      std::vector<double> tmp_result(num_tree_per_iteration_);
685
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
686
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
687
688
689
690
691
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
692
void GBDT::Boosting() {
693
  if (objective_function_ == nullptr) {
694
695
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
696
  // objective function will calculate gradients and hessians
697
  int64_t num_score = 0;
698
  objective_function_->
Guolin Ke's avatar
Guolin Ke committed
699
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
700
701
}

702
std::string GBDT::DumpModel(int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
703
  std::stringstream str_buf;
wxchan's avatar
wxchan committed
704

Guolin Ke's avatar
Guolin Ke committed
705
  str_buf << "{";
Guolin Ke's avatar
Guolin Ke committed
706
  str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
707
  str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
708
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
709
710
  str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
wxchan's avatar
wxchan committed
711

712
713
714
  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
    << std::endl;
Guolin Ke's avatar
Guolin Ke committed
715

Guolin Ke's avatar
Guolin Ke committed
716
  str_buf << "\"tree_info\":[";
717
718
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
Guolin Ke's avatar
Guolin Ke committed
719
    num_iteration += boost_from_average_ ? 1 : 0;
720
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
721
  }
722
  for (int i = 0; i < num_used_model; ++i) {
wxchan's avatar
wxchan committed
723
    if (i > 0) {
Guolin Ke's avatar
Guolin Ke committed
724
      str_buf << ",";
wxchan's avatar
wxchan committed
725
    }
Guolin Ke's avatar
Guolin Ke committed
726
727
728
729
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
wxchan's avatar
wxchan committed
730
  }
Guolin Ke's avatar
Guolin Ke committed
731
  str_buf << "]" << std::endl;
wxchan's avatar
wxchan committed
732

Guolin Ke's avatar
Guolin Ke committed
733
  str_buf << "}" << std::endl;
wxchan's avatar
wxchan committed
734

Guolin Ke's avatar
Guolin Ke committed
735
  return str_buf.str();
wxchan's avatar
wxchan committed
736
737
}

738
739
740
std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

741
742
743
744
  str_buf << "#include \"gbdt.h\"" << std::endl;
  str_buf << "#include <LightGBM/utils/common.h>" << std::endl;
  str_buf << "#include <LightGBM/objective_function.h>" << std::endl;
  str_buf << "#include <LightGBM/metric.h>" << std::endl;
cbecker's avatar
cbecker committed
745
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << std::endl;
746
747
748
749
750
751
752
753
  str_buf << "#include <ctime>" << std::endl;
  str_buf << "#include <sstream>" << std::endl;
  str_buf << "#include <chrono>" << std::endl;
  str_buf << "#include <string>" << std::endl;
  str_buf << "#include <vector>" << std::endl;
  str_buf << "#include <utility>" << std::endl;
  str_buf << "namespace LightGBM {" << std::endl;

754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
    str_buf << models_[i]->ToIfElse(i, false) << std::endl;
  }

  str_buf << "double (*PredictTreePtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i;
  }
  str_buf << " };" << std::endl << std::endl;

  std::stringstream pred_str_buf;

776
  pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
cbecker's avatar
cbecker committed
777
  pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
778
  pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
cbecker's avatar
cbecker committed
779
  pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
780
  pred_str_buf << "\t\t" << "}" << std::endl;
781
782
783
  pred_str_buf << "\t\t" << "++early_stop_round_counter;" << std::endl;
  pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
  pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
cbecker's avatar
cbecker committed
784
  pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
785
  pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
786
787
788
  pred_str_buf << "\t\t" << "}" << std::endl;
  pred_str_buf << "\t" << "}" << std::endl;

789
  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
790
791
792
793
794
  str_buf << pred_str_buf.str();
  str_buf << "}" << std::endl;
  str_buf << std::endl;

  // Predict
795
796
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "}" << std::endl;
  str_buf << std::endl;

  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
    str_buf << models_[i]->ToIfElse(i, true) << std::endl;
  }

  str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "Leaf";
  }
  str_buf << " };" << std::endl << std::endl;

  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "}" << std::endl;
823
824
825

  str_buf << "}  // namespace LightGBM" << std::endl;

826
827
828
829
830
831
  return str_buf.str();
}

/*!
 * \brief Write the if/else source form of the model (see ModelToIfElse) to a file.
 *
 * If the file can already be read, its previous content is kept inside the
 * "#ifndef USE_HARD_CODE" branch and the generated code is placed in the "#else"
 * branch; otherwise only the generated code is written.
 *
 * \param num_iteration Forwarded to ModelToIfElse
 * \param filename Path of the file to read and overwrite
 * \return State of the output stream after it has been closed
 */
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  /*! \brief File to write models */
  std::ofstream writer;
  std::ifstream existing(filename);
  if (!existing.good()) {
    writer.open(filename);
    writer << ModelToIfElse(num_iteration);
  } else {
    // Read the whole previous content before the file is reopened for writing.
    const std::string previous_content((std::istreambuf_iterator<char>(existing)),
                                       (std::istreambuf_iterator<char>()));
    writer.open(filename);
    writer << "#define USE_HARD_CODE 0" << std::endl;
    writer << "#ifndef USE_HARD_CODE" << std::endl;
    writer << previous_content << std::endl;
    writer << "#else" << std::endl;
    writer << ModelToIfElse(num_iteration);
    writer << "#endif" << std::endl;
  }

  existing.close();
  writer.close();

  return static_cast<bool>(writer);
}

Guolin Ke's avatar
Guolin Ke committed
854
std::string GBDT::SaveModelToString(int num_iteration) const {
855
  std::stringstream ss;
856

857
858
859
860
  // output model type
  ss << SubModelName() << std::endl;
  // output number of class
  ss << "num_class=" << num_class_ << std::endl;
861
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
862
863
864
865
  // output label index
  ss << "label_index=" << label_idx_ << std::endl;
  // output max_feature_idx
  ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
866
867
868
  // output objective
  if (objective_function_ != nullptr) {
    ss << "objective=" << objective_function_->ToString() << std::endl;
869
  }
870

871
872
873
  if (boost_from_average_) {
    ss << "boost_from_average" << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
874

875
  ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
876

877
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
878

879
880
  ss << std::endl;
  int num_used_model = static_cast<int>(models_.size());
Guolin Ke's avatar
Guolin Ke committed
881
882
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
883
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
884
885
886
887
888
889
890
891
892
893
894
895
896
897
  }
  // output tree models
  for (int i = 0; i < num_used_model; ++i) {
    ss << "Tree=" << i << std::endl;
    ss << models_[i]->ToString() << std::endl;
  }

  std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
  ss << std::endl << "feature importances:" << std::endl;
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
  }

  return ss.str();
898
899
}

900
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
wxchan's avatar
wxchan committed
901
902
903
  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
904

905
  output_file << SaveModelToString(num_iteration);
906

wxchan's avatar
wxchan committed
907
  output_file.close();
908
909

  return (bool)output_file;
Guolin Ke's avatar
Guolin Ke committed
910
911
}

912
bool GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
913
914
  // use serialized string to restore this object
  models_.clear();
Guolin Ke's avatar
Guolin Ke committed
915
  std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
916
917

  // get number of classes
918
919
920
921
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
922
    Log::Fatal("Model file doesn't specify the number of classes");
923
    return false;
924
  }
925
926
927
928
929
930
931
932

  line = Common::FindFromLines(lines, "num_tree_per_iteration=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
  } else {
    num_tree_per_iteration_ = num_class_;
  }

Guolin Ke's avatar
Guolin Ke committed
933
  // get index of label
934
935
936
937
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
938
    Log::Fatal("Model file doesn't specify the label index");
939
    return false;
Guolin Ke's avatar
Guolin Ke committed
940
  }
Guolin Ke's avatar
Guolin Ke committed
941
  // get max_feature_idx first
942
943
944
945
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
946
    Log::Fatal("Model file doesn't specify max_feature_idx");
947
    return false;
Guolin Ke's avatar
Guolin Ke committed
948
  }
949
950
951
952
953
  // get boost_from_average_
  line = Common::FindFromLines(lines, "boost_from_average");
  if (line.size() > 0) {
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
954
955
956
  // get feature names
  line = Common::FindFromLines(lines, "feature_names=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
957
    feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
958
959
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
960
      return false;
Guolin Ke's avatar
Guolin Ke committed
961
    }
962
  } else {
Guolin Ke's avatar
Guolin Ke committed
963
    Log::Fatal("Model file doesn't contain feature names");
964
    return false;
Guolin Ke's avatar
Guolin Ke committed
965
966
  }

Guolin Ke's avatar
Guolin Ke committed
967
968
  line = Common::FindFromLines(lines, "feature_infos=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
969
    feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
970
971
972
973
974
975
976
977
978
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature infos");
    return false;
  }

979
980
981
982
983
984
985
986
  line = Common::FindFromLines(lines, "objective=");

  if (line.size() > 0) {
    auto str = Common::Split(line.c_str(), '=')[1];
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }

Guolin Ke's avatar
Guolin Ke committed
987
  // get tree models
988
  size_t i = 0;
Guolin Ke's avatar
Guolin Ke committed
989
990
991
992
993
994
995
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
996
      std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
997
      models_.emplace_back(new Tree(tree_str));
Guolin Ke's avatar
Guolin Ke committed
998
999
1000
1001
    } else {
      ++i;
    }
  }
1002
  Log::Info("Finished loading %d models", models_.size());
1003
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
wxchan's avatar
wxchan committed
1004
  num_init_iteration_ = num_iteration_for_pred_;
1005
  iter_ = 0;
1006
1007

  return true;
Guolin Ke's avatar
Guolin Ke committed
1008
1009
}

wxchan's avatar
wxchan committed
1010
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
1011

1012
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
1013
1014
  for (size_t iter = 0; iter < models_.size(); ++iter) {
    for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
Guolin Ke's avatar
Guolin Ke committed
1015
1016
1017
      if (models_[iter]->split_gain(split_idx) > 0) {
        ++feature_importances[models_[iter]->split_feature(split_idx)];
      }
wxchan's avatar
wxchan committed
1018
    }
1019
1020
1021
1022
1023
1024
  }
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    if (feature_importances[i] > 0) {
      pairs.emplace_back(feature_importances[i], feature_names_[i]);
1025
    }
1026
1027
1028
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
Guolin Ke's avatar
Guolin Ke committed
1029
1030
            [](const std::pair<size_t, std::string>& lhs,
               const std::pair<size_t, std::string>& rhs) {
1031
1032
1033
    return lhs.first > rhs.first;
  });
  return pairs;
wxchan's avatar
wxchan committed
1034
1035
}

Guolin Ke's avatar
Guolin Ke committed
1036
}  // namespace LightGBM