rf.hpp 8.79 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#ifndef LIGHTGBM_BOOSTING_RF_H_
#define LIGHTGBM_BOOSTING_RF_H_

8
9
10
#include <LightGBM/boosting.h>
#include <LightGBM/metric.h>

Guolin Ke's avatar
Guolin Ke committed
11
#include <string>
12
#include <cstdio>
Guolin Ke's avatar
Guolin Ke committed
13
#include <fstream>
14
15
16
17
18
19
#include <memory>
#include <utility>
#include <vector>

#include "gbdt.h"
#include "score_updater.hpp"
Guolin Ke's avatar
Guolin Ke committed
20
21
22

namespace LightGBM {
/*!
guanqun's avatar
guanqun committed
23
* \brief Random Forest implementation
Guolin Ke's avatar
Guolin Ke committed
24
*/
Guolin Ke's avatar
Guolin Ke committed
25
class RF : public GBDT {
Nikita Titov's avatar
Nikita Titov committed
26
 public:
Guolin Ke's avatar
Guolin Ke committed
27
  RF() : GBDT() {
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
32
    average_output_ = true;
  }

  ~RF() {}

Guolin Ke's avatar
Guolin Ke committed
33
  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
Guolin Ke's avatar
Guolin Ke committed
34
    const std::vector<const Metric*>& training_metrics) override {
35
36
37
38
39
40
    if (config->data_sample_strategy == std::string("bagging")) {
      CHECK((config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f) ||
            (config->feature_fraction < 1.0f && config->feature_fraction > 0.0f));
    } else {
      CHECK_EQ(config->data_sample_strategy, std::string("goss"));
    }
Guolin Ke's avatar
Guolin Ke committed
41
42
43
44
45
46
47
    GBDT::Init(config, train_data, objective_function, training_metrics);

    if (num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        MultiplyScore(cur_tree_id, 1.0f / num_init_iteration_);
      }
    } else {
Nikita Titov's avatar
Nikita Titov committed
48
      CHECK_EQ(train_data->metadata().init_score(), nullptr);
Guolin Ke's avatar
Guolin Ke committed
49
    }
Nikita Titov's avatar
Nikita Titov committed
50
    CHECK_EQ(num_tree_per_iteration_, num_class_);
Guolin Ke's avatar
Guolin Ke committed
51
52
53
    // not shrinkage rate for the RF
    shrinkage_rate_ = 1.0f;
    // only boosting one time
Guolin Ke's avatar
Guolin Ke committed
54
    Boosting();
55
    if (data_sample_strategy_->is_use_subset() && data_sample_strategy_->bag_data_cnt() < num_data_) {
56
57
      tmp_grad_.resize(num_data_);
      tmp_hess_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
58
59
60
    }
  }

Guolin Ke's avatar
Guolin Ke committed
61
  void ResetConfig(const Config* config) override {
62
63
64
65
66
67
    if (config->data_sample_strategy == std::string("bagging")) {
      CHECK((config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f) ||
            (config->feature_fraction < 1.0f && config->feature_fraction > 0.0f));
    } else {
      CHECK_EQ(config->data_sample_strategy, std::string("goss"));
    }
Guolin Ke's avatar
Guolin Ke committed
68
69
70
71
72
73
    GBDT::ResetConfig(config);
    // not shrinkage rate for the RF
    shrinkage_rate_ = 1.0f;
  }

  void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
Guolin Ke's avatar
Guolin Ke committed
74
    const std::vector<const Metric*>& training_metrics) override {
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
    GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
    if (iter_ + num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        train_score_updater_->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
      }
    }
Nikita Titov's avatar
Nikita Titov committed
81
    CHECK_EQ(num_tree_per_iteration_, num_class_);
Guolin Ke's avatar
Guolin Ke committed
82
    // only boosting one time
Guolin Ke's avatar
Guolin Ke committed
83
    Boosting();
84
    if (data_sample_strategy_->is_use_subset() && data_sample_strategy_->bag_data_cnt() < num_data_) {
85
86
      tmp_grad_.resize(num_data_);
      tmp_hess_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
87
88
89
    }
  }

Guolin Ke's avatar
Guolin Ke committed
90
91
  void Boosting() override {
    if (objective_function_ == nullptr) {
92
      Log::Fatal("RF mode do not support custom objective function, please use built-in objectives.");
Guolin Ke's avatar
Guolin Ke committed
93
94
95
96
    }
    init_scores_.resize(num_tree_per_iteration_, 0.0);
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      init_scores_[cur_tree_id] = BoostFromAverage(cur_tree_id, false);
97
    }
Guolin Ke's avatar
Guolin Ke committed
98
99
    size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
    std::vector<double> tmp_scores(total_size, 0.0f);
100
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
Guolin Ke's avatar
Guolin Ke committed
101
    for (int j = 0; j < num_tree_per_iteration_; ++j) {
102
      size_t offset = static_cast<size_t>(j)* num_data_;
Guolin Ke's avatar
Guolin Ke committed
103
      for (data_size_t i = 0; i < num_data_; ++i) {
104
        tmp_scores[offset + i] = init_scores_[j];
105
      }
Guolin Ke's avatar
Guolin Ke committed
106
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
    objective_function_->
      GetGradients(tmp_scores.data(), gradients_.data(), hessians_.data());
Guolin Ke's avatar
Guolin Ke committed
109
110
  }

Guolin Ke's avatar
Guolin Ke committed
111
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
Guolin Ke's avatar
Guolin Ke committed
112
    // bagging logic
113
114
115
116
117
    data_sample_strategy_ ->Bagging(iter_, tree_learner_.get(), gradients_.data(), hessians_.data());
    const bool is_use_subset = data_sample_strategy_->is_use_subset();
    const data_size_t bag_data_cnt = data_sample_strategy_->bag_data_cnt();
    const std::vector<data_size_t, Common::AlignmentAllocator<data_size_t, kAlignedSize>>& bag_data_indices = data_sample_strategy_->bag_data_indices();

118
119
120
121
122
123
    // GOSSStrategy->Bagging may modify value of bag_data_cnt_
    if (is_use_subset && bag_data_cnt < num_data_) {
      tmp_grad_.resize(num_data_);
      tmp_hess_.resize(num_data_);
    }

Nikita Titov's avatar
Nikita Titov committed
124
125
    CHECK_EQ(gradients, nullptr);
    CHECK_EQ(hessians, nullptr);
Guolin Ke's avatar
Guolin Ke committed
126

127
128
    gradients = gradients_.data();
    hessians = hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
129
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
130
      std::unique_ptr<Tree> new_tree(new Tree(2, false, false));
131
      size_t offset = static_cast<size_t>(cur_tree_id)* num_data_;
Guolin Ke's avatar
Guolin Ke committed
132
      if (class_need_train_[cur_tree_id]) {
133
134
        auto grad = gradients + offset;
        auto hess = hessians + offset;
Guolin Ke's avatar
Guolin Ke committed
135

136
137
138
139
        if (is_use_subset && bag_data_cnt < num_data_ && !boosting_on_gpu_) {
          for (int i = 0; i < bag_data_cnt; ++i) {
            tmp_grad_[i] = grad[bag_data_indices[i]];
            tmp_hess_[i] = hess[bag_data_indices[i]];
Guolin Ke's avatar
Guolin Ke committed
140
141
142
          }
          grad = tmp_grad_.data();
          hess = tmp_hess_.data();
Guolin Ke's avatar
Guolin Ke committed
143
        }
Guolin Ke's avatar
Guolin Ke committed
144

145
        new_tree.reset(tree_learner_->Train(grad, hess, false));
Guolin Ke's avatar
Guolin Ke committed
146
      }
Guolin Ke's avatar
Guolin Ke committed
147

Guolin Ke's avatar
Guolin Ke committed
148
      if (new_tree->num_leaves() > 1) {
149
150
151
        double pred = init_scores_[cur_tree_id];
        auto residual_getter = [pred](const label_t* label, int i) {return static_cast<double>(label[i]) - pred; };
        tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter,
152
          num_data_, bag_data_indices.data(), bag_data_cnt, train_score_updater_->score());
Guolin Ke's avatar
Guolin Ke committed
153
154
155
        if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
          new_tree->AddBias(init_scores_[cur_tree_id]);
        }
Guolin Ke's avatar
Guolin Ke committed
156
157
158
159
        // update score
        MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
        UpdateScore(new_tree.get(), cur_tree_id);
        MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
164
165
166
167
168
169
170
      } else {
        // only add default score one-time
        if (models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
          double output = 0.0;
          if (!class_need_train_[cur_tree_id]) {
            if (objective_function_ != nullptr) {
              output = objective_function_->BoostFromScore(cur_tree_id);
            } else {
              output = init_scores_[cur_tree_id];
            }
          }
171
          new_tree->AsConstantTree(output, num_data_);
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
          MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
          UpdateScore(new_tree.get(), cur_tree_id);
          MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
        }
Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
      }
      // add model
      models_.push_back(std::move(new_tree));
    }
    ++iter_;
Guolin Ke's avatar
Guolin Ke committed
181
    return false;
Guolin Ke's avatar
Guolin Ke committed
182
183
184
  }

  void RollbackOneIter() override {
185
186
187
    if (iter_ <= 0) {
      return;
    }
Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
    int cur_iter = iter_ + num_init_iteration_ - 1;
    // reset score
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id;
      models_[curr_tree]->Shrinkage(-1.0);
      MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
      train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      for (auto& score_updater : valid_score_updater_) {
        score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
      MultiplyScore(cur_tree_id, 1.0f / (iter_ + num_init_iteration_ - 1));
    }
    // remove model
    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
      models_.pop_back();
    }
    --iter_;
  }

  void MultiplyScore(const int cur_tree_id, double val) {
    train_score_updater_->MultiplyScore(val, cur_tree_id);
    for (auto& score_updater : valid_score_updater_) {
      score_updater->MultiplyScore(val, cur_tree_id);
    }
  }

  void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
215
    const std::vector<const Metric*>& valid_metrics) override {
Guolin Ke's avatar
Guolin Ke committed
216
217
218
219
220
221
222
223
224
225
226
227
228
    GBDT::AddValidDataset(valid_data, valid_metrics);
    if (iter_ + num_init_iteration_ > 0) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        valid_score_updater_.back()->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
      }
    }
  }

  bool NeedAccuratePrediction() const override {
    // No early stopping for prediction
    return true;
  };

Nikita Titov's avatar
Nikita Titov committed
229
 private:
Guolin Ke's avatar
Guolin Ke committed
230
231
  std::vector<score_t> tmp_grad_;
  std::vector<score_t> tmp_hess_;
Guolin Ke's avatar
Guolin Ke committed
232
  std::vector<double> init_scores_;
Guolin Ke's avatar
Guolin Ke committed
233
234
235
};

}  // namespace LightGBM
236
#endif  // LIGHTGBM_BOOSTING_RF_H_