Commit 1e7ccbbb authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code for Boosting::ResetTrainingData.

parent a98b23d2
......@@ -43,14 +43,10 @@ public:
*/
virtual void MergeFrom(const Boosting* other) = 0;
/*!
* \brief Reset training data for current boosting
* \param config Configs for boosting
* \param train_data Training data
* \param objective_function Training objective function
* \param training_metrics Training metric
*/
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) = 0;
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0;
virtual void ResetConfig(const BoostingConfig* config) = 0;
/*!
* \brief Add a validation data
......
......@@ -91,7 +91,7 @@ public:
int data_random_seed = 1;
std::string data_filename = "";
std::vector<std::string> valid_data_filenames;
int snapshot_freq = 100;
int snapshot_freq = -1;
std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
......
......@@ -39,10 +39,6 @@ public:
sum_weight_ = 0.0f;
}
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
}
/*!
* \brief one training iteration
*/
......
......@@ -64,24 +64,14 @@ GBDT::~GBDT() {
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
train_data_ = train_data;
iter_ = 0;
num_iteration_for_pred_ = 0;
max_feature_idx_ = 0;
num_class_ = config->num_class;
train_data_ = nullptr;
gbdt_config_ = nullptr;
tree_learner_ = nullptr;
ResetTrainingData(config, train_data, objective_function, training_metrics);
}
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
......@@ -92,12 +82,10 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
is_constant_hessian_ = false;
}
if (train_data_ != train_data && train_data != nullptr) {
if (tree_learner_ == nullptr) {
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, new_config->device_type, &new_config->tree_config));
}
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));
// init tree learner
tree_learner_->Init(train_data, is_constant_hessian_);
tree_learner_->Init(train_data_, is_constant_hessian_);
// push training metrics
training_metrics_.clear();
......@@ -105,17 +93,10 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
// not same training data, need reset score and others
// create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data, num_tree_per_iteration_));
// update score
for (int i = 0; i < iter_; ++i) {
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
}
num_data_ = train_data->num_data();
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
num_data_ = train_data_->num_data();
// create buffer for gradients and hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
......@@ -123,56 +104,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
hessians_.resize(total_size);
}
// get max feature index
max_feature_idx_ = train_data->num_total_features() - 1;
max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
label_idx_ = train_data->label_idx();
label_idx_ = train_data_->label_idx();
// get feature names
feature_names_ = train_data->feature_names();
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
feature_infos_ = train_data->feature_infos();
}
if ((train_data_ != train_data && train_data != nullptr)
|| (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
bag_data_cnt_ =
static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
offsets_buf_.resize(num_threads_);
left_cnts_buf_.resize(num_threads_);
right_cnts_buf_.resize(num_threads_);
left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_);
double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
int sparse_group = 0;
for (int i = 0; i < train_data->num_feature_groups(); ++i) {
if (train_data->FeatureGroupIsSparse(i)) {
++sparse_group;
}
}
is_use_subset_ = false;
const int group_threshold_usesubset = 100;
const int sparse_group_threshold_usesubset = train_data->num_feature_groups() / 4;
if (average_bag_rate <= 0.5
&& (train_data->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data);
is_use_subset_ = true;
Log::Debug("use subset for bagging");
}
} else {
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
tmp_indices_.clear();
is_use_subset_ = false;
}
}
train_data_ = train_data;
if (train_data_ != nullptr) {
ResetBaggingConfig(gbdt_config_.get());
// reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config);
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
CHECK(num_tree_per_iteration_ == num_class_);
......@@ -213,10 +155,115 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
}
}
}
/*!
* \brief Rebind the booster to (possibly new) training data, objective and metrics.
*        Unlike Init(), the stored config (gbdt_config_) is left untouched; only
*        data-dependent state is rebuilt, and only when the dataset actually changed.
* \param train_data New training data (must be bin-aligned with the old data)
* \param objective_function Training objective function (may be nullptr)
* \param training_metrics Training metrics to track
*/
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                             const std::vector<const Metric*>& training_metrics) {
  // New data must share bin mappers with the old data; the existing trees were
  // built against the old bins and would be meaningless otherwise.
  // NOTE(review): assumes train_data_ is non-null here (Init ran first) — confirm.
  if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
  }
  objective_function_ = objective_function;
  // Default to one tree per class; the objective may override both fields below.
  num_tree_per_iteration_ = num_class_;
  if (objective_function_ != nullptr) {
    is_constant_hessian_ = objective_function_->IsConstantHessian();
    num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
  } else {
    is_constant_hessian_ = false;
  }
  // push training metrics
  training_metrics_.clear();
  for (const auto& metric : training_metrics) {
    training_metrics_.push_back(metric);
  }
  training_metrics_.shrink_to_fit();
  if (train_data != train_data_) {
    // Reassign first: everything below intentionally reads train_data_.
    train_data_ = train_data;
    // not same training data, need reset score and others
    // create score tracker
    train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
    // update score: replay all previously trained trees onto the new dataset's scores
    for (int i = 0; i < iter_; ++i) {
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
        auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
    }
    num_data_ = train_data_->num_data();
    // create buffer for gradients and hessians (one slot per data point per tree)
    if (objective_function_ != nullptr) {
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size);
      hessians_.resize(total_size);
    }
    // get max feature index
    max_feature_idx_ = train_data_->num_total_features() - 1;
    // get label index
    label_idx_ = train_data_->label_idx();
    // get feature names
    feature_names_ = train_data_->feature_names();
    feature_infos_ = train_data_->feature_infos();
    // Bagging buffers depend on num_data_, so they must be rebuilt for new data.
    ResetBaggingConfig(gbdt_config_.get());
    tree_learner_->ResetTrainingData(train_data);
  }
}
void GBDT::ResetConfig(const BoostingConfig* config) {
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
ResetBaggingConfig(new_config.get());
tree_learner_->ResetConfig(&new_config->tree_config);
gbdt_config_.reset(new_config.release());
}
/*!
* \brief (Re)allocate bagging state for the current config.
*        Requires num_data_ and train_data_ to be valid.
* \param config Current boosting config
*/
void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
  if (!(config->bagging_fraction < 1.0 && config->bagging_freq > 0)) {
    // Bagging disabled: every iteration trains on the full dataset.
    bag_data_cnt_ = num_data_;
    bag_data_indices_.clear();
    tmp_indices_.clear();
    is_use_subset_ = false;
    return;
  }
  // if need bagging, create buffer
  bag_data_cnt_ =
    static_cast<data_size_t>(config->bagging_fraction * num_data_);
  bag_data_indices_.resize(num_data_);
  tmp_indices_.resize(num_data_);
  // per-thread scratch buffers used when partitioning bagged indices in parallel
  offsets_buf_.resize(num_threads_);
  left_cnts_buf_.resize(num_threads_);
  right_cnts_buf_.resize(num_threads_);
  left_write_pos_buf_.resize(num_threads_);
  right_write_pos_buf_.resize(num_threads_);
  double average_bag_rate = config->bagging_fraction / config->bagging_freq;
  // Count sparse feature groups to decide whether materializing a subset pays off.
  int sparse_group = 0;
  for (int group = 0; group < train_data_->num_feature_groups(); ++group) {
    if (train_data_->FeatureGroupIsSparse(group)) {
      ++sparse_group;
    }
  }
  is_use_subset_ = false;
  const int group_threshold_usesubset = 100;
  const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
  const bool few_groups = train_data_->num_feature_groups() < group_threshold_usesubset
    || sparse_group < sparse_group_threshold_usesubset;
  if (average_bag_rate <= 0.5 && few_groups) {
    // Copy the bagged rows into a dedicated sub-dataset for faster training.
    tmp_subset_.reset(new Dataset(bag_data_cnt_));
    tmp_subset_->CopyFeatureMapperFrom(train_data_);
    is_use_subset_ = true;
    Log::Debug("use subset for bagging");
  }
}
void GBDT::AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) {
if (!train_data_->CheckAlign(*valid_data)) {
......@@ -358,7 +405,7 @@ double LabelAverage(const float* label, data_size_t num_data) {
Network::Allreduce(reinterpret_cast<char*>(&init_score),
sizeof(init_score), sizeof(init_score),
reinterpret_cast<char*>(&global_init_score),
[](const char* src, char* dst, int len) {
[] (const char* src, char* dst, int len) {
int used_size = 0;
const int type_size = sizeof(double);
const double *p1;
......@@ -1027,7 +1074,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[](const std::pair<size_t, std::string>& lhs,
[] (const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
......
......@@ -63,14 +63,10 @@ public:
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
}
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param objective_function Training objective function
* \param training_metrics Training metric
*/
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) override;
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override;
void ResetConfig(const BoostingConfig* config) override;
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
......@@ -258,6 +254,7 @@ public:
virtual const char* SubModelName() const override { return "tree"; }
protected:
void ResetBaggingConfig(const BoostingConfig* config);
/*!
* \brief Implement bagging logic
* \param iter Current interation
......
......@@ -41,21 +41,28 @@ public:
// Full initialization: run the base GBDT setup, then validate the GOSS
// sampling parameters via ResetGoss().
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, objective_function, training_metrics);
ResetGoss();
}
// Rebind training data through the base class, then re-validate GOSS state,
// since GOSS buffers depend on the dataset size.
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
ResetGoss();
}
// Apply a new config through the base class, then re-check that the new
// config is still valid for GOSS (rates, no bagging).
void ResetConfig(const BoostingConfig* config) override {
GBDT::ResetConfig(config);
ResetGoss();
}
// Validate the GOSS-related settings currently stored in gbdt_config_.
// Invoked after every Init / ResetTrainingData / ResetConfig.
void ResetGoss() {
// top_rate + other_rate is the total sampling rate; must lie in (0, 1].
CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
// GOSS performs its own gradient-based sampling, so bagging must be off.
// NOTE(review): exact float != comparison — assumes bagging_fraction is
// exactly 1.0f when bagging is unset; confirm against config defaults.
if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
Log::Fatal("cannot use bagging in GOSS");
}
Log::Info("using GOSS");
}
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
if (config->bagging_freq > 0 && config->bagging_fraction != 1.0f) {
Log::Fatal("cannot use bagging in GOSS");
}
GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
if (train_data_ == nullptr) { return; }
bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
tmp_indice_right_.resize(num_data_);
......@@ -66,8 +73,8 @@ public:
right_write_pos_buf_.resize(num_threads_);
is_use_subset_ = false;
if (config->top_rate + config->other_rate <= 0.5) {
auto bag_data_cnt = static_cast<data_size_t>((config->top_rate + config->other_rate) * num_data_);
if (gbdt_config_->top_rate + gbdt_config_->other_rate <= 0.5) {
auto bag_data_cnt = static_cast<data_size_t>((gbdt_config_->top_rate + gbdt_config_->other_rate) * num_data_);
tmp_subset_.reset(new Dataset(bag_data_cnt));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
is_use_subset_ = true;
......
......@@ -51,11 +51,12 @@ public:
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
train_data_ = train_data;
CreateObjectiveAndMetrics();
// initialize the boosting
boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
ResetTrainingData(train_data);
}
void MergeFrom(const Booster* other) {
......@@ -67,9 +68,7 @@ public:
}
void ResetTrainingData(const Dataset* train_data) {
std::lock_guard<std::mutex> lock(mutex_);
train_data_ = train_data;
void CreateObjectiveAndMetrics() {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
......@@ -91,10 +90,18 @@ public:
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
}
/*!
* \brief Swap in a new training dataset and rebuild objective, metrics and
*        boosting state. No-op when the pointer is unchanged.
* \param train_data New training data
*/
void ResetTrainingData(const Dataset* train_data) {
  // Fix: acquire the lock BEFORE reading train_data_. The original compared
  // train_data_ outside the mutex, racing with concurrent resets that modify
  // it under the lock (classic check-then-act race). Every other Booster
  // method locks first; this now matches that convention.
  std::lock_guard<std::mutex> lock(mutex_);
  if (train_data != train_data_) {
    train_data_ = train_data;
    // rebuild objective function and metrics for the new data
    CreateObjectiveAndMetrics();
    // reset the boosting
    boosting_->ResetTrainingData(train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }
}
void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
......@@ -125,10 +132,11 @@ public:
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
boosting_->ResetTrainingData(train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->ResetConfig(&config_.boosting_config);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment