Commit 82e273ba authored by Guolin Ke
Browse files

clean code for tree learner.

parent ca6018fe
......@@ -145,8 +145,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
this->ConstructHistograms(this->is_feature_used_, true);
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
TREELEARNER_T::ConstructHistograms(this->is_feature_used_, true);
// construct local histograms
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
......@@ -159,15 +159,21 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
// Reduce scatter for histogram
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(),
block_len_.data(), output_buffer_.data(), &HistogramBinEntry::SumReducer);
this->FindBestSplitsFromHistograms(this->is_feature_used_, true);
}
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->num_threads_, SplitInfo());
std::vector<SplitInfo> larger_bests_per_thread(this->num_threads_, SplitInfo());
std::vector<SplitInfo> smaller_best(this->num_threads_, SplitInfo());
std::vector<SplitInfo> larger_best(this->num_threads_, SplitInfo());
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN();
if (!is_feature_aggregated_[feature_index]) continue;
const int tid = omp_get_thread_num();
const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
// restore global histograms from buffer
this->smaller_leaf_histogram_array_[feature_index].FromMemory(
output_buffer_.data() + buffer_read_start_pos_[feature_index]);
......@@ -183,9 +189,9 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
&smaller_split);
if (smaller_split.gain > smaller_best[tid].gain) {
smaller_best[tid] = smaller_split;
smaller_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid]) {
smaller_bests_per_thread[tid] = smaller_split;
}
// only root leaf
......@@ -201,49 +207,45 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
this->larger_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()),
&larger_split);
if (larger_split.gain > larger_best[tid].gain) {
larger_best[tid] = larger_split;
larger_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index);
larger_split.feature = real_feature_index;
if (larger_split > larger_bests_per_thread[tid]) {
larger_bests_per_thread[tid] = larger_split;
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
int leaf = this->smaller_leaf_splits_->LeafIndex();
this->best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) { return; }
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex();
this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
this->best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
}
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
}
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsForLeaves() {
  // Synchronize the per-leaf best splits across all machines: each worker
  // contributes its local best for the smaller/larger leaf, and an Allreduce
  // with SplitInfo::MaxReducer keeps the globally best one.
  SplitInfo smaller_best_split, larger_best_split;
  smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
  // find local best split for larger leaf
  if (this->larger_leaf_splits_->LeafIndex() >= 0) {
    larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
  }
  // sync global best info: pack both candidates into one buffer so a single
  // Allreduce handles the pair.
  std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo));
  std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
  Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo),
                     output_buffer_.data(), &SplitInfo::MaxReducer);
  std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
  std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
  // set best split
  this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
  if (this->larger_leaf_splits_->LeafIndex() >= 0) {
    this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split;
  }
}
......
......@@ -50,27 +50,28 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}
template <typename TREELEARNER_T>
void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
  // First find the local best splits over this worker's feature subset, then
  // Allreduce (MaxReducer) so every machine ends up with the global best.
  TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
  SplitInfo smaller_best_split, larger_best_split;
  // get best split at smaller leaf
  smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
  // find local best split for larger leaf
  if (this->larger_leaf_splits_->LeafIndex() >= 0) {
    larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
  }
  // sync global best info: both candidates travel in one buffer
  std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo));
  std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
  Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo),
                     output_buffer_.data(), &SplitInfo::MaxReducer);
  // copy back
  std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
  std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
  // update best split
  this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
  if (this->larger_leaf_splits_->LeafIndex() >= 0) {
    this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()] = larger_best_split;
  }
}
......
......@@ -1071,8 +1071,8 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
}
}
void GPUTreeLearner::FindBestThresholds() {
SerialTreeLearner::FindBestThresholds();
void GPUTreeLearner::FindBestSplits() {
SerialTreeLearner::FindBestSplits();
#if GPU_DEBUG >= 3
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
......
......@@ -58,7 +58,7 @@ public:
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestThresholds() override;
void FindBestSplits() override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private:
......
......@@ -28,7 +28,7 @@ public:
protected:
void BeforeTrain() override;
void FindBestSplitsForLeaves() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private:
/*! \brief rank of local machine */
int rank_;
......@@ -54,8 +54,8 @@ public:
void ResetConfig(const TreeConfig* tree_config) override;
protected:
void BeforeTrain() override;
void FindBestThresholds() override;
void FindBestSplitsForLeaves() override;
void FindBestSplits() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
......@@ -108,8 +108,8 @@ public:
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestThresholds() override;
void FindBestSplitsForLeaves() override;
void FindBestSplits() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
......
......@@ -179,9 +179,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
init_split_time += std::chrono::steady_clock::now() - start_time;
#endif
// find best threshold for every feature
FindBestThresholds();
// find best split from all features
FindBestSplitsForLeaves();
FindBestSplits();
}
// Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
......@@ -405,10 +403,27 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true;
}
void SerialTreeLearner::FindBestSplits() {
  // Build the per-round feature mask: a feature participates only if it is
  // globally enabled and its parent histogram did not already prove it
  // unsplittable.
  std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static,1024) if (num_features_ >= 2048)
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    if (is_feature_used_[feature_index]) {
      const bool parent_says_unsplittable =
          parent_leaf_histogram_array_ != nullptr
          && !parent_leaf_histogram_array_[feature_index].is_splittable();
      if (parent_says_unsplittable) {
        // Propagate the "cannot split" verdict to the child histogram.
        smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      } else {
        is_feature_used[feature_index] = 1;
      }
    }
  }
  // With a parent histogram available, the larger leaf's histogram can be
  // obtained by subtraction instead of a full reconstruction.
  const bool use_subtract = (parent_leaf_histogram_array_ != nullptr);
  ConstructHistograms(is_feature_used, use_subtract);
  FindBestSplitsFromHistograms(is_feature_used, use_subtract);
}
void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
#ifdef TIMETAG
#ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now();
#endif
#endif
// construct smaller leaf
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
......@@ -428,29 +443,12 @@ void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_featur
ordered_gradients_.data(), ordered_hessians_.data(), is_constant_hessian_,
ptr_larger_leaf_hist_data);
}
#ifdef TIMETAG
#ifdef TIMETAG
hist_time += std::chrono::steady_clock::now() - start_time;
#endif
#endif
}
void SerialTreeLearner::FindBestThresholds() {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static,1024) if (num_features_ >= 2048)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
ConstructHistograms(is_feature_used, use_subtract);
void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
#ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now();
#endif
......@@ -517,10 +515,6 @@ void SerialTreeLearner::FindBestThresholds() {
#endif
}
// Intentionally a no-op in the serial learner: best_split_per_leaf_ is already
// filled during split finding. Parallel learners override this hook to sync
// the global best split across machines.
void SerialTreeLearner::FindBestSplitsForLeaves() {
}
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
......
......@@ -74,20 +74,11 @@ protected:
*/
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
virtual void FindBestSplits();
/*!
* \brief Find best thresholds for all features, using multi-threading.
* The result will be stored in smaller_leaf_splits_ and larger_leaf_splits_.
* This function will be called in FindBestSplit.
*/
virtual void FindBestThresholds();
virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
/*!
* \brief Find best features for leaves from smaller_leaf_splits_ and larger_leaf_splits_.
* This function will be called after FindBestThresholds.
*/
virtual void FindBestSplitsForLeaves();
virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
/*!
* \brief Partition tree and data according best split.
......
......@@ -252,7 +252,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vec
}
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
// use local data to find local best splits
std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static)
......@@ -269,10 +269,11 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
if (this->parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
this->ConstructHistograms(is_feature_used, use_subtract);
TREELEARNER_T::ConstructHistograms(is_feature_used, use_subtract);
std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_);
OMP_INIT_EX();
// find splits
#pragma omp parallel for schedule(static)
......@@ -350,13 +351,22 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(),
output_buffer_.data(), &HistogramBinEntry::SumReducer);
std::vector<SplitInfo> smaller_best(this->num_threads_);
std::vector<SplitInfo> larger_best(this->num_threads_);
this->FindBestSplitsFromHistograms(is_feature_used, false);
}
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->num_threads_);
std::vector<SplitInfo> larger_best_per_thread(this->num_threads_);
// find best split from local aggregated histograms
#pragma omp parallel for schedule(static)
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
if (smaller_is_feature_aggregated_[feature_index]) {
SplitInfo smaller_split;
// restore from buffer
......@@ -374,9 +384,9 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
&smaller_split);
if (smaller_split.gain > smaller_best[tid].gain) {
smaller_best[tid] = smaller_split;
smaller_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index);
smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid]) {
smaller_bests_per_thread[tid] = smaller_split;
}
}
......@@ -396,48 +406,45 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
&larger_split);
if (larger_split.gain > larger_best[tid].gain) {
larger_best[tid] = larger_split;
larger_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index);
larger_split.feature = real_feature_index;
if (larger_split > larger_best_per_thread[tid]) {
larger_best_per_thread[tid] = larger_split;
}
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
int leaf = this->smaller_leaf_splits_->LeafIndex();
this->best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];
this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
this->best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best_per_thread);
this->best_split_per_leaf_[leaf] = larger_best_per_thread[larger_best_idx];
}
}
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsForLeaves() {
  // find local best
  SplitInfo smaller_best_split, larger_best_split;
  smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
  // find local best split for larger leaf
  if (this->larger_leaf_splits_->LeafIndex() >= 0) {
    larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
  }
  // sync global best info: one Allreduce with MaxReducer over both candidates
  std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo));
  std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
  Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo), output_buffer_.data(), &SplitInfo::MaxReducer);
  std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
  std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
  // copy back; the larger leaf uses the *global* leaf index, and only when a
  // valid split (feature >= 0) was actually found for it.
  this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split;
  if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->LeafIndex() >= 0) {
    this->best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best_split;
  }
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment