Commit 98ffbb2b authored by Guolin Ke's avatar Guolin Ke
Browse files

Let TreeLearner share the same code of ConstructHistograms.

parent b6c973af
......@@ -140,12 +140,7 @@ void DataParallelTreeLearner::BeforeTrain() {
}
void DataParallelTreeLearner::FindBestThresholds() {
train_data_->ConstructHistograms(is_feature_used_,
smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
smaller_leaf_histogram_array_[0].RawData() - 1);
ConstructHistograms(is_feature_used_, true);
// construct local histograms
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
......
......@@ -417,25 +417,10 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true;
}
void SerialTreeLearner::FindBestThresholds() {
#ifdef TIMETAG
void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
#ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now();
#endif
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
#endif
// construct smaller leaf
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
......@@ -455,11 +440,31 @@ void SerialTreeLearner::FindBestThresholds() {
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_larger_leaf_hist_data);
}
#ifdef TIMETAG
#ifdef TIMETAG
hist_time += std::chrono::steady_clock::now() - start_time;
#endif
#endif
}
void SerialTreeLearner::FindBestThresholds() {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
ConstructHistograms(is_feature_used, use_subtract);
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
auto start_time = std::chrono::steady_clock::now();
#endif
std::vector<SplitInfo> smaller_best(num_threads_);
std::vector<SplitInfo> larger_best(num_threads_);
......
......@@ -69,6 +69,7 @@ protected:
*/
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
/*!
* \brief Find best thresholds for all features, using multi-threading.
......
......@@ -260,25 +260,7 @@ void VotingParallelTreeLearner::FindBestThresholds() {
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
// construct smaller leaf
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_smaller_leaf_hist_data);
if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
// construct larger leaf
HistogramBinEntry* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_larger_leaf_hist_data);
}
ConstructHistograms(is_feature_used, use_subtract);
std::vector<SplitInfo> smaller_bestsplit_per_features(num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(num_features_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment