gbdt.cpp 38.3 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
cbecker's avatar
cbecker committed
9
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
10
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
11
12
13
14
15
16
17

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
18
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
19
20
21

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
22
23
24
#ifdef TIMETAG
// Per-phase wall-clock accumulators (milliseconds), compiled in only when
// TIMETAG is defined. GBDT's destructor reports their totals.
std::chrono::duration<double, std::milli> boosting_time;          // Boosting() calls
std::chrono::duration<double, std::milli> train_score_time;       // training-score updates in UpdateScore()
std::chrono::duration<double, std::milli> out_of_bag_score_time;  // UpdateScoreOutOfBag()
std::chrono::duration<double, std::milli> valid_score_time;       // validation-score updates in UpdateScore()
std::chrono::duration<double, std::milli> metric_time;            // OutputMetric() during evaluation
std::chrono::duration<double, std::milli> bagging_time;           // Bagging()
std::chrono::duration<double, std::milli> sub_gradient_time;      // gathering in-bag gradients/hessians
std::chrono::duration<double, std::milli> tree_time;              // tree learner Train()
#endif  // TIMETAG

33
GBDT::GBDT()
34
  :iter_(0),
35
  train_data_(nullptr),
36
  objective_function_(nullptr),
37
38
  early_stopping_round_(0),
  max_feature_idx_(0),
39
  num_tree_per_iteration_(1),
40
  num_class_(1),
41
  num_iteration_for_pred_(0),
42
  shrinkage_rate_(0.1f),
43
44
  num_init_iteration_(0),
  boost_from_average_(false) {
Guolin Ke's avatar
Guolin Ke committed
45
46
  #pragma omp parallel
  #pragma omp master
47
48
49
    {
      num_threads_ = omp_get_num_threads();
    }
Guolin Ke's avatar
Guolin Ke committed
50
51
52
}

/*!
 * \brief Destructor. When compiled with TIMETAG, logs the accumulated time of
 *        every training phase, in seconds.
 */
GBDT::~GBDT() {
  #ifdef TIMETAG
  // The accumulators are std::chrono::duration objects. Their numeric count
  // (milliseconds) is scaled to seconds before being handed to the
  // printf-style logger. Passing a duration object to a "%f" vararg is
  // undefined behavior.
  Log::Info("GBDT::boosting costs %f", boosting_time.count() * 1e-3);
  Log::Info("GBDT::train_score costs %f", train_score_time.count() * 1e-3);
  Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time.count() * 1e-3);
  Log::Info("GBDT::valid_score costs %f", valid_score_time.count() * 1e-3);
  Log::Info("GBDT::metric costs %f", metric_time.count() * 1e-3);
  Log::Info("GBDT::bagging costs %f", bagging_time.count() * 1e-3);
  Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time.count() * 1e-3);
  Log::Info("GBDT::tree costs %f", tree_time.count() * 1e-3);
  #endif
}

65
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
66
                const std::vector<const Metric*>& training_metrics) {
67
  train_data_ = train_data;
68
  iter_ = 0;
wxchan's avatar
wxchan committed
69
  num_iteration_for_pred_ = 0;
70
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
71
  num_class_ = config->num_class;
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
  gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
  early_stopping_round_ = gbdt_config_->early_stopping_round;
  shrinkage_rate_ = gbdt_config_->learning_rate;

  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
  } else {
    is_constant_hessian_ = false;
  }

  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));

  // init tree learner
  tree_learner_->Init(train_data_, is_constant_hessian_);

  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();

  train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));

  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
  if (objective_function_ != nullptr) {
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    gradients_.resize(total_size);
    hessians_.resize(total_size);
  }
  // get max feature index
  max_feature_idx_ = train_data_->num_total_features() - 1;
  // get label index
  label_idx_ = train_data_->label_idx();
  // get feature names
  feature_names_ = train_data_->feature_names();
  feature_infos_ = train_data_->feature_infos();

  // if need bagging, create buffer
  ResetBaggingConfig(gbdt_config_.get());

  // reset config for tree learner
  class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
  if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
    CHECK(num_tree_per_iteration_ == num_class_);
    // + 1 here for the binary classification
    class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
    auto label = train_data_->metadata().label();
    if (num_tree_per_iteration_ > 1) {
      // multi-class
      std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
      for (data_size_t i = 0; i < num_data_; ++i) {
        int index = static_cast<int>(label[i]);
        CHECK(index < num_tree_per_iteration_);
        ++cnt_per_class[index];
      }
      for (int i = 0; i < num_tree_per_iteration_; ++i) {
        if (cnt_per_class[i] == num_data_) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(kEpsilon);
        } else if (cnt_per_class[i] == 0) {
          class_need_train_[i] = false;
          class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
        }
      }
    } else {
      // binary class
      data_size_t cnt_pos = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (label[i] > 0) {
          ++cnt_pos;
        }
      }
      if (cnt_pos == 0) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
      } else if (cnt_pos == num_data_) {
        class_need_train_[0] = false;
        class_default_output_[0] = -std::log(kEpsilon);
      }
    }
  }
wxchan's avatar
wxchan committed
158
159
}

160
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
161
                             const std::vector<const Metric*>& training_metrics) {
162
  if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
wxchan's avatar
wxchan committed
163
164
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
Guolin Ke's avatar
Guolin Ke committed
165

166
167
168
169
  objective_function_ = objective_function;
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
Guolin Ke's avatar
Guolin Ke committed
170
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
171
172
173
  } else {
    is_constant_hessian_ = false;
  }
Guolin Ke's avatar
Guolin Ke committed
174

175
176
177
178
179
180
  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
181

182
183
  if (train_data != train_data_) {
    train_data_ = train_data;
wxchan's avatar
wxchan committed
184
185
    // not same training data, need reset score and others
    // create score tracker
186
    train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
187
188
    // update score
    for (int i = 0; i < iter_; ++i) {
189
190
191
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
192
193
      }
    }
194
195
    num_data_ = train_data_->num_data();

wxchan's avatar
wxchan committed
196
    // create buffer for gradients and hessians
197
198
199
200
201
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
202

wxchan's avatar
wxchan committed
203
    // get max feature index
204
    max_feature_idx_ = train_data_->num_total_features() - 1;
wxchan's avatar
wxchan committed
205
    // get label index
206
    label_idx_ = train_data_->label_idx();
207
    // get feature names
208
209
210
211
212
    feature_names_ = train_data_->feature_names();

    feature_infos_ = train_data_->feature_infos();

    ResetBaggingConfig(gbdt_config_.get());
Guolin Ke's avatar
Guolin Ke committed
213

214
    tree_learner_->ResetTrainingData(train_data);
Guolin Ke's avatar
Guolin Ke committed
215
  }
216
}
Guolin Ke's avatar
Guolin Ke committed
217

218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
void GBDT::ResetConfig(const BoostingConfig* config) {
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));

  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;

  ResetBaggingConfig(new_config.get());

  tree_learner_->ResetConfig(&new_config->tree_config);
  gbdt_config_.reset(new_config.release());
}

/*!
 * \brief Size or clear the bagging buffers according to config.
 *
 * Bagging is active when bagging_fraction < 1 and bagging_freq > 0. When it
 * is, the index buffers and per-thread scratch buffers used by Bagging() are
 * sized. When (bagging_fraction / bagging_freq) <= 0.5 and the data has few
 * feature groups or few sparse groups, an empty subset Dataset is also
 * prepared. In that case Bagging() copies in-bag rows into the subset instead
 * of passing an index list. When bagging is inactive, every record counts as
 * in-bag.
 */
void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
  const bool bagging_enabled = config->bagging_fraction < 1.0 && config->bagging_freq > 0;
  if (!bagging_enabled) {
    bag_data_cnt_ = num_data_;
    bag_data_indices_.clear();
    tmp_indices_.clear();
    is_use_subset_ = false;
    return;
  }

  bag_data_cnt_ = static_cast<data_size_t>(config->bagging_fraction * num_data_);
  bag_data_indices_.resize(num_data_);
  tmp_indices_.resize(num_data_);
  // one slot per thread for the parallel partition in Bagging()
  offsets_buf_.resize(num_threads_);
  left_cnts_buf_.resize(num_threads_);
  right_cnts_buf_.resize(num_threads_);
  left_write_pos_buf_.resize(num_threads_);
  right_write_pos_buf_.resize(num_threads_);

  const double average_bag_rate = config->bagging_fraction / config->bagging_freq;

  const int num_groups = train_data_->num_feature_groups();
  int num_sparse_groups = 0;
  for (int group = 0; group < num_groups; ++group) {
    if (train_data_->FeatureGroupIsSparse(group)) {
      ++num_sparse_groups;
    }
  }

  const int group_threshold_usesubset = 100;
  const int sparse_group_threshold_usesubset = num_groups / 4;
  const bool few_groups = num_groups < group_threshold_usesubset;
  const bool few_sparse_groups = num_sparse_groups < sparse_group_threshold_usesubset;

  is_use_subset_ = false;
  if (average_bag_rate <= 0.5 && (few_groups || few_sparse_groups)) {
    tmp_subset_.reset(new Dataset(bag_data_cnt_));
    tmp_subset_->CopyFeatureMapperFrom(train_data_);
    is_use_subset_ = true;
    Log::Debug("use subset for bagging");
  }
}

wxchan's avatar
wxchan committed
267
void GBDT::AddValidDataset(const Dataset* valid_data,
268
                           const std::vector<const Metric*>& valid_metrics) {
wxchan's avatar
wxchan committed
269
270
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
271
  }
Guolin Ke's avatar
Guolin Ke committed
272
  // for a validation dataset, we need its score and metric
273
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_tree_per_iteration_));
wxchan's avatar
wxchan committed
274
275
  // update score
  for (int i = 0; i < iter_; ++i) {
276
277
278
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
      new_score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
279
280
    }
  }
Guolin Ke's avatar
Guolin Ke committed
281
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
282
  valid_metrics_.emplace_back();
283
284
285
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
286
    best_msg_.emplace_back();
287
  }
Guolin Ke's avatar
Guolin Ke committed
288
289
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
290
291
292
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
Guolin Ke's avatar
Guolin Ke committed
293
      best_msg_.back().emplace_back();
294
    }
Guolin Ke's avatar
Guolin Ke committed
295
  }
Guolin Ke's avatar
Guolin Ke committed
296
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
297
298
}

299
/*!
 * \brief Randomly split records [start, start + cnt) into in-bag and
 *        out-of-bag indices.
 * \param cur_rand Random generator for this chunk.
 * \param start Index of the first record in the chunk.
 * \param cnt Number of records in the chunk.
 * \param buffer Output with room for cnt indices. The in-bag indices are
 *        written first, then the out-of-bag indices.
 * \return Number of in-bag records. The CHECK below enforces that this equals
 *         the truncated value bagging_fraction * cnt.
 */
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) {
    return 0;
  }
  const data_size_t bag_data_cnt =
    static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t* in_bag = buffer;
  data_size_t* out_of_bag = buffer + bag_data_cnt;
  data_size_t num_in_bag = 0;
  data_size_t num_out_of_bag = 0;

  // Each record is picked with probability (picks still needed) /
  // (records still left). Once the remaining records are exactly the picks
  // still needed, the probability reaches 1. Once no picks are needed, it
  // reaches 0.
  for (data_size_t offset = 0; offset < cnt; ++offset) {
    const float prob =
      (bag_data_cnt - num_in_bag) / static_cast<float>(cnt - offset);
    const data_size_t record = start + offset;
    if (cur_rand.NextFloat() < prob) {
      in_bag[num_in_bag++] = record;
    } else {
      out_of_bag[num_out_of_bag++] = record;
    }
  }
  CHECK(num_in_bag == bag_data_cnt);
  return num_in_bag;
}
Guolin Ke's avatar
Guolin Ke committed
321

322
void GBDT::Bagging(int iter) {
Guolin Ke's avatar
Guolin Ke committed
323
  // if need bagging
324
  if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
Guolin Ke's avatar
Guolin Ke committed
325
    const data_size_t min_inner_size = 1000;
326
327
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
328
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
329
    #pragma omp parallel for schedule(static,1)
330
    for (int i = 0; i < num_threads_; ++i) {
331
      OMP_LOOP_EX_BEGIN();
332
333
334
335
336
337
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Guolin Ke's avatar
Guolin Ke committed
338
339
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
340
341
342
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
343
      OMP_LOOP_EX_END();
344
    }
345
    OMP_THROW_EX();
346
347
348
349
350
351
352
353
354
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

Guolin Ke's avatar
Guolin Ke committed
355
    #pragma omp parallel for schedule(static, 1)
356
    for (int i = 0; i < num_threads_; ++i) {
357
      OMP_LOOP_EX_BEGIN();
358
359
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
360
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
361
      }
362
363
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
364
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
Guolin Ke's avatar
Guolin Ke committed
365
      }
366
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
367
    }
368
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
369
    bag_data_cnt_ = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
370
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
371
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
372
373
374
375
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
Guolin Ke's avatar
Guolin Ke committed
376
377
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
Guolin Ke's avatar
Guolin Ke committed
378
379
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
Guolin Ke's avatar
Guolin Ke committed
380
381
382
  }
}

383
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
384
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
385
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
386
  #endif
387
  // we need to predict out-of-bag scores of data for boosting
Guolin Ke's avatar
Guolin Ke committed
388
  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
389
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
390
  }
Guolin Ke's avatar
Guolin Ke committed
391
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
392
  out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
393
  #endif
Guolin Ke's avatar
Guolin Ke committed
394
395
}

Guolin Ke's avatar
Guolin Ke committed
396
397
398
399
400
401
402
403
404
405
406
407
/*!
 * \brief Mean label value, taken over all machines in distributed training.
 * \param label Local label array with num_data entries.
 * \param num_data Number of local records.
 * \return (sum of labels) / (number of records) over all machines, or 0 when
 *         there are no records at all.
 *
 * The label sum and the record count are all-reduced together. The result is
 * therefore the true global mean even when machines hold different numbers of
 * records. Averaging per-machine means would be biased toward small shards.
 */
double LabelAverage(const float* label, data_size_t num_data) {
  double sum_label = 0.0;
  #pragma omp parallel for schedule(static) reduction(+:sum_label)
  for (data_size_t i = 0; i < num_data; ++i) {
    sum_label += label[i];
  }
  // stats[0] = sum of labels, stats[1] = number of records
  double stats[2] = {sum_label, static_cast<double>(num_data)};
  if (Network::num_machines() > 1) {
    double global_stats[2] = {0.0, 0.0};
    Network::Allreduce(reinterpret_cast<char*>(stats),
                       sizeof(stats), sizeof(double),
                       reinterpret_cast<char*>(global_stats),
                       [] (const char* src, char* dst, int len) {
      // element-wise sum of doubles
      int used_size = 0;
      const int type_size = sizeof(double);
      const double *p1;
      double *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const double *>(src);
        p2 = reinterpret_cast<double *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    stats[0] = global_stats[0];
    stats[1] = global_stats[1];
  }
  if (stats[1] <= 0.0) {
    // no records anywhere: avoid 0/0
    return 0.0;
  }
  return stats[0] / stats[1];
}

428
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
429
  // boosting from average prediction. It doesn't work well for classification, remove it for now.
Guolin Ke's avatar
Guolin Ke committed
430
431
  if (models_.empty()
      && gbdt_config_->boost_from_average
432
      && !train_score_updater_->has_init_score()
433
434
435
      && num_class_ <= 1
      && objective_function_ != nullptr
      && objective_function_->BoostFromAverage()) {
436
    auto label = train_data_->metadata().label();
Guolin Ke's avatar
Guolin Ke committed
437
    double init_score = LabelAverage(label, num_data_);
438
    std::unique_ptr<Tree> new_tree(new Tree(2));
Guolin Ke's avatar
Guolin Ke committed
439
    new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, 0, -1, 0, 0, 0);
440
441
442
443
444
    train_score_updater_->AddScore(init_score, 0);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(init_score, 0);
    }
    models_.push_back(std::move(new_tree));
445
446
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
447
448
  // boosting first
  if (gradient == nullptr || hessian == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
449
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
450
    auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
451
    #endif
Guolin Ke's avatar
Guolin Ke committed
452
    Boosting();
453
454
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
455
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
456
    boosting_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
457
    #endif
Guolin Ke's avatar
Guolin Ke committed
458
  }
Guolin Ke's avatar
Guolin Ke committed
459
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
460
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
461
  #endif
462
463
  // bagging logic
  Bagging(iter_);
Guolin Ke's avatar
Guolin Ke committed
464
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
465
  bagging_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
466
  #endif
Guolin Ke's avatar
Guolin Ke committed
467
  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
Guolin Ke's avatar
Guolin Ke committed
468
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
469
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
470
    #endif
471
472
473
474
475
    if (gradients_.empty()) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
Guolin Ke's avatar
Guolin Ke committed
476
    // get sub gradients
477
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
478
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
479
      // cannot multi-threading here.
Guolin Ke's avatar
Guolin Ke committed
480
      for (int i = 0; i < bag_data_cnt_; ++i) {
481
482
        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
Guolin Ke's avatar
Guolin Ke committed
483
484
      }
    }
485
486
    gradient = gradients_.data();
    hessian = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
487
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
488
    sub_gradient_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
489
    #endif
Guolin Ke's avatar
Guolin Ke committed
490
  }
Guolin Ke's avatar
Guolin Ke committed
491
  bool should_continue = false;
492
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
493
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
494
    start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
495
    #endif
496
    std::unique_ptr<Tree> new_tree(new Tree(2));
497
    if (class_need_train_[cur_tree_id]) {
498
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
499
      new_tree.reset(
500
        tree_learner_->Train(gradient + bias, hessian + bias, is_constant_hessian_));
501
    }
Guolin Ke's avatar
Guolin Ke committed
502
    #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
503
    tree_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
504
    #endif
Guolin Ke's avatar
Guolin Ke committed
505
506

    if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
507
508
509
510
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
511
512
      UpdateScore(new_tree.get(), cur_tree_id);
      UpdateScoreOutOfBag(new_tree.get(), cur_tree_id);
513
514
    } else {
      // only add default score one-time
515
516
      if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
        auto output = class_default_output_[cur_tree_id];
Guolin Ke's avatar
Guolin Ke committed
517
        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
Guolin Ke's avatar
Guolin Ke committed
518
                        output, output, 0, 0, -1, 0, 0, 0);
519
        train_score_updater_->AddScore(output, cur_tree_id);
520
        for (auto& score_updater : valid_score_updater_) {
521
          score_updater->AddScore(output, cur_tree_id);
522
523
524
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
525
526
527
    // add model
    models_.push_back(std::move(new_tree));
  }
Guolin Ke's avatar
Guolin Ke committed
528
  if (!should_continue) {
Guolin Ke's avatar
Guolin Ke committed
529
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
530
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
531
532
533
534
      models_.pop_back();
    }
    return true;
  }
Guolin Ke's avatar
Guolin Ke committed
535
536
537
538
539
540
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
541

Guolin Ke's avatar
Guolin Ke committed
542
}
543

wxchan's avatar
wxchan committed
544
void GBDT::RollbackOneIter() {
545
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
546
547
  int cur_iter = iter_ + num_init_iteration_ - 1;
  // reset score
548
549
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
    auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
wxchan's avatar
wxchan committed
550
    models_[curr_tree]->Shrinkage(-1.0);
551
    train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
552
    for (auto& score_updater : valid_score_updater_) {
553
      score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
wxchan's avatar
wxchan committed
554
555
556
    }
  }
  // remove model
557
  for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
wxchan's avatar
wxchan committed
558
559
560
561
562
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
563
bool GBDT::EvalAndCheckEarlyStopping() {
564
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
565
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
566
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
567
  #endif
568
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
569
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
570
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
571
  metric_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
572
  #endif
Guolin Ke's avatar
Guolin Ke committed
573
  is_met_early_stopping = !best_msg.empty();
574
575
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
576
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
577
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
578
    // pop last early_stopping_round_ models
579
    for (int i = 0; i < early_stopping_round_ * num_tree_per_iteration_; ++i) {
580
581
582
583
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
584
585
}

586
void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
Guolin Ke's avatar
Guolin Ke committed
587
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
588
  auto start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
589
  #endif
Guolin Ke's avatar
Guolin Ke committed
590
  // update training score
Guolin Ke's avatar
Guolin Ke committed
591
  if (!is_use_subset_) {
592
    train_score_updater_->AddScore(tree_learner_.get(), tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
593
  } else {
594
    train_score_updater_->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
595
  }
Guolin Ke's avatar
Guolin Ke committed
596
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
597
  train_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
598
599
  #endif
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
600
  start_time = std::chrono::steady_clock::now();
Guolin Ke's avatar
Guolin Ke committed
601
  #endif
Guolin Ke's avatar
Guolin Ke committed
602
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
603
  for (auto& score_updater : valid_score_updater_) {
604
    score_updater->AddScore(tree, cur_tree_id);
Guolin Ke's avatar
Guolin Ke committed
605
  }
Guolin Ke's avatar
Guolin Ke committed
606
  #ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
607
  valid_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
608
  #endif
Guolin Ke's avatar
Guolin Ke committed
609
610
}

Guolin Ke's avatar
Guolin Ke committed
611
612
613
614
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
615
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
616
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
617
  if (need_output) {
618
619
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
Guolin Ke's avatar
Guolin Ke committed
620
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
Guolin Ke's avatar
Guolin Ke committed
621
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
622
623
624
625
626
627
628
629
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
          msg_buf << tmp_buf.str() << std::endl;
        }
630
      }
631
    }
Guolin Ke's avatar
Guolin Ke committed
632
633
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
634
  if (need_output || early_stopping_round_ > 0) {
635
636
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
637
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(),
Guolin Ke's avatar
Guolin Ke committed
638
                                                      objective_function_);
Guolin Ke's avatar
Guolin Ke committed
639
640
641
642
643
644
645
646
647
648
649
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
            msg_buf << tmp_buf.str() << std::endl;
650
          }
wxchan's avatar
wxchan committed
651
        }
Guolin Ke's avatar
Guolin Ke committed
652
        if (ret.empty() && early_stopping_round_ > 0) {
653
654
655
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
656
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
657
            meet_early_stopping_pairs.emplace_back(i, j);
658
          } else {
Guolin Ke's avatar
Guolin Ke committed
659
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
660
          }
wxchan's avatar
wxchan committed
661
662
        }
      }
Guolin Ke's avatar
Guolin Ke committed
663
664
    }
  }
Guolin Ke's avatar
Guolin Ke committed
665
666
667
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
668
  return ret;
Guolin Ke's avatar
Guolin Ke committed
669
670
}

671
/*! \brief Get eval result */
672
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
673
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
674
675
  std::vector<double> ret;
  if (data_idx == 0) {
676
    for (auto& sub_metric : training_metrics_) {
Guolin Ke's avatar
Guolin Ke committed
677
      auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
678
679
680
      for (auto score : scores) {
        ret.push_back(score);
      }
681
    }
682
  } else {
683
684
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
685
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(),
Guolin Ke's avatar
Guolin Ke committed
686
                                                           objective_function_);
687
688
689
      for (auto score : test_scores) {
        ret.push_back(score);
      }
690
691
692
693
694
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
695
/*! \brief Get training scores result */
696
const double* GBDT::GetTrainingScore(int64_t* out_len) {
697
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
698
  return train_score_updater_->score();
699
700
}

Guolin Ke's avatar
Guolin Ke committed
701
702
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
Guolin Ke's avatar
Guolin Ke committed
703

704
  const double* raw_scores = nullptr;
Guolin Ke's avatar
Guolin Ke committed
705
706
  data_size_t num_data = 0;
  if (data_idx == 0) {
wxchan's avatar
wxchan committed
707
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
708
709
710
711
712
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
713
    *out_len = static_cast<int64_t>(num_data) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
714
  }
715
  if (objective_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
716
717
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
718
      std::vector<double> tree_pred(num_tree_per_iteration_);
719
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
720
        tree_pred[j] = raw_scores[j * num_data + i];
721
      }
Guolin Ke's avatar
Guolin Ke committed
722
723
      std::vector<double> tmp_result(num_class_);
      objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
Guolin Ke's avatar
Guolin Ke committed
724
      for (int j = 0; j < num_class_; ++j) {
725
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
Guolin Ke's avatar
Guolin Ke committed
726
727
      }
    }
728
  } else {
Guolin Ke's avatar
Guolin Ke committed
729
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
730
    for (data_size_t i = 0; i < num_data; ++i) {
Guolin Ke's avatar
Guolin Ke committed
731
      std::vector<double> tmp_result(num_tree_per_iteration_);
732
      for (int j = 0; j < num_tree_per_iteration_; ++j) {
Guolin Ke's avatar
Guolin Ke committed
733
        out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
Guolin Ke's avatar
Guolin Ke committed
734
735
736
737
738
      }
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
739
void GBDT::Boosting() {
740
  if (objective_function_ == nullptr) {
741
742
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
743
  // objective function will calculate gradients and hessians
744
  int64_t num_score = 0;
745
  objective_function_->
Guolin Ke's avatar
Guolin Ke committed
746
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
747
748
}

749
std::string GBDT::DumpModel(int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
750
  std::stringstream str_buf;
wxchan's avatar
wxchan committed
751

Guolin Ke's avatar
Guolin Ke committed
752
  str_buf << "{";
Guolin Ke's avatar
Guolin Ke committed
753
  str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
754
  str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
755
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
756
757
  str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
wxchan's avatar
wxchan committed
758

759
760
761
  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
    << std::endl;
Guolin Ke's avatar
Guolin Ke committed
762

Guolin Ke's avatar
Guolin Ke committed
763
  str_buf << "\"tree_info\":[";
764
765
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
Guolin Ke's avatar
Guolin Ke committed
766
    num_iteration += boost_from_average_ ? 1 : 0;
767
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
768
  }
769
  for (int i = 0; i < num_used_model; ++i) {
wxchan's avatar
wxchan committed
770
    if (i > 0) {
Guolin Ke's avatar
Guolin Ke committed
771
      str_buf << ",";
wxchan's avatar
wxchan committed
772
    }
Guolin Ke's avatar
Guolin Ke committed
773
774
775
776
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
wxchan's avatar
wxchan committed
777
  }
Guolin Ke's avatar
Guolin Ke committed
778
  str_buf << "]" << std::endl;
wxchan's avatar
wxchan committed
779

Guolin Ke's avatar
Guolin Ke committed
780
  str_buf << "}" << std::endl;
wxchan's avatar
wxchan committed
781

Guolin Ke's avatar
Guolin Ke committed
782
  return str_buf.str();
wxchan's avatar
wxchan committed
783
784
}

785
786
787
/*!
* \brief Generate C++ source that hard-codes the model's trees as if-else functions.
*
* The generated unit defines PredictTree<i> / PredictTree<i>Leaf per tree (via
* Tree::ToIfElse), function-pointer tables over them, and GBDT::PredictRaw,
* GBDT::Predict and GBDT::PredictLeafIndex that dispatch through those tables.
* \param num_iteration Number of iterations to export; <= 0 exports every stored tree.
* \return The generated source text.
*/
std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

  // Prologue of the generated translation unit. <cstring> is included explicitly
  // because the generated PredictRaw calls std::memset.
  str_buf << "#include \"gbdt.h\"" << std::endl;
  str_buf << "#include <LightGBM/utils/common.h>" << std::endl;
  str_buf << "#include <LightGBM/objective_function.h>" << std::endl;
  str_buf << "#include <LightGBM/metric.h>" << std::endl;
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << std::endl;
  str_buf << "#include <ctime>" << std::endl;
  str_buf << "#include <cstring>" << std::endl;
  str_buf << "#include <sstream>" << std::endl;
  str_buf << "#include <chrono>" << std::endl;
  str_buf << "#include <string>" << std::endl;
  str_buf << "#include <vector>" << std::endl;
  str_buf << "#include <utility>" << std::endl;
  str_buf << "namespace LightGBM {" << std::endl;

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // Emits "double (*<table_name>[])(const double*) = { PredictTree0<suffix> , ... };"
  // covering every exported tree.
  auto write_pointer_table = [&](const char* table_name, const char* suffix) {
    str_buf << "double (*" << table_name << "[])(const double*) = { ";
    for (int i = 0; i < num_used_model; ++i) {
      if (i > 0) {
        str_buf << " , ";
      }
      str_buf << "PredictTree" << i << suffix;
    }
    str_buf << " };" << std::endl << std::endl;
  };

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
    str_buf << models_[i]->ToIfElse(i, false) << std::endl;
  }
  write_pointer_table("PredictTreePtr", "");

  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
  str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
  str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << std::endl;
  str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
  str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
  str_buf << "\t\t" << "}" << std::endl;
  str_buf << "\t\t" << "++early_stop_round_counter;" << std::endl;
  str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
  str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
  str_buf << "\t\t\t\t" << "return;" << std::endl;
  str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
  str_buf << "\t\t" << "}" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "}" << std::endl;
  str_buf << std::endl;

  // Predict
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "}" << std::endl;
  str_buf << std::endl;

  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
    str_buf << models_[i]->ToIfElse(i, true) << std::endl;
  }
  write_pointer_table("PredictTreeLeafPtr", "Leaf");

  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
  str_buf << "\t" << "}" << std::endl;
  str_buf << "}" << std::endl;

  str_buf << "}  // namespace LightGBM" << std::endl;

  return str_buf.str();
}

