Commit 1e7ccbbb authored by Guolin Ke
Browse files

clean code for Boosting::ResetTrainingData.

parent a98b23d2
......@@ -43,14 +43,10 @@ public:
*/
virtual void MergeFrom(const Boosting* other) = 0;
/*!
* \brief Reset training data for current boosting
* \param config Configs for boosting
* \param train_data Training data
* \param objective_function Training objective function
* \param training_metrics Training metric
*/
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) = 0;
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0;
virtual void ResetConfig(const BoostingConfig* config) = 0;
/*!
* \brief Add a validation data
......
......@@ -91,7 +91,7 @@ public:
int data_random_seed = 1;
std::string data_filename = "";
std::vector<std::string> valid_data_filenames;
int snapshot_freq = 100;
int snapshot_freq = -1;
std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
......
......@@ -39,10 +39,6 @@ public:
sum_weight_ = 0.0f;
}
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
}
/*!
* \brief one training iteration
*/
......
......@@ -44,9 +44,9 @@ GBDT::GBDT()
boost_from_average_(false) {
#pragma omp parallel
#pragma omp master
{
num_threads_ = omp_get_num_threads();
}
{
num_threads_ = omp_get_num_threads();
}
}
GBDT::~GBDT() {
......@@ -64,24 +64,104 @@ GBDT::~GBDT() {
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
train_data_ = train_data;
iter_ = 0;
num_iteration_for_pred_ = 0;
max_feature_idx_ = 0;
num_class_ = config->num_class;
train_data_ = nullptr;
gbdt_config_ = nullptr;
tree_learner_ = nullptr;
ResetTrainingData(config, train_data, objective_function, training_metrics);
gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
if (objective_function_ != nullptr) {
is_constant_hessian_ = objective_function_->IsConstantHessian();
num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
} else {
is_constant_hessian_ = false;
}
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));
// init tree learner
tree_learner_->Init(train_data_, is_constant_hessian_);
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
num_data_ = train_data_->num_data();
// create buffer for gradients and hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
}
// get max feature index
max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
label_idx_ = train_data_->label_idx();
// get feature names
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
// if need bagging, create buffer
ResetBaggingConfig(gbdt_config_.get());
// reset config for tree learner
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
CHECK(num_tree_per_iteration_ == num_class_);
// + 1 here for the binary classification
class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
auto label = train_data_->metadata().label();
if (num_tree_per_iteration_ > 1) {
// multi-class
std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
for (data_size_t i = 0; i < num_data_; ++i) {
int index = static_cast<int>(label[i]);
CHECK(index < num_tree_per_iteration_);
++cnt_per_class[index];
}
for (int i = 0; i < num_tree_per_iteration_; ++i) {
if (cnt_per_class[i] == num_data_) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(kEpsilon);
} else if (cnt_per_class[i] == 0) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
}
}
} else {
// binary class
data_size_t cnt_pos = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (label[i] > 0) {
++cnt_pos;
}
}
if (cnt_pos == 0) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
} else if (cnt_pos == num_data_) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(kEpsilon);
}
}
}
}
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
......@@ -92,22 +172,18 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
is_constant_hessian_ = false;
}
if (train_data_ != train_data && train_data != nullptr) {
if (tree_learner_ == nullptr) {
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, new_config->device_type, &new_config->tree_config));
}
// init tree learner
tree_learner_->Init(train_data, is_constant_hessian_);
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
if (train_data != train_data_) {
train_data_ = train_data;
// not same training data, need reset score and others
// create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data, num_tree_per_iteration_));
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
// update score
for (int i = 0; i < iter_; ++i) {
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
......@@ -115,106 +191,77 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
}
num_data_ = train_data->num_data();
num_data_ = train_data_->num_data();
// create buffer for gradients and hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
}
// get max feature index
max_feature_idx_ = train_data->num_total_features() - 1;
max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
label_idx_ = train_data->label_idx();
label_idx_ = train_data_->label_idx();
// get feature names
feature_names_ = train_data->feature_names();
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
ResetBaggingConfig(gbdt_config_.get());
feature_infos_ = train_data->feature_infos();
tree_learner_->ResetTrainingData(train_data);
}
}
if ((train_data_ != train_data && train_data != nullptr)
|| (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
bag_data_cnt_ =
static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
offsets_buf_.resize(num_threads_);
left_cnts_buf_.resize(num_threads_);
right_cnts_buf_.resize(num_threads_);
left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_);
double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
int sparse_group = 0;
for (int i = 0; i < train_data->num_feature_groups(); ++i) {
if (train_data->FeatureGroupIsSparse(i)) {
++sparse_group;
}
}
is_use_subset_ = false;
const int group_threshold_usesubset = 100;
const int sparse_group_threshold_usesubset = train_data->num_feature_groups() / 4;
if (average_bag_rate <= 0.5
&& (train_data->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data);
is_use_subset_ = true;
Log::Debug("use subset for bagging");
void GBDT::ResetConfig(const BoostingConfig* config) {
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
ResetBaggingConfig(new_config.get());
tree_learner_->ResetConfig(&new_config->tree_config);
gbdt_config_.reset(new_config.release());
}
void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
// if need bagging, create buffer
if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
bag_data_cnt_ =
static_cast<data_size_t>(config->bagging_fraction * num_data_);
bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
offsets_buf_.resize(num_threads_);
left_cnts_buf_.resize(num_threads_);
right_cnts_buf_.resize(num_threads_);
left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_);
double average_bag_rate = config->bagging_fraction / config->bagging_freq;
int sparse_group = 0;
for (int i = 0; i < train_data_->num_feature_groups(); ++i) {
if (train_data_->FeatureGroupIsSparse(i)) {
++sparse_group;
}
} else {
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
tmp_indices_.clear();
is_use_subset_ = false;
}
}
train_data_ = train_data;
if (train_data_ != nullptr) {
// reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config);
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
CHECK(num_tree_per_iteration_ == num_class_);
// + 1 here for the binary classification
class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
auto label = train_data_->metadata().label();
if (num_tree_per_iteration_ > 1) {
// multi-class
std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
for (data_size_t i = 0; i < num_data_; ++i) {
int index = static_cast<int>(label[i]);
CHECK(index < num_tree_per_iteration_);
++cnt_per_class[index];
}
for (int i = 0; i < num_tree_per_iteration_; ++i) {
if (cnt_per_class[i] == num_data_) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(kEpsilon);
} else if (cnt_per_class[i] == 0) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
}
}
} else {
// binary class
data_size_t cnt_pos = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (label[i] > 0) {
++cnt_pos;
}
}
if (cnt_pos == 0) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
} else if (cnt_pos == num_data_) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(kEpsilon);
}
}
is_use_subset_ = false;
const int group_threshold_usesubset = 100;
const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
if (average_bag_rate <= 0.5
&& (train_data_->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
is_use_subset_ = true;
Log::Debug("use subset for bagging");
}
} else {
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
tmp_indices_.clear();
is_use_subset_ = false;
}
gbdt_config_.reset(new_config.release());
}
void GBDT::AddValidDataset(const Dataset* valid_data,
......@@ -358,7 +405,7 @@ double LabelAverage(const float* label, data_size_t num_data) {
Network::Allreduce(reinterpret_cast<char*>(&init_score),
sizeof(init_score), sizeof(init_score),
reinterpret_cast<char*>(&global_init_score),
[](const char* src, char* dst, int len) {
[] (const char* src, char* dst, int len) {
int used_size = 0;
const int type_size = sizeof(double);
const double *p1;
......@@ -833,7 +880,7 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
std::ifstream ifs(filename);
if (ifs.good()) {
std::string origin((std::istreambuf_iterator<char>(ifs)),
(std::istreambuf_iterator<char>()));
(std::istreambuf_iterator<char>()));
output_file.open(filename);
output_file << "#define USE_HARD_CODE 0" << std::endl;
output_file << "#ifndef USE_HARD_CODE" << std::endl;
......@@ -1027,8 +1074,8 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[](const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
[] (const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
return pairs;
......
......@@ -63,14 +63,10 @@ public:
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
}
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param objective_function Training objective function
* \param training_metrics Training metric
*/
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) override;
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override;
void ResetConfig(const BoostingConfig* config) override;
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
......@@ -258,6 +254,7 @@ public:
virtual const char* SubModelName() const override { return "tree"; }
protected:
void ResetBaggingConfig(const BoostingConfig* config);
/*!
* \brief Implement bagging logic
* \param iter Current interation
......
......@@ -41,21 +41,28 @@ public:
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, objective_function, training_metrics);
ResetGoss();
}
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
ResetGoss();
}
void ResetConfig(const BoostingConfig* config) override {
GBDT::ResetConfig(config);
ResetGoss();
}
void ResetGoss() {
CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
Log::Fatal("cannot use bagging in GOSS");
}
Log::Info("using GOSS");
}
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
if (config->bagging_freq > 0 && config->bagging_fraction != 1.0f) {
Log::Fatal("cannot use bagging in GOSS");
}
GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
if (train_data_ == nullptr) { return; }
bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
tmp_indice_right_.resize(num_data_);
......@@ -66,8 +73,8 @@ public:
right_write_pos_buf_.resize(num_threads_);
is_use_subset_ = false;
if (config->top_rate + config->other_rate <= 0.5) {
auto bag_data_cnt = static_cast<data_size_t>((config->top_rate + config->other_rate) * num_data_);
if (gbdt_config_->top_rate + gbdt_config_->other_rate <= 0.5) {
auto bag_data_cnt = static_cast<data_size_t>((gbdt_config_->top_rate + gbdt_config_->other_rate) * num_data_);
tmp_subset_.reset(new Dataset(bag_data_cnt));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
is_use_subset_ = true;
......
......@@ -51,11 +51,12 @@ public:
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
train_data_ = train_data;
CreateObjectiveAndMetrics();
// initialize the boosting
boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
ResetTrainingData(train_data);
}
void MergeFrom(const Booster* other) {
......@@ -67,9 +68,7 @@ public:
}
void ResetTrainingData(const Dataset* train_data) {
std::lock_guard<std::mutex> lock(mutex_);
train_data_ = train_data;
void CreateObjectiveAndMetrics() {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
......@@ -91,9 +90,17 @@ public:
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// reset the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
void ResetTrainingData(const Dataset* train_data) {
if (train_data != train_data_) {
std::lock_guard<std::mutex> lock(mutex_);
train_data_ = train_data;
CreateObjectiveAndMetrics();
// reset the boosting
boosting_->ResetTrainingData(train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
}
void ResetConfig(const char* parameters) {
......@@ -125,10 +132,11 @@ public:
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
boosting_->ResetTrainingData(train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->ResetConfig(&config_.boosting_config);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment