"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "bba800df8291aeacf3ff651c703ae0721edc6793"
Commit 9390dc21 authored by wxchan, committed by Guolin Ke

add DART: Dropouts meet Multiple Additive Regression Trees (#67)
parent 5a483a8d
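
For orientation before the diff: DART (K. V. Rashmi and R. Gilad-Bachrach, "DART: Dropouts meet Multiple Additive Regression Trees", AISTATS 2015) trains the same tree ensemble as GBDT, but before each iteration it temporarily drops a random subset of the already-built trees, fits the new tree against the residuals of the remaining ensemble, and then rescales the dropped trees and the new tree so the total prediction stays on the same scale. This commit exposes the mode through the existing key=value configuration. A minimal training config exercising the new parameters might look like the following sketch (only boosting_type=dart, drop_rate and dropping_seed come from this commit; the other keys are illustrative):

  task = train
  boosting_type = dart   # new in this commit
  drop_rate = 0.1        # new: per-tree drop probability each iteration (default 0.01)
  dropping_seed = 4      # new: seed for the drop-selection RNG (default 4)
  data = train.txt       # illustrative
  num_iterations = 100   # illustrative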
include/LightGBM/config.h (path inferred from the contents):

@@ -75,7 +75,7 @@ public:
   /*! \brief Types of boosting */
   enum BoostingType {
-    kGBDT, kUnknow
+    kGBDT, kDART, kUnknow
   };
@@ -191,6 +191,8 @@ public:
   int bagging_freq = 0;
   int early_stopping_round = 0;
   int num_class = 1;
+  double drop_rate = 0.01;
+  int dropping_seed = 4;
   void Set(const std::unordered_map<std::string, std::string>& params) override;
 };
src/application/application.cpp:

@@ -200,7 +200,7 @@ void Application::InitTrain() {
   Network::Init(config_.network_config);
   Log::Info("Finished initializing network");
   // sync global random seed for feature partition
-  if (config_.boosting_type == BoostingType::kGBDT) {
+  if (config_.boosting_type == BoostingType::kGBDT || config_.boosting_type == BoostingType::kDART) {
     GBDTConfig* gbdt_config =
         dynamic_cast<GBDTConfig*>(config_.boosting_config);
     gbdt_config->tree_config.feature_fraction_seed =
src/boosting/boosting.cpp:

 #include <LightGBM/boosting.h>
 #include "gbdt.h"
+#include "dart.h"

 namespace LightGBM {

@@ -8,6 +9,8 @@ BoostingType GetBoostingTypeFromModelFile(const char* filename) {
   std::string type = model_reader.first_line();
   if (type == std::string("gbdt")) {
     return BoostingType::kGBDT;
+  } else if (type == std::string("dart")) {
+    return BoostingType::kDART;
   }
   return BoostingType::kUnknow;
 }
@@ -28,6 +31,8 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
   if (filename == nullptr || filename[0] == '\0') {
     if (type == BoostingType::kGBDT) {
       return new GBDT();
+    } else if (type == BoostingType::kDART) {
+      return new DART();
     } else {
       return nullptr;
     }
@@ -37,6 +42,8 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
   if (type_in_file == type) {
     if (type == BoostingType::kGBDT) {
       ret = new GBDT();
+    } else if (type == BoostingType::kDART) {
+      ret = new DART();
     }
     LoadFileToBoosting(ret, filename);
   } else {
@@ -51,6 +58,8 @@ Boosting* Boosting::CreateBoosting(const char* filename) {
   Boosting* ret = nullptr;
   if (type == BoostingType::kGBDT) {
     ret = new GBDT();
+  } else if (type == BoostingType::kDART) {
+    ret = new DART();
   }
   LoadFileToBoosting(ret, filename);
   return ret;
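
With the factory updated, callers obtain a DART booster exactly the way they already obtain a GBDT one. A minimal usage sketch, assuming only the signatures visible above (the model file name is illustrative):

  #include <LightGBM/boosting.h>
  using namespace LightGBM;

  // Construct a fresh DART booster; no model file, so the nullptr branch runs.
  Boosting* booster = Boosting::CreateBoosting(BoostingType::kDART, nullptr);
  // ... Init(...) / TrainOneIter(...) as with GBDT ...

  // Or restore from disk: the type is read from the first line of the model
  // file, which SaveModelToFile() now writes via Name() (see gbdt.cpp below).
  Boosting* restored = Boosting::CreateBoosting("model.txt");
  delete booster;
  delete restored;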
#include "gbdt.h"
#include "dart.h"
#include <LightGBM/utils/common.h>
#include <LightGBM/feature.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <ctime>
#include <sstream>
#include <chrono>
#include <string>
#include <vector>
#include <utility>
namespace LightGBM {
DART::DART(){
}
DART::~DART(){
}
void DART::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
GBDT::Init(config, train_data, object_function, training_metrics);
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
drop_rate_ = gbdt_config_->drop_rate;
shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->dropping_seed);
}
bool DART::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
// boosting first
if (gradient == nullptr || hessian == nullptr) {
Boosting();
gradient = gradients_;
hessian = hessians_;
}
for (int curr_class = 0; curr_class < num_class_; ++curr_class){
// bagging logic
Bagging(iter_, curr_class);
// train a new tree
Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_);
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) {
Log::Info("Can't training anymore, there isn't any leaf meets split requirements.");
return true;
}
// shrink new tree
new_tree->Shrinkage(shrinkage_rate_);
// update score
UpdateScore(new_tree, curr_class);
UpdateScoreOutOfBag(new_tree, curr_class);
// add model
models_.push_back(new_tree);
}
// normalize
Normalize();
bool is_met_early_stopping = false;
// print message for metric
if (is_eval) {
is_met_early_stopping = OutputMetric(iter_ + 1);
}
++iter_;
if (is_met_early_stopping) {
Log::Info("Early stopping at iteration %d, the best iteration round is %d",
iter_, iter_ - early_stopping_round_);
// pop last early_stopping_round_ models
for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
delete models_.back();
models_.pop_back();
}
}
return is_met_early_stopping;
}
/*! \brief Get training scores result */
const score_t* DART::GetTrainingScore(data_size_t* out_len) {
DroppingTrees();
*out_len = train_score_updater_->num_data() * num_class_;
return train_score_updater_->score();
}
void DART::SaveModelToFile(int num_used_model, bool is_finish, const char* filename) {
// only save model once when is_finish = true
if (is_finish && saved_model_size_ < 0) {
GBDT::SaveModelToFile(num_used_model, is_finish, filename);
}
}
void DART::DroppingTrees() {
drop_index_.clear();
// select dropping tree indexes based on drop_rate
// if drop rate is too small, skip this step, drop one tree randomly
if (drop_rate_ > kEpsilon) {
for (size_t i = 0; i < static_cast<size_t>(iter_); ++i){
if (random_for_drop_.NextDouble() < drop_rate_) {
drop_index_.push_back(i);
}
}
}
// binomial-plus-one, at least one tree will be dropped
if (drop_index_.empty()){
drop_index_ = random_for_drop_.Sample(iter_, 1);
}
// drop trees
for (int i: drop_index_) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
int curr_tree = i * num_class_ + curr_class;
models_[curr_tree]->Shrinkage(-1.0);
train_score_updater_->AddScore(models_[curr_tree], curr_class);
}
}
shrinkage_rate_ = 1.0 / (1.0 + drop_index_.size());
}
void DART::Normalize() {
double k = static_cast<double>(drop_index_.size());
for (int i: drop_index_) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
int curr_tree = i * num_class_ + curr_class;
// update validation score
models_[curr_tree]->Shrinkage(shrinkage_rate_);
for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(models_[curr_tree], curr_class);
}
// update training score
models_[curr_tree]->Shrinkage(-k);
train_score_updater_->AddScore(models_[curr_tree], curr_class);
}
}
}
} // namespace LightGBM
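
The rescaling split between DroppingTrees() and Normalize() is the standard DART normalization. If k trees are dropped in an iteration, the new tree is trained with shrinkage 1/(1+k), and each dropped tree T ends up rescaled to (k/(1+k))T. The intermediate sign flips are a trick: negating a tree and passing it through AddScore subtracts its predictions, so the same code path serves both to remove the dropped trees before training and to re-add the rescaled versions afterwards. The net effect of the three Shrinkage calls on a dropped tree is

  \[ T \;\mapsto\; \underbrace{(-1)}_{\text{DroppingTrees}} \cdot \underbrace{\tfrac{1}{1+k}}_{\text{Normalize, 1st}} \cdot \underbrace{(-k)}_{\text{Normalize, 2nd}} \cdot T \;=\; \frac{k}{1+k}\,T, \]

which matches Algorithm 1 of the DART paper: factor k/(k+1) for the dropped trees, 1/(k+1) for the new tree.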
src/boosting/dart.h (new file):

#ifndef LIGHTGBM_BOOSTING_DART_H_
#define LIGHTGBM_BOOSTING_DART_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief DART algorithm implementation, including training, prediction and bagging.
*/
class DART: public GBDT {
public:
  /*!
  * \brief Constructor
  */
  DART();
  /*!
  * \brief Destructor
  */
  ~DART();
  /*!
  * \brief Initialization logic
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  */
  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
            const std::vector<const Metric*>& training_metrics) override;
  /*!
  * \brief One training iteration
  */
  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
  /*!
  * \brief Get current training score
  * \param out_len length of returned score
  * \return training score
  */
  const score_t* GetTrainingScore(data_size_t* out_len) override;
  /*!
  * \brief Save model to file; DART only writes the final model
  */
  void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override;
  /*!
  * \brief Get type name of this boosting object
  */
  const char* Name() const override { return "dart"; }

private:
  /*!
  * \brief Drop trees based on drop_rate
  */
  void DroppingTrees();
  /*!
  * \brief Normalize the dropped trees and restore the scores
  */
  void Normalize();
  /*! \brief Indices of the dropped trees */
  std::vector<size_t> drop_index_;
  /*! \brief Dropping rate */
  double drop_rate_;
  /*! \brief Shrinkage rate for one iteration */
  double shrinkage_rate_;
  /*! \brief Random generator used to select the dropped trees */
  Random random_for_drop_;
};

}  // namespace LightGBM
#endif  // LIGHTGBM_BOOSTING_DART_H_
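
Design note: DART deliberately reuses almost everything from GBDT through inheritance and overrides only the iteration loop, the training-score getter, model saving, and the type name. The gbdt.h changes further down (virtual added to TrainOneIter, GetTrainingScore, SaveModelToFile, UpdateScore and Name, and private relaxed to protected) are exactly the hooks that make these overrides possible.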
src/boosting/gbdt.cpp:

@@ -357,7 +357,7 @@ void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filename) {
   if (saved_model_size_ < 0) {
     model_output_file_.open(filename);
     // output model type
-    model_output_file_ << "gbdt" << std::endl;
+    model_output_file_ << Name() << std::endl;
     // output number of class
     model_output_file_ << "num_class=" << num_class_ << std::endl;
     // output label index
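
This Name() indirection is what closes the loop with GetBoostingTypeFromModelFile() in boosting.cpp: DART inherits this save routine but overrides Name(), so the first line of a model file now records the concrete boosting type. A model saved by a DART booster would begin roughly like this (values illustrative):

  dart
  num_class=1
  ...

while GBDT models keep their old "gbdt" first line, so existing model files remain loadable.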
src/boosting/gbdt.h:

@@ -48,7 +48,7 @@ public:
  * \param is_eval true if evaluation or early stopping is needed
  * \return True if early stopping is met or no further boosting is possible
  */
-  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
+  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
  /*!
  * \brief Get evaluation result at data_idx data
@@ -62,7 +62,7 @@ public:
  * \param out_len length of returned score
  * \return training score
  */
-  const score_t* GetTrainingScore(data_size_t* out_len) override;
+  virtual const score_t* GetTrainingScore(data_size_t* out_len) override;
  /*!
  * \brief Get prediction result at data_idx data
@@ -77,14 +77,14 @@ public:
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
  std::vector<double> PredictRaw(const double* feature_values) const override;
  /*!
  * \brief Prediction for one record, with sigmoid transformation if enabled
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
  std::vector<double> Predict(const double* feature_values) const override;
  /*!
  * \brief Prediction for one record with leaf index
@@ -97,7 +97,7 @@ public:
  * \brief Save model to file
  */
-  void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override;
+  virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override;
  /*!
  * \brief Restore from a serialized string
  */
@@ -134,14 +134,13 @@ public:
      num_used_model_ = static_cast<int>(num_used_model / num_class_);
    }
  }
  /*!
  * \brief Get type name of this boosting object
  */
-  const char* Name() const override { return "gbdt"; }
+  virtual const char* Name() const override { return "gbdt"; }

-private:
+protected:
  /*!
  * \brief Implement bagging logic
  * \param iter Current iteration
@@ -164,7 +163,7 @@ protected:
  * \param tree Trained tree of this iteration
  * \param curr_class Current class for multiclass training
  */
-  void UpdateScore(const Tree* tree, const int curr_class);
+  virtual void UpdateScore(const Tree* tree, const int curr_class);
  /*!
  * \brief Print metric result of current iteration
  * \param iter Current iteration
src/boosting/score_updater.hpp:

@@ -40,6 +40,7 @@ public:
  * \brief Use the tree model to get predictions and add them to the scores for all data
  *        Note: this function will generally be used on validation data too.
  * \param tree Trained tree model
+ * \param curr_class Current class for multiclass training
  */
  inline void AddScore(const Tree* tree, int curr_class) {
    tree->AddPredictionToScore(data_, num_data_, score_ + curr_class * num_data_);
@@ -49,6 +50,7 @@ public:
  *        The training data is partitioned into tree leaves after training,
  *        based on which we can get predictions quickly.
  * \param tree_learner
+ * \param curr_class Current class for multiclass training
  */
  inline void AddScore(const TreeLearner* tree_learner, int curr_class) {
    tree_learner->AddPredictionToScore(score_ + curr_class * num_data_);
@@ -59,6 +61,7 @@ public:
  * \param tree Trained tree model
  * \param data_indices Indices of the data that will be processed
  * \param data_cnt Number of data points that will be processed
+ * \param curr_class Current class for multiclass training
  */
  inline void AddScore(const Tree* tree, const data_size_t* data_indices,
                       data_size_t data_cnt, int curr_class) {
src/io/config.cpp:

@@ -39,7 +39,7 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetMetricType(params);
  // construct boosting configs
-  if (boosting_type == BoostingType::kGBDT) {
+  if (boosting_type == BoostingType::kGBDT || boosting_type == BoostingType::kDART) {
    boosting_config = new GBDTConfig();
  }
@@ -73,6 +73,8 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
  std::transform(value.begin(), value.end(), value.begin(), ::tolower);
  if (value == std::string("gbdt") || value == std::string("gbrt")) {
    boosting_type = BoostingType::kGBDT;
+  } else if (value == std::string("dart")) {
+    boosting_type = BoostingType::kDART;
  } else {
    Log::Fatal("Unknown boosting type %s", value.c_str());
  }
@@ -296,6 +298,9 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
  GetInt(params, "num_class", &num_class);
+  GetInt(params, "dropping_seed", &dropping_seed);
+  GetDouble(params, "drop_rate", &drop_rate);
+  CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
}

void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
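
With the parsing and range check in place, a bad value fails fast instead of silently training with a nonsensical probability. For instance, launching the CLI with an out-of-range rate (a sketch; the invocation follows LightGBM's usual key=value arguments):

  ./lightgbm config=train.conf boosting_type=dart drop_rate=1.5

would abort at CHECK(drop_rate <= 1.0 && drop_rate >= 0.0).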
windows/LightGBM.vcxproj:

@@ -211,6 +211,7 @@
    <ClInclude Include="..\include\LightGBM\utils\threading.h" />
    <ClInclude Include="..\src\application\predictor.hpp" />
    <ClInclude Include="..\src\boosting\gbdt.h" />
+   <ClInclude Include="..\src\boosting\dart.h" />
    <ClInclude Include="..\src\boosting\score_updater.hpp" />
    <ClInclude Include="..\src\io\dense_bin.hpp" />
    <ClInclude Include="..\src\io\ordered_sparse_bin.hpp" />
@@ -237,6 +238,7 @@
    <ClCompile Include="..\src\application\application.cpp" />
    <ClCompile Include="..\src\boosting\boosting.cpp" />
    <ClCompile Include="..\src\boosting\gbdt.cpp" />
+   <ClCompile Include="..\src\boosting\dart.cpp" />
    <ClCompile Include="..\src\c_api.cpp" />
    <ClCompile Include="..\src\io\bin.cpp" />
    <ClCompile Include="..\src\io\config.cpp" />
windows/LightGBM.vcxproj.filters:

@@ -39,6 +39,9 @@
    <ClInclude Include="..\src\boosting\gbdt.h">
      <Filter>src\boosting</Filter>
    </ClInclude>
+   <ClInclude Include="..\src\boosting\dart.h">
+     <Filter>src\boosting</Filter>
+   </ClInclude>
    <ClInclude Include="..\src\network\linkers.h">
      <Filter>src\network</Filter>
    </ClInclude>
@@ -191,6 +194,9 @@
    <ClCompile Include="..\src\Boosting\gbdt.cpp">
      <Filter>src\boosting</Filter>
    </ClCompile>
+   <ClCompile Include="..\src\Boosting\dart.cpp">
+     <Filter>src\boosting</Filter>
+   </ClCompile>
    <ClCompile Include="..\src\io\dataset.cpp">
      <Filter>src\io</Filter>
    </ClCompile>