Unverified Commit dcf9ad2e authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix-parallel-quantile (#1605)

* fix-parallel-quantile

* Update serial_tree_learner.cpp
parent 8ff1e94b
...@@ -774,6 +774,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -774,6 +774,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
} }
} }
void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction, void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const { data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
if (obj != nullptr && obj->IsRenewTreeOutput()) { if (obj != nullptr && obj->IsRenewTreeOutput()) {
...@@ -783,15 +784,21 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj ...@@ -783,15 +784,21 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
CHECK(bag_cnt == num_data_); CHECK(bag_cnt == num_data_);
bag_mapper = bag_indices; bag_mapper = bag_indices;
} }
std::vector<int> n_nozeroworker_perleaf(tree->num_leaves(), 1);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < tree->num_leaves(); ++i) { for (int i = 0; i < tree->num_leaves(); ++i) {
const double output = static_cast<double>(tree->LeafOutput(i)); const double output = static_cast<double>(tree->LeafOutput(i));
data_size_t cnt_leaf_data = 0; data_size_t cnt_leaf_data = 0;
auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data); auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
CHECK(cnt_leaf_data > 0); if (cnt_leaf_data > 0) {
// bag_mapper[index_mapper[i]] // bag_mapper[index_mapper[i]]
const double new_output = obj->RenewTreeOutput(output, prediction, index_mapper, bag_mapper, cnt_leaf_data); const double new_output = obj->RenewTreeOutput(output, prediction, index_mapper, bag_mapper, cnt_leaf_data);
tree->SetLeafOutput(i, new_output); tree->SetLeafOutput(i, new_output);
} else {
CHECK(Network::num_machines() > 1);
tree->SetLeafOutput(i, 0.0);
n_nozeroworker_perleaf[i] = 0;
}
} }
if (Network::num_machines() > 1) { if (Network::num_machines() > 1) {
std::vector<double> outputs(tree->num_leaves()); std::vector<double> outputs(tree->num_leaves());
...@@ -799,8 +806,9 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj ...@@ -799,8 +806,9 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
outputs[i] = static_cast<double>(tree->LeafOutput(i)); outputs[i] = static_cast<double>(tree->LeafOutput(i));
} }
Network::GlobalSum(outputs); Network::GlobalSum(outputs);
Network::GlobalSum(n_nozeroworker_perleaf);
for (int i = 0; i < tree->num_leaves(); ++i) { for (int i = 0; i < tree->num_leaves(); ++i) {
tree->SetLeafOutput(i, outputs[i] / Network::num_machines()); tree->SetLeafOutput(i, outputs[i] / n_nozeroworker_perleaf[i]);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment