"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "03469ae59bb91a64b2c8ff7e2de7377b23c80c53"
Unverified Commit cba82447 authored by Guolin Ke, committed by GitHub
Browse files

Fix bugs in RF (#1906)

* fix RF's bugs

* fix tests

* rollback num_iterations

* fix a bug and reduce memory costs

* reduce memory cost
parent 0c5f390a
...@@ -42,6 +42,13 @@ public: ...@@ -42,6 +42,13 @@ public:
const data_size_t*, const data_size_t*,
data_size_t) const { return ori_output; } data_size_t) const { return ori_output; }
/*!
* \brief Default implementation of the RenewTreeOutput overload whose second
*        argument is a single double rather than a pointer: returns the
*        original leaf output unchanged and ignores every other argument.
* \param ori_output Leaf output to hand back
* \return ori_output, unmodified
*/
virtual double RenewTreeOutput(double ori_output, double /*pred*/,
const data_size_t* /*index_mapper*/,
const data_size_t* /*bagging_mapper*/,
data_size_t /*num_data_in_leaf*/) const {
return ori_output;
}
virtual double BoostFromScore(int /*class_id*/) const { return 0.0; } virtual double BoostFromScore(int /*class_id*/) const { return 0.0; }
virtual bool ClassNeedTrain(int /*class_id*/) const { return true; } virtual bool ClassNeedTrain(int /*class_id*/) const { return true; }
......
...@@ -75,6 +75,9 @@ public: ...@@ -75,6 +75,9 @@ public:
virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction, virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0; data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
/*!
* \brief Pure-virtual overload of RenewTreeOutput that receives one double
*        `prediction` value instead of a `const double*` prediction buffer.
* NOTE(review): presumably the single value is applied to every row when the
* objective recomputes leaf outputs -- confirm against the implementing class.
*/
virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
TreeLearner() = default; TreeLearner() = default;
/*! \brief Disable copy */ /*! \brief Disable copy */
TreeLearner& operator=(const TreeLearner&) = delete; TreeLearner& operator=(const TreeLearner&) = delete;
......
...@@ -308,15 +308,17 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, int class_id) ...@@ -308,15 +308,17 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, int class_id)
return init_score; return init_score;
} }
double GBDT::BoostFromAverage(int class_id) { double GBDT::BoostFromAverage(int class_id, bool update_scorer) {
// boosting from average label; or customized "average" if implemented for the current objective // boosting from average label; or customized "average" if implemented for the current objective
if (models_.empty() && !train_score_updater_->has_init_score() && objective_function_ != nullptr) { if (models_.empty() && !train_score_updater_->has_init_score() && objective_function_ != nullptr) {
if (config_->boost_from_average || (train_data_ != nullptr && train_data_->num_features() == 0)) { if (config_->boost_from_average || (train_data_ != nullptr && train_data_->num_features() == 0)) {
double init_score = ObtainAutomaticInitialScore(objective_function_, class_id); double init_score = ObtainAutomaticInitialScore(objective_function_, class_id);
if (std::fabs(init_score) > kEpsilon) { if (std::fabs(init_score) > kEpsilon) {
train_score_updater_->AddScore(init_score, class_id); if (update_scorer) {
for (auto& score_updater : valid_score_updater_) { train_score_updater_->AddScore(init_score, class_id);
score_updater->AddScore(init_score, class_id); for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(init_score, class_id);
}
} }
Log::Info("Start training from score %lf", init_score); Log::Info("Start training from score %lf", init_score);
return init_score; return init_score;
...@@ -335,7 +337,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { ...@@ -335,7 +337,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
// boosting first // boosting first
if (gradients == nullptr || hessians == nullptr) { if (gradients == nullptr || hessians == nullptr) {
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
init_scores[cur_tree_id] = BoostFromAverage(cur_tree_id); init_scores[cur_tree_id] = BoostFromAverage(cur_tree_id, true);
} }
Boosting(); Boosting();
gradients = gradients_.data(); gradients = gradients_.data();
...@@ -597,7 +599,7 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { ...@@ -597,7 +599,7 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
num_data = valid_score_updater_[used_idx]->num_data(); num_data = valid_score_updater_[used_idx]->num_data();
*out_len = static_cast<int64_t>(num_data) * num_class_; *out_len = static_cast<int64_t>(num_data) * num_class_;
} }
if (objective_function_ != nullptr && !average_output_) { if (objective_function_ != nullptr) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tree_pred(num_tree_per_iteration_); std::vector<double> tree_pred(num_tree_per_iteration_);
......
...@@ -407,7 +407,7 @@ protected: ...@@ -407,7 +407,7 @@ protected:
*/ */
std::string OutputMetric(int iter); std::string OutputMetric(int iter);
double BoostFromAverage(int class_id); double BoostFromAverage(int class_id, bool update_scorer);
/*! \brief current iteration */ /*! \brief current iteration */
int iter_; int iter_;
......
...@@ -152,7 +152,7 @@ std::string GBDT::ModelToIfElse(int num_iteration) const { ...@@ -152,7 +152,7 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n'; str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
str_buf << "\t\t" << "}" << '\n'; str_buf << "\t\t" << "}" << '\n';
str_buf << "\t" << "}" << '\n'; str_buf << "\t" << "}" << '\n';
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << '\n'; str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n'; str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
str_buf << "\t" << "}" << '\n'; str_buf << "\t" << "}" << '\n';
str_buf << "}" << '\n'; str_buf << "}" << '\n';
...@@ -166,7 +166,7 @@ std::string GBDT::ModelToIfElse(int num_iteration) const { ...@@ -166,7 +166,7 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n'; str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
str_buf << "\t\t" << "}" << '\n'; str_buf << "\t\t" << "}" << '\n';
str_buf << "\t" << "}" << '\n'; str_buf << "\t" << "}" << '\n';
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << '\n'; str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n'; str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
str_buf << "\t" << "}" << '\n'; str_buf << "\t" << "}" << '\n';
str_buf << "}" << '\n'; str_buf << "}" << '\n';
......
...@@ -52,7 +52,8 @@ void GBDT::Predict(const double* features, double* output, const PredictionEarly ...@@ -52,7 +52,8 @@ void GBDT::Predict(const double* features, double* output, const PredictionEarly
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] /= num_iteration_for_pred_; output[k] /= num_iteration_for_pred_;
} }
} else if (objective_function_ != nullptr) { }
if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output); objective_function_->ConvertOutput(output, output);
} }
} }
...@@ -63,7 +64,8 @@ void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* ...@@ -63,7 +64,8 @@ void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double*
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] /= num_iteration_for_pred_; output[k] /= num_iteration_for_pred_;
} }
} else if (objective_function_ != nullptr) { }
if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output); objective_function_->ConvertOutput(output, output);
} }
} }
......
...@@ -15,17 +15,17 @@ namespace LightGBM { ...@@ -15,17 +15,17 @@ namespace LightGBM {
/*! /*!
* \brief Random Forest implementation * \brief Random Forest implementation
*/ */
class RF: public GBDT { class RF : public GBDT {
public: public:
RF() : GBDT() { RF() : GBDT() {
average_output_ = true; average_output_ = true;
} }
~RF() {} ~RF() {}
void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function, void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override { const std::vector<const Metric*>& training_metrics) override {
CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f); CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f); CHECK(config->feature_fraction <= 1.0f && config->feature_fraction > 0.0f);
GBDT::Init(config, train_data, objective_function, training_metrics); GBDT::Init(config, train_data, objective_function, training_metrics);
...@@ -37,17 +37,15 @@ public: ...@@ -37,17 +37,15 @@ public:
} else { } else {
CHECK(train_data->metadata().init_score() == nullptr); CHECK(train_data->metadata().init_score() == nullptr);
} }
// cannot use RF for multi-class.
CHECK(num_tree_per_iteration_ == num_class_); CHECK(num_tree_per_iteration_ == num_class_);
// no shrinkage rate for the RF // no shrinkage rate for the RF
shrinkage_rate_ = 1.0f; shrinkage_rate_ = 1.0f;
// only boosting one time // only boosting one time
GetRFTargets(train_data); Boosting();
if (is_use_subset_ && bag_data_cnt_ < num_data_) { if (is_use_subset_ && bag_data_cnt_ < num_data_) {
tmp_grad_.resize(num_data_); tmp_grad_.resize(num_data_);
tmp_hess_.resize(num_data_); tmp_hess_.resize(num_data_);
} }
tmp_score_.resize(num_data_, 0.0);
} }
void ResetConfig(const Config* config) override { void ResetConfig(const Config* config) override {
...@@ -59,54 +57,41 @@ public: ...@@ -59,54 +57,41 @@ public:
} }
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function, void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override { const std::vector<const Metric*>& training_metrics) override {
GBDT::ResetTrainingData(train_data, objective_function, training_metrics); GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
if (iter_ + num_init_iteration_ > 0) { if (iter_ + num_init_iteration_ > 0) {
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
train_score_updater_->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id); train_score_updater_->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
} }
} }
// cannot use RF for multi-class.
CHECK(num_tree_per_iteration_ == num_class_); CHECK(num_tree_per_iteration_ == num_class_);
// only boosting one time // only boosting one time
GetRFTargets(train_data); Boosting();
if (is_use_subset_ && bag_data_cnt_ < num_data_) { if (is_use_subset_ && bag_data_cnt_ < num_data_) {
tmp_grad_.resize(num_data_); tmp_grad_.resize(num_data_);
tmp_hess_.resize(num_data_); tmp_hess_.resize(num_data_);
} }
tmp_score_.resize(num_data_, 0.0);
} }
void GetRFTargets(const Dataset* train_data) { void Boosting() override {
auto label_ptr = train_data->metadata().label(); if (objective_function_ == nullptr) {
std::fill(hessians_.begin(), hessians_.end(), 1.0f); Log::Fatal("No object function provided");
if (num_tree_per_iteration_ == 1) { }
OMP_INIT_EX(); init_scores_.resize(num_tree_per_iteration_, 0.0);
#pragma omp parallel for schedule(static,1) for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
for (data_size_t i = 0; i < train_data->num_data(); ++i) { init_scores_[cur_tree_id] = BoostFromAverage(cur_tree_id, false);
OMP_LOOP_EX_BEGIN();
score_t label = label_ptr[i];
gradients_[i] = static_cast<score_t>(-label);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
} }
else { size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
std::fill(gradients_.begin(), gradients_.end(), 0.0f); std::vector<double> tmp_scores(total_size, 0.0f);
OMP_INIT_EX(); #pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static,1) for (int j = 0; j < num_tree_per_iteration_; ++j) {
for (data_size_t i = 0; i < train_data->num_data(); ++i) { size_t bias = static_cast<size_t>(j)* num_data_;
OMP_LOOP_EX_BEGIN(); for (data_size_t i = 0; i < num_data_; ++i) {
score_t label = label_ptr[i]; tmp_scores[bias + i] = init_scores_[j];
gradients_[i + static_cast<int>(label) * num_data_] = -1.0f;
OMP_LOOP_EX_END();
} }
OMP_THROW_EX();
} }
} objective_function_->
GetGradients(tmp_scores.data(), gradients_.data(), hessians_.data());
void Boosting() override {
} }
bool TrainOneIter(const score_t* gradients, const score_t* hessians) override { bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
...@@ -120,27 +105,51 @@ public: ...@@ -120,27 +105,51 @@ public:
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
std::unique_ptr<Tree> new_tree(new Tree(2)); std::unique_ptr<Tree> new_tree(new Tree(2));
size_t bias = static_cast<size_t>(cur_tree_id)* num_data_; size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
auto grad = gradients + bias; if (class_need_train_[cur_tree_id]) {
auto hess = hessians + bias;
auto grad = gradients + bias;
// need to copy gradients for bagging subset. auto hess = hessians + bias;
if (is_use_subset_ && bag_data_cnt_ < num_data_) {
for (int i = 0; i < bag_data_cnt_; ++i) { // need to copy gradients for bagging subset.
tmp_grad_[i] = grad[bag_data_indices_[i]]; if (is_use_subset_ && bag_data_cnt_ < num_data_) {
tmp_hess_[i] = hess[bag_data_indices_[i]]; for (int i = 0; i < bag_data_cnt_; ++i) {
tmp_grad_[i] = grad[bag_data_indices_[i]];
tmp_hess_[i] = hess[bag_data_indices_[i]];
}
grad = tmp_grad_.data();
hess = tmp_hess_.data();
} }
grad = tmp_grad_.data();
hess = tmp_hess_.data(); new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_,
forced_splits_json_));
} }
new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_,
forced_splits_json_));
if (new_tree->num_leaves() > 1) { if (new_tree->num_leaves() > 1) {
tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, tmp_score_.data(), tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, init_scores_[cur_tree_id],
num_data_, bag_data_indices_.data(), bag_data_cnt_); num_data_, bag_data_indices_.data(), bag_data_cnt_);
if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
new_tree->AddBias(init_scores_[cur_tree_id]);
}
// update score // update score
MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_)); MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
UpdateScore(new_tree.get(), cur_tree_id); UpdateScore(new_tree.get(), cur_tree_id);
MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1)); MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
} else {
// only add default score one-time
if (models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
double output = 0.0;
if (!class_need_train_[cur_tree_id]) {
if (objective_function_ != nullptr) {
output = objective_function_->BoostFromScore(cur_tree_id);
} else {
output = init_scores_[cur_tree_id];
}
}
new_tree->AsConstantTree(output);
MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
UpdateScore(new_tree.get(), cur_tree_id);
MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
}
} }
// add model // add model
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
...@@ -178,7 +187,7 @@ public: ...@@ -178,7 +187,7 @@ public:
} }
void AddValidDataset(const Dataset* valid_data, void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) override { const std::vector<const Metric*>& valid_metrics) override {
GBDT::AddValidDataset(valid_data, valid_metrics); GBDT::AddValidDataset(valid_data, valid_metrics);
if (iter_ + num_init_iteration_ > 0) { if (iter_ + num_init_iteration_ > 0) {
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
...@@ -192,17 +201,13 @@ public: ...@@ -192,17 +201,13 @@ public:
return true; return true;
}; };
std::vector<double> EvalOneMetric(const Metric* metric, const double* score) const override {
return metric->Eval(score, nullptr);
}
private: private:
std::vector<score_t> tmp_grad_; std::vector<score_t> tmp_grad_;
std::vector<score_t> tmp_hess_; std::vector<score_t> tmp_hess_;
std::vector<double> tmp_score_; std::vector<double> init_scores_;
}; };
} // namespace LightGBM } // namespace LightGBM
#endif // LIGHTGBM_BOOSTING_RF_H_ #endif // LIGHTGBM_BOOSTING_RF_H_
\ No newline at end of file
...@@ -250,6 +250,38 @@ public: ...@@ -250,6 +250,38 @@ public:
} }
} }
/*!
* \brief Overload of RenewTreeOutput taking one scalar prediction: feeds the
*        residuals (label - pred) of the rows in a leaf to the percentile
*        macros with alpha fixed at 0.5.
* \param pred Single prediction value subtracted from each label read
* \param index_mapper Maps a position i inside the leaf to a row index
* \param bagging_mapper When non-null, the value from index_mapper is
*        translated through this array before indexing label_ / weights_
* \param num_data_in_leaf Count forwarded to the percentile macros
* NOTE(review): the first (unnamed) double argument is ignored. There is no
* explicit return statement in this body; PercentileFun and
* WeightedPercentileFun are macros defined outside this view and presumably
* expand to code that returns the computed percentile -- confirm.
*/
double RenewTreeOutput(double, double pred,
const data_size_t* index_mapper,
const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override {
const double alpha = 0.5;
if (weights_ == nullptr) {
// Unweighted case: only residuals are supplied to the macro.
if (bagging_mapper == nullptr) {
// data_reader is picked up by name inside the macro invocation below.
#define data_reader(i) (label_[index_mapper[i]] - pred)
PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader
} else {
// Leaf positions are first mapped through bagging_mapper.
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
PercentileFun(double, data_reader, num_data_in_leaf, alpha);
#undef data_reader
}
} else {
// Weighted case: residuals and weights_ are read with the same index.
if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred)
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
} else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
}
}
}
const char* GetName() const override { const char* GetName() const override {
return "regression_l1"; return "regression_l1";
} }
...@@ -540,6 +572,37 @@ public: ...@@ -540,6 +572,37 @@ public:
} }
} }
/*!
* \brief Overload of RenewTreeOutput taking one scalar prediction: feeds the
*        residuals (label - pred) of the rows in a leaf to the percentile
*        macros, using the member alpha_ as the percentile level.
* \param pred Single prediction value subtracted from each label read
* \param index_mapper Maps a position i inside the leaf to a row index
* \param bagging_mapper When non-null, the value from index_mapper is
*        translated through this array before indexing label_ / weights_
* \param num_data_in_leaf Count forwarded to the percentile macros
* NOTE(review): the first (unnamed) double argument is ignored. There is no
* explicit return statement in this body; PercentileFun and
* WeightedPercentileFun are macros defined outside this view and presumably
* expand to code that returns the computed percentile -- confirm.
*/
double RenewTreeOutput(double, double pred,
const data_size_t* index_mapper,
const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override {
if (weights_ == nullptr) {
// Unweighted case: only residuals are supplied to the macro.
if (bagging_mapper == nullptr) {
// data_reader is picked up by name inside the macro invocation below.
#define data_reader(i) (label_[index_mapper[i]] - pred)
PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader
} else {
// Leaf positions are first mapped through bagging_mapper.
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
#undef data_reader
}
} else {
// Weighted case: residuals and weights_ are read with the same index.
if (bagging_mapper == nullptr) {
#define data_reader(i) (label_[index_mapper[i]] - pred)
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader
#undef weight_reader
} else {
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
#undef data_reader
#undef weight_reader
}
}
}
private: private:
score_t alpha_; score_t alpha_;
}; };
...@@ -631,6 +694,26 @@ public: ...@@ -631,6 +694,26 @@ public:
} }
} }
/*!
* \brief Overload of RenewTreeOutput taking one scalar prediction: feeds the
*        residuals (label - pred) of the rows in a leaf, weighted by
*        label_weight_, to WeightedPercentileFun with alpha fixed at 0.5.
* \param pred Single prediction value subtracted from each label read
* \param index_mapper Maps a position i inside the leaf to a row index
* \param bagging_mapper When non-null, the value from index_mapper is
*        translated through this array before indexing label_ / label_weight_
* \param num_data_in_leaf Count forwarded to the percentile macro
* NOTE(review): the first (unnamed) double argument is ignored. There is no
* explicit return statement in this body; WeightedPercentileFun is a macro
* defined outside this view and presumably expands to code that returns the
* computed percentile -- confirm.
*/
double RenewTreeOutput(double, double pred,
const data_size_t* index_mapper,
const data_size_t* bagging_mapper,
data_size_t num_data_in_leaf) const override {
const double alpha = 0.5;
if (bagging_mapper == nullptr) {
// data_reader / weight_reader are picked up by name inside the macro invocation.
#define data_reader(i) (label_[index_mapper[i]] - pred)
#define weight_reader(i) (label_weight_[index_mapper[i]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
} else {
// Leaf positions are first mapped through bagging_mapper.
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
#define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
#undef data_reader
#undef weight_reader
}
}
const char* GetName() const override { const char* GetName() const override {
return "mape"; return "mape";
} }
......
...@@ -819,4 +819,44 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj ...@@ -819,4 +819,44 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
} }
} }
/*!
* \brief Overload of RenewTreeOutput that passes one scalar `prediction` to the
*        objective for every leaf instead of a prediction buffer.
*
* Does nothing unless `obj` is non-null and obj->IsRenewTreeOutput() is true.
* Otherwise each leaf's output is replaced by the value returned from
* obj->RenewTreeOutput(); with more than one machine, the per-leaf outputs are
* then combined through Network::GlobalSum.
*
* \param tree Tree whose leaf outputs are overwritten in place
* \param obj Objective consulted for the new leaf outputs; may be nullptr
* \param prediction Scalar forwarded unchanged to obj->RenewTreeOutput()
* \param total_num_data Compared with num_data_ to decide whether bag_indices is used
* \param bag_indices Used as the bagging mapper when total_num_data != num_data_
* \param bag_cnt Must equal num_data_ when bag_indices is used (enforced by CHECK)
*/
void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
if (obj != nullptr && obj->IsRenewTreeOutput()) {
CHECK(tree->num_leaves() <= data_partition_->num_leaves());
// A null bag_mapper tells the objective to index rows directly.
const data_size_t* bag_mapper = nullptr;
if (total_num_data != num_data_) {
CHECK(bag_cnt == num_data_);
bag_mapper = bag_indices;
}
// 1 = this worker produced an output for the leaf, 0 = it had no rows in it.
std::vector<int> n_nozeroworker_perleaf(tree->num_leaves(), 1);
int num_machines = Network::num_machines();
#pragma omp parallel for schedule(static)
for (int i = 0; i < tree->num_leaves(); ++i) {
const double output = static_cast<double>(tree->LeafOutput(i));
data_size_t cnt_leaf_data = 0;
auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
if (cnt_leaf_data > 0) {
// bag_mapper[index_mapper[i]]
const double new_output = obj->RenewTreeOutput(output, prediction, index_mapper, bag_mapper, cnt_leaf_data);
tree->SetLeafOutput(i, new_output);
} else {
// An empty leaf is only accepted when running on several machines.
CHECK(num_machines > 1);
tree->SetLeafOutput(i, 0.0);
n_nozeroworker_perleaf[i] = 0;
}
}
if (num_machines > 1) {
std::vector<double> outputs(tree->num_leaves());
for (int i = 0; i < tree->num_leaves(); ++i) {
outputs[i] = static_cast<double>(tree->LeafOutput(i));
}
// NOTE(review): presumably GlobalSum reduces both vectors in place across
// machines, so the division below averages over the workers that had rows
// in the leaf -- confirm against Network::GlobalSum.
Network::GlobalSum(outputs);
Network::GlobalSum(n_nozeroworker_perleaf);
for (int i = 0; i < tree->num_leaves(); ++i) {
tree->SetLeafOutput(i, outputs[i] / n_nozeroworker_perleaf[i]);
}
}
}
}
} // namespace LightGBM } // namespace LightGBM
...@@ -72,6 +72,9 @@ public: ...@@ -72,6 +72,9 @@ public:
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction, void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override; data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
/*!
* \brief Override of the RenewTreeOutput overload that takes a single double
*        `prediction` rather than a `const double*` prediction buffer.
*/
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
protected: protected:
/*! /*!
* \brief Some initial works before training * \brief Some initial works before training
......
...@@ -732,7 +732,7 @@ class TestEngine(unittest.TestCase): ...@@ -732,7 +732,7 @@ class TestEngine(unittest.TestCase):
'bagging_freq': 1, 'bagging_freq': 1,
'bagging_fraction': 0.8, 'bagging_fraction': 0.8,
'feature_fraction': 0.8, 'feature_fraction': 0.8,
'boost_from_average': False 'boost_from_average': True
} }
lgb_train = lgb.Dataset(X, y) lgb_train = lgb.Dataset(X, y)
gbm = lgb.train(params, lgb_train, num_boost_round=20) gbm = lgb.train(params, lgb_train, num_boost_round=20)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment