Commit f449a45b authored by Guolin Ke's avatar Guolin Ke
Browse files

reduce memory cost for multi classification

parent 381a945d
...@@ -35,7 +35,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -35,7 +35,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
random_ = Random(config->bagging_seed); random_ = Random(config->bagging_seed);
train_data_ = nullptr; train_data_ = nullptr;
gbdt_config_ = nullptr; gbdt_config_ = nullptr;
tree_learner_.clear(); tree_learner_ = nullptr;
ResetTrainingData(config, train_data, object_function, training_metrics); ResetTrainingData(config, train_data, object_function, training_metrics);
} }
...@@ -58,17 +58,11 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -58,17 +58,11 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
} }
if (train_data_ != train_data && train_data != nullptr) { if (train_data_ != train_data && train_data != nullptr) {
if (tree_learner_.empty()) { if (tree_learner_ == nullptr) {
for (int i = 0; i < num_class_; ++i) { tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
} }
// init tree learner // init tree learner
for (int i = 0; i < num_class_; ++i) { tree_learner_->Init(train_data);
tree_learner_[i]->Init(train_data);
}
// push training metrics // push training metrics
training_metrics_.clear(); training_metrics_.clear();
...@@ -114,9 +108,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -114,9 +108,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
train_data_ = train_data; train_data_ = train_data;
if (train_data_ != nullptr) { if (train_data_ != nullptr) {
// reset config for tree learner // reset config for tree learner
for (int i = 0; i < num_class_; ++i) { tree_learner_->ResetConfig(&new_config->tree_config);
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}
} }
gbdt_config_.reset(new_config.release()); gbdt_config_.reset(new_config.release());
} }
...@@ -154,7 +146,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data, ...@@ -154,7 +146,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
} }
void GBDT::Bagging(int iter, const int curr_class) { void GBDT::Bagging(int iter) {
// if need bagging // if need bagging
if (!out_of_bag_data_indices_.empty() && iter % gbdt_config_->bagging_freq == 0) { if (!out_of_bag_data_indices_.empty() && iter % gbdt_config_->bagging_freq == 0) {
// if doesn't have query data // if doesn't have query data
...@@ -203,7 +195,7 @@ void GBDT::Bagging(int iter, const int curr_class) { ...@@ -203,7 +195,7 @@ void GBDT::Bagging(int iter, const int curr_class) {
} }
Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_); Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner // set bagging data to tree learner
tree_learner_[curr_class]->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_); tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
} }
} }
...@@ -221,13 +213,12 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -221,13 +213,12 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
gradient = gradients_.data(); gradient = gradients_.data();
hessian = hessians_.data(); hessian = hessians_.data();
} }
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
// bagging logic // bagging logic
Bagging(iter_, curr_class); Bagging(iter_);
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
// train a new tree // train a new tree
std::unique_ptr<Tree> new_tree(tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_)); std::unique_ptr<Tree> new_tree(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
// if cannot learn a new tree, then stop // if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) { if (new_tree->num_leaves() <= 1) {
Log::Info("Stopped training because there are no more leafs that meet the split requirements."); Log::Info("Stopped training because there are no more leafs that meet the split requirements.");
...@@ -290,7 +281,7 @@ bool GBDT::EvalAndCheckEarlyStopping() { ...@@ -290,7 +281,7 @@ bool GBDT::EvalAndCheckEarlyStopping() {
void GBDT::UpdateScore(const Tree* tree, const int curr_class) { void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
// update training score // update training score
train_score_updater_->AddScore(tree_learner_[curr_class].get(), curr_class); train_score_updater_->AddScore(tree_learner_.get(), curr_class);
// update validation score // update validation score
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(tree, curr_class); score_updater->AddScore(tree, curr_class);
...@@ -301,7 +292,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -301,7 +292,7 @@ std::string GBDT::OutputMetric(int iter) {
bool need_output = (iter % gbdt_config_->output_freq) == 0; bool need_output = (iter % gbdt_config_->output_freq) == 0;
std::string ret = ""; std::string ret = "";
std::stringstream msg_buf; std::stringstream msg_buf;
std::vector<std::pair<int, int>> meet_early_stopping_pairs; std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
// print training metric // print training metric
if (need_output) { if (need_output) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
......
...@@ -214,9 +214,8 @@ protected: ...@@ -214,9 +214,8 @@ protected:
/*! /*!
* \brief Implement bagging logic * \brief Implement bagging logic
* \param iter Current iteration * \param iter Current iteration
* \param curr_class Current class for multiclass training
*/ */
void Bagging(int iter, const int curr_class); void Bagging(int iter);
/*! /*!
* \brief updating score for out-of-bag data. * \brief updating score for out-of-bag data.
* Data should be update since we may re-bagging data on training * Data should be update since we may re-bagging data on training
...@@ -252,7 +251,7 @@ protected: ...@@ -252,7 +251,7 @@ protected:
/*! \brief Config of gbdt */ /*! \brief Config of gbdt */
std::unique_ptr<BoostingConfig> gbdt_config_; std::unique_ptr<BoostingConfig> gbdt_config_;
/*! \brief Tree learner, will use this class to learn trees */ /*! \brief Tree learner, will use this class to learn trees */
std::vector<std::unique_ptr<TreeLearner>> tree_learner_; std::unique_ptr<TreeLearner> tree_learner_;
/*! \brief Objective function */ /*! \brief Objective function */
const ObjectiveFunction* object_function_; const ObjectiveFunction* object_function_;
/*! \brief Store and update training data's score */ /*! \brief Store and update training data's score */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment