Unverified Commit 7744757a authored by Belinda Trotta's avatar Belinda Trotta Committed by GitHub
Browse files

Improve performance of path smoothing (#3396)

* Make path smoothing faster

* Fix bug

* Fix bug

* Minor style fix
parent 4278f222
......@@ -180,6 +180,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features =
this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_.get());
double larger_leaf_parent_output = this->GetParentOutput(tree, this->larger_leaf_splits_.get());
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
......@@ -200,7 +202,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
smaller_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->leaf_index()),
this->smaller_leaf_splits_.get(),
&smaller_bests_per_thread[tid]);
&smaller_bests_per_thread[tid],
smaller_leaf_parent_output);
// only root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) continue;
......@@ -214,7 +217,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->leaf_index()),
this->larger_leaf_splits_.get(),
&larger_bests_per_thread[tid]);
&larger_bests_per_thread[tid],
larger_leaf_parent_output);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......
......@@ -372,6 +372,11 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
std::vector<SplitInfo> larger_best(share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode(tree, smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features;
double smaller_leaf_parent_output = GetParentOutput(tree, smaller_leaf_splits_.get());
double larger_leaf_parent_output = 0;
if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->leaf_index() >= 0) {
larger_leaf_parent_output = GetParentOutput(tree, larger_leaf_splits_.get());
}
if (larger_leaf_splits_->leaf_index() >= 0) {
larger_node_used_features = col_sampler_.GetByNode(tree, larger_leaf_splits_->leaf_index());
}
......@@ -394,7 +399,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
real_fidx,
smaller_node_used_features[feature_index],
smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_.get(), &smaller_best[tid]);
smaller_leaf_splits_.get(), &smaller_best[tid],
smaller_leaf_parent_output);
// only has root leaf
if (larger_leaf_splits_ == nullptr ||
......@@ -416,7 +422,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
real_fidx,
larger_node_used_features[feature_index],
larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_.get(), &larger_best[tid]);
larger_leaf_splits_.get(), &larger_best[tid],
larger_leaf_parent_output);
OMP_LOOP_EX_END();
}
......@@ -710,8 +717,7 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
void SerialTreeLearner::ComputeBestSplitForFeature(
FeatureHistogram* histogram_array_, int feature_index, int real_fidx,
bool is_feature_used, int num_data, const LeafSplits* leaf_splits,
SplitInfo* best_split) {
SplitInfo* best_split, double parent_output) {
bool is_feature_numerical = train_data_->FeatureBinMapper(feature_index)
->bin_type() == BinType::NumericalBin;
if (is_feature_numerical & !config_->monotone_constraints.empty()) {
......@@ -719,21 +725,10 @@ void SerialTreeLearner::ComputeBestSplitForFeature(
constraints_.get(), feature_index, ~(leaf_splits->leaf_index()),
train_data_->FeatureNumBin(feature_index));
}
SplitInfo new_split;
double parent_output;
if (leaf_splits->leaf_index() == 0) {
// for root leaf the "parent" output is its own output because we don't apply any smoothing to the root
parent_output = FeatureHistogram::CalculateSplittedLeafOutput<false, true, true, false>(
leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), config_->lambda_l1,
config_->lambda_l2, config_->max_delta_step, BasicConstraint(),
config_->path_smooth, static_cast<data_size_t>(num_data), 0);
} else {
parent_output = leaf_splits->weight();
}
histogram_array_[feature_index].FindBestThreshold(
leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), num_data,
constraints_->GetFeatureConstraint(leaf_splits->leaf_index(), feature_index), parent_output, &new_split);
constraints_->GetFeatureConstraint(leaf_splits->leaf_index(), feature_index), parent_output, &new_split);
new_split.feature = real_fidx;
if (cegb_ != nullptr) {
new_split.gain -=
......@@ -752,6 +747,20 @@ void SerialTreeLearner::ComputeBestSplitForFeature(
}
}
double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* leaf_splits) const {
  // A single-leaf tree means we are at the root: there is no real parent, so
  // path smoothing uses the root's own (unsmoothed) output as the "parent"
  // output. Otherwise the leaf already carries its parent's output as weight.
  if (tree->num_leaves() == 1) {
    return FeatureHistogram::CalculateSplittedLeafOutput<true, true, true, false>(
        leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), config_->lambda_l1,
        config_->lambda_l2, config_->max_delta_step, BasicConstraint(),
        config_->path_smooth, static_cast<data_size_t>(leaf_splits->num_data_in_leaf()), 0);
  }
  return leaf_splits->weight();
}
void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
FeatureHistogram* histogram_array_;
if (!histogram_pool_.Get(leaf, &histogram_array_)) {
......@@ -769,6 +778,14 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
LeafSplits leaf_splits(num_data);
leaf_splits.Init(leaf, sum_gradients, sum_hessians);
// can't use GetParentOutput because leaf_splits doesn't have weight property set
double parent_output = 0;
if (config_->path_smooth > kEpsilon) {
parent_output = FeatureHistogram::CalculateSplittedLeafOutput<true, true, true, false>(
sum_gradients, sum_hessians, config_->lambda_l1, config_->lambda_l2, config_->max_delta_step,
BasicConstraint(), config_->path_smooth, static_cast<data_size_t>(num_data), 0);
}
OMP_INIT_EX();
// find splits
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
......@@ -780,10 +797,8 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
}
const int tid = omp_get_thread_num();
int real_fidx = train_data_->RealFeatureIndex(feature_index);
ComputeBestSplitForFeature(
histogram_array_, feature_index, real_fidx,
true,
num_data, &leaf_splits, &bests[tid]);
ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, true,
num_data, &leaf_splits, &bests[tid], parent_output);
OMP_LOOP_EX_END();
}
......
......@@ -114,12 +114,16 @@ class SerialTreeLearner: public TreeLearner {
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
/*! \brief Get output of parent node, used for path smoothing */
double GetParentOutput(const Tree* tree, const LeafSplits* leaf_splits) const;
protected:
void ComputeBestSplitForFeature(FeatureHistogram* histogram_array_,
int feature_index, int real_fidx,
bool is_feature_used, int num_data,
const LeafSplits* leaf_splits,
SplitInfo* best_split);
SplitInfo* best_split, double parent_output);
void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);
......
......@@ -262,7 +262,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree)
std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_);
double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_.get());
double larger_leaf_parent_output = this->GetParentOutput(tree, this->larger_leaf_splits_.get());
OMP_INIT_EX();
// find splits
#pragma omp parallel for schedule(static)
......@@ -278,7 +279,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree)
this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
true, this->smaller_leaf_splits_->num_data_in_leaf(),
this->smaller_leaf_splits_.get(),
&smaller_bestsplit_per_features[feature_index]);
&smaller_bestsplit_per_features[feature_index],
smaller_leaf_parent_output);
// only has root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) { continue; }
......@@ -292,7 +294,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree)
this->larger_leaf_histogram_array_, feature_index, real_feature_index,
true, this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_splits_.get(),
&larger_bestsplit_per_features[feature_index]);
&larger_bestsplit_per_features[feature_index],
larger_leaf_parent_output);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......@@ -354,8 +357,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features =
this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_global_.get());
double larger_leaf_parent_output = this->GetParentOutput(tree, this->larger_leaf_splits_global_.get());
// find best split from local aggregated histograms
OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(this->share_state_->num_threads)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
......@@ -375,7 +379,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
smaller_leaf_histogram_array_global_.get(), feature_index,
real_feature_index, smaller_node_used_features[feature_index],
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->leaf_index()),
smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid]);
smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid],
smaller_leaf_parent_output);
}
if (larger_is_feature_aggregated_[feature_index]) {
......@@ -391,7 +396,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
real_feature_index,
larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()),
larger_leaf_splits_global_.get(), &larger_bests_per_thread[tid]);
larger_leaf_splits_global_.get(), &larger_bests_per_thread[tid],
larger_leaf_parent_output);
}
OMP_LOOP_EX_END();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment