"...include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "c2d90bddc6a2a562ee7750c14351e9ca16a6a37a"
Commit cdba7147 authored by Guolin Ke's avatar Guolin Ke Committed by Qiwei Ye
Browse files

balanced bagging (#2214)

* add balanced bagging

* refine code

* fix format

* clarify usage only for binary application
parent 7da11ffe
...@@ -210,6 +210,38 @@ Learning Control Parameters ...@@ -210,6 +210,38 @@ Learning Control Parameters
- **Note**: to enable bagging, ``bagging_freq`` should be set to a non zero value as well - **Note**: to enable bagging, ``bagging_freq`` should be set to a non zero value as well
- ``pos_bagging_fraction`` :raw-html:`<a id="pos_bagging_fraction" title="Permalink to this parameter" href="#pos_bagging_fraction">&#x1F517;&#xFE0E;</a>`, default = ``1.0``, type = double, aliases: ``pos_sub_row``, ``pos_subsample``, ``pos_bagging``, constraints: ``0.0 < pos_bagging_fraction <= 1.0``
- used only in ``binary`` application
- used for imbalanced binary classification problems; randomly samples approximately ``#pos_samples * pos_bagging_fraction`` positive samples in bagging (each positive sample is kept independently with probability ``pos_bagging_fraction``)
- should be used together with ``neg_bagging_fraction``
- set this to ``1.0`` to disable
- **Note**: to enable this, you need to set ``bagging_freq`` and ``neg_bagging_fraction`` as well
- **Note**: if both ``pos_bagging_fraction`` and ``neg_bagging_fraction`` are set to ``1.0``, balanced bagging is disabled
- **Note**: if balanced bagging is enabled, ``bagging_fraction`` will be ignored
- ``neg_bagging_fraction`` :raw-html:`<a id="neg_bagging_fraction" title="Permalink to this parameter" href="#neg_bagging_fraction">&#x1F517;&#xFE0E;</a>`, default = ``1.0``, type = double, aliases: ``neg_sub_row``, ``neg_subsample``, ``neg_bagging``, constraints: ``0.0 < neg_bagging_fraction <= 1.0``
- used only in ``binary`` application
- used for imbalanced binary classification problems; randomly samples approximately ``#neg_samples * neg_bagging_fraction`` negative samples in bagging (each negative sample is kept independently with probability ``neg_bagging_fraction``)
- should be used together with ``pos_bagging_fraction``
- set this to ``1.0`` to disable
- **Note**: to enable this, you need to set ``bagging_freq`` and ``pos_bagging_fraction`` as well
- **Note**: if both ``pos_bagging_fraction`` and ``neg_bagging_fraction`` are set to ``1.0``, balanced bagging is disabled
- **Note**: if balanced bagging is enabled, ``bagging_fraction`` will be ignored
- ``bagging_freq`` :raw-html:`<a id="bagging_freq" title="Permalink to this parameter" href="#bagging_freq">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, aliases: ``subsample_freq`` - ``bagging_freq`` :raw-html:`<a id="bagging_freq" title="Permalink to this parameter" href="#bagging_freq">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, aliases: ``subsample_freq``
- frequency for bagging - frequency for bagging
......
...@@ -234,6 +234,30 @@ struct Config { ...@@ -234,6 +234,30 @@ struct Config {
// desc = **Note**: to enable bagging, ``bagging_freq`` should be set to a non zero value as well // desc = **Note**: to enable bagging, ``bagging_freq`` should be set to a non zero value as well
double bagging_fraction = 1.0; double bagging_fraction = 1.0;
// alias = pos_sub_row, pos_subsample, pos_bagging
// check = >0.0
// check = <=1.0
// desc = used only in ``binary`` application
// desc = used for imbalanced binary classification problem, will randomly sample ``#pos_samples * pos_bagging_fraction`` positive samples in bagging
// desc = should be used together with ``neg_bagging_fraction``
// desc = set this to ``1.0`` to disable
// desc = **Note**: to enable this, you need to set ``bagging_freq`` and ``neg_bagging_fraction`` as well
// desc = **Note**: if both ``pos_bagging_fraction`` and ``neg_bagging_fraction`` are set to ``1.0``, balanced bagging is disabled
// desc = **Note**: if balanced bagging is enabled, ``bagging_fraction`` will be ignored
double pos_bagging_fraction = 1.0;
// alias = neg_sub_row, neg_subsample, neg_bagging
// check = >0.0
// check = <=1.0
// desc = used only in ``binary`` application
// desc = used for imbalanced binary classification problem, will randomly sample ``#neg_samples * neg_bagging_fraction`` negative samples in bagging
// desc = should be used together with ``pos_bagging_fraction``
// desc = set this to ``1.0`` to disable
// desc = **Note**: to enable this, you need to set ``bagging_freq`` and ``pos_bagging_fraction`` as well
// desc = **Note**: if both ``pos_bagging_fraction`` and ``neg_bagging_fraction`` are set to ``1.0``, balanced bagging is disabled
// desc = **Note**: if balanced bagging is enabled, ``bagging_fraction`` will be ignored
double neg_bagging_fraction = 1.0;
// alias = subsample_freq // alias = subsample_freq
// desc = frequency for bagging // desc = frequency for bagging
// desc = ``0`` means disable bagging; ``k`` means perform bagging at every ``k`` iteration // desc = ``0`` means disable bagging; ``k`` means perform bagging at every ``k`` iteration
......
...@@ -61,6 +61,9 @@ class ObjectiveFunction { ...@@ -61,6 +61,9 @@ class ObjectiveFunction {
/*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */ /*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
virtual bool NeedAccuratePrediction() const { return true; } virtual bool NeedAccuratePrediction() const { return true; }
/*! \brief Number of positive samples; this default reports 0, meaning the objective is not a binary classification task. */
virtual data_size_t NumPositiveData() const {
  return 0;
}
virtual void ConvertOutput(const double* input, double* output) const { virtual void ConvertOutput(const double* input, double* output) const {
output[0] = input[0]; output[0] = input[0];
} }
......
...@@ -29,7 +29,8 @@ num_class_(1), ...@@ -29,7 +29,8 @@ num_class_(1),
num_iteration_for_pred_(0), num_iteration_for_pred_(0),
shrinkage_rate_(0.1f), shrinkage_rate_(0.1f),
num_init_iteration_(0), num_init_iteration_(0),
need_re_bagging_(false) { need_re_bagging_(false),
balanced_bagging_(false) {
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
...@@ -176,6 +177,35 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t ...@@ -176,6 +177,35 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
return cur_left_cnt; return cur_left_cnt;
} }
/*!
 * \brief Label-aware bagging over the records [start, start + cnt).
 *
 * A record whose label is > 0 is kept with probability pos_bagging_fraction,
 * any other record with probability neg_bagging_fraction. Exactly one random
 * number is drawn per record, in record order.
 *
 * \param cur_rand random generator used for the per-record draw
 * \param start index of the first record of the range
 * \param cnt number of records in the range
 * \param buffer output of size cnt: in-bag indices first, then out-of-bag
 *        indices, both in ascending record order
 * \return number of in-bag records written at the front of buffer
 */
data_size_t GBDT::BalancedBaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) {
  // an empty range has nothing to sample
  if (cnt <= 0) { return 0; }

  const auto labels = train_data_->metadata().label();
  // in-bag indices grow from the front, out-of-bag indices from the back
  data_size_t* in_bag_end = buffer;
  data_size_t* out_of_bag_begin = buffer + cnt;

  for (data_size_t offset = 0; offset < cnt; ++offset) {
    const data_size_t record = start + offset;
    // pick the keep probability that matches the record's class
    const double keep_rate = (labels[record] > 0) ? config_->pos_bagging_fraction
                                                  : config_->neg_bagging_fraction;
    if (cur_rand.NextFloat() < keep_rate) {
      *in_bag_end++ = record;
    } else {
      *--out_of_bag_begin = record;
    }
  }

  // the out-of-bag part was filled back to front; restore ascending order
  std::reverse(in_bag_end, buffer + cnt);
  return static_cast<data_size_t>(in_bag_end - buffer);
}
void GBDT::Bagging(int iter) { void GBDT::Bagging(int iter) {
// if need bagging // if need bagging
if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0) if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0)
...@@ -195,7 +225,12 @@ void GBDT::Bagging(int iter) { ...@@ -195,7 +225,12 @@ void GBDT::Bagging(int iter) {
data_size_t cur_cnt = inner_size; data_size_t cur_cnt = inner_size;
if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; } if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Random cur_rand(config_->bagging_seed + iter * num_threads_ + i); Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start); data_size_t cur_left_count = 0;
if (balanced_bagging_) {
cur_left_count = BalancedBaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
} else {
cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
}
offsets_buf_[i] = cur_start; offsets_buf_[i] = cur_start;
left_cnts_buf_[i] = cur_left_count; left_cnts_buf_[i] = cur_left_count;
right_cnts_buf_[i] = cur_cnt - cur_left_count; right_cnts_buf_[i] = cur_cnt - cur_left_count;
...@@ -690,14 +725,25 @@ void GBDT::ResetConfig(const Config* config) { ...@@ -690,14 +725,25 @@ void GBDT::ResetConfig(const Config* config) {
void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
// if need bagging, create buffer // if need bagging, create buffer
if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) { data_size_t num_pos_data = 0;
if (objective_function_ != nullptr) {
num_pos_data = objective_function_->NumPositiveData();
}
bool balance_bagging_cond = (config->pos_bagging_fraction < 1.0 || config->neg_bagging_fraction < 1.0) && (num_pos_data > 0);
if ((config->bagging_fraction < 1.0 || balance_bagging_cond) && config->bagging_freq > 0) {
need_re_bagging_ = false; need_re_bagging_ = false;
if (!is_change_dataset && if (!is_change_dataset &&
config_.get() != nullptr && config_->bagging_fraction == config->bagging_fraction && config_->bagging_freq == config->bagging_freq) { config_.get() != nullptr && config_->bagging_fraction == config->bagging_fraction && config_->bagging_freq == config->bagging_freq
&& config_->pos_bagging_fraction == config->pos_bagging_fraction && config_->neg_bagging_fraction == config->neg_bagging_fraction) {
return; return;
} }
bag_data_cnt_ = if (balance_bagging_cond) {
static_cast<data_size_t>(config->bagging_fraction * num_data_); balanced_bagging_ = true;
bag_data_cnt_ = static_cast<data_size_t>(num_pos_data * config->pos_bagging_fraction)
+ static_cast<data_size_t>((num_data_ - num_pos_data) * config->neg_bagging_fraction);
} else {
bag_data_cnt_ = static_cast<data_size_t>(config->bagging_fraction * num_data_);
}
bag_data_indices_.resize(num_data_); bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_); tmp_indices_.resize(num_data_);
...@@ -707,7 +753,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { ...@@ -707,7 +753,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
left_write_pos_buf_.resize(num_threads_); left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_); right_write_pos_buf_.resize(num_threads_);
// Right-hand (new revision) column: cast before dividing. bag_data_cnt_ and num_data_ are
// integer counts, so bag_data_cnt_ / num_data_ truncates to 0 whenever fewer than all rows
// are bagged, which would make average_bag_rate always 0.0.
double average_bag_rate = config->bagging_fraction / config->bagging_freq; double average_bag_rate = (static_cast<double>(bag_data_cnt_) / num_data_) / config->bagging_freq;
int sparse_group = 0; int sparse_group = 0;
for (int i = 0; i < train_data_->num_feature_groups(); ++i) { for (int i = 0; i < train_data_->num_feature_groups(); ++i) {
if (train_data_->FeatureGroupIsSparse(i)) { if (train_data_->FeatureGroupIsSparse(i)) {
......
...@@ -387,6 +387,16 @@ class GBDT : public GBDTBase { ...@@ -387,6 +387,16 @@ class GBDT : public GBDTBase {
*/ */
data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer); data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer);
/*!
* \brief Helper function for bagging, used for multi-threading optimization, balanced sampling:
*        records with label > 0 are kept with probability pos_bagging_fraction,
*        all other records with probability neg_bagging_fraction
* \param cur_rand random generator used to draw the keep/drop decision of each record
* \param start start index of bagging
* \param cnt count of records to process, beginning at start
* \param buffer output buffer of size cnt; receives the kept (in-bag) indices first,
*        followed by the dropped (out-of-bag) indices
* \return count of left size, i.e. the number of kept indices at the front of buffer
*/
data_size_t BalancedBaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer);
/*! /*!
* \brief calculate the object function * \brief calculate the object function
*/ */
...@@ -492,6 +502,7 @@ class GBDT : public GBDTBase { ...@@ -492,6 +502,7 @@ class GBDT : public GBDTBase {
std::unique_ptr<ObjectiveFunction> loaded_objective_; std::unique_ptr<ObjectiveFunction> loaded_objective_;
bool average_output_; bool average_output_;
bool need_re_bagging_; bool need_re_bagging_;
bool balanced_bagging_;
std::string loaded_parameter_; std::string loaded_parameter_;
Json forced_splits_json_; Json forced_splits_json_;
......
...@@ -58,6 +58,12 @@ std::unordered_map<std::string, std::string> Config::alias_table({ ...@@ -58,6 +58,12 @@ std::unordered_map<std::string, std::string> Config::alias_table({
{"sub_row", "bagging_fraction"}, {"sub_row", "bagging_fraction"},
{"subsample", "bagging_fraction"}, {"subsample", "bagging_fraction"},
{"bagging", "bagging_fraction"}, {"bagging", "bagging_fraction"},
{"pos_sub_row", "pos_bagging_fraction"},
{"pos_subsample", "pos_bagging_fraction"},
{"pos_bagging", "pos_bagging_fraction"},
{"neg_sub_row", "neg_bagging_fraction"},
{"neg_subsample", "neg_bagging_fraction"},
{"neg_bagging", "neg_bagging_fraction"},
{"subsample_freq", "bagging_freq"}, {"subsample_freq", "bagging_freq"},
{"bagging_fraction_seed", "bagging_seed"}, {"bagging_fraction_seed", "bagging_seed"},
{"sub_feature", "feature_fraction"}, {"sub_feature", "feature_fraction"},
...@@ -176,6 +182,8 @@ std::unordered_set<std::string> Config::parameter_set({ ...@@ -176,6 +182,8 @@ std::unordered_set<std::string> Config::parameter_set({
"min_data_in_leaf", "min_data_in_leaf",
"min_sum_hessian_in_leaf", "min_sum_hessian_in_leaf",
"bagging_fraction", "bagging_fraction",
"pos_bagging_fraction",
"neg_bagging_fraction",
"bagging_freq", "bagging_freq",
"bagging_seed", "bagging_seed",
"feature_fraction", "feature_fraction",
...@@ -302,6 +310,14 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -302,6 +310,14 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
CHECK(bagging_fraction >0.0); CHECK(bagging_fraction >0.0);
CHECK(bagging_fraction <=1.0); CHECK(bagging_fraction <=1.0);
GetDouble(params, "pos_bagging_fraction", &pos_bagging_fraction);
CHECK(pos_bagging_fraction >0.0);
CHECK(pos_bagging_fraction <=1.0);
GetDouble(params, "neg_bagging_fraction", &neg_bagging_fraction);
CHECK(neg_bagging_fraction >0.0);
CHECK(neg_bagging_fraction <=1.0);
GetInt(params, "bagging_freq", &bagging_freq); GetInt(params, "bagging_freq", &bagging_freq);
GetInt(params, "bagging_seed", &bagging_seed); GetInt(params, "bagging_seed", &bagging_seed);
...@@ -558,6 +574,8 @@ std::string Config::SaveMembersToString() const { ...@@ -558,6 +574,8 @@ std::string Config::SaveMembersToString() const {
str_buf << "[min_data_in_leaf: " << min_data_in_leaf << "]\n"; str_buf << "[min_data_in_leaf: " << min_data_in_leaf << "]\n";
str_buf << "[min_sum_hessian_in_leaf: " << min_sum_hessian_in_leaf << "]\n"; str_buf << "[min_sum_hessian_in_leaf: " << min_sum_hessian_in_leaf << "]\n";
str_buf << "[bagging_fraction: " << bagging_fraction << "]\n"; str_buf << "[bagging_fraction: " << bagging_fraction << "]\n";
str_buf << "[pos_bagging_fraction: " << pos_bagging_fraction << "]\n";
str_buf << "[neg_bagging_fraction: " << neg_bagging_fraction << "]\n";
str_buf << "[bagging_freq: " << bagging_freq << "]\n"; str_buf << "[bagging_freq: " << bagging_freq << "]\n";
str_buf << "[bagging_seed: " << bagging_seed << "]\n"; str_buf << "[bagging_seed: " << bagging_seed << "]\n";
str_buf << "[feature_fraction: " << feature_fraction << "]\n"; str_buf << "[feature_fraction: " << feature_fraction << "]\n";
......
...@@ -96,6 +96,7 @@ class BinaryLogloss: public ObjectiveFunction { ...@@ -96,6 +96,7 @@ class BinaryLogloss: public ObjectiveFunction {
} }
} }
label_weights_[1] *= scale_pos_weight_; label_weights_[1] *= scale_pos_weight_;
num_pos_data_ = cnt_positive;
} }
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
...@@ -179,9 +180,13 @@ class BinaryLogloss: public ObjectiveFunction { ...@@ -179,9 +180,13 @@ class BinaryLogloss: public ObjectiveFunction {
bool NeedAccuratePrediction() const override { return false; } bool NeedAccuratePrediction() const override { return false; }
/*! \brief Number of positive samples, as recorded in num_pos_data_ */
data_size_t NumPositiveData() const override {
  return num_pos_data_;
}
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of positive samples */
data_size_t num_pos_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const label_t* label_; const label_t* label_;
/*! \brief True if using unbalance training */ /*! \brief True if using unbalance training */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment