Unverified commit 46d21476, authored by Guolin Ke, committed by GitHub
Browse files

fix a bug when bagging with reset_config (#2149)

* fix a bug when bagging with reset_config

* clean code
parent 2c41d15e
...@@ -43,18 +43,11 @@ class ObjectiveFunction { ...@@ -43,18 +43,11 @@ class ObjectiveFunction {
virtual bool IsRenewTreeOutput() const { return false; } virtual bool IsRenewTreeOutput() const { return false; }
virtual double RenewTreeOutput(double ori_output, const double*, virtual double RenewTreeOutput(double ori_output, std::function<double(const label_t*, int)> residual_getter,
const data_size_t*, const data_size_t*,
const data_size_t*, const data_size_t*,
data_size_t) const { return ori_output; } data_size_t) const { return ori_output; }
virtual double RenewTreeOutput(double ori_output, double,
const data_size_t*,
const data_size_t*,
data_size_t) const {
return ori_output;
}
virtual double BoostFromScore(int /*class_id*/) const { return 0.0; } virtual double BoostFromScore(int /*class_id*/) const { return 0.0; }
virtual bool ClassNeedTrain(int /*class_id*/) const { return true; } virtual bool ClassNeedTrain(int /*class_id*/) const { return true; }
......
...@@ -77,10 +77,7 @@ class TreeLearner { ...@@ -77,10 +77,7 @@ class TreeLearner {
*/ */
virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0; virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0;
virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction, virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0; data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
TreeLearner() = default; TreeLearner() = default;
......
...@@ -364,7 +364,9 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { ...@@ -364,7 +364,9 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
if (new_tree->num_leaves() > 1) { if (new_tree->num_leaves() > 1) {
should_continue = true; should_continue = true;
tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, train_score_updater_->score() + bias, auto score_ptr = train_score_updater_->score() + bias;
auto residual_getter = [score_ptr](const label_t* label, int i) {return static_cast<double>(label[i]) - score_ptr[i]; };
tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter,
num_data_, bag_data_indices_.data(), bag_data_cnt_); num_data_, bag_data_indices_.data(), bag_data_cnt_);
// shrinkage by learning rate // shrinkage by learning rate
new_tree->Shrinkage(shrinkage_rate_); new_tree->Shrinkage(shrinkage_rate_);
...@@ -688,6 +690,11 @@ void GBDT::ResetConfig(const Config* config) { ...@@ -688,6 +690,11 @@ void GBDT::ResetConfig(const Config* config) {
void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
// if need bagging, create buffer // if need bagging, create buffer
if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) { if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
need_re_bagging_ = false;
if (!is_change_dataset &&
config_.get() != nullptr && config_->bagging_fraction == config->bagging_fraction && config_->bagging_freq == config->bagging_freq) {
return;
}
bag_data_cnt_ = bag_data_cnt_ =
static_cast<data_size_t>(config->bagging_fraction * num_data_); static_cast<data_size_t>(config->bagging_fraction * num_data_);
bag_data_indices_.resize(num_data_); bag_data_indices_.resize(num_data_);
...@@ -719,9 +726,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { ...@@ -719,9 +726,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
Log::Debug("Use subset for bagging"); Log::Debug("Use subset for bagging");
} }
if (is_change_dataset) {
need_re_bagging_ = true; need_re_bagging_ = true;
}
if (is_use_subset_ && bag_data_cnt_ < num_data_) { if (is_use_subset_ && bag_data_cnt_ < num_data_) {
if (objective_function_ == nullptr) { if (objective_function_ == nullptr) {
......
...@@ -130,7 +130,9 @@ class RF : public GBDT { ...@@ -130,7 +130,9 @@ class RF : public GBDT {
} }
if (new_tree->num_leaves() > 1) { if (new_tree->num_leaves() > 1) {
tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, init_scores_[cur_tree_id], double pred = init_scores_[cur_tree_id];
auto residual_getter = [pred](const label_t* label, int i) {return static_cast<double>(label[i]) - pred; };
tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter,
num_data_, bag_data_indices_.data(), bag_data_cnt_); num_data_, bag_data_indices_.data(), bag_data_cnt_);
if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) { if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
new_tree->AddBias(init_scores_[cur_tree_id]); new_tree->AddBias(init_scores_[cur_tree_id]);
......
...@@ -232,62 +232,30 @@ class RegressionL1loss: public RegressionL2loss { ...@@ -232,62 +232,30 @@ class RegressionL1loss: public RegressionL2loss {
bool IsRenewTreeOutput() const override { return true; } bool IsRenewTreeOutput() const override { return true; }
double RenewTreeOutput(double, const double* pred, double RenewTreeOutput(double, std::function<double(const label_t*, int)> residual_getter,
const data_size_t* index_mapper, const data_size_t* index_mapper,
const data_size_t* bagging_mapper, const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override { data_size_t num_data_in_leaf) const override {
const double alpha = 0.5; const double alpha = 0.5;
if (weights_ == nullptr) { if (weights_ == nullptr) {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]]) #define data_reader(i) (residual_getter(label_,index_mapper[i]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha); PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
} else { } else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]]) #define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha); PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
} }
} else { } else {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]]) #define data_reader(i) (residual_getter(label_,index_mapper[i]))
#define weight_reader(i) (weights_[index_mapper[i]]) #define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
#undef weight_reader #undef weight_reader
} else { } else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]]) #define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]]))
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
}
}
}
double RenewTreeOutput(double, double pred,
const data_size_t* index_mapper,
const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override {
const double alpha = 0.5;
if (weights_ == nullptr) {
if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred)
PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader
} else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader
}
} else {
if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred)
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
} else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]]) #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
...@@ -552,60 +520,29 @@ class RegressionQuantileloss : public RegressionL2loss { ...@@ -552,60 +520,29 @@ class RegressionQuantileloss : public RegressionL2loss {
bool IsRenewTreeOutput() const override { return true; } bool IsRenewTreeOutput() const override { return true; }
double RenewTreeOutput(double, const double* pred, double RenewTreeOutput(double, std::function<double(const label_t*, int)> residual_getter,
const data_size_t* index_mapper,
const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override {
if (weights_ == nullptr) {
if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader
} else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader
}
} else {
if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader
#undef weight_reader
} else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader
#undef weight_reader
}
}
}
double RenewTreeOutput(double, double pred,
const data_size_t* index_mapper, const data_size_t* index_mapper,
const data_size_t* bagging_mapper, const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override { data_size_t num_data_in_leaf) const override {
if (weights_ == nullptr) { if (weights_ == nullptr) {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred) #define data_reader(i) (residual_getter(label_,index_mapper[i]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha_); PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
} else { } else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred) #define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]]))
PercentileFun(double, data_reader, num_data_in_leaf, alpha_); PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
} }
} else { } else {
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred) #define data_reader(i) (residual_getter(label_,index_mapper[i]))
#define weight_reader(i) (weights_[index_mapper[i]]) #define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
#undef weight_reader #undef weight_reader
} else { } else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred) #define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]]))
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]]) #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader #undef data_reader
...@@ -684,39 +621,19 @@ class RegressionMAPELOSS : public RegressionL1loss { ...@@ -684,39 +621,19 @@ class RegressionMAPELOSS : public RegressionL1loss {
bool IsRenewTreeOutput() const override { return true; } bool IsRenewTreeOutput() const override { return true; }
double RenewTreeOutput(double, const double* pred, double RenewTreeOutput(double, std::function<double(const label_t*, int)> residual_getter,
const data_size_t* index_mapper,
const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override {
const double alpha = 0.5;
if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
#define weight_reader(i) (label_weight_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
} else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
#define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
}
}
double RenewTreeOutput(double, double pred,
const data_size_t* index_mapper, const data_size_t* index_mapper,
const data_size_t* bagging_mapper, const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override { data_size_t num_data_in_leaf) const override {
const double alpha = 0.5; const double alpha = 0.5;
if (bagging_mapper == nullptr) { if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred) #define data_reader(i) (residual_getter(label_,index_mapper[i]))
#define weight_reader(i) (label_weight_[index_mapper[i]]) #define weight_reader(i) (label_weight_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
#undef weight_reader #undef weight_reader
} else { } else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred) #define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]]))
#define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]]) #define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha); WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader #undef data_reader
......
...@@ -851,7 +851,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -851,7 +851,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
} }
void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction, void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const { data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
if (obj != nullptr && obj->IsRenewTreeOutput()) { if (obj != nullptr && obj->IsRenewTreeOutput()) {
CHECK(tree->num_leaves() <= data_partition_->num_leaves()); CHECK(tree->num_leaves() <= data_partition_->num_leaves());
...@@ -869,47 +869,7 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj ...@@ -869,47 +869,7 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data); auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
if (cnt_leaf_data > 0) { if (cnt_leaf_data > 0) {
// bag_mapper[index_mapper[i]] // bag_mapper[index_mapper[i]]
const double new_output = obj->RenewTreeOutput(output, prediction, index_mapper, bag_mapper, cnt_leaf_data); const double new_output = obj->RenewTreeOutput(output, residual_getter, index_mapper, bag_mapper, cnt_leaf_data);
tree->SetLeafOutput(i, new_output);
} else {
CHECK(num_machines > 1);
tree->SetLeafOutput(i, 0.0);
n_nozeroworker_perleaf[i] = 0;
}
}
if (num_machines > 1) {
std::vector<double> outputs(tree->num_leaves());
for (int i = 0; i < tree->num_leaves(); ++i) {
outputs[i] = static_cast<double>(tree->LeafOutput(i));
}
Network::GlobalSum(outputs);
Network::GlobalSum(n_nozeroworker_perleaf);
for (int i = 0; i < tree->num_leaves(); ++i) {
tree->SetLeafOutput(i, outputs[i] / n_nozeroworker_perleaf[i]);
}
}
}
}
void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
if (obj != nullptr && obj->IsRenewTreeOutput()) {
CHECK(tree->num_leaves() <= data_partition_->num_leaves());
const data_size_t* bag_mapper = nullptr;
if (total_num_data != num_data_) {
CHECK(bag_cnt == num_data_);
bag_mapper = bag_indices;
}
std::vector<int> n_nozeroworker_perleaf(tree->num_leaves(), 1);
int num_machines = Network::num_machines();
#pragma omp parallel for schedule(static)
for (int i = 0; i < tree->num_leaves(); ++i) {
const double output = static_cast<double>(tree->LeafOutput(i));
data_size_t cnt_leaf_data = 0;
auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
if (cnt_leaf_data > 0) {
// bag_mapper[index_mapper[i]]
const double new_output = obj->RenewTreeOutput(output, prediction, index_mapper, bag_mapper, cnt_leaf_data);
tree->SetLeafOutput(i, new_output); tree->SetLeafOutput(i, new_output);
} else { } else {
CHECK(num_machines > 1); CHECK(num_machines > 1);
......
...@@ -74,10 +74,7 @@ class SerialTreeLearner: public TreeLearner { ...@@ -74,10 +74,7 @@ class SerialTreeLearner: public TreeLearner {
} }
} }
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction, void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override; data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
protected: protected:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment