Commit 82e273ba authored by Guolin Ke
Browse files

clean code for tree learner.

parent ca6018fe
...@@ -145,8 +145,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -145,8 +145,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->ConstructHistograms(this->is_feature_used_, true); TREELEARNER_T::ConstructHistograms(this->is_feature_used_, true);
// construct local histograms // construct local histograms
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
...@@ -159,23 +159,29 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -159,23 +159,29 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
// Reduce scatter for histogram // Reduce scatter for histogram
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(),
block_len_.data(), output_buffer_.data(), &HistogramBinEntry::SumReducer); block_len_.data(), output_buffer_.data(), &HistogramBinEntry::SumReducer);
this->FindBestSplitsFromHistograms(this->is_feature_used_, true);
}
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->num_threads_, SplitInfo());
std::vector<SplitInfo> larger_bests_per_thread(this->num_threads_, SplitInfo());
std::vector<SplitInfo> smaller_best(this->num_threads_, SplitInfo());
std::vector<SplitInfo> larger_best(this->num_threads_, SplitInfo());
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
if (!is_feature_aggregated_[feature_index]) continue; if (!is_feature_aggregated_[feature_index]) continue;
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
// restore global histograms from buffer // restore global histograms from buffer
this->smaller_leaf_histogram_array_[feature_index].FromMemory( this->smaller_leaf_histogram_array_[feature_index].FromMemory(
output_buffer_.data() + buffer_read_start_pos_[feature_index]); output_buffer_.data() + buffer_read_start_pos_[feature_index]);
this->train_data_->FixHistogram(feature_index, this->train_data_->FixHistogram(feature_index,
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(), this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()), GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
this->smaller_leaf_histogram_array_[feature_index].RawData()); this->smaller_leaf_histogram_array_[feature_index].RawData());
SplitInfo smaller_split; SplitInfo smaller_split;
// find best threshold for smaller child // find best threshold for smaller child
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold( this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
...@@ -183,9 +189,9 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -183,9 +189,9 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
this->smaller_leaf_splits_->sum_hessians(), this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()), GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
&smaller_split); &smaller_split);
if (smaller_split.gain > smaller_best[tid].gain) { smaller_split.feature = real_feature_index;
smaller_best[tid] = smaller_split; if (smaller_split > smaller_bests_per_thread[tid]) {
smaller_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index); smaller_bests_per_thread[tid] = smaller_split;
} }
// only root leaf // only root leaf
...@@ -201,49 +207,45 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -201,49 +207,45 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
this->larger_leaf_splits_->sum_hessians(), this->larger_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()), GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()),
&larger_split); &larger_split);
if (larger_split.gain > larger_best[tid].gain) { larger_split.feature = real_feature_index;
larger_best[tid] = larger_split; if (larger_split > larger_bests_per_thread[tid]) {
larger_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index); larger_bests_per_thread[tid] = larger_split;
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
int leaf = this->smaller_leaf_splits_->LeafIndex();
this->best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) { return; } auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex();
leaf = this->larger_leaf_splits_->LeafIndex(); this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
this->best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
} if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
}
template <typename TREELEARNER_T> SplitInfo smaller_best_split, larger_best_split;
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsForLeaves() { smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
SplitInfo smaller_best, larger_best;
smaller_best = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
// find local best split for larger leaf // find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
std::memcpy(input_buffer_.data(), &smaller_best, sizeof(SplitInfo)); std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo));
std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best, sizeof(SplitInfo)); std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo), Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo),
output_buffer_.data(), &SplitInfo::MaxReducer); output_buffer_.data(), &SplitInfo::MaxReducer);
std::memcpy(&smaller_best, output_buffer_.data(), sizeof(SplitInfo)); std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
std::memcpy(&larger_best, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo)); std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// set best split // set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best; this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split;
} }
} }
......
...@@ -50,27 +50,28 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -50,27 +50,28 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsForLeaves() { void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
SplitInfo smaller_best, larger_best; TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
SplitInfo smaller_best_split, larger_best_split;
// get best split at smaller leaf // get best split at smaller leaf
smaller_best = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()]; smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
// find local best split for larger leaf // find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
std::memcpy(input_buffer_.data(), &smaller_best, sizeof(SplitInfo)); std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo));
std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best, sizeof(SplitInfo)); std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo), Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo),
output_buffer_.data(), &SplitInfo::MaxReducer); output_buffer_.data(), &SplitInfo::MaxReducer);
// copy back // copy back
std::memcpy(&smaller_best, output_buffer_.data(), sizeof(SplitInfo)); std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
std::memcpy(&larger_best, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo)); std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// update best split // update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best; this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split;
} }
} }
......
...@@ -1071,8 +1071,8 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u ...@@ -1071,8 +1071,8 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
} }
} }
void GPUTreeLearner::FindBestThresholds() { void GPUTreeLearner::FindBestSplits() {
SerialTreeLearner::FindBestThresholds(); SerialTreeLearner::FindBestSplits();
#if GPU_DEBUG >= 3 #if GPU_DEBUG >= 3
for (int feature_index = 0; feature_index < num_features_; ++feature_index) { for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
......
...@@ -58,7 +58,7 @@ public: ...@@ -58,7 +58,7 @@ public:
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override; bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestThresholds() override; void FindBestSplits() override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override; void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override; void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private: private:
......
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
void FindBestSplitsForLeaves() override; void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private: private:
/*! \brief rank of local machine */ /*! \brief rank of local machine */
int rank_; int rank_;
...@@ -54,8 +54,8 @@ public: ...@@ -54,8 +54,8 @@ public:
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const TreeConfig* tree_config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
void FindBestThresholds() override; void FindBestSplits() override;
void FindBestSplitsForLeaves() override; void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override; void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override { inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
...@@ -108,8 +108,8 @@ public: ...@@ -108,8 +108,8 @@ public:
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override; bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestThresholds() override; void FindBestSplits() override;
void FindBestSplitsForLeaves() override; void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override; void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override { inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
......
...@@ -179,9 +179,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -179,9 +179,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
init_split_time += std::chrono::steady_clock::now() - start_time; init_split_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
// find best threshold for every feature // find best threshold for every feature
FindBestThresholds(); FindBestSplits();
// find best split from all features
FindBestSplitsForLeaves();
} }
// Get a leaf with max split gain // Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_)); int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
...@@ -405,10 +403,27 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int ...@@ -405,10 +403,27 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true; return true;
} }
/*!
* \brief Build a per-feature mask for the current leaves, then construct histograms
* and search them for the best splits.
* The mask is passed unchanged to both ConstructHistograms and FindBestSplitsFromHistograms.
*/
void SerialTreeLearner::FindBestSplits() {
// local mask, all zero initially; distinct from the member is_feature_used_
std::vector<int8_t> is_feature_used(num_features_, 0);
// the OpenMP "if" clause limits parallel execution to num_features_ >= 2048;
// each iteration writes only its own feature_index slot
#pragma omp parallel for schedule(static,1024) if (num_features_ >= 2048)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
// features not enabled in the member mask stay at 0
if (!is_feature_used_[feature_index]) continue;
// when a parent histogram array exists and marks this feature as not splittable,
// flag the smaller leaf's histogram the same way and leave the feature at 0
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
// true only when a parent histogram array is available
// NOTE(review): presumably lets the callees derive one leaf by subtraction from the parent - confirm in ConstructHistograms
bool use_subtract = parent_leaf_histogram_array_ != nullptr;
ConstructHistograms(is_feature_used, use_subtract);
FindBestSplitsFromHistograms(is_feature_used, use_subtract);
}
void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) { void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
// construct smaller leaf // construct smaller leaf
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1; HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used, train_data_->ConstructHistograms(is_feature_used,
...@@ -428,29 +443,12 @@ void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_featur ...@@ -428,29 +443,12 @@ void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_featur
ordered_gradients_.data(), ordered_hessians_.data(), is_constant_hessian_, ordered_gradients_.data(), ordered_hessians_.data(), is_constant_hessian_,
ptr_larger_leaf_hist_data); ptr_larger_leaf_hist_data);
} }
#ifdef TIMETAG #ifdef TIMETAG
hist_time += std::chrono::steady_clock::now() - start_time; hist_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
} }
void SerialTreeLearner::FindBestThresholds() { void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static,1024) if (num_features_ >= 2048)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
ConstructHistograms(is_feature_used, use_subtract);
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
...@@ -498,7 +496,7 @@ void SerialTreeLearner::FindBestThresholds() { ...@@ -498,7 +496,7 @@ void SerialTreeLearner::FindBestThresholds() {
larger_split.feature = real_fidx; larger_split.feature = real_fidx;
if (larger_split > larger_best[tid]) { if (larger_split > larger_best[tid]) {
larger_best[tid] = larger_split; larger_best[tid] = larger_split;
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -517,10 +515,6 @@ void SerialTreeLearner::FindBestThresholds() { ...@@ -517,10 +515,6 @@ void SerialTreeLearner::FindBestThresholds() {
#endif #endif
} }
/*! \brief Empty body in SerialTreeLearner: this definition contains no statements. */
void SerialTreeLearner::FindBestSplitsForLeaves() {
}
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) { void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf]; const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
......
...@@ -74,20 +74,11 @@ protected: ...@@ -74,20 +74,11 @@ protected:
*/ */
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf); virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract); virtual void FindBestSplits();
/*! virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
* \brief Find best thresholds for all features, using multi-threading.
* The result will be stored in smaller_leaf_splits_ and larger_leaf_splits_.
* This function will be called in FindBestSplit.
*/
virtual void FindBestThresholds();
/*! virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
* \brief Find best features for leaves from smaller_leaf_splits_ and larger_leaf_splits_.
* This function will be called after FindBestThresholds.
*/
virtual void FindBestSplitsForLeaves();
/*! /*!
* \brief Partition tree and data according best split. * \brief Partition tree and data according best split.
......
...@@ -252,7 +252,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vec ...@@ -252,7 +252,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vec
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
// use local data to find local best splits // use local data to find local best splits
std::vector<int8_t> is_feature_used(this->num_features_, 0); std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
...@@ -269,10 +269,11 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -269,10 +269,11 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
if (this->parent_leaf_histogram_array_ == nullptr) { if (this->parent_leaf_histogram_array_ == nullptr) {
use_subtract = false; use_subtract = false;
} }
this->ConstructHistograms(is_feature_used, use_subtract); TREELEARNER_T::ConstructHistograms(is_feature_used, use_subtract);
std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_); std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_); std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_);
OMP_INIT_EX(); OMP_INIT_EX();
// find splits // find splits
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
...@@ -350,13 +351,22 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -350,13 +351,22 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(), Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(),
output_buffer_.data(), &HistogramBinEntry::SumReducer); output_buffer_.data(), &HistogramBinEntry::SumReducer);
std::vector<SplitInfo> smaller_best(this->num_threads_); this->FindBestSplitsFromHistograms(is_feature_used, false);
std::vector<SplitInfo> larger_best(this->num_threads_); }
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->num_threads_);
std::vector<SplitInfo> larger_best_per_thread(this->num_threads_);
// find best split from local aggregated histograms // find best split from local aggregated histograms
#pragma omp parallel for schedule(static)
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
if (smaller_is_feature_aggregated_[feature_index]) { if (smaller_is_feature_aggregated_[feature_index]) {
SplitInfo smaller_split; SplitInfo smaller_split;
// restore from buffer // restore from buffer
...@@ -364,9 +374,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -364,9 +374,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]); output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
this->train_data_->FixHistogram(feature_index, this->train_data_->FixHistogram(feature_index,
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(), smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()), GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
smaller_leaf_histogram_array_global_[feature_index].RawData()); smaller_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold // find best threshold
smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold( smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold(
...@@ -374,9 +384,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -374,9 +384,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
smaller_leaf_splits_global_->sum_hessians(), smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()), GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
&smaller_split); &smaller_split);
if (smaller_split.gain > smaller_best[tid].gain) { smaller_split.feature = real_feature_index;
smaller_best[tid] = smaller_split; if (smaller_split > smaller_bests_per_thread[tid]) {
smaller_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index); smaller_bests_per_thread[tid] = smaller_split;
} }
} }
...@@ -386,9 +396,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -386,9 +396,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]); larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
this->train_data_->FixHistogram(feature_index, this->train_data_->FixHistogram(feature_index,
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(), larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()), GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
larger_leaf_histogram_array_global_[feature_index].RawData()); larger_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold // find best threshold
larger_leaf_histogram_array_global_[feature_index].FindBestThreshold( larger_leaf_histogram_array_global_[feature_index].FindBestThreshold(
...@@ -396,48 +406,45 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() { ...@@ -396,48 +406,45 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
larger_leaf_splits_global_->sum_hessians(), larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()), GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
&larger_split); &larger_split);
if (larger_split.gain > larger_best[tid].gain) { larger_split.feature = real_feature_index;
larger_best[tid] = larger_split; if (larger_split > larger_best_per_thread[tid]) {
larger_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index); larger_best_per_thread[tid] = larger_split;
} }
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex(); int leaf = this->smaller_leaf_splits_->LeafIndex();
this->best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx]; this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex(); leaf = this->larger_leaf_splits_->LeafIndex();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best); auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best_per_thread);
this->best_split_per_leaf_[leaf] = larger_best[larger_best_idx]; this->best_split_per_leaf_[leaf] = larger_best_per_thread[larger_best_idx];
} }
}
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsForLeaves() {
// find local best // find local best
SplitInfo smaller_best, larger_best; SplitInfo smaller_best_split, larger_best_split;
smaller_best = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()]; smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
// find local best split for larger leaf // find local best split for larger leaf
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
std::memcpy(input_buffer_.data(), &smaller_best, sizeof(SplitInfo)); std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo));
std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best, sizeof(SplitInfo)); std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo), output_buffer_.data(), &SplitInfo::MaxReducer); Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo), output_buffer_.data(), &SplitInfo::MaxReducer);
std::memcpy(&smaller_best, output_buffer_.data(), sizeof(SplitInfo)); std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
std::memcpy(&larger_best, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo)); std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// copy back // copy back
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best; this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split;
if (larger_best.feature >= 0 && larger_leaf_splits_global_->LeafIndex() >= 0) { if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->LeafIndex() >= 0) {
this->best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best; this->best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best_split;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment