gbdt.cpp 16.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include "gbdt.h"

#include <LightGBM/utils/common.h>

#include <LightGBM/feature.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <ctime>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
16
17
18

namespace LightGBM {

19
GBDT::GBDT()
  : train_score_updater_(nullptr),
    gradients_(nullptr),
    hessians_(nullptr),
    out_of_bag_data_indices_(nullptr),
    bag_data_indices_(nullptr),
    saved_model_size_(-1),
    num_used_model_(0) {
  // Owned buffers start out null so the destructor only frees what Init() actually allocated.
  // saved_model_size_ == -1 means the model file has not been opened yet (see SaveModelToFile).
}

GBDT::~GBDT() {
  // Every member released here is either nullptr or owned by this object;
  // delete / delete[] on nullptr is a no-op, so no explicit checks are needed.
  for (const auto& learner : tree_learner_) {
    delete learner;
  }
  delete[] gradients_;
  delete[] hessians_;
  delete[] out_of_bag_data_indices_;
  delete[] bag_data_indices_;
  for (const auto& model : models_) {
    delete model;
  }
  delete train_score_updater_;
  for (const auto& updater : valid_score_updater_) {
    delete updater;
  }
}

43
44
45
46
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
     const std::vector<const Metric*>& training_metrics) {
  gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
  iter_ = 0;
47
  saved_model_size_ = -1;
48
49
  max_feature_idx_ = 0;
  early_stopping_round_ = gbdt_config_->early_stopping_round;
Guolin Ke's avatar
Guolin Ke committed
50
  train_data_ = train_data;
51
52
  num_class_ = config->num_class;
  tree_learner_ = std::vector<TreeLearner*>(num_class_, nullptr);
Guolin Ke's avatar
Guolin Ke committed
53
  // create tree learner
54
55
56
57
58
59
  for (int i = 0; i < num_class_; ++i){
      tree_learner_[i] =
        TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config);
      // init tree learner
      tree_learner_[i]->Init(train_data_);
  }
Guolin Ke's avatar
Guolin Ke committed
60
61
62
63
64
65
  object_function_ = object_function;
  // push training metrics
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  // create score tracker
66
  train_score_updater_ = new ScoreUpdater(train_data_, num_class_);
Guolin Ke's avatar
Guolin Ke committed
67
68
  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
69
  if (object_function_ != nullptr) {
70
71
    gradients_ = new score_t[num_data_ * num_class_];
    hessians_ = new score_t[num_data_ * num_class_];
72
  }
Guolin Ke's avatar
Guolin Ke committed
73
74

  // get max feature index
75
  max_feature_idx_ = train_data_->num_total_features() - 1;
Guolin Ke's avatar
Guolin Ke committed
76
77
  // get label index
  label_idx_ = train_data_->label_idx();
Guolin Ke's avatar
Guolin Ke committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
  // if need bagging, create buffer
  if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
    out_of_bag_data_indices_ = new data_size_t[num_data_];
    bag_data_indices_ = new data_size_t[num_data_];
  } else {
    out_of_bag_data_cnt_ = 0;
    out_of_bag_data_indices_ = nullptr;
    bag_data_cnt_ = num_data_;
    bag_data_indices_ = nullptr;
  }
  // initialize random generator
  random_ = Random(gbdt_config_->bagging_seed);

}

void GBDT::AddDataset(const Dataset* valid_data,
         const std::vector<const Metric*>& valid_metrics) {
  // for a validation dataset, we need its score and metric
96
  valid_score_updater_.push_back(new ScoreUpdater(valid_data, num_class_));
Guolin Ke's avatar
Guolin Ke committed
97
  valid_metrics_.emplace_back();
wxchan's avatar
wxchan committed
98
99
  best_iter_.emplace_back();
  best_score_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
100
101
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
wxchan's avatar
wxchan committed
102
103
    best_iter_.back().push_back(0);
    best_score_.back().push_back(-1);
Guolin Ke's avatar
Guolin Ke committed
104
105
106
107
  }
}


108
void GBDT::Bagging(int iter, const int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
  // if need bagging
  if (out_of_bag_data_indices_ != nullptr && iter % gbdt_config_->bagging_freq == 0) {
    // if doesn't have query data
    if (train_data_->metadata().query_boundaries() == nullptr) {
      bag_data_cnt_ =
        static_cast<data_size_t>(gbdt_config_->bagging_fraction * num_data_);
      out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
      data_size_t cur_left_cnt = 0;
      data_size_t cur_right_cnt = 0;
      // random bagging, minimal unit is one record
      for (data_size_t i = 0; i < num_data_; ++i) {
        double prob =
          (bag_data_cnt_ - cur_left_cnt) / static_cast<double>(num_data_ - i);
        if (random_.NextDouble() < prob) {
          bag_data_indices_[cur_left_cnt++] = i;
        } else {
          out_of_bag_data_indices_[cur_right_cnt++] = i;
        }
      }
    } else {
      // if have query data
      const data_size_t* query_boundaries = train_data_->metadata().query_boundaries();
      data_size_t num_query = train_data_->metadata().num_queries();
      data_size_t bag_query_cnt =
          static_cast<data_size_t>(num_query * gbdt_config_->bagging_fraction);
      data_size_t cur_left_query_cnt = 0;
      data_size_t cur_left_cnt = 0;
      data_size_t cur_right_cnt = 0;
      // random bagging, minimal unit is one query
      for (data_size_t i = 0; i < num_query; ++i) {
        double prob =
            (bag_query_cnt - cur_left_query_cnt) / static_cast<double>(num_query - i);
        if (random_.NextDouble() < prob) {
          for (data_size_t j = query_boundaries[i]; j < query_boundaries[i + 1]; ++j) {
            bag_data_indices_[cur_left_cnt++] = j;
          }
          cur_left_query_cnt++;
        } else {
          for (data_size_t j = query_boundaries[i]; j < query_boundaries[i + 1]; ++j) {
            out_of_bag_data_indices_[cur_right_cnt++] = j;
          }
        }
      }
      bag_data_cnt_ = cur_left_cnt;
      out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
    }
155
    Log::Info("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
156
    // set bagging data to tree learner
157
    tree_learner_[curr_class]->SetBaggingData(bag_data_indices_, bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
158
159
160
  }
}

161
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
Hui Xue's avatar
Hui Xue committed
162
  // we need to predict out-of-bag socres of data for boosting
Guolin Ke's avatar
Guolin Ke committed
163
164
  if (out_of_bag_data_indices_ != nullptr) {
    train_score_updater_->
165
      AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_, curr_class);
Guolin Ke's avatar
Guolin Ke committed
166
167
168
  }
}

169
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
170
171
172
173
174
175
    // boosting first
    if (gradient == nullptr || hessian == nullptr) {
      Boosting();
      gradient = gradients_;
      hessian = hessians_;
    }
176

177
178
179
    for (int curr_class = 0; curr_class < num_class_; ++curr_class){
      // bagging logic
      Bagging(iter_, curr_class);
180

181
182
183
184
      // train a new tree
      Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_);
      // if cannot learn a new tree, then stop
      if (new_tree->num_leaves() <= 1) {
185
        Log::Info("Stopped training because there are no more leafs that meet the split requirements.");
186
187
        return true;
      }
188

189
190
191
192
193
      // shrinkage by learning rate
      new_tree->Shrinkage(gbdt_config_->learning_rate);
      // update score
      UpdateScore(new_tree, curr_class);
      UpdateScoreOutOfBag(new_tree, curr_class);
194

195
196
197
      // add model
      models_.push_back(new_tree);
    }
198

199
200
201
202
203
204
205
206
207
208
  bool is_met_early_stopping = false;
  // print message for metric
  if (is_eval) {
    is_met_early_stopping = OutputMetric(iter_ + 1);
  }
  ++iter_;
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
      iter_, iter_ - early_stopping_round_);
    // pop last early_stopping_round_ models
209
    for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
210
211
212
213
214
      delete models_.back();
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
215

Guolin Ke's avatar
Guolin Ke committed
216
217
}

218
void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
219
  // update training score
220
  train_score_updater_->AddScore(tree_learner_[curr_class], curr_class);
Guolin Ke's avatar
Guolin Ke committed
221
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
222
223
  for (auto& score_updater : valid_score_updater_) {
    score_updater->AddScore(tree, curr_class);
Guolin Ke's avatar
Guolin Ke committed
224
225
226
  }
}

wxchan's avatar
wxchan committed
227
228
bool GBDT::OutputMetric(int iter) {
  bool ret = false;
Guolin Ke's avatar
Guolin Ke committed
229
  // print training metric
230
231
232
233
  if ((iter % gbdt_config_->output_freq) == 0) {
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
      auto scores = sub_metric->Eval(train_score_updater_->score());
234
      Log::Info("Iteration: %d, %s: %s", iter, name, Common::ArrayToString<double>(scores, ' ').c_str());
235
    }
Guolin Ke's avatar
Guolin Ke committed
236
237
  }
  // print validation metric
