/*! * Copyright (c) 2017 Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. */ #ifndef LIGHTGBM_BOOSTING_RF_H_ #define LIGHTGBM_BOOSTING_RF_H_ #include #include #include #include #include #include #include #include #include "gbdt.h" #include "score_updater.hpp" namespace LightGBM { /*! * \brief Rondom Forest implementation */ class RF : public GBDT { public: RF() : GBDT() { average_output_ = true; } ~RF() {} void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector& training_metrics) override { CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f); CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f); GBDT::Init(config, train_data, objective_function, training_metrics); if (num_init_iteration_ > 0) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { MultiplyScore(cur_tree_id, 1.0f / num_init_iteration_); } } else { CHECK(train_data->metadata().init_score() == nullptr); } CHECK(num_tree_per_iteration_ == num_class_); // not shrinkage rate for the RF shrinkage_rate_ = 1.0f; // only boosting one time Boosting(); if (is_use_subset_ && bag_data_cnt_ < num_data_) { tmp_grad_.resize(num_data_); tmp_hess_.resize(num_data_); } } void ResetConfig(const Config* config) override { CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f); CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f); GBDT::ResetConfig(config); // not shrinkage rate for the RF shrinkage_rate_ = 1.0f; } void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector& training_metrics) override { GBDT::ResetTrainingData(train_data, objective_function, training_metrics); if (iter_ + num_init_iteration_ > 0) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { train_score_updater_->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id); } } CHECK(num_tree_per_iteration_ == num_class_); // only boosting one time Boosting(); if (is_use_subset_ && bag_data_cnt_ < num_data_) { tmp_grad_.resize(num_data_); tmp_hess_.resize(num_data_); } } void Boosting() override { if (objective_function_ == nullptr) { Log::Fatal("RF mode do not support custom objective function, please use built-in objectives."); } init_scores_.resize(num_tree_per_iteration_, 0.0); for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { init_scores_[cur_tree_id] = BoostFromAverage(cur_tree_id, false); } size_t total_size = static_cast(num_data_) * num_tree_per_iteration_; std::vector tmp_scores(total_size, 0.0f); #pragma omp parallel for schedule(static) for (int j = 0; j < num_tree_per_iteration_; ++j) { size_t offset = static_cast(j)* num_data_; for (data_size_t i = 0; i < num_data_; ++i) { tmp_scores[offset + i] = init_scores_[j]; } } objective_function_-> GetGradients(tmp_scores.data(), gradients_.data(), hessians_.data()); } bool TrainOneIter(const score_t* gradients, const score_t* hessians) override { // bagging logic Bagging(iter_); CHECK(gradients == nullptr); CHECK(hessians == nullptr); gradients = gradients_.data(); hessians = hessians_.data(); for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { std::unique_ptr new_tree(new Tree(2)); size_t offset = static_cast(cur_tree_id)* num_data_; if (class_need_train_[cur_tree_id]) { auto grad = gradients + offset; auto hess = hessians + offset; // need to copy gradients for bagging subset. if (is_use_subset_ && bag_data_cnt_ < num_data_) { for (int i = 0; i < bag_data_cnt_; ++i) { tmp_grad_[i] = grad[bag_data_indices_[i]]; tmp_hess_[i] = hess[bag_data_indices_[i]]; } grad = tmp_grad_.data(); hess = tmp_hess_.data(); } new_tree.reset(tree_learner_->Train(grad, hess, forced_splits_json_)); } if (new_tree->num_leaves() > 1) { double pred = init_scores_[cur_tree_id]; auto residual_getter = [pred](const label_t* label, int i) {return static_cast(label[i]) - pred; }; tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter, num_data_, bag_data_indices_.data(), bag_data_cnt_); if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) { new_tree->AddBias(init_scores_[cur_tree_id]); } // update score MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_)); UpdateScore(new_tree.get(), cur_tree_id); MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1)); } else { // only add default score one-time if (models_.size() < static_cast(num_tree_per_iteration_)) { double output = 0.0; if (!class_need_train_[cur_tree_id]) { if (objective_function_ != nullptr) { output = objective_function_->BoostFromScore(cur_tree_id); } else { output = init_scores_[cur_tree_id]; } } new_tree->AsConstantTree(output); MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_)); UpdateScore(new_tree.get(), cur_tree_id); MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1)); } } // add model models_.push_back(std::move(new_tree)); } ++iter_; return false; } void RollbackOneIter() override { if (iter_ <= 0) { return; } int cur_iter = iter_ + num_init_iteration_ - 1; // reset score for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { auto curr_tree = cur_iter * num_tree_per_iteration_ + cur_tree_id; models_[curr_tree]->Shrinkage(-1.0); MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_)); train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id); for (auto& score_updater : valid_score_updater_) { score_updater->AddScore(models_[curr_tree].get(), cur_tree_id); } MultiplyScore(cur_tree_id, 1.0f / (iter_ + num_init_iteration_ - 1)); } // remove model for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { models_.pop_back(); } --iter_; } void MultiplyScore(const int cur_tree_id, double val) { train_score_updater_->MultiplyScore(val, cur_tree_id); for (auto& score_updater : valid_score_updater_) { score_updater->MultiplyScore(val, cur_tree_id); } } void AddValidDataset(const Dataset* valid_data, const std::vector& valid_metrics) override { GBDT::AddValidDataset(valid_data, valid_metrics); if (iter_ + num_init_iteration_ > 0) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { valid_score_updater_.back()->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id); } } } bool NeedAccuratePrediction() const override { // No early stopping for prediction return true; }; private: std::vector tmp_grad_; std::vector tmp_hess_; std::vector init_scores_; }; } // namespace LightGBM #endif // LIGHTGBM_BOOSTING_RF_H_