"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "574d7800f899415511a66122f76381fa8dc22636"
gbdt.cpp 30.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include "gbdt.h"

3
#include <LightGBM/utils/openmp_wrapper.h>
4

Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
9
10
11
12
13
14
15
#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>

#include <ctime>

#include <sstream>
#include <chrono>
#include <string>
#include <vector>
16
#include <utility>
Guolin Ke's avatar
Guolin Ke committed
17
18
19

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
21
22
#ifdef TIMETAG
std::chrono::duration<double, std::milli> boosting_time;
std::chrono::duration<double, std::milli> train_score_time;
Guolin Ke's avatar
Guolin Ke committed
23
std::chrono::duration<double, std::milli> out_of_bag_score_time;
Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
28
29
30
std::chrono::duration<double, std::milli> valid_score_time;
std::chrono::duration<double, std::milli> metric_time;
std::chrono::duration<double, std::milli> bagging_time;
std::chrono::duration<double, std::milli> sub_gradient_time;
std::chrono::duration<double, std::milli> tree_time;
#endif // TIMETAG

31
// Default-constructs the booster with inert state: no training data, no
// objective, single class, and prediction disabled until Init() is called.
GBDT::GBDT()
  :iter_(0),
  train_data_(nullptr),
  object_function_(nullptr),
  early_stopping_round_(0),
  max_feature_idx_(0),
  num_class_(1),
  // sigmoid_ < 0 means "no sigmoid transform"; only enabled for binary objective
  sigmoid_(-1.0f),
  num_iteration_for_pred_(0),
  shrinkage_rate_(0.1f),
  num_init_iteration_(0),
  boost_from_average_(false) {
  // Query the real OpenMP team size once; `master` ensures a single writer.
#pragma omp parallel
#pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
}

// Destructor. When TIMETAG profiling is enabled, reports the accumulated
// per-phase wall times (converted from milliseconds to seconds).
GBDT::~GBDT() {
#ifdef TIMETAG
  // Use .count() to extract a plain double before handing the value to the
  // C-variadic logger: passing a std::chrono::duration object through "..."
  // for a %f specifier is not a portable vararg type and is at best
  // conditionally-supported by the ABI.
  Log::Info("GBDT::boosting costs %f", (boosting_time * 1e-3).count());
  Log::Info("GBDT::train_score costs %f", (train_score_time * 1e-3).count());
  Log::Info("GBDT::out_of_bag_score costs %f", (out_of_bag_score_time * 1e-3).count());
  Log::Info("GBDT::valid_score costs %f", (valid_score_time * 1e-3).count());
  Log::Info("GBDT::metric costs %f", (metric_time * 1e-3).count());
  Log::Info("GBDT::bagging costs %f", (bagging_time * 1e-3).count());
  Log::Info("GBDT::sub_gradient costs %f", (sub_gradient_time * 1e-3).count());
#endif
}

63
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
64
                const std::vector<const Metric*>& training_metrics) {
65
  iter_ = 0;
wxchan's avatar
wxchan committed
66
  num_iteration_for_pred_ = 0;
67
  max_feature_idx_ = 0;
wxchan's avatar
wxchan committed
68
69
  num_class_ = config->num_class;
  train_data_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
70
  gbdt_config_ = nullptr;
71
  tree_learner_ = nullptr;
wxchan's avatar
wxchan committed
72
73
74
75
  ResetTrainingData(config, train_data, object_function, training_metrics);
}

// Re-binds the booster to (possibly new) training data, objective, config and
// training metrics. Safe to call repeatedly; expensive per-dataset state is
// rebuilt only when the dataset pointer actually changes. The new config is
// committed to gbdt_config_ only at the very end, so comparisons against the
// old config remain valid throughout.
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
                             const std::vector<const Metric*>& training_metrics) {
  // work on a private copy; the caller keeps ownership of *config
  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
  // replacement data must share the old data's bin mappers, otherwise the
  // existing trees/thresholds would be meaningless
  if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;

  object_function_ = object_function;

  // sigmoid_ stays negative (disabled) unless the objective is binary
  sigmoid_ = -1.0f;
  if (object_function_ != nullptr
      && std::string(object_function_->GetName()) == std::string("binary")) {
    // only binary classification need sigmoid transform
    sigmoid_ = new_config->sigmoid;
  }

  // per-dataset state: rebuilt only when the dataset actually changed
  if (train_data_ != train_data && train_data != nullptr) {
    if (tree_learner_ == nullptr) {
      tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
    }
    // init tree learner
    tree_learner_->Init(train_data);

    // push training metrics
    training_metrics_.clear();
    for (const auto& metric : training_metrics) {
      training_metrics_.push_back(metric);
    }
    training_metrics_.shrink_to_fit();
    // not same training data, need reset score and others
    // create score tracker
    train_score_updater_.reset(new ScoreUpdater(train_data, num_class_));
    // replay all already-trained trees so the score cache matches the model
    for (int i = 0; i < iter_; ++i) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class;
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
    num_data_ = train_data->num_data();
    // create buffer for gradients and hessians (one slab per class)
    if (object_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_class_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
    // get max feature index
    max_feature_idx_ = train_data->num_total_features() - 1;
    // get label index
    label_idx_ = train_data->label_idx();
    // get feature names
    feature_names_ = train_data->feature_names();

    feature_infos_ = train_data->feature_infos();
  }

  // bagging buffers: rebuilt when the dataset changed or when the bagging
  // fraction differs from the previously committed config
  if ((train_data_ != train_data && train_data != nullptr)
      || (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
    // if need bagging, create buffer
    if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
      bag_data_cnt_ =
        static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
      bag_data_indices_.resize(num_data_);
      // per-thread scratch used by Bagging()'s parallel partition
      tmp_indices_.resize(num_data_);
      offsets_buf_.resize(num_threads_);
      left_cnts_buf_.resize(num_threads_);
      right_cnts_buf_.resize(num_threads_);
      left_write_pos_buf_.resize(num_threads_);
      right_write_pos_buf_.resize(num_threads_);
      // with a small effective bag rate it is cheaper to materialize the
      // sampled rows into a subset dataset than to mask the full data
      double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
      is_use_subset_ = false;
      if (average_bag_rate <= 0.5) {
        tmp_subset_.reset(new Dataset(bag_data_cnt_));
        tmp_subset_->CopyFeatureMapperFrom(train_data);
        is_use_subset_ = true;
        Log::Debug("use subset for bagging");
      }
    } else {
      // bagging disabled: every record is "in the bag"
      bag_data_cnt_ = num_data_;
      bag_data_indices_.clear();
      tmp_indices_.clear();
      is_use_subset_ = false;
    }
  }
  train_data_ = train_data;
  if (train_data_ != nullptr) {
    // reset config for tree learner
    tree_learner_->ResetConfig(&new_config->tree_config);
    // detect degenerate classes (all-positive / all-negative) that need no
    // training and get a fixed default output instead
    class_need_train_ = std::vector<bool>(num_class_, true);
    if (num_class_ > 1 || sigmoid_ > 0) {
      // + 1 here for the binary classification
      class_default_output_ = std::vector<double>(num_class_ + 1, 0.0f);
      std::vector<data_size_t> cnt_per_class(num_class_, 0);
      auto label = train_data_->metadata().label();
      for (int i = 0; i < num_data_; ++i) {
        ++cnt_per_class[static_cast<int>(label[i])];
      }
      if (num_class_ > 1) {
        for (int i = 0; i < num_class_; ++i) {
          if (cnt_per_class[i] == num_data_) {
            Log::Warning("Only contain one class.");
            class_need_train_[i] = false;
            class_default_output_[i] = -std::log(kEpsilon);
          } else if (cnt_per_class[i] == 0) {
            class_need_train_[i] = false;
            class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
          }
        }
      } else {
        // binary classification. 
        if (cnt_per_class[1] == 0) {
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
        } else if (cnt_per_class[1] == num_data_) {
          class_need_train_[0] = false;
          class_default_output_[0] = -std::log(kEpsilon);
        }
      }
    }
  }
  // commit the new configuration last (earlier code compares old vs. new)
  gbdt_config_.reset(new_config.release());
}

wxchan's avatar
wxchan committed
200
// Registers a validation dataset together with its metrics. Builds a score
// updater for it, replays all trees trained so far, and (when early stopping
// is enabled) appends matching per-metric slots to best_iter_/best_score_/
// best_msg_ — these four vectors must stay index-aligned.
void GBDT::AddValidDataset(const Dataset* valid_data,
                           const std::vector<const Metric*>& valid_metrics) {
  // validation data must share the training data's bin mappers
  if (!train_data_->CheckAlign(*valid_data)) {
    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
  }
  // for a validation dataset, we need its score and metric
  auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_class_));
  // replay all already-trained trees so scores are up to date
  for (int i = 0; i < iter_; ++i) {
    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
      auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class;
      new_score_updater->AddScore(models_[curr_tree].get(), curr_class);
    }
  }
  valid_score_updater_.push_back(std::move(new_score_updater));
  valid_metrics_.emplace_back();
  // one early-stopping record per validation set
  if (early_stopping_round_ > 0) {
    best_iter_.emplace_back();
    best_score_.emplace_back();
    best_msg_.emplace_back();
  }
  // one slot per metric of this validation set
  for (const auto& metric : valid_metrics) {
    valid_metrics_.back().push_back(metric);
    if (early_stopping_round_ > 0) {
      best_iter_.back().push_back(0);
      best_score_.back().push_back(kMinScore);
      best_msg_.back().emplace_back();
    }
  }
  valid_metrics_.back().shrink_to_fit();
}

232
// Samples roughly bagging_fraction of the records in [start, start + cnt).
// Selected indices are written to the front of `buffer`, rejected indices
// immediately after them. Returns the number of selected records.
data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  if (cnt <= 0) { return 0; }
  const data_size_t target_cnt =
    static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t* rejected = buffer + target_cnt;
  data_size_t num_taken = 0;
  data_size_t num_rejected = 0;
  // Sequential sampling without replacement: at record i, accept with
  // probability (still needed) / (still remaining). This always accepts the
  // tail once "needed == remaining", so exactly target_cnt records are taken.
  for (data_size_t i = 0; i < cnt; ++i) {
    const float accept_prob =
      (target_cnt - num_taken) / static_cast<float>(cnt - i);
    if (cur_rand.NextFloat() < accept_prob) {
      buffer[num_taken++] = start + i;
    } else {
      rejected[num_rejected++] = start + i;
    }
  }
  CHECK(num_taken == target_cnt);
  return num_taken;
}
Guolin Ke's avatar
Guolin Ke committed
254

Guolin Ke's avatar
Guolin Ke committed
255
256


257
// Re-samples the bag for iteration `iter` (every bagging_freq iterations).
// Three phases: (1) each thread samples its slice of records into
// tmp_indices_ via BaggingHelper; (2) a serial prefix sum computes where each
// thread's in-bag / out-of-bag runs land in the output; (3) threads scatter
// their runs into bag_data_indices_ (in-bag first, then out-of-bag).
void GBDT::Bagging(int iter) {
  // if need bagging
  if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
    const data_size_t min_inner_size = 1000;
    // slice size per thread, clamped so tiny slices don't dominate overhead
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
    if (inner_size < min_inner_size) { inner_size = min_inner_size; }
    OMP_INIT_EX();
  #pragma omp parallel for schedule(static,1)
    for (int i = 0; i < num_threads_; ++i) {
      OMP_LOOP_EX_BEGIN();
      left_cnts_buf_[i] = 0;
      right_cnts_buf_[i] = 0;
      data_size_t cur_start = i * inner_size;
      // with clamped inner_size some trailing threads may have no slice
      if (cur_start > num_data_) { continue; }
      data_size_t cur_cnt = inner_size;
      if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
      // per-thread RNG seeded from (seed, iter, thread) for reproducibility
      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
      offsets_buf_[i] = cur_start;
      left_cnts_buf_[i] = cur_left_count;
      right_cnts_buf_[i] = cur_cnt - cur_left_count;
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // serial exclusive prefix sums -> write positions for each thread's runs
    data_size_t left_cnt = 0;
    left_write_pos_buf_[0] = 0;
    right_write_pos_buf_[0] = 0;
    for (int i = 1; i < num_threads_; ++i) {
      left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
      right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
    }
    left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];

  #pragma omp parallel for schedule(static, 1)
    for (int i = 0; i < num_threads_; ++i) {
      OMP_LOOP_EX_BEGIN();
      // in-bag indices go to the front of bag_data_indices_
      if (left_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i],
                    tmp_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] * sizeof(data_size_t));
      }
      // out-of-bag indices go after all in-bag ones
      if (right_cnts_buf_[i] > 0) {
        std::memcpy(bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i],
                    tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], right_cnts_buf_[i] * sizeof(data_size_t));
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    bag_data_cnt_ = left_cnt;
    // NOTE(review): sanity check that the in-bag/out-of-bag boundary is a
    // real partition point; reads index bag_data_cnt_, which is only in
    // range while bag_data_cnt_ < num_data_ — confirm that invariant holds.
    CHECK(bag_data_indices_[bag_data_cnt_ - 1] > bag_data_indices_[bag_data_cnt_]);
    Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
    // set bagging data to tree learner
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
      tree_learner_->ResetTrainingData(tmp_subset_.get());
    }
  }
}

319
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
320
321
322
#ifdef TIMETAG
  auto start_time = std::chrono::steady_clock::now();
#endif
323
  // we need to predict out-of-bag scores of data for boosting
Guolin Ke's avatar
Guolin Ke committed
324
  if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
325
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
Guolin Ke's avatar
Guolin Ke committed
326
  }
Guolin Ke's avatar
Guolin Ke committed
327
#ifdef TIMETAG
Guolin Ke's avatar
Guolin Ke committed
328
  out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
Guolin Ke's avatar
Guolin Ke committed
329
#endif
Guolin Ke's avatar
Guolin Ke committed
330
331
}

332
// Trains one boosting iteration (one tree per class). If gradient/hessian are
// null they are computed from the objective; otherwise the caller's buffers
// are used. Returns true when training cannot/should not continue (no valid
// split anywhere, or early stopping triggered when is_eval is set).
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
  // boosting from average prediction. It doesn't work well for classification, remove it for now.
  if (models_.empty() 
      && gbdt_config_->boost_from_average 
      && !train_score_updater_->has_init_score()
      && sigmoid_ < 0.0f
      && num_class_ <= 1) {
    double init_score = 0.0f;
    auto label = train_data_->metadata().label();
    // mean label over all training records
    #pragma omp parallel for schedule(static) reduction(+:init_score)
    for (data_size_t i = 0; i < num_data_; ++i) {
      init_score += label[i];
    }
    init_score /= num_data_;
    // encode the constant baseline as a trivial one-split tree
    std::unique_ptr<Tree> new_tree(new Tree(2));
    new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, num_data_, 1);
    train_score_updater_->AddScore(init_score, 0);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(init_score, 0);
    }
    models_.push_back(std::move(new_tree));
    boost_from_average_ = true;
  }
  // boosting first: compute gradients/hessians unless the caller supplied them
  if (gradient == nullptr || hessian == nullptr) {
  #ifdef TIMETAG
    auto start_time = std::chrono::steady_clock::now();
  #endif
    Boosting();
    gradient = gradients_.data();
    hessian = hessians_.data();
  #ifdef TIMETAG
    boosting_time += std::chrono::steady_clock::now() - start_time;
  #endif
  }
#ifdef TIMETAG
  auto start_time = std::chrono::steady_clock::now();
#endif
  // bagging logic
  Bagging(iter_);
#ifdef TIMETAG
  bagging_time += std::chrono::steady_clock::now() - start_time;
#endif
  // in subset mode the learner sees only bagged rows, so gradients must be
  // gathered into bag order first
  if (is_use_subset_ && bag_data_cnt_ < num_data_) {
  #ifdef TIMETAG
    start_time = std::chrono::steady_clock::now();
  #endif
    // caller-supplied gradients may arrive before our buffers exist
    if (gradients_.empty()) {
      size_t total_size = static_cast<size_t>(num_data_) * num_class_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
    // get sub gradients
    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
      auto bias = curr_class * num_data_;
      // cannot multi-threading here.
      for (int i = 0; i < bag_data_cnt_; ++i) {
        gradients_[bias + i] = gradient[bias + bag_data_indices_[i]];
        hessians_[bias + i] = hessian[bias + bag_data_indices_[i]];
      }
    }
    gradient = gradients_.data();
    hessian = hessians_.data();
  #ifdef TIMETAG
    sub_gradient_time += std::chrono::steady_clock::now() - start_time;
  #endif
  }
  bool should_continue = false;
  // one tree per class
  for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
  #ifdef TIMETAG
    start_time = std::chrono::steady_clock::now();
  #endif
    // placeholder stump; replaced by a real tree when the class is trainable
    std::unique_ptr<Tree> new_tree(new Tree(2));
    if (class_need_train_[curr_class]) {
      new_tree.reset(
        tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
    }
  #ifdef TIMETAG
    tree_time += std::chrono::steady_clock::now() - start_time;
  #endif

    if (new_tree->num_leaves() > 1) {
      should_continue = true;
      // shrinkage by learning rate
      new_tree->Shrinkage(shrinkage_rate_);
      // update score
      UpdateScore(new_tree.get(), curr_class);
      UpdateScoreOutOfBag(new_tree.get(), curr_class);
    } else {
      // only add default score one-time
      if (!class_need_train_[curr_class] && models_.size() < num_class_) {
        auto output = class_default_output_[curr_class];
        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, 
                        output, output, 0, num_data_, 1);
        train_score_updater_->AddScore(output, curr_class);
        for (auto& score_updater : valid_score_updater_) {
          score_updater->AddScore(output, curr_class);
        }
      }
    }
    // add model
    models_.push_back(std::move(new_tree));
  }
  // no class produced a usable split: undo this iteration's trees and stop
  if (!should_continue) {
    Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
      models_.pop_back();
    }
    return true;
  }
  ++iter_;
  if (is_eval) {
    return EvalAndCheckEarlyStopping();
  } else {
    return false;
  }
}
450

wxchan's avatar
wxchan committed
451
void GBDT::RollbackOneIter() {
452
  if (iter_ <= 0) { return; }
wxchan's avatar
wxchan committed
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
  int cur_iter = iter_ + num_init_iteration_ - 1;
  // reset score
  for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
    auto curr_tree = cur_iter * num_class_ + curr_class;
    models_[curr_tree]->Shrinkage(-1.0);
    train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->AddScore(models_[curr_tree].get(), curr_class);
    }
  }
  // remove model
  for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
    models_.pop_back();
  }
  --iter_;
}

Guolin Ke's avatar
Guolin Ke committed
470
bool GBDT::EvalAndCheckEarlyStopping() {
471
  bool is_met_early_stopping = false;
Guolin Ke's avatar
Guolin Ke committed
472
473
474
#ifdef TIMETAG
  auto start_time = std::chrono::steady_clock::now();
#endif
475
  // print message for metric
Guolin Ke's avatar
Guolin Ke committed
476
  auto best_msg = OutputMetric(iter_);
Guolin Ke's avatar
Guolin Ke committed
477
478
479
#ifdef TIMETAG
  metric_time += std::chrono::steady_clock::now() - start_time;
#endif
Guolin Ke's avatar
Guolin Ke committed
480
  is_met_early_stopping = !best_msg.empty();
481
482
  if (is_met_early_stopping) {
    Log::Info("Early stopping at iteration %d, the best iteration round is %d",
483
              iter_, iter_ - early_stopping_round_);
Guolin Ke's avatar
Guolin Ke committed
484
    Log::Info("Output of best iteration round:\n%s", best_msg.c_str());
485
    // pop last early_stopping_round_ models
486
    for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
487
488
489
490
      models_.pop_back();
    }
  }
  return is_met_early_stopping;
Guolin Ke's avatar
Guolin Ke committed
491
492
}

493
void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
494
495
496
#ifdef TIMETAG
  auto start_time = std::chrono::steady_clock::now();
#endif
Guolin Ke's avatar
Guolin Ke committed
497
  // update training score
Guolin Ke's avatar
Guolin Ke committed
498
  if (!is_use_subset_) {
Guolin Ke's avatar
Guolin Ke committed
499
    train_score_updater_->AddScore(tree_learner_.get(), tree, curr_class);
Guolin Ke's avatar
Guolin Ke committed
500
501
502
  } else {
    train_score_updater_->AddScore(tree, curr_class);
  }
Guolin Ke's avatar
Guolin Ke committed
503
504
505
506
507
508
#ifdef TIMETAG
  train_score_time += std::chrono::steady_clock::now() - start_time;
#endif
#ifdef TIMETAG
  start_time = std::chrono::steady_clock::now();
#endif
Guolin Ke's avatar
Guolin Ke committed
509
  // update validation score
Guolin Ke's avatar
Guolin Ke committed
510
511
  for (auto& score_updater : valid_score_updater_) {
    score_updater->AddScore(tree, curr_class);
Guolin Ke's avatar
Guolin Ke committed
512
  }
Guolin Ke's avatar
Guolin Ke committed
513
514
515
#ifdef TIMETAG
  valid_score_time += std::chrono::steady_clock::now() - start_time;
#endif
Guolin Ke's avatar
Guolin Ke committed
516
517
}

Guolin Ke's avatar
Guolin Ke committed
518
519
520
521
std::string GBDT::OutputMetric(int iter) {
  bool need_output = (iter % gbdt_config_->output_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
522
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
Guolin Ke's avatar
Guolin Ke committed
523
  // print training metric
Guolin Ke's avatar
Guolin Ke committed
524
  if (need_output) {
525
526
527
    for (auto& sub_metric : training_metrics_) {
      auto name = sub_metric->GetName();
      auto scores = sub_metric->Eval(train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
528
      for (size_t k = 0; k < name.size(); ++k) {
Guolin Ke's avatar
Guolin Ke committed
529
530
531
532
533
534
535
536
        std::stringstream tmp_buf;
        tmp_buf << "Iteration:" << iter
          << ", training " << name[k]
          << " : " << scores[k];
        Log::Info(tmp_buf.str().c_str());
        if (early_stopping_round_ > 0) {
          msg_buf << tmp_buf.str() << std::endl;
        }
537
      }
538
    }
Guolin Ke's avatar
Guolin Ke committed
539
540
  }
  // print validation metric
Guolin Ke's avatar
Guolin Ke committed
541
  if (need_output || early_stopping_round_ > 0) {
542
543
544
    for (size_t i = 0; i < valid_metrics_.size(); ++i) {
      for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
        auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
Guolin Ke's avatar
Guolin Ke committed
545
546
547
548
549
550
551
552
553
554
555
        auto name = valid_metrics_[i][j]->GetName();
        for (size_t k = 0; k < name.size(); ++k) {
          std::stringstream tmp_buf;
          tmp_buf << "Iteration:" << iter
            << ", valid_" << i + 1 << " " << name[k]
            << " : " << test_scores[k];
          if (need_output) {
            Log::Info(tmp_buf.str().c_str());
          }
          if (early_stopping_round_ > 0) {
            msg_buf << tmp_buf.str() << std::endl;
556
          }
wxchan's avatar
wxchan committed
557
        }
Guolin Ke's avatar
Guolin Ke committed
558
        if (ret.empty() && early_stopping_round_ > 0) {
559
560
561
          auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
          if (cur_score > best_score_[i][j]) {
            best_score_[i][j] = cur_score;
562
            best_iter_[i][j] = iter;
Guolin Ke's avatar
Guolin Ke committed
563
            meet_early_stopping_pairs.emplace_back(i, j);
564
          } else {
Guolin Ke's avatar
Guolin Ke committed
565
            if (iter - best_iter_[i][j] >= early_stopping_round_) { ret = best_msg_[i][j]; }
566
          }
wxchan's avatar
wxchan committed
567
568
        }
      }
Guolin Ke's avatar
Guolin Ke committed
569
570
    }
  }
Guolin Ke's avatar
Guolin Ke committed
571
572
573
  for (auto& pair : meet_early_stopping_pairs) {
    best_msg_[pair.first][pair.second] = msg_buf.str();
  }
wxchan's avatar
wxchan committed
574
  return ret;
Guolin Ke's avatar
Guolin Ke committed
575
576
}

577
/*! \brief Get eval result */
578
std::vector<double> GBDT::GetEvalAt(int data_idx) const {
Guolin Ke's avatar
Guolin Ke committed
579
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
580
581
  std::vector<double> ret;
  if (data_idx == 0) {
582
583
    for (auto& sub_metric : training_metrics_) {
      auto scores = sub_metric->Eval(train_score_updater_->score());
584
585
586
      for (auto score : scores) {
        ret.push_back(score);
      }
587
    }
588
  } else {
589
590
591
592
593
594
    auto used_idx = data_idx - 1;
    for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
      auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score());
      for (auto score : test_scores) {
        ret.push_back(score);
      }
595
596
597
598
599
    }
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
600
/*! \brief Get training scores result */
601
const double* GBDT::GetTrainingScore(int64_t* out_len) {
602
  *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
603
  return train_score_updater_->score();
604
605
}

Guolin Ke's avatar
Guolin Ke committed
606
607
// Writes transformed predictions for one dataset into out_result (length
// stored in *out_len). data_idx == 0 selects the training data, 1..N the N-th
// validation dataset. Multiclass goes through softmax, binary through the
// sigmoid, regression is copied raw. Layout is class-major:
// out_result[class * num_data + row].
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
  CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));

  const double* raw_scores = nullptr;
  data_size_t num_data = 0;
  if (data_idx == 0) {
    // GetTrainingScore fills *out_len as a side effect
    raw_scores = GetTrainingScore(out_len);
    num_data = train_score_updater_->num_data();
  } else {
    auto used_idx = data_idx - 1;
    raw_scores = valid_score_updater_[used_idx]->score();
    num_data = valid_score_updater_[used_idx]->num_data();
    *out_len = static_cast<int64_t>(num_data) * num_class_;
  }
  if (num_class_ > 1) {
    // multiclass: softmax across the per-class raw scores of each row
  #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
      std::vector<double> tmp_result(num_class_);
      for (int j = 0; j < num_class_; ++j) {
        tmp_result[j] = raw_scores[j * num_data + i];
      }
      Common::Softmax(&tmp_result);
      for (int j = 0; j < num_class_; ++j) {
        out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
      }
    }
  } else if (sigmoid_ > 0.0f) {
    // binary classification: logistic transform with the configured slope
  #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
      out_result[i] = static_cast<double>(1.0f / (1.0f + std::exp(-sigmoid_ * raw_scores[i])));
    }
  } else {
    // regression: raw scores pass through unchanged
  #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data; ++i) {
      out_result[i] = static_cast<double>(raw_scores[i]);
    }
  }

}

Guolin Ke's avatar
Guolin Ke committed
646
void GBDT::Boosting() {
647
648
649
  if (object_function_ == nullptr) {
    Log::Fatal("No object function provided");
  }
Hui Xue's avatar
Hui Xue committed
650
  // objective function will calculate gradients and hessians
651
  int64_t num_score = 0;
Guolin Ke's avatar
Guolin Ke committed
652
  object_function_->
Guolin Ke's avatar
Guolin Ke committed
653
    GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
654
655
}

656
std::string GBDT::DumpModel(int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
657
  std::stringstream str_buf;
wxchan's avatar
wxchan committed
658

Guolin Ke's avatar
Guolin Ke committed
659
  str_buf << "{";
Guolin Ke's avatar
Guolin Ke committed
660
  str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
661
662
663
664
  str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
  str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
  str_buf << "\"sigmoid\":" << sigmoid_ << "," << std::endl;
wxchan's avatar
wxchan committed
665

666
667
668
  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
    << std::endl;
Guolin Ke's avatar
Guolin Ke committed
669

Guolin Ke's avatar
Guolin Ke committed
670
  str_buf << "\"tree_info\":[";
671
672
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
Guolin Ke's avatar
Guolin Ke committed
673
    num_iteration += boost_from_average_ ? 1 : 0;
674
    num_used_model = std::min(num_iteration * num_class_, num_used_model);
675
  }
676
  for (int i = 0; i < num_used_model; ++i) {
wxchan's avatar
wxchan committed
677
    if (i > 0) {
Guolin Ke's avatar
Guolin Ke committed
678
      str_buf << ",";
wxchan's avatar
wxchan committed
679
    }
Guolin Ke's avatar
Guolin Ke committed
680
681
682
683
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
wxchan's avatar
wxchan committed
684
  }
Guolin Ke's avatar
Guolin Ke committed
685
  str_buf << "]" << std::endl;
wxchan's avatar
wxchan committed
686

Guolin Ke's avatar
Guolin Ke committed
687
  str_buf << "}" << std::endl;
wxchan's avatar
wxchan committed
688

Guolin Ke's avatar
Guolin Ke committed
689
  return str_buf.str();
wxchan's avatar
wxchan committed
690
691
}

Guolin Ke's avatar
Guolin Ke committed
692
std::string GBDT::SaveModelToString(int num_iteration) const {
693
  std::stringstream ss;
694

695
696
697
698
699
700
701
702
703
704
705
706
707
708
  // output model type
  ss << SubModelName() << std::endl;
  // output number of class
  ss << "num_class=" << num_class_ << std::endl;
  // output label index
  ss << "label_index=" << label_idx_ << std::endl;
  // output max_feature_idx
  ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
  // output objective name
  if (object_function_ != nullptr) {
    ss << "objective=" << object_function_->GetName() << std::endl;
  }
  // output sigmoid parameter
  ss << "sigmoid=" << sigmoid_ << std::endl;
709

710
711
712
  if (boost_from_average_) {
    ss << "boost_from_average" << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
713

714
  ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
715

716
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
717

718
719
  ss << std::endl;
  int num_used_model = static_cast<int>(models_.size());
Guolin Ke's avatar
Guolin Ke committed
720
721
722
  if (num_iteration > 0) {
    num_iteration += boost_from_average_ ? 1 : 0;
    num_used_model = std::min(num_iteration * num_class_, num_used_model);
723
724
725
726
727
728
729
730
731
732
733
734
735
736
  }
  // output tree models
  for (int i = 0; i < num_used_model; ++i) {
    ss << "Tree=" << i << std::endl;
    ss << models_[i]->ToString() << std::endl;
  }

  std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
  ss << std::endl << "feature importances:" << std::endl;
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
  }

  return ss.str();
737
738
}

739
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
wxchan's avatar
wxchan committed
740
741
742
  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
743

744
  output_file << SaveModelToString(num_iteration);
745

wxchan's avatar
wxchan committed
746
  output_file.close();
747
748

  return (bool)output_file;
Guolin Ke's avatar
Guolin Ke committed
749
750
}

/*!
* \brief Restore this booster from a serialized model string (the format
* written by SaveModelToString). Parses the key=value header fields, the
* feature name/info lists, then every "Tree=" section.
* \param model_str Whole model file content
* \return true on success; header errors call Log::Fatal (the subsequent
*         "return false" only matters if Fatal is non-terminating — verify)
*/
bool GBDT::LoadModelFromString(const std::string& model_str) {
  // use serialized string to restore this object
  models_.clear();
  std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');

  // get number of classes
  auto line = Common::FindFromLines(lines, "num_class=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
  } else {
    Log::Fatal("Model file doesn't specify the number of classes");
    return false;
  }
  // get index of label
  line = Common::FindFromLines(lines, "label_index=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
  } else {
    Log::Fatal("Model file doesn't specify the label index");
    return false;
  }
  // get max_feature_idx first
  // (must precede feature_names/feature_infos: their sizes are validated
  // against max_feature_idx_ + 1 below)
  line = Common::FindFromLines(lines, "max_feature_idx=");
  if (line.size() > 0) {
    Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
  } else {
    Log::Fatal("Model file doesn't specify max_feature_idx");
    return false;
  }
  // get sigmoid parameter
  // optional: absent means "no sigmoid transform" (sentinel -1)
  line = Common::FindFromLines(lines, "sigmoid=");
  if (line.size() > 0) {
    Common::Atof(Common::Split(line.c_str(), '=')[1].c_str(), &sigmoid_);
  } else {
    sigmoid_ = -1.0f;
  }
  // get boost_from_average_
  // optional flag line; present means the first model is the average score
  line = Common::FindFromLines(lines, "boost_from_average");
  if (line.size() > 0) {
    boost_from_average_ = true;
  }
  // get feature names
  line = Common::FindFromLines(lines, "feature_names=");
  if (line.size() > 0) {
    feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), " ");
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature names");
    return false;
  }

  // get feature infos (e.g. per-feature bin/range description; one per feature)
  line = Common::FindFromLines(lines, "feature_infos=");
  if (line.size() > 0) {
    feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), " ");
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature infos");
    return false;
  }

  // get tree models
  // each tree spans from the line after "Tree=" up to (not including) the
  // next "Tree=" marker or end of input
  size_t i = 0;
  while (i < lines.size()) {
    size_t find_pos = lines[i].find("Tree=");
    if (find_pos != std::string::npos) {
      ++i;
      int start = static_cast<int>(i);
      while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
      int end = static_cast<int>(i);
      std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
      auto new_tree = std::unique_ptr<Tree>(new Tree(tree_str));
      models_.push_back(std::move(new_tree));
    } else {
      ++i;
    }
  }
  Log::Info("Finished loading %d models", models_.size());
  // trees are stored class-major: models_.size() == iterations * num_class_
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;

  return true;
}

wxchan's avatar
wxchan committed
841
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
842

843
  std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
844
845
846
  for (size_t iter = 0; iter < models_.size(); ++iter) {
    for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
      ++feature_importances[models_[iter]->split_feature(split_idx)];
wxchan's avatar
wxchan committed
847
    }
848
849
850
851
852
853
  }
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    if (feature_importances[i] > 0) {
      pairs.emplace_back(feature_importances[i], feature_names_[i]);
854
    }
855
856
857
858
859
860
861
862
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
            [] (const std::pair<size_t, std::string>& lhs,
                const std::pair<size_t, std::string>& rhs) {
    return lhs.first > rhs.first;
  });
  return pairs;
wxchan's avatar
wxchan committed
863
864
}

865
866
std::vector<double> GBDT::PredictRaw(const double* value) const {
  std::vector<double> ret(num_class_, 0.0f);
wxchan's avatar
wxchan committed
867
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
868
869
870
    for (int j = 0; j < num_class_; ++j) {
      ret[j] += models_[i * num_class_ + j]->Predict(value);
    }
Guolin Ke's avatar
Guolin Ke committed
871
872
873
874
  }
  return ret;
}

875
std::vector<double> GBDT::Predict(const double* value) const {
876
  std::vector<double> ret(num_class_, 0.0f);
wxchan's avatar
wxchan committed
877
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
878
879
    for (int j = 0; j < num_class_; ++j) {
      ret[j] += models_[i * num_class_ + j]->Predict(value);
880
881
    }
  }
882
883
  // if need sigmoid transform
  if (sigmoid_ > 0 && num_class_ == 1) {
884
    ret[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * ret[0]));
885
886
887
  } else if (num_class_ > 1) {
    Common::Softmax(&ret);
  }
888
889
890
  return ret;
}

891
std::vector<int> GBDT::PredictLeafIndex(const double* value) const {
wxchan's avatar
wxchan committed
892
  std::vector<int> ret;
wxchan's avatar
wxchan committed
893
  for (int i = 0; i < num_iteration_for_pred_; ++i) {
894
895
896
    for (int j = 0; j < num_class_; ++j) {
      ret.push_back(models_[i * num_class_ + j]->PredictLeafIndex(value));
    }
wxchan's avatar
wxchan committed
897
898
899
900
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
901
}  // namespace LightGBM