gbdt.cpp 18.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include "gbdt.h"

#include <LightGBM/utils/common.h>

#include <LightGBM/feature.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
15
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
16
17
18

namespace LightGBM {

19
/*!
* \brief Default constructor.
*
* Initializes all owned raw pointers to nullptr so the destructor is safe
* to run even if Init() was never called. saved_model_size_ starts at -1,
* which SaveModelToFile() uses as the "file not opened yet" sentinel.
* NOTE(review): initializer order must match the member declaration order
* in gbdt.h — confirm before reordering.
*/
GBDT::GBDT()
  : train_score_updater_(nullptr),
  gradients_(nullptr), hessians_(nullptr),
  out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr),
  saved_model_size_(-1), num_used_model_(0) {

}

/*!
* \brief Destructor: releases every heap object this booster owns
*        (per-class tree learners, gradient/hessian buffers, bagging
*        index buffers, trained trees, and all score updaters).
*/
GBDT::~GBDT() {
  // deleting a null pointer is a no-op, so no explicit checks are required
  for (auto& learner : tree_learner_) {
    delete learner;
  }
  delete[] gradients_;
  delete[] hessians_;
  delete[] out_of_bag_data_indices_;
  delete[] bag_data_indices_;
  for (auto& model : models_) {
    delete model;
  }
  delete train_score_updater_;
  for (auto& updater : valid_score_updater_) {
    delete updater;
  }
}

43
44
45
46
/*!
* \brief Initialize the booster before training.
* \param config Boosting configuration; must actually be a GBDTConfig.
* \param train_data Training dataset (not owned).
* \param object_function Objective used to compute gradients/hessians;
*        may be nullptr when gradients are supplied externally, in which
*        case no gradient/hessian buffers are allocated.
* \param training_metrics Metrics to evaluate on the training data.
*/
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
     const std::vector<const Metric*>& training_metrics) {
  // NOTE(review): dynamic_cast result is not null-checked; assumes caller
  // always passes a GBDTConfig — confirm at the call sites.
  gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
  iter_ = 0;
  // reset model-file state so a later SaveModelToFile re-opens the file
  saved_model_size_ = -1;
  max_feature_idx_ = 0;
  early_stopping_round_ = gbdt_config_->early_stopping_round;
  train_data_ = train_data;
  // one tree learner per class (multiclass trains one tree per class per iteration)
  num_class_ = config->num_class;
  tree_learner_ = std::vector<TreeLearner*>(num_class_, nullptr);
  // create tree learner
  for (int i = 0; i < num_class_; ++i){
      tree_learner_[i] =
        TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config);
      // init tree learner
      tree_learner_[i]->Init(train_data_);
  }
  object_function_ = object_function;
  // push training metrics
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  // create score tracker
  train_score_updater_ = new ScoreUpdater(train_data_, num_class_);
  num_data_ = train_data_->num_data();
  // create buffer for gradients and hessians
  // (one contiguous buffer holding num_class_ class-major slices)
  if (object_function_ != nullptr) {
    gradients_ = new score_t[num_data_ * num_class_];
    hessians_ = new score_t[num_data_ * num_class_];
  }
  // get max feature index
  max_feature_idx_ = train_data_->num_total_features() - 1;
  // get label index
  label_idx_ = train_data_->label_idx();
  // if need bagging, create buffer
  // (bagging is active only when a fraction < 1 AND a positive frequency are set)
  if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
    out_of_bag_data_indices_ = new data_size_t[num_data_];
    bag_data_indices_ = new data_size_t[num_data_];
  } else {
    // no bagging: every record is "in the bag"
    out_of_bag_data_cnt_ = 0;
    out_of_bag_data_indices_ = nullptr;
    bag_data_cnt_ = num_data_;
    bag_data_indices_ = nullptr;
  }
  // initialize random generator
  random_ = Random(gbdt_config_->bagging_seed);

}

void GBDT::AddDataset(const Dataset* valid_data,
         const std::vector<const Metric*>& valid_metrics) {
95
96
97
  if (iter_ > 0) {
    Log::Fatal("Cannot add validation data after training started");
  }
Guolin Ke's avatar
Guolin Ke committed
98
  // for a validation dataset, we need its score and metric
99
  valid_score_updater_.push_back(new ScoreUpdater(valid_data, num_class_));
Guolin Ke's avatar
Guolin Ke committed
100
  valid_metrics_.emplace_back();
101
102
103
104
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
  }
Guolin Ke's avatar
Guolin Ke committed
105
106
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
107
108
109
110
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
    }
Guolin Ke's avatar
Guolin Ke committed
111
112
113
114
  }
}


115
/*!
* \brief Re-sample the training data for one bagging round and hand the
*        sampled indices to the tree learner of the given class.
* \param iter Current iteration; re-bagging only happens every
*        bagging_freq iterations.
* \param curr_class Class whose tree learner receives the bagged indices.
*
* Uses proportional sampling: each remaining record (or query group) is
* kept with probability (slots still needed) / (records still unseen),
* which yields exactly bag_data_cnt_ picks in a single pass while
* consuming one random number per record/query.
*/
void GBDT::Bagging(int iter, const int curr_class) {
  // if need bagging
  if (out_of_bag_data_indices_ != nullptr && iter % gbdt_config_->bagging_freq == 0) {
    // if doesn't have query data
    if (train_data_->metadata().query_boundaries() == nullptr) {
      bag_data_cnt_ =
        static_cast<data_size_t>(gbdt_config_->bagging_fraction * num_data_);
      out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
      data_size_t cur_left_cnt = 0;
      data_size_t cur_right_cnt = 0;
      // random bagging, minimal unit is one record
      for (data_size_t i = 0; i < num_data_; ++i) {
        // probability of picking record i so that exactly bag_data_cnt_
        // records end up in the bag
        double prob =
          (bag_data_cnt_ - cur_left_cnt) / static_cast<double>(num_data_ - i);
        if (random_.NextDouble() < prob) {
          bag_data_indices_[cur_left_cnt++] = i;
        } else {
          out_of_bag_data_indices_[cur_right_cnt++] = i;
        }
      }
    } else {
      // if have query data
      // (ranking tasks: keep whole query groups together so per-query
      //  metrics/objectives stay meaningful)
      const data_size_t* query_boundaries = train_data_->metadata().query_boundaries();
      data_size_t num_query = train_data_->metadata().num_queries();
      data_size_t bag_query_cnt =
          static_cast<data_size_t>(num_query * gbdt_config_->bagging_fraction);
      data_size_t cur_left_query_cnt = 0;
      data_size_t cur_left_cnt = 0;
      data_size_t cur_right_cnt = 0;
      // random bagging, minimal unit is one query
      for (data_size_t i = 0; i < num_query; ++i) {
        double prob =
            (bag_query_cnt - cur_left_query_cnt) / static_cast<double>(num_query - i);
        if (random_.NextDouble() < prob) {
          // whole query group goes into the bag
          for (data_size_t j = query_boundaries[i]; j < query_boundaries[i + 1]; ++j) {
            bag_data_indices_[cur_left_cnt++] = j;
          }
          cur_left_query_cnt++;
        } else {
          for (data_size_t j = query_boundaries[i]; j < query_boundaries[i + 1]; ++j) {
            out_of_bag_data_indices_[cur_right_cnt++] = j;
          }
        }
      }
      // query groups vary in size, so the record count is known only now
      bag_data_cnt_ = cur_left_cnt;
      out_of_bag_data_cnt_ = num_data_ - bag_data_cnt_;
    }
    Log::Info("Re-bagging, using %d data to train", bag_data_cnt_);
    // set bagging data to tree learner
    tree_learner_[curr_class]->SetBaggingData(bag_data_indices_, bag_data_cnt_);
  }
}

168
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
Hui Xue's avatar
Hui Xue committed
169
  // we need to predict out-of-bag socres of data for boosting
Guolin Ke's avatar
Guolin Ke committed
170
171
  if (out_of_bag_data_indices_ != nullptr) {
    train_score_updater_->
172
      AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_, curr_class);
Guolin Ke's avatar
Guolin Ke committed
173
174
175
  }
}

176
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
177
178
179
180
181
182
    // boosting first
    if (gradient == nullptr || hessian == nullptr) {
      Boosting();
      gradient = gradients_;
      hessian = hessians_;
    }
183

184
185
186
    for (int curr_class = 0; curr_class < num_class_; ++curr_class){
      // bagging logic
      Bagging(iter_, curr_class);
187

188
189
190
191
      // train a new tree
      Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_);
      // if cannot learn a new tree, then stop
      if (new_tree->num_leaves() <= 1) {
192
        Log::Info("Stopped training because there are no more leafs that meet the split requirements.");
193
194
        return true;
      }
195

196
197
198
199
200
      // shrinkage by learning rate
      new_tree->Shrinkage(gbdt_config_->learning_rate);
      // update score
      UpdateScore(new_tree, curr_class);
      UpdateScoreOutOfBag(new_tree, curr_class);
201

202
203
204
      // add model
      models_.push_back(new_tree);
    }
205

206
207
208
209
210
211
212
213
214
215
  bool is_met_early_stopping = false;
  // print message for metric
  if (is_eval) {
    is_met_early_stopping = OutputMetric(iter_ + 1);
  }
  ++iter_;
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
      iter_, iter_ - early_stopping_round_);
    // pop last early_stopping_round_ models
216
    for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
217
218
219
220
221
      delete models_.back();
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
222

Guolin Ke's avatar
Guolin Ke committed
223
224
}

225
void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
226
  // update training score
227
  train_score_updater_->AddScore(tree_learner_[curr_class], curr_class);
Guolin Ke's avatar
Guolin Ke committed
228
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
229
230
  for (auto& score_updater : valid_score_updater_) {
    score_updater->AddScore(tree, curr_class);
Guolin Ke's avatar
Guolin Ke committed
231
232
233
  }
}

wxchan's avatar
wxchan committed
234
235
bool GBDT::OutputMetric(int iter) {
  bool ret = false;
Guolin Ke's avatar
Guolin Ke committed
236
  // print training metric
237
238
239
240
  if ((iter % gbdt_config_->output_freq) == 0) {
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
      auto scores = sub_metric->Eval(train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
241
      for (size_t k = 0; k < name.size(); ++k) {
242
243
        Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), scores[k]);
      }
244
    }
Guolin Ke's avatar
Guolin Ke committed
245
246
  }
  // print validation metric
247
248
249
250
251
252
  if ((iter % gbdt_config_->output_freq) == 0 || early_stopping_round_ > 0) {
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
        if ((iter % gbdt_config_->output_freq) == 0) {
          auto name = valid_metrics_[i][j]->GetName();
Guolin Ke's avatar
Guolin Ke committed
253
          for (size_t k = 0; k < name.size(); ++k) {
254
255
            Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), test_scores[k]);
          }
wxchan's avatar
wxchan committed
256
        }
257
        if (!ret && early_stopping_round_ > 0) {
258
259
260
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
261
262
            best_iter_[i][j] = iter;
          } else {
263
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = true; }
264
          }
wxchan's avatar
wxchan committed
265
266
        }
      }
Guolin Ke's avatar
Guolin Ke committed
267
268
    }
  }
wxchan's avatar
wxchan committed
269
  return ret;
Guolin Ke's avatar
Guolin Ke committed
270
271
}

272
/*! \brief Get eval result */
273
274
275
276
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
  std::vector<double> ret;
  if (data_idx == 0) {
277
278
    for (auto& sub_metric : training_metrics_) {
      auto scores = sub_metric->Eval(train_score_updater_->score());
279
280
281
      for (auto score : scores) {
        ret.push_back(score);
      }
282
283
    }
  }
284
285
286
287
288
289
290
  else {
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score());
      for (auto score : test_scores) {
        ret.push_back(score);
      }
291
292
293
294
295
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
296
/*! \brief Get training scores result */
Guolin Ke's avatar
Guolin Ke committed
297
298
299
/*! \brief Get training scores result
*  \param out_len [out] Total score count: num_data * num_class
*                 (scores are stored class-major).
*  \return Pointer to the internal training score buffer (not owned by caller).
*/
const score_t* GBDT::GetTrainingScore(data_size_t* out_len) const {
  const data_size_t total_len = train_score_updater_->num_data() * num_class_;
  *out_len = total_len;
  return train_score_updater_->score();
}

Guolin Ke's avatar
Guolin Ke committed
302
void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const {
Guolin Ke's avatar
Guolin Ke committed
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
  std::vector<double> ret;

  const score_t* raw_scores = nullptr;
  data_size_t num_data = 0;
  if (data_idx == 0) {
    raw_scores = train_score_updater_->score();
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
  }
  *out_len = num_data * num_class_;

  if (num_class_ > 1) {
#pragma omp parallel for schedule(guided)
    for (data_size_t i = 0; i < num_data; ++i) {
      std::vector<double> tmp_result;
      for (int j = 0; j < num_class_; ++j) {
        tmp_result.push_back(raw_scores[j * num_data + i]);
      }
      Common::Softmax(&tmp_result);
      for (int j = 0; j < num_class_; ++j) {
        out_result[j * num_data + i] = static_cast<score_t>(tmp_result[i]);
      }
    }
  } else if(sigmoid_ > 0){
#pragma omp parallel for schedule(guided)
    for (data_size_t i = 0; i < num_data; ++i) {
      out_result[i] = static_cast<score_t>(1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * raw_scores[i])));
    }
  } else {
#pragma omp parallel for schedule(guided)
    for (data_size_t i = 0; i < num_data; ++i) {
      out_result[i] = raw_scores[i];
    }
  }

}

Guolin Ke's avatar
Guolin Ke committed
344
void GBDT::Boosting() {
345
346
347
  if (object_function_ == nullptr) {
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
348
  // objective function will calculate gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
349
  int num_score = 0;
Guolin Ke's avatar
Guolin Ke committed
350
  object_function_->
Guolin Ke's avatar
Guolin Ke committed
351
    GetGradients(GetTrainingScore(&num_score), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
352
353
}

Guolin Ke's avatar
Guolin Ke committed
354
/*!
* \brief Incrementally write the model to a text file.
* \param num_used_model Number of iterations to save, or NO_LIMIT for all
*        trees currently in models_.
* \param is_finish Whether training is done; flushes the held-back trees,
*        appends feature importances, and closes the file.
* \param filename Output path; only used the first time, when the file is opened.
*
* saved_model_size_ tracks how many trees are already on disk, so repeated
* calls append only new trees. The last early_stopping_round_ * num_class_
* trees are held back until is_finish, because early stopping may pop them.
*/
void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filename) {
  // first time to this function, open file
  // (saved_model_size_ == -1 is the "not opened yet" sentinel from the ctor)
  if (saved_model_size_ < 0) {
    model_output_file_.open(filename);
    // output model type
    model_output_file_ << "gbdt" << std::endl;
    // output number of class
    model_output_file_ << "num_class=" << num_class_ << std::endl;
    // output label index
    model_output_file_ << "label_index=" << label_idx_ << std::endl;
    // output max_feature_idx
    model_output_file_ << "max_feature_idx=" << max_feature_idx_ << std::endl;
    // output sigmoid parameter
    model_output_file_ << "sigmoid=" << object_function_->GetSigmoid() << std::endl;
    model_output_file_ << std::endl;
    saved_model_size_ = 0;
  }
  // already saved
  if (!model_output_file_.is_open()) {
    return;
  }
  if (num_used_model == NO_LIMIT) {
    num_used_model = static_cast<int>(models_.size());
  } else {
    // convert iterations to tree count (one tree per class per iteration)
    num_used_model = num_used_model * num_class_;
  }
  // hold back the trees early stopping might still discard
  int rest = num_used_model - early_stopping_round_ * num_class_;
  // output tree models
  for (int i = saved_model_size_; i < rest; ++i) {
    model_output_file_ << "Tree=" << i << std::endl;
    model_output_file_ << models_[i]->ToString() << std::endl;
  }
  saved_model_size_ = Common::Max(saved_model_size_, rest);
  model_output_file_.flush();
  // training finished, can close file
  if (is_finish) {
    // flush the trees that were held back for early stopping
    for (int i = saved_model_size_; i < num_used_model; ++i) {
      model_output_file_ << "Tree=" << i << std::endl;
      model_output_file_ << models_[i]->ToString() << std::endl;
    }
    model_output_file_ << std::endl << FeatureImportance() << std::endl;
    model_output_file_.close();
  }
}

Guolin Ke's avatar
Guolin Ke committed
402
void GBDT::LoadModelFromString(const std::string& model_str) {
Guolin Ke's avatar
Guolin Ke committed
403
404
405
406
  // use serialized string to restore this object
  models_.clear();
  std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
  size_t i = 0;
407
408

  // get number of classes
409
410
411
412
413
414
415
416
417
418
419
420
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("num_class=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atoi(strs[1].c_str(), &num_class_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  if (i == lines.size()) {
421
    Log::Fatal("Model file doesn't specify the number of classes");
422
423
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
424
425

  // get index of label
426
  i = 0;
Guolin Ke's avatar
Guolin Ke committed
427
428
429
430
431
432
433
434
435
436
437
438
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("label_index=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atoi(strs[1].c_str(), &label_idx_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  if (i == lines.size()) {
439
    Log::Fatal("Model file doesn't specify the label index");
Guolin Ke's avatar
Guolin Ke committed
440
441
442
    return;
  }

Guolin Ke's avatar
Guolin Ke committed
443
  // get max_feature_idx first
Guolin Ke's avatar
Guolin Ke committed
444
  i = 0;
Guolin Ke's avatar
Guolin Ke committed
445
446
447
448
449
450
451
452
453
454
455
456
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("max_feature_idx=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atoi(strs[1].c_str(), &max_feature_idx_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  if (i == lines.size()) {
457
    Log::Fatal("Model file doesn't specify max_feature_idx");
Guolin Ke's avatar
Guolin Ke committed
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
    return;
  }
  // get sigmoid parameter
  i = 0;
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("sigmoid=");
    if (find_pos != std::string::npos) {
      std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
      Common::Atof(strs[1].c_str(), &sigmoid_);
      ++i;
      break;
    } else {
      ++i;
    }
  }
  // if sigmoid doesn't exists
  if (i == lines.size()) {
475
    sigmoid_ = -1.0f;
Guolin Ke's avatar
Guolin Ke committed
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
  }
  // get tree models
  i = 0;
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
      std::string tree_str = Common::Join(lines, start, end, '\n');
      models_.push_back(new Tree(tree_str));
    } else {
      ++i;
    }
  }
492
  Log::Info("Finished loading %d models", models_.size());
493
  num_used_model_ = static_cast<int>(models_.size()) / num_class_;
Guolin Ke's avatar
Guolin Ke committed
494
495
}

496
std::string GBDT::FeatureImportance() const {
497
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
498
    for (size_t iter = 0; iter < models_.size(); ++iter) {
499
500
        for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
            ++feature_importances[models_[iter]->split_feature_real(split_idx)];
wxchan's avatar
wxchan committed
501
502
        }
    }
503
504
505
506
507
508
509
510
511
    // store the importance first
    std::vector<std::pair<size_t, std::string>> pairs;
    for (size_t i = 0; i < feature_importances.size(); ++i) {
      pairs.emplace_back(feature_importances[i], train_data_->feature_names()[i]);
    }
    // sort the importance
    std::sort(pairs.begin(), pairs.end(),
      [](const std::pair<size_t, std::string>& lhs,
        const std::pair<size_t, std::string>& rhs) {
512
      return lhs.first > rhs.first;
513
    });
514
    std::stringstream str_buf;
515
    // write to model file
516
    str_buf << std::endl << "feature importances:" << std::endl;
517
    for (size_t i = 0; i < pairs.size(); ++i) {
518
      str_buf << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
519
    }
520
    return str_buf.str();
wxchan's avatar
wxchan committed
521
522
}

523
524
std::vector<double> GBDT::PredictRaw(const double* value) const {
  std::vector<double> ret(num_class_, 0.0f);
525
  for (int i = 0; i < num_used_model_; ++i) {
526
527
528
    for (int j = 0; j < num_class_; ++j) {
      ret[j] += models_[i * num_class_ + j]->Predict(value);
    }
Guolin Ke's avatar
Guolin Ke committed
529
530
531
532
  }
  return ret;
}

533
std::vector<double> GBDT::Predict(const double* value) const {
534
  std::vector<double> ret(num_class_, 0.0f);
535
  for (int i = 0; i < num_used_model_; ++i) {
536
537
    for (int j = 0; j < num_class_; ++j) {
      ret[j] += models_[i * num_class_ + j]->Predict(value);
538
539
    }
  }
540
541
542
543
544
545
  // if need sigmoid transform
  if (sigmoid_ > 0 && num_class_ == 1) {
    ret[0] = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret[0]));
  } else if (num_class_ > 1) {
    Common::Softmax(&ret);
  }
546
547
548
  return ret;
}

549
std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
wxchan's avatar
wxchan committed
550
  std::vector<int> ret;
551
  for (int i = 0; i < num_used_model_; ++i) {
552
553
554
    for (int j = 0; j < num_class_; ++j) {
      ret.push_back(models_[i * num_class_ + j]->PredictLeafIndex(value));
    }
wxchan's avatar
wxchan committed
555
556
557
558
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
559
}  // namespace LightGBM