/*!
* \brief Write the if-else (hard-coded) form of the model to a C++ source file.
*
* If filename already exists, its previous contents are kept and wrapped in a
* preprocessor switch together with the generated code. #ifndef only tests whether
* USE_HARD_CODE is defined, not its value: as emitted, the generated (#else) branch is
* the one compiled, and deleting the #define line selects the previous contents instead.
* \param num_iteration Forwarded to ModelToIfElse.
* \param filename Path of the file to (over)write.
* \return true if the output stream is still good after writing and closing.
*/
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  // Read any existing contents first and release the read handle before the same
  // path is reopened (and truncated) for writing.
  std::string origin;
  bool has_origin = false;
  {
    std::ifstream ifs(filename);
    if (ifs.good()) {
      has_origin = true;
      origin.assign(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
    }
  }

  /*! \brief File to write models */
  std::ofstream output_file(filename);
  if (has_origin) {
    output_file << "#define USE_HARD_CODE 0" << std::endl;
    output_file << "#ifndef USE_HARD_CODE" << std::endl;
    output_file << origin << std::endl;
    output_file << "#else" << std::endl;
    output_file << ModelToIfElse(num_iteration);
    output_file << "#endif" << std::endl;
  } else {
    output_file << ModelToIfElse(num_iteration);
  }
  output_file.close();

  return static_cast<bool>(output_file);
}

std::string GBDT::SaveModelToString(int num_iteration) const {
903
  std::stringstream ss;
904

905
906
907
908
  // output model type
  ss << SubModelName() << std::endl;
  // output number of class
  ss << "num_class=" << num_class_ << std::endl;
909
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
910
911
912
913
  // output label index
  ss << "label_index=" << label_idx_ << std::endl;
  // output max_feature_idx
  ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
914
915
916
  // output objective
  if (objective_function_ != nullptr) {
    ss << "objective=" << objective_function_->ToString() << std::endl;
917
  }
918

919
920
921
  if (boost_from_average_) {
    ss << "boost_from_average" << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
922

923
  ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
924

925
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
926

927
928
  ss << std::endl;
  int num_used_model = static_cast<int>(models_.size());
Guolin Ke's avatar
Guolin Ke committed
929
930
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
931
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
932
933
934
935
936
937
938
939
940
941
942
943
944
945
  }
  // output tree models
  for (int i = 0; i < num_used_model; ++i) {
    ss << "Tree=" << i << std::endl;
    ss << models_[i]->ToString() << std::endl;
  }

  std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
  ss << std::endl << "feature importances:" << std::endl;
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
  }

  return ss.str();
