rf.hpp 7.39 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#ifndef LIGHTGBM_BOOSTING_RF_H_
#define LIGHTGBM_BOOSTING_RF_H_

#include <LightGBM/boosting.h>
#include <LightGBM/metric.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief Random Forest implementation
*/
Guolin Ke's avatar
Guolin Ke committed
18
class RF : public GBDT {
Guolin Ke's avatar
Guolin Ke committed
19
20
public:

Guolin Ke's avatar
Guolin Ke committed
21
  RF() : GBDT() {
Guolin Ke's avatar
Guolin Ke committed
22
23
24
25
26
    average_output_ = true;
  }

  /*! \brief Destructor; RF has nothing to release beyond its own members */
  ~RF() = default;

Guolin Ke's avatar
Guolin Ke committed
27
  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
Guolin Ke's avatar
Guolin Ke committed
28
    const std::vector<const Metric*>& training_metrics) override {
Guolin Ke's avatar
Guolin Ke committed
29
    CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
30
    CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f);
Guolin Ke's avatar
Guolin Ke committed
31
32
33
34
35
36
37
38
39
    GBDT::Init(config, train_data, objective_function, training_metrics);

    if (num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        MultiplyScore(cur_tree_id, 1.0f / num_init_iteration_);
      }
    } else {
      CHECK(train_data->metadata().init_score() == nullptr);
    }
40
    CHECK(num_tree_per_iteration_ == num_class_);
Guolin Ke's avatar
Guolin Ke committed
41
42
43
    // not shrinkage rate for the RF
    shrinkage_rate_ = 1.0f;
    // only boosting one time
Guolin Ke's avatar
Guolin Ke committed
44
    Boosting();
Guolin Ke's avatar
Guolin Ke committed
45
    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
46
47
      tmp_grad_.resize(num_data_);
      tmp_hess_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
48
49
50
    }
  }

Guolin Ke's avatar
Guolin Ke committed
51
  void ResetConfig(const Config* config) override {
Guolin Ke's avatar
Guolin Ke committed
52
    CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
53
    CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f);
Guolin Ke's avatar
Guolin Ke committed
54
55
56
57
58
59
    GBDT::ResetConfig(config);
    // not shrinkage rate for the RF
    shrinkage_rate_ = 1.0f;
  }

  void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
Guolin Ke's avatar
Guolin Ke committed
60
    const std::vector<const Metric*>& training_metrics) override {
Guolin Ke's avatar
Guolin Ke committed
61
62
63
64
65
66
    GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
    if (iter_ + num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        train_score_updater_->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
      }
    }
67
    CHECK(num_tree_per_iteration_ == num_class_);
Guolin Ke's avatar
Guolin Ke committed
68
    // only boosting one time
Guolin Ke's avatar
Guolin Ke committed
69
    Boosting();
Guolin Ke's avatar
Guolin Ke committed
70
    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
71
72
      tmp_grad_.resize(num_data_);
      tmp_hess_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
73
74
75
    }
  }

Guolin Ke's avatar
Guolin Ke committed
76
77
78
79
80
81
82
  void Boosting() override {
    if (objective_function_ == nullptr) {
      Log::Fatal("No object function provided");
    }
    init_scores_.resize(num_tree_per_iteration_, 0.0);
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      init_scores_[cur_tree_id] = BoostFromAverage(cur_tree_id, false);
83
    }
Guolin Ke's avatar
Guolin Ke committed
84
85
86
87
88
89
90
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    std::vector<double> tmp_scores(total_size, 0.0f);
    #pragma omp parallel for schedule(static)
    for (int j = 0; j < num_tree_per_iteration_; ++j) {
      size_t bias = static_cast<size_t>(j)* num_data_;
      for (data_size_t i = 0; i < num_data_; ++i) {
        tmp_scores[bias + i] = init_scores_[j];
91
      }
Guolin Ke's avatar
Guolin Ke committed
92
    }
Guolin Ke's avatar
Guolin Ke committed
93
94
    objective_function_->
      GetGradients(tmp_scores.data(), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
95
96
  }

Guolin Ke's avatar
Guolin Ke committed
97
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
Guolin Ke's avatar
Guolin Ke committed
98
99
    // bagging logic
    Bagging(iter_);
100
101
    CHECK(gradients == nullptr);
    CHECK(hessians == nullptr);
Guolin Ke's avatar
Guolin Ke committed
102

103
104
    gradients = gradients_.data();
    hessians = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
105
106
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      std::unique_ptr<Tree> new_tree(new Tree(2));
107
      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
Guolin Ke's avatar
Guolin Ke committed
108
109
110
111
112
113
114
115
116
117
118
119
120
      if (class_need_train_[cur_tree_id]) {

        auto grad = gradients + bias;
        auto hess = hessians + bias;

        // need to copy gradients for bagging subset.
        if (is_use_subset_ && bag_data_cnt_ < num_data_) {
          for (int i = 0; i < bag_data_cnt_; ++i) {
            tmp_grad_[i] = grad[bag_data_indices_[i]];
            tmp_hess_[i] = hess[bag_data_indices_[i]];
          }
          grad = tmp_grad_.data();
          hess = tmp_hess_.data();
Guolin Ke's avatar
Guolin Ke committed
121
        }
Guolin Ke's avatar
Guolin Ke committed
122
123
124

        new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_,
          forced_splits_json_));
Guolin Ke's avatar
Guolin Ke committed
125
      }
Guolin Ke's avatar
Guolin Ke committed
126

Guolin Ke's avatar
Guolin Ke committed
127
      if (new_tree->num_leaves() > 1) {
Guolin Ke's avatar
Guolin Ke committed
128
        tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, init_scores_[cur_tree_id],
129
          num_data_, bag_data_indices_.data(), bag_data_cnt_);
Guolin Ke's avatar
Guolin Ke committed
130
131
132
        if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
          new_tree->AddBias(init_scores_[cur_tree_id]);
        }
Guolin Ke's avatar
Guolin Ke committed
133
134
135
136
        // update score
        MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
        UpdateScore(new_tree.get(), cur_tree_id);
        MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
      } else {
        // only add default score one-time
        if (models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
          double output = 0.0;
          if (!class_need_train_[cur_tree_id]) {
            if (objective_function_ != nullptr) {
              output = objective_function_->BoostFromScore(cur_tree_id);
            } else {
              output = init_scores_[cur_tree_id];
            }
          }
          new_tree->AsConstantTree(output);
          MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
          UpdateScore(new_tree.get(), cur_tree_id);
          MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
        }
Guolin Ke's avatar
Guolin Ke committed
153
154
155
156
157
      }
      // add model
      models_.push_back(std::move(new_tree));
    }
    ++iter_;
Guolin Ke's avatar
Guolin Ke committed
158
    return false;
Guolin Ke's avatar
Guolin Ke committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
  }

  /*!
  * \brief Undo the most recent iteration: take its trees out of the averaged scores and drop them
  *
  * Does nothing when no iteration has been trained in this session (iter_ <= 0).
  */
  void RollbackOneIter() override {
    if (iter_ <= 0) { return; }
    const int total_iters = iter_ + num_init_iteration_;
    const int remaining_iters = total_iters - 1;
    // reset score
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      const auto tree_idx = remaining_iters * num_tree_per_iteration_ + cur_tree_id;
      // negate the tree so that adding it below subtracts its contribution
      models_[tree_idx]->Shrinkage(-1.0);
      // scale the scores by the current iteration count before subtracting
      MultiplyScore(cur_tree_id, total_iters);
      train_score_updater_->AddScore(models_[tree_idx].get(), cur_tree_id);
      for (auto& score_updater : valid_score_updater_) {
        score_updater->AddScore(models_[tree_idx].get(), cur_tree_id);
      }
      // Divide by the remaining iteration count. When the only iteration is rolled
      // back there is nothing left to average; dividing by zero here would multiply
      // the scores by infinity, so the rescale is skipped in that case.
      if (remaining_iters > 0) {
        MultiplyScore(cur_tree_id, 1.0 / remaining_iters);
      }
    }
    // remove model
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      models_.pop_back();
    }
    --iter_;
  }

  void MultiplyScore(const int cur_tree_id, double val) {
    train_score_updater_->MultiplyScore(val, cur_tree_id);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->MultiplyScore(val, cur_tree_id);
    }
  }

  void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
190
    const std::vector<const Metric*>& valid_metrics) override {
Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
    GBDT::AddValidDataset(valid_data, valid_metrics);
    if (iter_ + num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        valid_score_updater_.back()->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
      }
    }
  }

  /*!
  * \brief Report that prediction must be accurate, i.e. no early stopping during prediction
  * \return Always true
  */
  bool NeedAccuratePrediction() const override { return true; }

private:

  /*! \brief Scratch buffer for the gradients of the bagged rows; sized to num_data_ in Init/ResetTrainingData when the bagging subset is in use, filled in TrainOneIter */
  std::vector<score_t> tmp_grad_;
  /*! \brief Scratch buffer for the hessians of the bagged rows; sized and filled alongside tmp_grad_ */
  std::vector<score_t> tmp_hess_;
Guolin Ke's avatar
Guolin Ke committed
208
  /*! \brief Initial score per tree id, filled in Boosting() from BoostFromAverage and used in TrainOneIter */
  std::vector<double> init_scores_;
Guolin Ke's avatar
Guolin Ke committed
209
210
211
212

};

}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
213
#endif   // LIGHTGBM_BOOSTING_RF_H_