238
239
240
241
242
243
  if ((iter % gbdt_config_->output_freq) == 0 || early_stopping_round_ > 0) {
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
        if ((iter % gbdt_config_->output_freq) == 0) {
          auto name = valid_metrics_[i][j]->GetName();
244
          Log::Info("Iteration: %d, %s: %s", iter, name, Common::ArrayToString<double>(test_scores, ' ').c_str());
wxchan's avatar
wxchan committed
245
        }
246
247
248
249
250
251
252
253
254
255
        if (!ret && early_stopping_round_ > 0) {
          bool the_bigger_the_better = valid_metrics_[i][j]->is_bigger_better();
          if (best_score_[i][j] < 0
            || (!the_bigger_the_better && test_scores.back() < best_score_[i][j])
            || (the_bigger_the_better && test_scores.back() > best_score_[i][j])) {
            best_score_[i][j] = test_scores.back();
            best_iter_[i][j] = iter;
          } else {
            if (iter - best_iter_[i][j] >= early_stopping_round_) ret = true;
          }
wxchan's avatar
wxchan committed
256
257
        }
      }
Guolin Ke's avatar
Guolin Ke committed
258
259
    }
  }
wxchan's avatar
wxchan committed
260
  return ret;
Guolin Ke's avatar
Guolin Ke committed
261
262
}

263
264
265
266
267
268
269
270
/*! \brief Get eval result */
std::vector<std::string> GBDT::EvalCurrent(bool is_eval_train) const {
  std::vector<std::string> ret;
  if (is_eval_train) {
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
      auto scores = sub_metric->Eval(train_score_updater_->score());
      std::stringstream str_buf;
271
      str_buf << name << " : " << Common::ArrayToString<double>(scores, ' ');
272
273
274
275
276
277
278
279
280
      ret.emplace_back(str_buf.str());
    }
  }

  for (size_t i = 0; i < valid_metrics_.size(); ++i) {
    for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
      auto name = valid_metrics_[i][j]->GetName();
      auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
      std::stringstream str_buf;
281
      str_buf << name << " : " << Common::ArrayToString<double>(test_scores, ' ');
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
      ret.emplace_back(str_buf.str());
    }
  }
  return ret;
}

/*! \brief Get prediction result */
const std::vector<const score_t*> GBDT::PredictCurrent(bool is_predict_train) const {
  std::vector<const score_t*> ret;
  if (is_predict_train) {
    ret.push_back(train_score_updater_->score());
  }
  for (size_t i = 0; i < valid_metrics_.size(); ++i) {
    ret.push_back(valid_score_updater_[i]->score());
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
300
void GBDT::Boosting() {
301
302
303
  if (object_function_ == nullptr) {
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
304
  // objective function will calculate gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
305
306
307
308
  object_function_->
    GetGradients(train_score_updater_->score(), gradients_, hessians_);
}

309
310
311
312
313
314
315
void GBDT::SaveModelToFile(bool is_finish, const char* filename) {

  // first time to this function, open file
  if (saved_model_size_ == -1) {
    model_output_file_.open(filename);
    // output model type
    model_output_file_ << "gbdt" << std::endl;
316
317
    // output number of class
    model_output_file_ << "num_class=" << num_class_ << std::endl;
318
319
320
321
322
323
324
325
326
327
328
329
330
    // output label index
    model_output_file_ << "label_index=" << label_idx_ << std::endl;
    // output max_feature_idx
    model_output_file_ << "max_feature_idx=" << max_feature_idx_ << std::endl;
    // output sigmoid parameter
    model_output_file_ << "sigmoid=" << object_function_->GetSigmoid() << std::endl;
    model_output_file_ << std::endl;
    saved_model_size_ = 0;
  }
  // already saved
  if (!model_output_file_.is_open()) {
    return;
  }
331
  int rest = static_cast<int>(models_.size()) - early_stopping_round_ * num_class_;
332
333
334
335
336
  // output tree models
  for (int i = saved_model_size_; i < rest; ++i) {
    model_output_file_ << "Tree=" << i << std::endl;
    model_output_file_ << models_[i]->ToString() << std::endl;
  }
337

Guolin Ke's avatar
Guolin Ke committed
338
  saved_model_size_ = Common::Max(saved_model_size_, rest);
339

340
341
342
343
344
345
346
347
348
  model_output_file_.flush();
  // training finished, can close file
  if (is_finish) {
    for (int i = saved_model_size_; i < static_cast<int>(models_.size()); ++i) {
      model_output_file_ << "Tree=" << i << std::endl;
      model_output_file_ << models_[i]->ToString() << std::endl;
    }
    model_output_file_ << std::endl << FeatureImportance() << std::endl;
    model_output_file_.close();
Guolin Ke's avatar
Guolin Ke committed
349
350
351
  }
}

352
void GBDT::ModelsFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
353
354
355
356
  // use serialized string to restore this object
  models_.clear();
  std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
  size_t i = 0;
357
358

  // get number of classes
359
360
361
362
363
364
365
366
367
368
369
370
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("num_class=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atoi(strs[1].c_str(), &num_class_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  if (i == lines.size()) {
371
    Log::Fatal("Model file doesn't specify the number of classes");
372
373
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
374
375

  // get index of label
376
  i = 0;
Guolin Ke's avatar
Guolin Ke committed
377
378
379
380
381
382
383
384
385
386
387
388
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("label_index=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atoi(strs[1].c_str(), &label_idx_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  if (i == lines.size()) {
389
    Log::Fatal("Model file doesn't specify the label index");
Guolin Ke's avatar
Guolin Ke committed
390
391
392
    return;
  }

Guolin Ke's avatar
Guolin Ke committed
393
  // get max_feature_idx first
Guolin Ke's avatar
Guolin Ke committed
394
  i = 0;
Guolin Ke's avatar
Guolin Ke committed
395
396
397
398
399
400
401
402
403
404
405
406
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("max_feature_idx=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atoi(strs[1].c_str(), &max_feature_idx_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  if (i == lines.size()) {
407
    Log::Fatal("Model file doesn't specify max_feature_idx");
Guolin Ke's avatar
Guolin Ke committed
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
    return;
  }
  // get sigmoid parameter
  i = 0;
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("sigmoid=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atof(strs[1].c_str(), &sigmoid_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  // if sigmoid doesn't exists
  if (i == lines.size()) {
425
    sigmoid_ = -1.0f;
Guolin Ke's avatar
Guolin Ke committed
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
  }
  // get tree models
  i = 0;
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
      std::string tree_str = Common::Join(lines, start, end, '\n');
      models_.push_back(new Tree(tree_str));
    } else {
      ++i;
    }
  }
442
  Log::Info("Finished loading %d models", models_.size());
443
  num_used_model_ = static_cast<int>(models_.size()) / num_class_;
Guolin Ke's avatar
Guolin Ke committed
444
445
}

446
std::string GBDT::FeatureImportance() const {
447
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
448
    for (size_t iter = 0; iter < models_.size(); ++iter) {
449
450
        for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
            ++feature_importances[models_[iter]->split_feature_real(split_idx)];
wxchan's avatar
wxchan committed
451
452
        }
    }
453
454
455
456
457
458
459
460
461
    // store the importance first
    std::vector<std::pair<size_t, std::string>> pairs;
    for (size_t i = 0; i < feature_importances.size(); ++i) {
      pairs.emplace_back(feature_importances[i], train_data_->feature_names()[i]);
    }
    // sort the importance
    std::sort(pairs.begin(), pairs.end(),
      [](const std::pair<size_t, std::string>& lhs,
        const std::pair<size_t, std::string>& rhs) {
462
      return lhs.first > rhs.first;
463
    });
464
    std::stringstream str_buf;
465
    // write to model file
466
    str_buf << std::endl << "feature importances:" << std::endl;
467
    for (size_t i = 0; i < pairs.size(); ++i) {
468
      str_buf << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
469
    }
470
    return str_buf.str();
wxchan's avatar
wxchan committed
471
472
}

473
double GBDT::PredictRaw(const double* value) const {
474
  double ret = 0.0f;
475
  for (int i = 0; i < num_used_model_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
476
477
478
479
480
    ret += models_[i]->Predict(value);
  }
  return ret;
}

481
double GBDT::Predict(const double* value) const {
482
  double ret = 0.0f;
483
  for (int i = 0; i < num_used_model_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
484
485
486
487
    ret += models_[i]->Predict(value);
  }
  // if need sigmoid transform
  if (sigmoid_ > 0) {
488
    ret = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret));
Guolin Ke's avatar
Guolin Ke committed
489
490
491
492
  }
  return ret;
}

493
std::vector<double> GBDT::PredictMulticlass(const double* value) const {
494
  std::vector<double> ret(num_class_, 0.0f);
495
  for (int i = 0; i < num_used_model_; ++i) {
496
497
498
499
500
501
502
    for (int j = 0; j < num_class_; ++j){
        ret[j] += models_[i * num_class_ + j] -> Predict(value);
    }
  }
  return ret;
}

503
std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
wxchan's avatar
wxchan committed
504
  std::vector<int> ret;
505
  for (int i = 0; i < num_used_model_; ++i) {
wxchan's avatar
wxchan committed
506
507
508
509
510
    ret.push_back(models_[i]->PredictLeafIndex(value));
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
511
}  // namespace LightGBM