946
947
}

948
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
wxchan's avatar
wxchan committed
949
950
951
  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
952

953
  output_file << SaveModelToString(num_iteration);
954

wxchan's avatar
wxchan committed
955
  output_file.close();
956
957

  return (bool)output_file;
Guolin Ke's avatar
Guolin Ke committed
958
959
}

960
bool GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
961
962
  // use serialized string to restore this object
  models_.clear();
Guolin Ke's avatar
Guolin Ke committed
963
  std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
964
965

  // get number of classes
966
967
968
969
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
970
    Log::Fatal("Model file doesn't specify the number of classes");
971
    return false;
972
  }
973
974
975
976
977
978
979
980

  line = Common::FindFromLines(lines, "num_tree_per_iteration=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
  } else {
    num_tree_per_iteration_ = num_class_;
  }

Guolin Ke's avatar
Guolin Ke committed
981
  // get index of label
982
983
984
985
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
986
    Log::Fatal("Model file doesn't specify the label index");
987
    return false;
Guolin Ke's avatar
Guolin Ke committed
988
  }
Guolin Ke's avatar
Guolin Ke committed
989
  // get max_feature_idx first
990
991
992
993
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
994
    Log::Fatal("Model file doesn't specify max_feature_idx");
995
    return false;
Guolin Ke's avatar
Guolin Ke committed
996
  }
997
998
999
1000
1001
  // get boost_from_average_
  line = Common::FindFromLines(lines, "boost_from_average");
  if (line.size() > 0) {
    boost_from_average_ = true;
  }
Guolin Ke's avatar
Guolin Ke committed
1002
1003
1004
  // get feature names
  line = Common::FindFromLines(lines, "feature_names=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
1005
    feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
1006
1007
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
1008
      return false;
Guolin Ke's avatar
Guolin Ke committed
1009
    }
1010
  } else {
Guolin Ke's avatar
Guolin Ke committed
1011
    Log::Fatal("Model file doesn't contain feature names");
1012
    return false;
Guolin Ke's avatar
Guolin Ke committed
1013
1014
  }

Guolin Ke's avatar
Guolin Ke committed
1015
1016
  line = Common::FindFromLines(lines, "feature_infos=");
  if (line.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
1017
    feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
1018
1019
1020
1021
1022
1023
1024
1025
1026
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature infos");
    return false;
  }

1027
1028
1029
1030
1031
1032
1033
1034
  line = Common::FindFromLines(lines, "objective=");

  if (line.size() > 0) {
    auto str = Common::Split(line.c_str(), '=')[1];
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }

Guolin Ke's avatar
Guolin Ke committed
1035
  // get tree models
1036
  size_t i = 0;
Guolin Ke's avatar
Guolin Ke committed
1037
1038
1039
1040
1041
1042
1043
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
1044
      std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
1045
      models_.emplace_back(new Tree(tree_str));
Guolin Ke's avatar
Guolin Ke committed
1046
1047
1048
1049
    } else {
      ++i;
    }
  }
1050
  Log::Info("Finished loading %d models", models_.size());
1051
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
wxchan's avatar
wxchan committed
1052
  num_init_iteration_ = num_iteration_for_pred_;
1053
  iter_ = 0;
1054
1055

  return true;
Guolin Ke's avatar
Guolin Ke committed
1056
1057
}

wxchan's avatar
wxchan committed
1058
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
1059

1060
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
1061
1062
  for (size_t iter = 0; iter < models_.size(); ++iter) {
    for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
Guolin Ke's avatar
Guolin Ke committed
1063
1064
1065
      if (models_[iter]->split_gain(split_idx) > 0) {
        ++feature_importances[models_[iter]->split_feature(split_idx)];
      }
wxchan's avatar
wxchan committed
1066
    }
1067
1068
1069
1070
1071
1072
  }
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    if (feature_importances[i] > 0) {
      pairs.emplace_back(feature_importances[i], feature_names_[i]);
1073
    }
1074
1075
1076
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
1077
1078
            [] (const std::pair<size_t, std::string>& lhs,
                const std::pair<size_t, std::string>& rhs) {
1079
1080
1081
    return lhs.first > rhs.first;
  });
  return pairs;
wxchan's avatar
wxchan committed
1082
1083
}

Guolin Ke's avatar
Guolin Ke committed
1084
}  // namespace LightGBM