"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "5694f166247052ed79a12145d67d6a7b436a604b"
Commit 98ffbb2b authored by Guolin Ke's avatar Guolin Ke
Browse files

Let TreeLearner share the same code of ConstructHistograms.

parent b6c973af
...@@ -140,12 +140,7 @@ void DataParallelTreeLearner::BeforeTrain() { ...@@ -140,12 +140,7 @@ void DataParallelTreeLearner::BeforeTrain() {
} }
void DataParallelTreeLearner::FindBestThresholds() { void DataParallelTreeLearner::FindBestThresholds() {
train_data_->ConstructHistograms(is_feature_used_, ConstructHistograms(is_feature_used_, true);
smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
smaller_leaf_histogram_array_[0].RawData() - 1);
// construct local histograms // construct local histograms
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) { for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
......
...@@ -417,25 +417,10 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int ...@@ -417,25 +417,10 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true; return true;
} }
void SerialTreeLearner::FindBestThresholds() { void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
// construct smaller leaf // construct smaller leaf
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1; HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used, train_data_->ConstructHistograms(is_feature_used,
...@@ -455,11 +440,31 @@ void SerialTreeLearner::FindBestThresholds() { ...@@ -455,11 +440,31 @@ void SerialTreeLearner::FindBestThresholds() {
ordered_gradients_.data(), ordered_hessians_.data(), ordered_gradients_.data(), ordered_hessians_.data(),
ptr_larger_leaf_hist_data); ptr_larger_leaf_hist_data);
} }
#ifdef TIMETAG #ifdef TIMETAG
hist_time += std::chrono::steady_clock::now() - start_time; hist_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
}
void SerialTreeLearner::FindBestThresholds() {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
ConstructHistograms(is_feature_used, use_subtract);
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
std::vector<SplitInfo> smaller_best(num_threads_); std::vector<SplitInfo> smaller_best(num_threads_);
std::vector<SplitInfo> larger_best(num_threads_); std::vector<SplitInfo> larger_best(num_threads_);
......
...@@ -69,6 +69,7 @@ protected: ...@@ -69,6 +69,7 @@ protected:
*/ */
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf); virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
/*! /*!
* \brief Find best thresholds for all features, using multi-threading. * \brief Find best thresholds for all features, using multi-threading.
......
...@@ -260,25 +260,7 @@ void VotingParallelTreeLearner::FindBestThresholds() { ...@@ -260,25 +260,7 @@ void VotingParallelTreeLearner::FindBestThresholds() {
if (parent_leaf_histogram_array_ == nullptr) { if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false; use_subtract = false;
} }
// construct smaller leaf ConstructHistograms(is_feature_used, use_subtract);
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_smaller_leaf_hist_data);
if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
// construct larger leaf
HistogramBinEntry* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_larger_leaf_hist_data);
}
std::vector<SplitInfo> smaller_bestsplit_per_features(num_features_); std::vector<SplitInfo> smaller_bestsplit_per_features(num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(num_features_); std::vector<SplitInfo> larger_bestsplit_per_features(num_features_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment