Commit 5e1a5135 authored by Guolin Ke
Browse files

fix #865

parent 603bffcf
......@@ -115,7 +115,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
feature_infos_ = train_data_->feature_infos();
// if need bagging, create buffer
ResetBaggingConfig(gbdt_config_.get());
ResetBaggingConfig(gbdt_config_.get(), true);
// reset config for tree learner
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
......@@ -211,7 +211,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
feature_infos_ = train_data_->feature_infos();
ResetBaggingConfig(gbdt_config_.get());
ResetBaggingConfig(gbdt_config_.get(), true);
tree_learner_->ResetTrainingData(train_data);
}
......@@ -222,13 +222,13 @@ void GBDT::ResetConfig(const BoostingConfig* config) {
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
if (tree_learner_ != nullptr) {
ResetBaggingConfig(new_config.get());
ResetBaggingConfig(new_config.get(), false);
tree_learner_->ResetConfig(&new_config->tree_config);
}
gbdt_config_.reset(new_config.release());
}
void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
void GBDT::ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset) {
// if need bagging, create buffer
if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
bag_data_cnt_ =
......@@ -252,8 +252,10 @@ void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
if (average_bag_rate <= 0.5
&& (train_data_->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
if (tmp_subset_ == nullptr || is_change_dataset) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
}
is_use_subset_ = true;
Log::Debug("use subset for bagging");
}
......@@ -263,6 +265,10 @@ void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
tmp_indices_.clear();
is_use_subset_ = false;
}
if (is_change_dataset) {
need_re_bagging_ = true;
}
}
void GBDT::AddValidDataset(const Dataset* valid_data,
......@@ -322,7 +328,8 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
void GBDT::Bagging(int iter) {
// if need bagging
if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
if ( (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) || need_re_bagging_) {
need_re_bagging_ = false;
const data_size_t min_inner_size = 1000;
data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
if (inner_size < min_inner_size) { inner_size = min_inner_size; }
......
......@@ -270,7 +270,7 @@ public:
virtual const char* SubModelName() const override { return "tree"; }
protected:
void ResetBaggingConfig(const BoostingConfig* config);
void ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset);
/*!
* \brief Implement bagging logic
* \param iter Current iteration
......@@ -388,6 +388,7 @@ protected:
bool is_constant_hessian_;
std::unique_ptr<ObjectiveFunction> loaded_objective_;
bool average_output_;
bool need_re_bagging_ = false;
};
} // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment