Unverified Commit 77d92b7c authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

speed up `FindBestThresholdFromHistogram` (#2867)

* speed up for const hessian

* rename template

* some refactorings

* refine

* refine

* simplify codes

* fix random in feature histogram

* code refine

* refine

* try fix

* make gcc happy

* remove timer

* rollback some changes

* more templates

* fix a bug

* reduce the cost of timer

* fix gpu

* fix bug

* fix gpu
parent 7776cfea
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <LightGBM/json11.hpp> #include <LightGBM/utils/json11.h>
namespace LightGBM { namespace LightGBM {
...@@ -48,6 +48,8 @@ class TreeLearner { ...@@ -48,6 +48,8 @@ class TreeLearner {
*/ */
virtual void ResetConfig(const Config* config) = 0; virtual void ResetConfig(const Config* config) = 0;
virtual void SetForcedSplit(const Json* forced_split_json) = 0;
/*! /*!
* \brief training tree model on dataset * \brief training tree model on dataset
* \param gradients The first order gradients * \param gradients The first order gradients
...@@ -55,8 +57,7 @@ class TreeLearner { ...@@ -55,8 +57,7 @@ class TreeLearner {
* \param is_constant_hessian True if all hessians share the same value * \param is_constant_hessian True if all hessians share the same value
* \return A trained tree * \return A trained tree
*/ */
virtual Tree* Train(const score_t* gradients, const score_t* hessians, virtual Tree* Train(const score_t* gradients, const score_t* hessians) = 0;
const Json& forced_split_json) = 0;
/*! /*!
* \brief use an existing tree to fit the new gradients and hessians. * \brief use an existing tree to fit the new gradients and hessians.
......
...@@ -1089,11 +1089,10 @@ class Timer { ...@@ -1089,11 +1089,10 @@ class Timer {
// Note: this class is not thread-safe, don't use it inside omp blocks // Note: this class is not thread-safe, don't use it inside omp blocks
class FunctionTimer { class FunctionTimer {
public: public:
#ifdef TIMETAG
FunctionTimer(const std::string& name, Timer& timer) : timer_(timer) { FunctionTimer(const std::string& name, Timer& timer) : timer_(timer) {
timer.Start(name); timer.Start(name);
#ifdef TIMETAG
name_ = name; name_ = name;
#endif // TIMETAG
} }
~FunctionTimer() { timer_.Stop(name_); } ~FunctionTimer() { timer_.Stop(name_); }
...@@ -1101,6 +1100,9 @@ class FunctionTimer { ...@@ -1101,6 +1100,9 @@ class FunctionTimer {
private: private:
std::string name_; std::string name_;
Timer& timer_; Timer& timer_;
#else
FunctionTimer(const std::string&, Timer&) {}
#endif // TIMETAG
}; };
} // namespace Common } // namespace Common
......
...@@ -58,10 +58,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective ...@@ -58,10 +58,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
es_first_metric_only_ = config_->first_metric_only; es_first_metric_only_ = config_->first_metric_only;
shrinkage_rate_ = config_->learning_rate; shrinkage_rate_ = config_->learning_rate;
std::string forced_splits_path = config->forcedsplits_filename;
// load forced_splits file // load forced_splits file
if (forced_splits_path != "") { if (!config->forcedsplits_filename.empty()) {
std::ifstream forced_splits_file(forced_splits_path.c_str()); std::ifstream forced_splits_file(config->forcedsplits_filename.c_str());
std::stringstream buffer; std::stringstream buffer;
buffer << forced_splits_file.rdbuf(); buffer << forced_splits_file.rdbuf();
std::string err; std::string err;
...@@ -81,6 +80,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective ...@@ -81,6 +80,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
// init tree learner // init tree learner
tree_learner_->Init(train_data_, is_constant_hessian_); tree_learner_->Init(train_data_, is_constant_hessian_);
tree_learner_->SetForcedSplit(&forced_splits_json_);
// push training metrics // push training metrics
training_metrics_.clear(); training_metrics_.clear();
...@@ -366,7 +366,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) { ...@@ -366,7 +366,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
grad = gradients_.data() + offset; grad = gradients_.data() + offset;
hess = hessians_.data() + offset; hess = hessians_.data() + offset;
} }
new_tree.reset(tree_learner_->Train(grad, hess, forced_splits_json_)); new_tree.reset(tree_learner_->Train(grad, hess));
} }
if (new_tree->num_leaves() > 1) { if (new_tree->num_leaves() > 1) {
...@@ -717,6 +717,21 @@ void GBDT::ResetConfig(const Config* config) { ...@@ -717,6 +717,21 @@ void GBDT::ResetConfig(const Config* config) {
if (train_data_ != nullptr) { if (train_data_ != nullptr) {
ResetBaggingConfig(new_config.get(), false); ResetBaggingConfig(new_config.get(), false);
} }
if (config_->forcedsplits_filename != new_config->forcedbins_filename) {
// load forced_splits file
if (!new_config->forcedsplits_filename.empty()) {
std::ifstream forced_splits_file(
new_config->forcedsplits_filename.c_str());
std::stringstream buffer;
buffer << forced_splits_file.rdbuf();
std::string err;
forced_splits_json_ = Json::parse(buffer.str(), err);
tree_learner_->SetForcedSplit(&forced_splits_json_);
} else {
forced_splits_json_ = Json();
tree_learner_->SetForcedSplit(nullptr);
}
}
config_.reset(new_config.release()); config_.reset(new_config.release());
} }
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <LightGBM/json11.hpp> #include <LightGBM/utils/json11.h>
#include "score_updater.hpp" #include "score_updater.hpp"
namespace LightGBM { namespace LightGBM {
......
...@@ -125,7 +125,7 @@ class RF : public GBDT { ...@@ -125,7 +125,7 @@ class RF : public GBDT {
hess = tmp_hess_.data(); hess = tmp_hess_.data();
} }
new_tree.reset(tree_learner_->Train(grad, hess, forced_splits_json_)); new_tree.reset(tree_learner_->Train(grad, hess));
} }
if (new_tree->num_leaves() > 1) { if (new_tree->num_leaves() > 1) {
......
...@@ -295,7 +295,6 @@ class Booster { ...@@ -295,7 +295,6 @@ class Booster {
if (param.count("metric")) { if (param.count("metric")) {
Log::Fatal("Cannot change metric during training"); Log::Fatal("Cannot change metric during training");
} }
CheckDatasetResetConfig(config_, param); CheckDatasetResetConfig(config_, param);
config_.Set(param); config_.Set(param);
......
...@@ -293,6 +293,12 @@ void Config::CheckParamConflict() { ...@@ -293,6 +293,12 @@ void Config::CheckParamConflict() {
histogram_pool_size = -1; histogram_pool_size = -1;
} }
} }
if (is_data_based_parallel) {
if (!forcedsplits_filename.empty()) {
Log::Fatal("Don't support forcedsplits in %s tree learner",
tree_learner.c_str());
}
}
// Check max_depth and num_leaves // Check max_depth and num_leaves
if (max_depth > 0) { if (max_depth > 0) {
double full_num_leaves = std::pow(2, max_depth); double full_num_leaves = std::pow(2, max_depth);
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <fstream> #include <fstream>
#include <LightGBM/json11.hpp> #include <LightGBM/utils/json11.h>
namespace LightGBM { namespace LightGBM {
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <LightGBM/json11.hpp> #include <LightGBM/utils/json11.h>
#include <limits> #include <limits>
#include <cassert> #include <cassert>
......
/*!
* Copyright (c) 2020 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifndef LIGHTGBM_TREELEARNER_COL_SAMPLER_HPP_
#define LIGHTGBM_TREELEARNER_COL_SAMPLER_HPP_
#include <LightGBM/dataset.h>
#include <LightGBM/meta.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
namespace LightGBM {
/*!
 * \brief Handles column (feature) sub-sampling for tree learners.
 *
 * Supports two levels of sampling driven by the config:
 *  - per-tree  (feature_fraction):        ResetByTree() redraws the set once per tree;
 *  - per-node  (feature_fraction_bynode): GetByNode() draws a fresh set per node,
 *    restricted to the per-tree set when one is active.
 */
class ColSampler {
 public:
  explicit ColSampler(const Config* config)
      : fraction_bytree_(config->feature_fraction),
        fraction_bynode_(config->feature_fraction_bynode),
        seed_(config->feature_fraction_seed),
        random_(config->feature_fraction_seed) {
  }

  /*!
   * \brief Number of features to keep when sampling `fraction` of `total_cnt`.
   * Keeps at least min(2, total_cnt) so a split remains possible.
   */
  static int GetCnt(size_t total_cnt, double fraction) {
    const int min = std::min(2, static_cast<int>(total_cnt));
    int used_feature_cnt = static_cast<int>(Common::RoundInt(total_cnt * fraction));
    return std::max(used_feature_cnt, min);
  }

  /*!
   * \brief Bind a dataset, recompute the per-tree sample size, and draw the
   *        initial per-tree feature set.
   */
  void SetTrainingData(const Dataset* train_data) {
    train_data_ = train_data;
    is_feature_used_.resize(train_data_->num_features(), 1);
    valid_feature_indices_ = train_data->ValidFeatureIndices();
    ResetUsedCntByTree();
    ResetByTree();
  }

  /*!
   * \brief Apply a new config (fractions / seed) and redraw the per-tree set.
   * The RNG is re-seeded only when the seed actually changed, so repeated
   * ResetConfig calls with the same seed keep the sampling stream.
   */
  void SetConfig(const Config* config) {
    fraction_bytree_ = config->feature_fraction;
    fraction_bynode_ = config->feature_fraction_bynode;
    is_feature_used_.resize(train_data_->num_features(), 1);
    // seed is changed
    if (seed_ != config->feature_fraction_seed) {
      seed_ = config->feature_fraction_seed;
      random_ = Random(seed_);
    }
    ResetUsedCntByTree();
    ResetByTree();
  }

  /*!
   * \brief Redraw the per-tree feature mask. No-op when feature_fraction >= 1
   *        (all features stay enabled).
   */
  void ResetByTree() {
    if (need_reset_bytree_) {
      std::memset(is_feature_used_.data(), 0,
                  sizeof(int8_t) * is_feature_used_.size());
      used_feature_indices_ = random_.Sample(
          static_cast<int>(valid_feature_indices_.size()), used_cnt_bytree_);
      int omp_loop_size = static_cast<int>(used_feature_indices_.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
      for (int i = 0; i < omp_loop_size; ++i) {
        int used_feature = valid_feature_indices_[used_feature_indices_[i]];
        int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
        is_feature_used_[inner_feature_index] = 1;
      }
    }
  }

  /*!
   * \brief Draw the per-node feature mask. When per-tree sampling is active,
   *        the per-node sample is taken from the per-tree subset; otherwise
   *        from all valid features.
   * \return mask over inner feature indices (1 = usable at this node)
   */
  std::vector<int8_t> GetByNode() {
    if (fraction_bynode_ >= 1.0f) {
      return std::vector<int8_t>(train_data_->num_features(), 1);
    }
    std::vector<int8_t> ret(train_data_->num_features(), 0);
    if (need_reset_bytree_) {
      // sample within the features already selected for this tree
      auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
      auto sampled_indices = random_.Sample(
          static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
      int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
      for (int i = 0; i < omp_loop_size; ++i) {
        int used_feature =
            valid_feature_indices_[used_feature_indices_[sampled_indices[i]]];
        int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
        ret[inner_feature_index] = 1;
      }
    } else {
      // no per-tree sampling: draw directly from all valid features
      auto used_feature_cnt =
          GetCnt(valid_feature_indices_.size(), fraction_bynode_);
      auto sampled_indices = random_.Sample(
          static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
      int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
      for (int i = 0; i < omp_loop_size; ++i) {
        int used_feature = valid_feature_indices_[sampled_indices[i]];
        int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
        ret[inner_feature_index] = 1;
      }
    }
    return ret;
  }

  /*! \brief Current per-tree feature mask (inner feature indices). */
  const std::vector<int8_t>& is_feature_used_bytree() const {
    return is_feature_used_;
  }

  /*! \brief Force a single feature on/off in the per-tree mask
   *         (used by the feature-parallel learner to split ownership). */
  void SetIsFeatureUsedByTree(int fid, bool val) {
    is_feature_used_[fid] = val;
  }

 private:
  /*! \brief Recompute whether per-tree sampling is needed and how many
   *         features each tree keeps (shared by SetTrainingData/SetConfig). */
  void ResetUsedCntByTree() {
    if (fraction_bytree_ >= 1.0f) {
      need_reset_bytree_ = false;
      used_cnt_bytree_ = static_cast<int>(valid_feature_indices_.size());
    } else {
      need_reset_bytree_ = true;
      used_cnt_bytree_ =
          GetCnt(valid_feature_indices_.size(), fraction_bytree_);
    }
  }

  // initialized in-class so that accidental use before SetTrainingData()
  // reads defined values instead of indeterminate ones
  const Dataset* train_data_ = nullptr;
  double fraction_bytree_;
  double fraction_bynode_;
  bool need_reset_bytree_ = false;
  int used_cnt_bytree_ = 0;
  int seed_;
  Random random_;
  std::vector<int8_t> is_feature_used_;
  std::vector<int> used_feature_indices_;
  std::vector<int> valid_feature_indices_;
};
} // namespace LightGBM
#endif // LIGHTGBM_TREELEARNER_COL_SAMPLER_HPP_
...@@ -57,7 +57,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -57,7 +57,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
for (int i = 0; i < this->train_data_->num_total_features(); ++i) { for (int i = 0; i < this->train_data_->num_total_features(); ++i) {
int inner_feature_index = this->train_data_->InnerFeatureIndex(i); int inner_feature_index = this->train_data_->InnerFeatureIndex(i);
if (inner_feature_index == -1) { continue; } if (inner_feature_index == -1) { continue; }
if (this->is_feature_used_[inner_feature_index]) { if (this->col_sampler_.is_feature_used_bytree()[inner_feature_index]) {
int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed)); int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed));
feature_distribution[cur_min_machine].push_back(inner_feature_index); feature_distribution[cur_min_machine].push_back(inner_feature_index);
auto num_bin = this->train_data_->FeatureNumBin(inner_feature_index); auto num_bin = this->train_data_->FeatureNumBin(inner_feature_index);
...@@ -147,11 +147,13 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -147,11 +147,13 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
TREELEARNER_T::ConstructHistograms(this->is_feature_used_, true); TREELEARNER_T::ConstructHistograms(
this->col_sampler_.is_feature_used_bytree(), true);
// construct local histograms // construct local histograms
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
if ((!this->is_feature_used_.empty() && this->is_feature_used_[feature_index] == false)) continue; if (this->col_sampler_.is_feature_used_bytree()[feature_index] == false)
continue;
// copy to buffer // copy to buffer
std::memcpy(input_buffer_.data() + buffer_write_start_pos_[feature_index], std::memcpy(input_buffer_.data() + buffer_write_start_pos_[feature_index],
this->smaller_leaf_histogram_array_[feature_index].RawData(), this->smaller_leaf_histogram_array_[feature_index].RawData(),
...@@ -160,19 +162,18 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -160,19 +162,18 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
// Reduce scatter for histogram // Reduce scatter for histogram
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(),
block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer); block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
this->FindBestSplitsFromHistograms(this->is_feature_used_, true); this->FindBestSplitsFromHistograms(
this->col_sampler_.is_feature_used_bytree(), true);
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) { void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads); std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads); std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features(this->num_features_, 1); std::vector<int8_t> smaller_node_used_features =
std::vector<int8_t> larger_node_used_features(this->num_features_, 1); this->col_sampler_.GetByNode();
if (this->config_->feature_fraction_bynode < 1.0f) { std::vector<int8_t> larger_node_used_features =
smaller_node_used_features = this->GetUsedFeatures(false); this->col_sampler_.GetByNode();
larger_node_used_features = this->GetUsedFeatures(false);
}
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
...@@ -241,7 +242,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -241,7 +242,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) { void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
this->SplitInner(tree, best_Leaf, left_leaf, right_leaf, false); TREELEARNER_T::SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf]; const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
// need update global number of data in leaf // need update global number of data in leaf
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count; global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
......
This diff is collapsed.
...@@ -38,16 +38,16 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -38,16 +38,16 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
for (int i = 0; i < this->train_data_->num_total_features(); ++i) { for (int i = 0; i < this->train_data_->num_total_features(); ++i) {
int inner_feature_index = this->train_data_->InnerFeatureIndex(i); int inner_feature_index = this->train_data_->InnerFeatureIndex(i);
if (inner_feature_index == -1) { continue; } if (inner_feature_index == -1) { continue; }
if (this->is_feature_used_[inner_feature_index]) { if (this->col_sampler_.is_feature_used_bytree()[inner_feature_index]) {
int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed)); int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed));
feature_distribution[cur_min_machine].push_back(inner_feature_index); feature_distribution[cur_min_machine].push_back(inner_feature_index);
num_bins_distributed[cur_min_machine] += this->train_data_->FeatureNumBin(inner_feature_index); num_bins_distributed[cur_min_machine] += this->train_data_->FeatureNumBin(inner_feature_index);
this->is_feature_used_[inner_feature_index] = false; this->col_sampler_.SetIsFeatureUsedByTree(inner_feature_index, false);
} }
} }
// get local used features // get local used features
for (auto fid : feature_distribution[rank_]) { for (auto fid : feature_distribution[rank_]) {
this->is_feature_used_[fid] = true; this->col_sampler_.SetIsFeatureUsedByTree(fid, true);
} }
} }
......
...@@ -735,9 +735,8 @@ void GPUTreeLearner::InitGPU(int platform_id, int device_id) { ...@@ -735,9 +735,8 @@ void GPUTreeLearner::InitGPU(int platform_id, int device_id) {
SetupKernelArguments(); SetupKernelArguments();
} }
Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians, Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
const Json& forced_split_json) { return SerialTreeLearner::Train(gradients, hessians);
return SerialTreeLearner::Train(gradients, hessians, forced_split_json);
} }
void GPUTreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) { void GPUTreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) {
...@@ -957,7 +956,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u ...@@ -957,7 +956,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
std::vector<int8_t> is_dense_feature_used(num_features_, 0); std::vector<int8_t> is_dense_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) { for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue; if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (!is_feature_used[feature_index]) continue; if (!is_feature_used[feature_index]) continue;
if (train_data_->IsMultiGroup(train_data_->Feature2Group(feature_index))) { if (train_data_->IsMultiGroup(train_data_->Feature2Group(feature_index))) {
is_sparse_feature_used[feature_index] = 1; is_sparse_feature_used[feature_index] = 1;
...@@ -1062,7 +1061,7 @@ void GPUTreeLearner::FindBestSplits() { ...@@ -1062,7 +1061,7 @@ void GPUTreeLearner::FindBestSplits() {
#if GPU_DEBUG >= 3 #if GPU_DEBUG >= 3
for (int feature_index = 0; feature_index < num_features_; ++feature_index) { for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue; if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) { && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false); smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
......
...@@ -47,9 +47,8 @@ class GPUTreeLearner: public SerialTreeLearner { ...@@ -47,9 +47,8 @@ class GPUTreeLearner: public SerialTreeLearner {
~GPUTreeLearner(); ~GPUTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override; void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override;
void ResetIsConstantHessian(bool is_constant_hessian); void ResetIsConstantHessian(bool is_constant_hessian) override;
Tree* Train(const score_t* gradients, const score_t *hessians, Tree* Train(const score_t* gradients, const score_t *hessians) override;
const Json& forced_split_json) override;
void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override { void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(subset, used_indices, num_data); SerialTreeLearner::SetBaggingData(subset, used_indices, num_data);
......
...@@ -19,8 +19,7 @@ ...@@ -19,8 +19,7 @@
namespace LightGBM { namespace LightGBM {
SerialTreeLearner::SerialTreeLearner(const Config* config) SerialTreeLearner::SerialTreeLearner(const Config* config)
:config_(config) { : config_(config), col_sampler_(config) {
random_ = Random(config_->feature_fraction_seed);
} }
SerialTreeLearner::~SerialTreeLearner() { SerialTreeLearner::~SerialTreeLearner() {
...@@ -55,8 +54,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian ...@@ -55,8 +54,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
// initialize data partition // initialize data partition
data_partition_.reset(new DataPartition(num_data_, config_->num_leaves)); data_partition_.reset(new DataPartition(num_data_, config_->num_leaves));
is_feature_used_.resize(num_features_); col_sampler_.SetTrainingData(train_data_);
valid_feature_indices_ = train_data_->ValidFeatureIndices();
// initialize ordered gradients and hessians // initialize ordered gradients and hessians
ordered_gradients_.resize(num_data_); ordered_gradients_.resize(num_data_);
ordered_hessians_.resize(num_data_); ordered_hessians_.resize(num_data_);
...@@ -74,15 +72,15 @@ void SerialTreeLearner::GetShareStates(const Dataset* dataset, ...@@ -74,15 +72,15 @@ void SerialTreeLearner::GetShareStates(const Dataset* dataset,
bool is_constant_hessian, bool is_constant_hessian,
bool is_first_time) { bool is_first_time) {
if (is_first_time) { if (is_first_time) {
auto used_feature = GetUsedFeatures(true);
share_state_.reset(dataset->GetShareStates( share_state_.reset(dataset->GetShareStates(
ordered_gradients_.data(), ordered_hessians_.data(), used_feature, ordered_gradients_.data(), ordered_hessians_.data(),
is_constant_hessian, config_->force_col_wise, config_->force_row_wise)); col_sampler_.is_feature_used_bytree(), is_constant_hessian,
config_->force_col_wise, config_->force_row_wise));
} else { } else {
CHECK_NOTNULL(share_state_); CHECK_NOTNULL(share_state_);
// cannot change is_hist_col_wise during training // cannot change is_hist_col_wise during training
share_state_.reset(dataset->GetShareStates( share_state_.reset(dataset->GetShareStates(
ordered_gradients_.data(), ordered_hessians_.data(), is_feature_used_, ordered_gradients_.data(), ordered_hessians_.data(), col_sampler_.is_feature_used_bytree(),
is_constant_hessian, share_state_->is_colwise, is_constant_hessian, share_state_->is_colwise,
!share_state_->is_colwise)); !share_state_->is_colwise));
} }
...@@ -102,15 +100,14 @@ void SerialTreeLearner::ResetTrainingDataInner(const Dataset* train_data, ...@@ -102,15 +100,14 @@ void SerialTreeLearner::ResetTrainingDataInner(const Dataset* train_data,
// initialize data partition // initialize data partition
data_partition_->ResetNumData(num_data_); data_partition_->ResetNumData(num_data_);
if (reset_multi_val_bin) { if (reset_multi_val_bin) {
col_sampler_.SetTrainingData(train_data_);
GetShareStates(train_data_, is_constant_hessian, false); GetShareStates(train_data_, is_constant_hessian, false);
} }
// initialize ordered gradients and hessians // initialize ordered gradients and hessians
ordered_gradients_.resize(num_data_); ordered_gradients_.resize(num_data_);
ordered_hessians_.resize(num_data_); ordered_hessians_.resize(num_data_);
if (cegb_ != nullptr) { if (cegb_ != nullptr) {
cegb_->Init(); cegb_->Init();
} }
...@@ -141,6 +138,7 @@ void SerialTreeLearner::ResetConfig(const Config* config) { ...@@ -141,6 +138,7 @@ void SerialTreeLearner::ResetConfig(const Config* config) {
} else { } else {
config_ = config; config_ = config;
} }
col_sampler_.SetConfig(config_);
histogram_pool_.ResetConfig(train_data_, config_); histogram_pool_.ResetConfig(train_data_, config_);
if (CostEfficientGradientBoosting::IsEnable(config_)) { if (CostEfficientGradientBoosting::IsEnable(config_)) {
cegb_.reset(new CostEfficientGradientBoosting(this)); cegb_.reset(new CostEfficientGradientBoosting(this));
...@@ -148,7 +146,7 @@ void SerialTreeLearner::ResetConfig(const Config* config) { ...@@ -148,7 +146,7 @@ void SerialTreeLearner::ResetConfig(const Config* config) {
} }
} }
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, const Json& forced_split_json) { Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer); Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
gradients_ = gradients; gradients_ = gradients;
hessians_ = hessians; hessians_ = hessians;
...@@ -165,28 +163,21 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -165,28 +163,21 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
BeforeTrain(); BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves)); auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves));
auto tree_prt = tree.get();
// root leaf // root leaf
int left_leaf = 0; int left_leaf = 0;
int cur_depth = 1; int cur_depth = 1;
// only root leaf can be splitted on first time // only root leaf can be splitted on first time
int right_leaf = -1; int right_leaf = -1;
int init_splits = 0; int init_splits = ForceSplits(tree_prt, &left_leaf, &right_leaf, &cur_depth);
bool aborted_last_force_split = false;
if (!forced_split_json.is_null()) {
init_splits = ForceSplits(tree.get(), forced_split_json, &left_leaf,
&right_leaf, &cur_depth, &aborted_last_force_split);
}
for (int split = init_splits; split < config_->num_leaves - 1; ++split) { for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
// some initial works before finding best split // some initial works before finding best split
if (!aborted_last_force_split && BeforeFindBestSplit(tree.get(), left_leaf, right_leaf)) { if (BeforeFindBestSplit(tree_prt, left_leaf, right_leaf)) {
// find best threshold for every feature // find best threshold for every feature
FindBestSplits(); FindBestSplits();
} else if (aborted_last_force_split) {
aborted_last_force_split = false;
} }
// Get a leaf with max split gain // Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_)); int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
// Get split information for best leaf // Get split information for best leaf
...@@ -197,7 +188,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -197,7 +188,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
break; break;
} }
// split tree with best leaf // split tree with best leaf
Split(tree.get(), best_leaf, &left_leaf, &right_leaf); Split(tree_prt, best_leaf, &left_leaf, &right_leaf);
cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf)); cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
} }
Log::Debug("Trained a tree with leaves = %d and max_depth = %d", tree->num_leaves(), cur_depth); Log::Debug("Trained a tree with leaves = %d and max_depth = %d", tree->num_leaves(), cur_depth);
...@@ -220,8 +211,9 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* ...@@ -220,8 +211,9 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
sum_grad += gradients[idx]; sum_grad += gradients[idx];
sum_hess += hessians[idx]; sum_hess += hessians[idx];
} }
double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess, double output = FeatureHistogram::CalculateSplittedLeafOutput<true, true>(
config_->lambda_l1, config_->lambda_l2, config_->max_delta_step); sum_grad, sum_hess, config_->lambda_l1, config_->lambda_l2,
config_->max_delta_step);
auto old_leaf_output = tree->LeafOutput(i); auto old_leaf_output = tree->LeafOutput(i);
auto new_leaf_output = output * tree->shrinkage(); auto new_leaf_output = output * tree->shrinkage();
tree->SetLeafOutput(i, config_->refit_decay_rate * old_leaf_output + (1.0 - config_->refit_decay_rate) * new_leaf_output); tree->SetLeafOutput(i, config_->refit_decay_rate * old_leaf_output + (1.0 - config_->refit_decay_rate) * new_leaf_output);
...@@ -236,70 +228,13 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vect ...@@ -236,70 +228,13 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vect
return FitByExistingTree(old_tree, gradients, hessians); return FitByExistingTree(old_tree, gradients, hessians);
} }
std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
std::vector<int8_t> ret(num_features_, 1);
if (config_->feature_fraction >= 1.0f && is_tree_level) {
return ret;
}
if (config_->feature_fraction_bynode >= 1.0f && !is_tree_level) {
return ret;
}
std::memset(ret.data(), 0, sizeof(int8_t) * num_features_);
const int min_used_features = std::min(2, static_cast<int>(valid_feature_indices_.size()));
if (is_tree_level) {
int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
used_feature_indices_ = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(used_feature_indices_.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature = valid_feature_indices_[used_feature_indices_[i]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
CHECK_GE(inner_feature_index, 0);
ret[inner_feature_index] = 1;
}
} else if (used_feature_indices_.size() <= 0) {
int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature = valid_feature_indices_[sampled_indices[i]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
CHECK_GE(inner_feature_index, 0);
ret[inner_feature_index] = 1;
}
} else {
int used_feature_cnt = static_cast<int>(std::round(used_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature = valid_feature_indices_[used_feature_indices_[sampled_indices[i]]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
CHECK_GE(inner_feature_index, 0);
ret[inner_feature_index] = 1;
}
}
return ret;
}
void SerialTreeLearner::BeforeTrain() { void SerialTreeLearner::BeforeTrain() {
Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeTrain", global_timer); Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeTrain", global_timer);
// reset histogram pool // reset histogram pool
histogram_pool_.ResetMap(); histogram_pool_.ResetMap();
if (config_->feature_fraction < 1.0f) { col_sampler_.ResetByTree();
is_feature_used_ = GetUsedFeatures(true); train_data_->InitTrain(col_sampler_.is_feature_used_bytree(), share_state_.get());
} else {
#pragma omp parallel for schedule(static, 512) if (num_features_ >= 1024)
for (int i = 0; i < num_features_; ++i) {
is_feature_used_[i] = 1;
}
}
train_data_->InitTrain(is_feature_used_, share_state_.get());
// initialize data partition // initialize data partition
data_partition_->Init(); data_partition_->Init();
...@@ -367,9 +302,9 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int ...@@ -367,9 +302,9 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
void SerialTreeLearner::FindBestSplits() { void SerialTreeLearner::FindBestSplits() {
std::vector<int8_t> is_feature_used(num_features_, 0); std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static, 1024) if (num_features_ >= 2048) #pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) { for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue; if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) { && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false); smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
...@@ -413,12 +348,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms( ...@@ -413,12 +348,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
"SerialTreeLearner::FindBestSplitsFromHistograms", global_timer); "SerialTreeLearner::FindBestSplitsFromHistograms", global_timer);
std::vector<SplitInfo> smaller_best(share_state_->num_threads); std::vector<SplitInfo> smaller_best(share_state_->num_threads);
std::vector<SplitInfo> larger_best(share_state_->num_threads); std::vector<SplitInfo> larger_best(share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features(num_features_, 1); std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode();
std::vector<int8_t> larger_node_used_features(num_features_, 1); std::vector<int8_t> larger_node_used_features = col_sampler_.GetByNode();
if (config_->feature_fraction_bynode < 1.0f) {
smaller_node_used_features = GetUsedFeatures(false);
larger_node_used_features = GetUsedFeatures(false);
}
OMP_INIT_EX(); OMP_INIT_EX();
// find splits // find splits
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
...@@ -477,18 +408,21 @@ void SerialTreeLearner::FindBestSplitsFromHistograms( ...@@ -477,18 +408,21 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
} }
} }
int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json, int* left_leaf, int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
int* right_leaf, int *cur_depth, int* right_leaf, int *cur_depth) {
bool *aborted_last_force_split) { bool abort_last_forced_split = false;
if (forced_split_json_ == nullptr) {
return 0;
}
int32_t result_count = 0; int32_t result_count = 0;
// start at root leaf // start at root leaf
*left_leaf = 0; *left_leaf = 0;
std::queue<std::pair<Json, int>> q; std::queue<std::pair<Json, int>> q;
Json left = forced_split_json; Json left = *forced_split_json_;
Json right; Json right;
bool left_smaller = true; bool left_smaller = true;
std::unordered_map<int, SplitInfo> forceSplitMap; std::unordered_map<int, SplitInfo> forceSplitMap;
q.push(std::make_pair(forced_split_json, *left_leaf)); q.push(std::make_pair(left, *left_leaf));
while (!q.empty()) { while (!q.empty()) {
// before processing next node from queue, store info for current left/right leaf // before processing next node from queue, store info for current left/right leaf
// store "best split" for left and right, even if they might be overwritten by forced split // store "best split" for left and right, even if they might be overwritten by forced split
...@@ -546,88 +480,13 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -546,88 +480,13 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
int current_leaf = pair.second; int current_leaf = pair.second;
// split info should exist because searching in bfs fashion - should have added from parent // split info should exist because searching in bfs fashion - should have added from parent
if (forceSplitMap.find(current_leaf) == forceSplitMap.end()) { if (forceSplitMap.find(current_leaf) == forceSplitMap.end()) {
*aborted_last_force_split = true; abort_last_forced_split = true;
break; break;
} }
SplitInfo current_split_info = forceSplitMap[current_leaf]; best_split_per_leaf_[current_leaf] = forceSplitMap[current_leaf];
const int inner_feature_index = train_data_->InnerFeatureIndex( Split(tree, current_leaf, left_leaf, right_leaf);
current_split_info.feature); left_smaller = best_split_per_leaf_[current_leaf].left_count <
auto threshold_double = train_data_->RealThreshold( best_split_per_leaf_[current_leaf].right_count;
inner_feature_index, current_split_info.threshold);
// split tree, will return right leaf
*left_leaf = current_leaf;
auto next_leaf_id = tree->NextLeafId();
if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) {
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
&current_split_info.threshold, 1,
current_split_info.default_left, next_leaf_id);
current_split_info.left_count = data_partition_->leaf_count(*left_leaf);
current_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
*right_leaf = tree->Split(current_leaf,
inner_feature_index,
current_split_info.feature,
current_split_info.threshold,
threshold_double,
static_cast<double>(current_split_info.left_output),
static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
current_split_info.default_left);
} else {
std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(
current_split_info.cat_threshold.data(), current_split_info.num_cat_threshold);
std::vector<int> threshold_int(current_split_info.num_cat_threshold);
for (int i = 0; i < current_split_info.num_cat_threshold; ++i) {
threshold_int[i] = static_cast<int>(train_data_->RealThreshold(
inner_feature_index, current_split_info.cat_threshold[i]));
}
std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
threshold_int.data(), current_split_info.num_cat_threshold);
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
current_split_info.default_left, next_leaf_id);
current_split_info.left_count = data_partition_->leaf_count(*left_leaf);
current_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
*right_leaf = tree->SplitCategorical(current_leaf,
inner_feature_index,
current_split_info.feature,
cat_bitset_inner.data(),
static_cast<int>(cat_bitset_inner.size()),
cat_bitset.data(),
static_cast<int>(cat_bitset.size()),
static_cast<double>(current_split_info.left_output),
static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
}
#ifdef DEBUG
CHECK(*right_leaf == next_leaf_id);
#endif
if (current_split_info.left_count < current_split_info.right_count) {
left_smaller = true;
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
current_split_info.left_sum_gradient,
current_split_info.left_sum_hessian);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
current_split_info.right_sum_gradient,
current_split_info.right_sum_hessian);
} else {
left_smaller = false;
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
current_split_info.right_sum_gradient, current_split_info.right_sum_hessian);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
current_split_info.left_sum_gradient, current_split_info.left_sum_hessian);
}
left = Json(); left = Json();
right = Json(); right = Json();
if ((pair.first).object_items().count("left") > 0) { if ((pair.first).object_items().count("left") > 0) {
...@@ -645,6 +504,19 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -645,6 +504,19 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
result_count++; result_count++;
*(cur_depth) = std::max(*(cur_depth), tree->leaf_depth(*left_leaf)); *(cur_depth) = std::max(*(cur_depth), tree->leaf_depth(*left_leaf));
} }
if (abort_last_forced_split) {
int best_leaf =
static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
if (best_leaf_SplitInfo.gain <= 0.0) {
Log::Warning("No further splits with positive gain, best gain: %f",
best_leaf_SplitInfo.gain);
return config_->num_leaves;
}
Split(tree, best_leaf, left_leaf, right_leaf);
*(cur_depth) = std::max(*(cur_depth), tree->leaf_depth(*left_leaf));
result_count++;
}
return result_count; return result_count;
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <LightGBM/tree.h> #include <LightGBM/tree.h>
#include <LightGBM/tree_learner.h> #include <LightGBM/tree_learner.h>
#include <LightGBM/utils/array_args.h> #include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/json11.h>
#include <LightGBM/utils/random.h> #include <LightGBM/utils/random.h>
#include <string> #include <string>
...@@ -18,6 +19,7 @@ ...@@ -18,6 +19,7 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "col_sampler.hpp"
#include "data_partition.hpp" #include "data_partition.hpp"
#include "feature_histogram.hpp" #include "feature_histogram.hpp"
#include "leaf_splits.hpp" #include "leaf_splits.hpp"
...@@ -63,8 +65,15 @@ class SerialTreeLearner: public TreeLearner { ...@@ -63,8 +65,15 @@ class SerialTreeLearner: public TreeLearner {
void ResetConfig(const Config* config) override; void ResetConfig(const Config* config) override;
Tree* Train(const score_t* gradients, const score_t *hessians, inline void SetForcedSplit(const Json* forced_split_json) override {
const Json& forced_split_json) override; if (forced_split_json != nullptr && !forced_split_json->is_null()) {
forced_split_json_ = forced_split_json;
} else {
forced_split_json_ = nullptr;
}
}
Tree* Train(const score_t* gradients, const score_t *hessians) override;
Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override; Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override;
...@@ -113,7 +122,6 @@ class SerialTreeLearner: public TreeLearner { ...@@ -113,7 +122,6 @@ class SerialTreeLearner: public TreeLearner {
void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time); void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);
virtual std::vector<int8_t> GetUsedFeatures(bool is_tree_level);
/*! /*!
* \brief Some initial works before training * \brief Some initial works before training
*/ */
...@@ -142,12 +150,12 @@ class SerialTreeLearner: public TreeLearner { ...@@ -142,12 +150,12 @@ class SerialTreeLearner: public TreeLearner {
SplitInner(tree, best_leaf, left_leaf, right_leaf, true); SplitInner(tree, best_leaf, left_leaf, right_leaf, true);
} }
void SplitInner(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf, bool update_cnt); void SplitInner(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf,
bool update_cnt);
/* Force splits with forced_split_json dict and then return num splits forced.*/ /* Force splits with forced_split_json dict and then return num splits forced.*/
virtual int32_t ForceSplits(Tree* tree, const Json& forced_split_json, int* left_leaf, int32_t ForceSplits(Tree* tree, int* left_leaf, int* right_leaf,
int* right_leaf, int* cur_depth, int* cur_depth);
bool *aborted_last_force_split);
/*! /*!
* \brief Get the number of data in a leaf * \brief Get the number of data in a leaf
...@@ -168,12 +176,6 @@ class SerialTreeLearner: public TreeLearner { ...@@ -168,12 +176,6 @@ class SerialTreeLearner: public TreeLearner {
const score_t* hessians_; const score_t* hessians_;
/*! \brief training data partition on leaves */ /*! \brief training data partition on leaves */
std::unique_ptr<DataPartition> data_partition_; std::unique_ptr<DataPartition> data_partition_;
/*! \brief used for generate used features */
Random random_;
/*! \brief used for sub feature training, is_feature_used_[i] = false means don't used feature i */
std::vector<int8_t> is_feature_used_;
/*! \brief used feature indices in current tree */
std::vector<int> used_feature_indices_;
/*! \brief pointer to histograms array of parent of current leaves */ /*! \brief pointer to histograms array of parent of current leaves */
FeatureHistogram* parent_leaf_histogram_array_; FeatureHistogram* parent_leaf_histogram_array_;
/*! \brief pointer to histograms array of smaller leaf */ /*! \brief pointer to histograms array of smaller leaf */
...@@ -192,7 +194,6 @@ class SerialTreeLearner: public TreeLearner { ...@@ -192,7 +194,6 @@ class SerialTreeLearner: public TreeLearner {
std::unique_ptr<LeafSplits> smaller_leaf_splits_; std::unique_ptr<LeafSplits> smaller_leaf_splits_;
/*! \brief stores best thresholds for all feature for larger leaf */ /*! \brief stores best thresholds for all feature for larger leaf */
std::unique_ptr<LeafSplits> larger_leaf_splits_; std::unique_ptr<LeafSplits> larger_leaf_splits_;
std::vector<int> valid_feature_indices_;
#ifdef USE_GPU #ifdef USE_GPU
/*! \brief gradients of current iteration, ordered for cache optimized, aligned to 4K page */ /*! \brief gradients of current iteration, ordered for cache optimized, aligned to 4K page */
...@@ -209,6 +210,8 @@ class SerialTreeLearner: public TreeLearner { ...@@ -209,6 +210,8 @@ class SerialTreeLearner: public TreeLearner {
HistogramPool histogram_pool_; HistogramPool histogram_pool_;
/*! \brief config of tree learner*/ /*! \brief config of tree learner*/
const Config* config_; const Config* config_;
ColSampler col_sampler_;
const Json* forced_split_json_;
std::unique_ptr<TrainingShareStates> share_state_; std::unique_ptr<TrainingShareStates> share_state_;
std::unique_ptr<CostEfficientGradientBoosting> cegb_; std::unique_ptr<CostEfficientGradientBoosting> cegb_;
}; };
......
...@@ -66,7 +66,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -66,7 +66,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
auto num_total_bin = train_data->NumTotalBin(); auto num_total_bin = train_data->NumTotalBin();
smaller_leaf_histogram_data_.resize(num_total_bin); smaller_leaf_histogram_data_.resize(num_total_bin);
larger_leaf_histogram_data_.resize(num_total_bin); larger_leaf_histogram_data_.resize(num_total_bin);
HistogramPool::SetFeatureInfo(train_data, this->config_, &feature_metas_); HistogramPool::SetFeatureInfo<true, true>(train_data, this->config_, &feature_metas_);
uint64_t offset = 0; uint64_t offset = 0;
for (int j = 0; j < train_data->num_features(); ++j) { for (int j = 0; j < train_data->num_features(); ++j) {
offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j)); offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j));
...@@ -91,7 +91,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) ...@@ -91,7 +91,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config)
this->histogram_pool_.ResetConfig(this->train_data_, &local_config_); this->histogram_pool_.ResetConfig(this->train_data_, &local_config_);
global_data_count_in_leaf_.resize(this->config_->num_leaves); global_data_count_in_leaf_.resize(this->config_->num_leaves);
HistogramPool::SetFeatureInfoConfig(this->train_data_, config, &feature_metas_); HistogramPool::SetFeatureInfo<false, true>(this->train_data_, config, &feature_metas_);
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -247,7 +247,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -247,7 +247,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
std::vector<int8_t> is_feature_used(this->num_features_, 0); std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
if (!this->is_feature_used_[feature_index]) continue; if (!this->col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (this->parent_leaf_histogram_array_ != nullptr if (this->parent_leaf_histogram_array_ != nullptr
&& !this->parent_leaf_histogram_array_[feature_index].is_splittable()) { && !this->parent_leaf_histogram_array_[feature_index].is_splittable()) {
this->smaller_leaf_histogram_array_[feature_index].set_is_splittable(false); this->smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
...@@ -351,12 +351,10 @@ template <typename TREELEARNER_T> ...@@ -351,12 +351,10 @@ template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) { void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads); std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads); std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features(this->num_features_, 1); std::vector<int8_t> smaller_node_used_features =
std::vector<int8_t> larger_node_used_features(this->num_features_, 1); this->col_sampler_.GetByNode();
if (this->config_->feature_fraction_bynode < 1.0f) { std::vector<int8_t> larger_node_used_features =
smaller_node_used_features = this->GetUsedFeatures(false); this->col_sampler_.GetByNode();
larger_node_used_features = this->GetUsedFeatures(false);
}
// find best split from local aggregated histograms // find best split from local aggregated histograms
OMP_INIT_EX(); OMP_INIT_EX();
...@@ -429,7 +427,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -429,7 +427,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) { void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
this->SplitInner(tree, best_Leaf, left_leaf, right_leaf, false); TREELEARNER_T::SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf]; const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
// set the global number of data for leaves // set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count; global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
......
...@@ -406,7 +406,8 @@ class TestEngine(unittest.TestCase): ...@@ -406,7 +406,8 @@ class TestEngine(unittest.TestCase):
'num_class': 10, 'num_class': 10,
'num_leaves': 50, 'num_leaves': 50,
'min_data': 1, 'min_data': 1,
'verbose': -1 'verbose': -1,
'gpu_use_dp': True
} }
lgb_train = lgb.Dataset(X_train, y_train, params=params) lgb_train = lgb.Dataset(X_train, y_train, params=params)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, params=params) lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, params=params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment