gbdt.cpp 19.4 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include "gbdt.h"

#include <LightGBM/utils/common.h>

#include <LightGBM/feature.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>

#include <ctime>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <sstream>
#include <string>
#include <vector>
15
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
16
17
18

namespace LightGBM {

19
20
21
22
/*!
 * \brief Construct an empty booster.
 * saved_model_size_ < 0 marks that no model file has been opened yet
 * (checked by SaveModelToFile); no iterations are available for prediction.
 */
GBDT::GBDT()
  : saved_model_size_(-1),
    num_iteration_for_pred_(0),
    num_init_iteration_(0) {
}

/*! \brief Destructor; no explicit cleanup is performed here. */
GBDT::~GBDT() {
}

30
31
32
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
     const std::vector<const Metric*>& training_metrics) {
  iter_ = 0;
33
  saved_model_size_ = -1;
Guolin Ke's avatar
Guolin Ke committed
34
  num_iteration_for_pred_ = 0;
35
  max_feature_idx_ = 0;
36
  num_class_ = config->num_class;
37
  train_data_ = nullptr;
38
  ResetTrainingData(config, train_data, object_function, training_metrics);
39
40
}

41
42
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
  const std::vector<const Metric*>& training_metrics) {
43
44
45
  if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
46
47
48
  gbdt_config_ = config;
  early_stopping_round_ = gbdt_config_->early_stopping_round;
  shrinkage_rate_ = gbdt_config_->learning_rate;
49
  train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
50
  // create tree learner
51
  tree_learner_.clear();
Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
  for (int i = 0; i < num_class_; ++i) {
    auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
    new_tree_learner->Init(train_data_);
    // init tree learner
    tree_learner_.push_back(std::move(new_tree_learner));
57
  }
Guolin Ke's avatar
Guolin Ke committed
58
  tree_learner_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
59
60
  object_function_ = object_function;
  // push training metrics
61
  training_metrics_.clear();
Guolin Ke's avatar
Guolin Ke committed
62
63
64
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
Guolin Ke's avatar
Guolin Ke committed
65
  training_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
66
  // create score tracker
Guolin Ke's avatar
Guolin Ke committed
67
  train_score_updater_.reset(new ScoreUpdater(train_data_, num_class_));
Guolin Ke's avatar
Guolin Ke committed
68
69
  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
70
  if (object_function_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
    gradients_ = std::vector<score_t>(num_data_ * num_class_);
    hessians_ = std::vector<score_t>(num_data_ * num_class_);
  }
  sigmoid_ = -1.0f;
75
  if (object_function_ != nullptr
Guolin Ke's avatar
Guolin Ke committed
76
77
78
    && std::string(object_function_->GetName()) == std::string("binary")) {
    // only binary classification need sigmoid transform
    sigmoid_ = gbdt_config_->sigmoid;
79
  }
Guolin Ke's avatar
Guolin Ke committed
80
  // get max feature index
81
  max_feature_idx_ = train_data_->num_total_features() - 1;
Guolin Ke's avatar
Guolin Ke committed
82
83
  // get label index
  label_idx_ = train_data_->label_idx();
Guolin Ke's avatar
Guolin Ke committed
84
85
  // if need bagging, create buffer
  if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
Guolin Ke's avatar
Guolin Ke committed
86
87
    out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
    bag_data_indices_ = std::vector<data_size_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
88
89
  } else {
    out_of_bag_data_cnt_ = 0;
Guolin Ke's avatar
Guolin Ke committed
90
    out_of_bag_data_indices_.clear();
Guolin Ke's avatar
Guolin Ke committed
91
    bag_data_cnt_ = num_data_;
Guolin Ke's avatar
Guolin Ke committed
92
    bag_data_indices_.clear();
Guolin Ke's avatar
Guolin Ke committed
93
  }
94
  random_ = Random(gbdt_config_->bagging_seed);
95
96
97
98
99
100
  // update score
  for (int i = 0; i < iter_; ++i) {
    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
      auto curr_tree = i * num_class_ + curr_class;
      train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
    }
101
102
  }
}
103
104

void GBDT::AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
105
  const std::vector<const Metric*>& valid_metrics) {
106
107
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
108
  }
Guolin Ke's avatar
Guolin Ke committed
109
  // for a validation dataset, we need its score and metric
Guolin Ke's avatar
Guolin Ke committed
110
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_class_));
111
112
113
114
115
116
117
  // update score
  for (int i = 0; i < iter_; ++i) {
    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
      auto curr_tree = i * num_class_ + curr_class;
      new_score_updater->AddScore(models_[curr_tree].get(), curr_class);
    }
  }
Guolin Ke's avatar
Guolin Ke committed
118
  valid_score_updater_.push_back(std::move(new_score_updater));
Guolin Ke's avatar
Guolin Ke committed
119
  valid_metrics_.emplace_back();
120
121
122
123
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
  }
Guolin Ke's avatar
Guolin Ke committed
124
125
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
126
127
128
129
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
    }
Guolin Ke's avatar
Guolin Ke committed
130
  }
Guolin Ke's avatar
Guolin Ke committed
131
  valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
132
133
134
}


135
void GBDT::Bagging(int iter, const int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
136
  // if need bagging
Guolin Ke's avatar
Guolin Ke committed
137
  if (out_of_bag_data_indices_.size() > 0 && iter % gbdt_config_->bagging_freq == 0) {
Guolin Ke's avatar
Guolin Ke committed
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
    // if doesn't have query data
    if (train_data_->metadata().query_boundaries() == nullptr) {
      bag_data_cnt_ =
        static_cast<data_size_t>(gbdt_config_->bagging_fraction * num_data_);
      out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
      data_size_t cur_left_cnt = 0;
      data_size_t cur_right_cnt = 0;
      // random bagging, minimal unit is one record
      for (data_size_t i = 0; i < num_data_; ++i) {
        double prob =
          (bag_data_cnt_ - cur_left_cnt) / static_cast<double>(num_data_ - i);
        if (random_.NextDouble() < prob) {
          bag_data_indices_[cur_left_cnt++] = i;
        } else {
          out_of_bag_data_indices_[cur_right_cnt++] = i;
        }
      }
    } else {
      // if have query data
      const data_size_t* query_boundaries = train_data_->metadata().query_boundaries();
      data_size_t num_query = train_data_->metadata().num_queries();
      data_size_t bag_query_cnt =
          static_cast<data_size_t>(num_query * gbdt_config_->bagging_fraction);
      data_size_t cur_left_query_cnt = 0;
      data_size_t cur_left_cnt = 0;
      data_size_t cur_right_cnt = 0;
      // random bagging, minimal unit is one query
      for (data_size_t i = 0; i < num_query; ++i) {
        double prob =
            (bag_query_cnt - cur_left_query_cnt) / static_cast<double>(num_query - i);
        if (random_.NextDouble() < prob) {
          for (data_size_t j = query_boundaries[i]; j < query_boundaries[i + 1]; ++j) {
            bag_data_indices_[cur_left_cnt++] = j;
          }
          cur_left_query_cnt++;
        } else {
          for (data_size_t j = query_boundaries[i]; j < query_boundaries[i + 1]; ++j) {
            out_of_bag_data_indices_[cur_right_cnt++] = j;
          }
        }
      }
      bag_data_cnt_ = cur_left_cnt;
      out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
    }
Guolin Ke's avatar
Guolin Ke committed
182
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
183
    // set bagging data to tree learner
Guolin Ke's avatar
Guolin Ke committed
184
    tree_learner_[curr_class]->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
185
186
187
  }
}

188
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
Hui Xue's avatar
Hui Xue committed
189
  // we need to predict out-of-bag socres of data for boosting
Guolin Ke's avatar
Guolin Ke committed
190
191
  if (out_of_bag_data_indices_.size() > 0) {
    train_score_updater_->AddScore(tree, out_of_bag_data_indices_.data(), out_of_bag_data_cnt_, curr_class);
Guolin Ke's avatar
Guolin Ke committed
192
193
194
  }
}

195
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
Guolin Ke's avatar
Guolin Ke committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
  // boosting first
  if (gradient == nullptr || hessian == nullptr) {
    Boosting();
    gradient = gradients_.data();
    hessian = hessians_.data();
  }

  for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
    // bagging logic
    Bagging(iter_, curr_class);

    // train a new tree
    std::unique_ptr<Tree> new_tree(tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
    // if cannot learn a new tree, then stop
    if (new_tree->num_leaves() <= 1) {
      Log::Info("Stopped training because there are no more leafs that meet the split requirements.");
      return true;
213
    }
214

Guolin Ke's avatar
Guolin Ke committed
215
216
217
218
219
    // shrinkage by learning rate
    new_tree->Shrinkage(shrinkage_rate_);
    // update score
    UpdateScore(new_tree.get(), curr_class);
    UpdateScoreOutOfBag(new_tree.get(), curr_class);
220

Guolin Ke's avatar
Guolin Ke committed
221
222
223
224
225
226
227
228
229
    // add model
    models_.push_back(std::move(new_tree));
  }
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
230

Guolin Ke's avatar
Guolin Ke committed
231
}
232

233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
void GBDT::RollbackOneIter() {
  if (iter_ == 0) { return; }
  int cur_iter = iter_ - 1;
  // reset score
  for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
    auto curr_tree = cur_iter * num_class_ + curr_class;
    models_[curr_tree]->Shrinkage(-1.0);
    train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(models_[curr_tree].get(), curr_class);
    }
  }
  // remove model
  for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
252
bool GBDT::EvalAndCheckEarlyStopping() {
253
254
  bool is_met_early_stopping = false;
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
255
  is_met_early_stopping = OutputMetric(iter_);
256
257
258
259
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
      iter_, iter_ - early_stopping_round_);
    // pop last early_stopping_round_ models
260
    for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
261
262
263
264
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
265
266
}

267
void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
268
  // update training score
Guolin Ke's avatar
Guolin Ke committed
269
  train_score_updater_->AddScore(tree_learner_[curr_class].get(), curr_class);
Guolin Ke's avatar
Guolin Ke committed
270
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
271
272
  for (auto& score_updater : valid_score_updater_) {
    score_updater->AddScore(tree, curr_class);
Guolin Ke's avatar
Guolin Ke committed
273
274
275
  }
}

wxchan's avatar
wxchan committed
276
277
bool GBDT::OutputMetric(int iter) {
  bool ret = false;
Guolin Ke's avatar
Guolin Ke committed
278
  // print training metric
279
280
281
282
  if ((iter % gbdt_config_->output_freq) == 0) {
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
      auto scores = sub_metric->Eval(train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
283
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
284
        Log::Info("Iteration:%d, training %s : %f", iter, name[k].c_str(), scores[k]);
285
      }
286
    }
Guolin Ke's avatar
Guolin Ke committed
287
288
  }
  // print validation metric
289
290
291
292
293
294
  if ((iter % gbdt_config_->output_freq) == 0 || early_stopping_round_ > 0) {
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
        if ((iter % gbdt_config_->output_freq) == 0) {
          auto name = valid_metrics_[i][j]->GetName();
Guolin Ke's avatar
Guolin Ke committed
295
          for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
296
            Log::Info("Iteration:%d, valid_%d %s : %f", iter, i + 1, name[k].c_str(), test_scores[k]);
297
          }
wxchan's avatar
wxchan committed
298
        }
299
        if (!ret && early_stopping_round_ > 0) {
300
301
302
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
303
304
            best_iter_[i][j] = iter;
          } else {
305
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = true; }
306
          }
wxchan's avatar
wxchan committed
307
308
        }
      }
Guolin Ke's avatar
Guolin Ke committed
309
310
    }
  }
wxchan's avatar
wxchan committed
311
  return ret;
Guolin Ke's avatar
Guolin Ke committed
312
313
}

314
/*! \brief Get eval result */
315
316
317
318
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
  std::vector<double> ret;
  if (data_idx == 0) {
319
320
    for (auto& sub_metric : training_metrics_) {
      auto scores = sub_metric->Eval(train_score_updater_->score());
321
322
323
      for (auto score : scores) {
        ret.push_back(score);
      }
324
325
    }
  }
326
327
328
329
330
331
332
  else {
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score());
      for (auto score : test_scores) {
        ret.push_back(score);
      }
333
334
335
336
337
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
338
/*! \brief Get training scores result */
339
const score_t* GBDT::GetTrainingScore(data_size_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
340
341
  *out_len = train_score_updater_->num_data() * num_class_;
  return train_score_updater_->score();
342
343
}

Guolin Ke's avatar
Guolin Ke committed
344
void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
345
346
347
348
349
350
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
  std::vector<double> ret;

  const score_t* raw_scores = nullptr;
  data_size_t num_data = 0;
  if (data_idx == 0) {
Guolin Ke's avatar
Guolin Ke committed
351
    raw_scores = GetTrainingScore(out_len);
Guolin Ke's avatar
Guolin Ke committed
352
353
354
355
356
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
Guolin Ke's avatar
Guolin Ke committed
357
    *out_len = num_data * num_class_;
Guolin Ke's avatar
Guolin Ke committed
358
359
  }
  if (num_class_ > 1) {
Guolin Ke's avatar
Guolin Ke committed
360
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
361
362
363
364
365
366
367
368
369
370
    for (data_size_t i = 0; i < num_data; ++i) {
      std::vector<double> tmp_result;
      for (int j = 0; j < num_class_; ++j) {
        tmp_result.push_back(raw_scores[j * num_data + i]);
      }
      Common::Softmax(&tmp_result);
      for (int j = 0; j < num_class_; ++j) {
        out_result[j * num_data + i] = static_cast<score_t>(tmp_result[i]);
      }
    }
Guolin Ke's avatar
Guolin Ke committed
371
  } else if(sigmoid_ > 0.0f){
Guolin Ke's avatar
Guolin Ke committed
372
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
373
374
375
376
    for (data_size_t i = 0; i < num_data; ++i) {
      out_result[i] = static_cast<score_t>(1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * raw_scores[i])));
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
377
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
378
379
380
381
382
383
384
    for (data_size_t i = 0; i < num_data; ++i) {
      out_result[i] = raw_scores[i];
    }
  }

}

Guolin Ke's avatar
Guolin Ke committed
385
void GBDT::Boosting() {
386
387
388
  if (object_function_ == nullptr) {
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
389
  // objective function will calculate gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
390
  int num_score = 0;
Guolin Ke's avatar
Guolin Ke committed
391
  object_function_->
Guolin Ke's avatar
Guolin Ke committed
392
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
393
394
}

Guolin Ke's avatar
Guolin Ke committed
395
void GBDT::SaveModelToFile(int num_iteration, bool is_finish, const char* filename) {
396
  // first time to this function, open file
Guolin Ke's avatar
Guolin Ke committed
397
  if (saved_model_size_ < 0) {
398
399
    model_output_file_.open(filename);
    // output model type
400
    model_output_file_ << Name() << std::endl;
401
402
    // output number of class
    model_output_file_ << "num_class=" << num_class_ << std::endl;
403
404
405
406
    // output label index
    model_output_file_ << "label_index=" << label_idx_ << std::endl;
    // output max_feature_idx
    model_output_file_ << "max_feature_idx=" << max_feature_idx_ << std::endl;
Guolin Ke's avatar
Guolin Ke committed
407
408
409
410
    // output objective name
    if (object_function_ != nullptr) {
      model_output_file_ << "objective=" << object_function_->GetName() << std::endl;
    }
411
    // output sigmoid parameter
Guolin Ke's avatar
Guolin Ke committed
412
    model_output_file_ << "sigmoid=" << sigmoid_ << std::endl;
413
414
415
416
417
418
419
    model_output_file_ << std::endl;
    saved_model_size_ = 0;
  }
  // already saved
  if (!model_output_file_.is_open()) {
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
420
421
  int num_used_model = 0;
  if (num_iteration == NO_LIMIT) {
Guolin Ke's avatar
Guolin Ke committed
422
423
    num_used_model = static_cast<int>(models_.size());
  } else {
Guolin Ke's avatar
Guolin Ke committed
424
    num_used_model = num_iteration * num_class_;
Guolin Ke's avatar
Guolin Ke committed
425
426
  }
  int rest = num_used_model - early_stopping_round_ * num_class_;
427
428
429
430
431
  // output tree models
  for (int i = saved_model_size_; i < rest; ++i) {
    model_output_file_ << "Tree=" << i << std::endl;
    model_output_file_ << models_[i]->ToString() << std::endl;
  }
432

Guolin Ke's avatar
Guolin Ke committed
433
  saved_model_size_ = std::max(saved_model_size_, rest);
434

435
436
437
  model_output_file_.flush();
  // training finished, can close file
  if (is_finish) {
Guolin Ke's avatar
Guolin Ke committed
438
    for (int i = saved_model_size_; i < num_used_model; ++i) {
439
440
441
442
443
      model_output_file_ << "Tree=" << i << std::endl;
      model_output_file_ << models_[i]->ToString() << std::endl;
    }
    model_output_file_ << std::endl << FeatureImportance() << std::endl;
    model_output_file_.close();
Guolin Ke's avatar
Guolin Ke committed
444
445
446
  }
}

Guolin Ke's avatar
Guolin Ke committed
447
void GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
448
449
450
  // use serialized string to restore this object
  models_.clear();
  std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
451
452

  // get number of classes
453
454
455
456
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
457
    Log::Fatal("Model file doesn't specify the number of classes");
458
459
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
460
  // get index of label
461
462
463
464
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
465
    Log::Fatal("Model file doesn't specify the label index");
Guolin Ke's avatar
Guolin Ke committed
466
467
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
468
  // get max_feature_idx first
469
470
471
472
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
473
    Log::Fatal("Model file doesn't specify max_feature_idx");
Guolin Ke's avatar
Guolin Ke committed
474
475
476
    return;
  }
  // get sigmoid parameter
477
478
479
480
  line = Common::FindFromLines(lines, "sigmoid=");
  if (line.size() > 0) {
    Common::Atof(Common::Split(line.c_str(), '=')[1].c_str(), &sigmoid_);
  } else {
481
    sigmoid_ = -1.0f;
Guolin Ke's avatar
Guolin Ke committed
482
483
  }
  // get tree models
484
  size_t i = 0;
Guolin Ke's avatar
Guolin Ke committed
485
486
487
488
489
490
491
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
492
      std::string tree_str = Common::Join<std::string>(lines, start, end, '\n');
Guolin Ke's avatar
Guolin Ke committed
493
494
      auto new_tree = std::unique_ptr<Tree>(new Tree(tree_str));
      models_.push_back(std::move(new_tree));
Guolin Ke's avatar
Guolin Ke committed
495
496
497
498
    } else {
      ++i;
    }
  }
499
  Log::Info("Finished loading %d models", models_.size());
Guolin Ke's avatar
Guolin Ke committed
500
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
501
  num_init_iteration_ = num_iteration_for_pred_;
Guolin Ke's avatar
Guolin Ke committed
502
503
}

504
std::string GBDT::FeatureImportance() const {
505
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
506
    for (size_t iter = 0; iter < models_.size(); ++iter) {
507
508
        for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
            ++feature_importances[models_[iter]->split_feature_real(split_idx)];
wxchan's avatar
wxchan committed
509
510
        }
    }
511
512
513
    // store the importance first
    std::vector<std::pair<size_t, std::string>> pairs;
    for (size_t i = 0; i < feature_importances.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
514
515
516
      if (feature_importances[i] > 0) {
        pairs.emplace_back(feature_importances[i], train_data_->feature_names()[i]);
      }
517
518
519
520
521
    }
    // sort the importance
    std::sort(pairs.begin(), pairs.end(),
      [](const std::pair<size_t, std::string>& lhs,
        const std::pair<size_t, std::string>& rhs) {
522
      return lhs.first > rhs.first;
523
    });
524
    std::stringstream str_buf;
525
    // write to model file
526
    str_buf << std::endl << "feature importances:" << std::endl;
527
    for (size_t i = 0; i < pairs.size(); ++i) {
528
      str_buf << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
529
    }
530
    return str_buf.str();
wxchan's avatar
wxchan committed
531
532
}

533
534
std::vector<double> GBDT::PredictRaw(const double* value) const {
  std::vector<double> ret(num_class_, 0.0f);
Guolin Ke's avatar
Guolin Ke committed
535
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
536
537
538
    for (int j = 0; j < num_class_; ++j) {
      ret[j] += models_[i * num_class_ + j]->Predict(value);
    }
Guolin Ke's avatar
Guolin Ke committed
539
540
541
542
  }
  return ret;
}

543
std::vector<double> GBDT::Predict(const double* value) const {
544
  std::vector<double> ret(num_class_, 0.0f);
Guolin Ke's avatar
Guolin Ke committed
545
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
546
547
    for (int j = 0; j < num_class_; ++j) {
      ret[j] += models_[i * num_class_ + j]->Predict(value);
548
549
    }
  }
550
551
552
553
554
555
  // if need sigmoid transform
  if (sigmoid_ > 0 && num_class_ == 1) {
    ret[0] = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret[0]));
  } else if (num_class_ > 1) {
    Common::Softmax(&ret);
  }
556
557
558
  return ret;
}

559
std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
wxchan's avatar
wxchan committed
560
  std::vector<int> ret;
Guolin Ke's avatar
Guolin Ke committed
561
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
562
563
564
    for (int j = 0; j < num_class_; ++j) {
      ret.push_back(models_[i * num_class_ + j]->PredictLeafIndex(value));
    }
wxchan's avatar
wxchan committed
565
566
567
568
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
569
}  // namespace LightGBM