Commit f449a45b authored by Guolin Ke's avatar Guolin Ke
Browse files

reduce memory cost for multi classification

parent 381a945d
......@@ -35,7 +35,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
random_ = Random(config->bagging_seed);
train_data_ = nullptr;
gbdt_config_ = nullptr;
tree_learner_.clear();
tree_learner_ = nullptr;
ResetTrainingData(config, train_data, object_function, training_metrics);
}
......@@ -58,17 +58,11 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
if (train_data_ != train_data && train_data != nullptr) {
if (tree_learner_.empty()) {
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
if (tree_learner_ == nullptr) {
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
}
// init tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
tree_learner_->Init(train_data);
// push training metrics
training_metrics_.clear();
......@@ -114,9 +108,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
train_data_ = train_data;
if (train_data_ != nullptr) {
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}
tree_learner_->ResetConfig(&new_config->tree_config);
}
gbdt_config_.reset(new_config.release());
}
......@@ -154,7 +146,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
}
void GBDT::Bagging(int iter, const int curr_class) {
void GBDT::Bagging(int iter) {
// if need bagging
if (!out_of_bag_data_indices_.empty() && iter % gbdt_config_->bagging_freq == 0) {
// if doesn't have query data
......@@ -203,7 +195,7 @@ void GBDT::Bagging(int iter, const int curr_class) {
}
Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner
tree_learner_[curr_class]->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
}
}
......@@ -221,13 +213,12 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
gradient = gradients_.data();
hessian = hessians_.data();
}
// bagging logic
Bagging(iter_);
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
// bagging logic
Bagging(iter_, curr_class);
// train a new tree
std::unique_ptr<Tree> new_tree(tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
std::unique_ptr<Tree> new_tree(tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) {
Log::Info("Stopped training because there are no more leafs that meet the split requirements.");
......@@ -290,7 +281,7 @@ bool GBDT::EvalAndCheckEarlyStopping() {
void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
// update training score
train_score_updater_->AddScore(tree_learner_[curr_class].get(), curr_class);
train_score_updater_->AddScore(tree_learner_.get(), curr_class);
// update validation score
for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(tree, curr_class);
......@@ -301,7 +292,7 @@ std::string GBDT::OutputMetric(int iter) {
bool need_output = (iter % gbdt_config_->output_freq) == 0;
std::string ret = "";
std::stringstream msg_buf;
std::vector<std::pair<int, int>> meet_early_stopping_pairs;
std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
// print training metric
if (need_output) {
for (auto& sub_metric : training_metrics_) {
......
......@@ -214,9 +214,8 @@ protected:
/*!
* \brief Implement bagging logic
* \param iter Current interation
* \param curr_class Current class for multiclass training
*/
void Bagging(int iter, const int curr_class);
void Bagging(int iter);
/*!
* \brief updating score for out-of-bag data.
* Data should be update since we may re-bagging data on training
......@@ -252,7 +251,7 @@ protected:
/*! \brief Config of gbdt */
std::unique_ptr<BoostingConfig> gbdt_config_;
/*! \brief Tree learner, will use this class to learn trees */
std::vector<std::unique_ptr<TreeLearner>> tree_learner_;
std::unique_ptr<TreeLearner> tree_learner_;
/*! \brief Objective function */
const ObjectiveFunction* object_function_;
/*! \brief Store and update training data's score */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment