Unverified Commit dc699574 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Refine config object (#1381)

* [WIP] refine config

* [wip] ready for the auto code generate

* auto generate config codes

* use with to open file

* fix bug

* fix pylint

* fix bug

* fix pylint

* fix bugs.

* tmp for failed test.

* fix tests.

* added nthreads alias

* added new aliases from new config.h

* fixed duplicated alias

* refactored parameter_generator.py

* added new aliases from config.h and removed remaining old names

* fix bugs & some miss alias

* added aliases

* add more descriptions.

* add comment.
parent 497e60ed
...@@ -15,7 +15,7 @@ namespace LightGBM { ...@@ -15,7 +15,7 @@ namespace LightGBM {
*/ */
class MulticlassSoftmax: public ObjectiveFunction { class MulticlassSoftmax: public ObjectiveFunction {
public: public:
explicit MulticlassSoftmax(const ObjectiveConfig& config) { explicit MulticlassSoftmax(const Config& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
} }
...@@ -138,7 +138,7 @@ private: ...@@ -138,7 +138,7 @@ private:
*/ */
class MulticlassOVA: public ObjectiveFunction { class MulticlassOVA: public ObjectiveFunction {
public: public:
explicit MulticlassOVA(const ObjectiveConfig& config) { explicit MulticlassOVA(const Config& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
binary_loss_.emplace_back( binary_loss_.emplace_back(
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace LightGBM { namespace LightGBM {
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const ObjectiveConfig& config) { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
if (type == std::string("regression") || type == std::string("regression_l2") if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse") || type == std::string("mean_squared_error") || type == std::string("mse")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) { || type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
......
...@@ -18,15 +18,12 @@ namespace LightGBM { ...@@ -18,15 +18,12 @@ namespace LightGBM {
*/ */
class LambdarankNDCG: public ObjectiveFunction { class LambdarankNDCG: public ObjectiveFunction {
public: public:
explicit LambdarankNDCG(const ObjectiveConfig& config) { explicit LambdarankNDCG(const Config& config) {
sigmoid_ = static_cast<double>(config.sigmoid); sigmoid_ = static_cast<double>(config.sigmoid);
label_gain_ = config.label_gain;
// initialize DCG calculator // initialize DCG calculator
DCGCalculator::Init(config.label_gain); DCGCalculator::DefaultLabelGain(&label_gain_);
// copy lable gain to local DCGCalculator::Init(label_gain_);
for (auto gain : config.label_gain) {
label_gain_.push_back(static_cast<double>(gain));
}
label_gain_.shrink_to_fit();
// will optimize NDCG@optimize_pos_at_ // will optimize NDCG@optimize_pos_at_
optimize_pos_at_ = config.max_position; optimize_pos_at_ = config.max_position;
sigmoid_table_.clear(); sigmoid_table_.clear();
......
...@@ -63,7 +63,7 @@ namespace LightGBM { ...@@ -63,7 +63,7 @@ namespace LightGBM {
*/ */
class RegressionL2loss: public ObjectiveFunction { class RegressionL2loss: public ObjectiveFunction {
public: public:
explicit RegressionL2loss(const ObjectiveConfig& config) { explicit RegressionL2loss(const Config& config) {
sqrt_ = config.reg_sqrt; sqrt_ = config.reg_sqrt;
} }
...@@ -174,7 +174,7 @@ protected: ...@@ -174,7 +174,7 @@ protected:
*/ */
class RegressionL1loss: public RegressionL2loss { class RegressionL1loss: public RegressionL2loss {
public: public:
explicit RegressionL1loss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionL1loss(const Config& config): RegressionL2loss(config) {
} }
explicit RegressionL1loss(const std::vector<std::string>& strs): RegressionL2loss(strs) { explicit RegressionL1loss(const std::vector<std::string>& strs): RegressionL2loss(strs) {
...@@ -260,7 +260,7 @@ public: ...@@ -260,7 +260,7 @@ public:
*/ */
class RegressionHuberLoss: public RegressionL2loss { class RegressionHuberLoss: public RegressionL2loss {
public: public:
explicit RegressionHuberLoss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionHuberLoss(const Config& config): RegressionL2loss(config) {
alpha_ = static_cast<double>(config.alpha); alpha_ = static_cast<double>(config.alpha);
} }
...@@ -315,7 +315,7 @@ private: ...@@ -315,7 +315,7 @@ private:
// http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html // http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html
class RegressionFairLoss: public RegressionL2loss { class RegressionFairLoss: public RegressionL2loss {
public: public:
explicit RegressionFairLoss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionFairLoss(const Config& config): RegressionL2loss(config) {
c_ = static_cast<double>(config.fair_c); c_ = static_cast<double>(config.fair_c);
} }
...@@ -363,7 +363,7 @@ private: ...@@ -363,7 +363,7 @@ private:
*/ */
class RegressionPoissonLoss: public RegressionL2loss { class RegressionPoissonLoss: public RegressionL2loss {
public: public:
explicit RegressionPoissonLoss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionPoissonLoss(const Config& config): RegressionL2loss(config) {
max_delta_step_ = static_cast<double>(config.poisson_max_delta_step); max_delta_step_ = static_cast<double>(config.poisson_max_delta_step);
if (sqrt_) { if (sqrt_) {
Log::Warning("Cannot use sqrt transform in %s Regression, will auto disable it", GetName()); Log::Warning("Cannot use sqrt transform in %s Regression, will auto disable it", GetName());
...@@ -444,7 +444,7 @@ private: ...@@ -444,7 +444,7 @@ private:
class RegressionQuantileloss : public RegressionL2loss { class RegressionQuantileloss : public RegressionL2loss {
public: public:
explicit RegressionQuantileloss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionQuantileloss(const Config& config): RegressionL2loss(config) {
alpha_ = static_cast<score_t>(config.alpha); alpha_ = static_cast<score_t>(config.alpha);
CHECK(alpha_ > 0 && alpha_ < 1); CHECK(alpha_ > 0 && alpha_ < 1);
} }
...@@ -543,7 +543,7 @@ private: ...@@ -543,7 +543,7 @@ private:
*/ */
class RegressionMAPELOSS : public RegressionL1loss { class RegressionMAPELOSS : public RegressionL1loss {
public: public:
explicit RegressionMAPELOSS(const ObjectiveConfig& config) : RegressionL1loss(config) { explicit RegressionMAPELOSS(const Config& config) : RegressionL1loss(config) {
} }
explicit RegressionMAPELOSS(const std::vector<std::string>& strs) : RegressionL1loss(strs) { explicit RegressionMAPELOSS(const std::vector<std::string>& strs) : RegressionL1loss(strs) {
...@@ -644,7 +644,7 @@ private: ...@@ -644,7 +644,7 @@ private:
*/ */
class RegressionGammaLoss : public RegressionPoissonLoss { class RegressionGammaLoss : public RegressionPoissonLoss {
public: public:
explicit RegressionGammaLoss(const ObjectiveConfig& config) : RegressionPoissonLoss(config) { explicit RegressionGammaLoss(const Config& config) : RegressionPoissonLoss(config) {
} }
explicit RegressionGammaLoss(const std::vector<std::string>& strs) : RegressionPoissonLoss(strs) { explicit RegressionGammaLoss(const std::vector<std::string>& strs) : RegressionPoissonLoss(strs) {
...@@ -681,7 +681,7 @@ public: ...@@ -681,7 +681,7 @@ public:
*/ */
class RegressionTweedieLoss: public RegressionPoissonLoss { class RegressionTweedieLoss: public RegressionPoissonLoss {
public: public:
explicit RegressionTweedieLoss(const ObjectiveConfig& config) : RegressionPoissonLoss(config) { explicit RegressionTweedieLoss(const Config& config) : RegressionPoissonLoss(config) {
rho_ = config.tweedie_variance_power; rho_ = config.tweedie_variance_power;
} }
......
...@@ -38,7 +38,7 @@ namespace LightGBM { ...@@ -38,7 +38,7 @@ namespace LightGBM {
*/ */
class CrossEntropy: public ObjectiveFunction { class CrossEntropy: public ObjectiveFunction {
public: public:
explicit CrossEntropy(const ObjectiveConfig&) { explicit CrossEntropy(const Config&) {
} }
explicit CrossEntropy(const std::vector<std::string>&) { explicit CrossEntropy(const std::vector<std::string>&) {
...@@ -141,7 +141,7 @@ private: ...@@ -141,7 +141,7 @@ private:
*/ */
class CrossEntropyLambda: public ObjectiveFunction { class CrossEntropyLambda: public ObjectiveFunction {
public: public:
explicit CrossEntropyLambda(const ObjectiveConfig&) { explicit CrossEntropyLambda(const Config&) {
min_weight_ = max_weight_ = 0.0f; min_weight_ = max_weight_ = 0.0f;
} }
......
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
namespace LightGBM { namespace LightGBM {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
DataParallelTreeLearner<TREELEARNER_T>::DataParallelTreeLearner(const TreeConfig* tree_config) DataParallelTreeLearner<TREELEARNER_T>::DataParallelTreeLearner(const Config* config)
:TREELEARNER_T(tree_config) { :TREELEARNER_T(config) {
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -37,13 +37,13 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo ...@@ -37,13 +37,13 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo
buffer_write_start_pos_.resize(this->num_features_); buffer_write_start_pos_.resize(this->num_features_);
buffer_read_start_pos_.resize(this->num_features_); buffer_read_start_pos_.resize(this->num_features_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves); global_data_count_in_leaf_.resize(this->config_->num_leaves);
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::ResetConfig(const TreeConfig* tree_config) { void DataParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
TREELEARNER_T::ResetConfig(tree_config); TREELEARNER_T::ResetConfig(config);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves); global_data_count_in_leaf_.resize(this->config_->num_leaves);
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -236,7 +236,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -236,7 +236,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// set best split // set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
......
...@@ -20,7 +20,7 @@ public: ...@@ -20,7 +20,7 @@ public:
uint32_t default_bin; uint32_t default_bin;
int8_t monotone_type; int8_t monotone_type;
/*! \brief pointer of tree config */ /*! \brief pointer of tree config */
const TreeConfig* tree_config; const Config* config;
BinType bin_type; BinType bin_type;
}; };
/*! /*!
...@@ -84,8 +84,8 @@ public: ...@@ -84,8 +84,8 @@ public:
is_splittable_ = false; is_splittable_ = false;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step); meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) { if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
if (meta_->missing_type == MissingType::Zero) { if (meta_->missing_type == MissingType::Zero) {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false); FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false);
...@@ -115,35 +115,35 @@ public: ...@@ -115,35 +115,35 @@ public:
data_size_t best_left_count = 0; data_size_t best_left_count = 0;
double best_sum_left_gradient = 0; double best_sum_left_gradient = 0;
double best_sum_left_hessian = 0; double best_sum_left_hessian = 0;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step); double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None; bool is_full_categorical = meta_->missing_type == MissingType::None;
int used_bin = meta_->num_bin - 1 + is_full_categorical; int used_bin = meta_->num_bin - 1 + is_full_categorical;
std::vector<int> sorted_idx; std::vector<int> sorted_idx;
double l2 = meta_->tree_config->lambda_l2; double l2 = meta_->config->lambda_l2;
bool use_onehot = meta_->num_bin <= meta_->tree_config->max_cat_to_onehot; bool use_onehot = meta_->num_bin <= meta_->config->max_cat_to_onehot;
int best_threshold = -1; int best_threshold = -1;
int best_dir = 1; int best_dir = 1;
if (use_onehot) { if (use_onehot) {
for (int t = 0; t < used_bin; ++t) { for (int t = 0; t < used_bin; ++t) {
// if data not enough, or sum hessian too small // if data not enough, or sum hessian too small
if (data_[t].cnt < meta_->tree_config->min_data_in_leaf if (data_[t].cnt < meta_->config->min_data_in_leaf
|| data_[t].sum_hessians < meta_->tree_config->min_sum_hessian_in_leaf) continue; || data_[t].sum_hessians < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t other_count = num_data - data_[t].cnt; data_size_t other_count = num_data - data_[t].cnt;
// if data not enough // if data not enough
if (other_count < meta_->tree_config->min_data_in_leaf) continue; if (other_count < meta_->config->min_data_in_leaf) continue;
double sum_other_hessian = sum_hessian - data_[t].sum_hessians - kEpsilon; double sum_other_hessian = sum_hessian - data_[t].sum_hessians - kEpsilon;
// if sum hessian too small // if sum hessian too small
if (sum_other_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue; if (sum_other_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
double sum_other_gradient = sum_gradient - data_[t].sum_gradients; double sum_other_gradient = sum_gradient - data_[t].sum_gradients;
// current split gain // current split gain
double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, data_[t].sum_gradients, data_[t].sum_hessians + kEpsilon, double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, data_[t].sum_gradients, data_[t].sum_hessians + kEpsilon,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint, 0); min_constraint, max_constraint, 0);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -161,16 +161,16 @@ public: ...@@ -161,16 +161,16 @@ public:
} }
} else { } else {
for (int i = 0; i < used_bin; ++i) { for (int i = 0; i < used_bin; ++i) {
if (data_[i].cnt >= meta_->tree_config->cat_smooth) { if (data_[i].cnt >= meta_->config->cat_smooth) {
sorted_idx.push_back(i); sorted_idx.push_back(i);
} }
} }
used_bin = static_cast<int>(sorted_idx.size()); used_bin = static_cast<int>(sorted_idx.size());
l2 += meta_->tree_config->cat_l2; l2 += meta_->config->cat_l2;
auto ctr_fun = [this](double sum_grad, double sum_hess) { auto ctr_fun = [this](double sum_grad, double sum_hess) {
return (sum_grad) / (sum_hess + meta_->tree_config->cat_smooth); return (sum_grad) / (sum_hess + meta_->config->cat_smooth);
}; };
std::sort(sorted_idx.begin(), sorted_idx.end(), std::sort(sorted_idx.begin(), sorted_idx.end(),
[this, &ctr_fun](int i, int j) { [this, &ctr_fun](int i, int j) {
...@@ -181,13 +181,13 @@ public: ...@@ -181,13 +181,13 @@ public:
std::vector<int> start_position(1, 0); std::vector<int> start_position(1, 0);
find_direction.push_back(-1); find_direction.push_back(-1);
start_position.push_back(used_bin - 1); start_position.push_back(used_bin - 1);
const int max_num_cat = std::min(meta_->tree_config->max_cat_threshold, (used_bin + 1) / 2); const int max_num_cat = std::min(meta_->config->max_cat_threshold, (used_bin + 1) / 2);
is_splittable_ = false; is_splittable_ = false;
for (size_t out_i = 0; out_i < find_direction.size(); ++out_i) { for (size_t out_i = 0; out_i < find_direction.size(); ++out_i) {
auto dir = find_direction[out_i]; auto dir = find_direction[out_i];
auto start_pos = start_position[out_i]; auto start_pos = start_position[out_i];
data_size_t min_data_per_group = meta_->tree_config->min_data_per_group; data_size_t min_data_per_group = meta_->config->min_data_per_group;
data_size_t cnt_cur_group = 0; data_size_t cnt_cur_group = 0;
double sum_left_gradient = 0.0f; double sum_left_gradient = 0.0f;
double sum_left_hessian = kEpsilon; double sum_left_hessian = kEpsilon;
...@@ -201,13 +201,13 @@ public: ...@@ -201,13 +201,13 @@ public:
left_count += data_[t].cnt; left_count += data_[t].cnt;
cnt_cur_group += data_[t].cnt; cnt_cur_group += data_[t].cnt;
if (left_count < meta_->tree_config->min_data_in_leaf if (left_count < meta_->config->min_data_in_leaf
|| sum_left_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue; || sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t right_count = num_data - left_count; data_size_t right_count = num_data - left_count;
if (right_count < meta_->tree_config->min_data_in_leaf || right_count < min_data_per_group) break; if (right_count < meta_->config->min_data_in_leaf || right_count < min_data_per_group) break;
double sum_right_hessian = sum_hessian - sum_left_hessian; double sum_right_hessian = sum_hessian - sum_left_hessian;
if (sum_right_hessian < meta_->tree_config->min_sum_hessian_in_leaf) break; if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) break;
if (cnt_cur_group < min_data_per_group) continue; if (cnt_cur_group < min_data_per_group) continue;
...@@ -215,7 +215,7 @@ public: ...@@ -215,7 +215,7 @@ public:
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint, 0); min_constraint, max_constraint, 0);
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
is_splittable_ = true; is_splittable_ = true;
...@@ -233,14 +233,14 @@ public: ...@@ -233,14 +233,14 @@ public:
if (is_splittable_) { if (is_splittable_) {
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint); min_constraint, max_constraint);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian, sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint); min_constraint, max_constraint);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
...@@ -285,9 +285,9 @@ public: ...@@ -285,9 +285,9 @@ public:
uint32_t threshold, data_size_t num_data, uint32_t threshold, data_size_t num_data,
SplitInfo *output) { SplitInfo *output) {
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
// do stuff here // do stuff here
const int8_t bias = meta_->bias; const int8_t bias = meta_->bias;
...@@ -325,11 +325,11 @@ public: ...@@ -325,11 +325,11 @@ public:
double sum_left_hessian = sum_hessian - sum_right_hessian; double sum_left_hessian = sum_hessian - sum_right_hessian;
data_size_t left_count = num_data - right_count; data_size_t left_count = num_data - right_count;
double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian, double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->tree_config->max_delta_step) meta_->config->max_delta_step)
+ GetLeafSplitGain(sum_right_gradient, sum_right_hessian, + GetLeafSplitGain(sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
// gain with split is worse than without split // gain with split is worse than without split
if (std::isnan(current_gain) || current_gain <= min_gain_shift) { if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
...@@ -341,15 +341,15 @@ public: ...@@ -341,15 +341,15 @@ public:
// update split information // update split information
output->threshold = threshold; output->threshold = threshold;
output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
output->left_count = left_count; output->left_count = left_count;
output->left_sum_gradient = sum_left_gradient; output->left_sum_gradient = sum_left_gradient;
output->left_sum_hessian = sum_left_hessian - kEpsilon; output->left_sum_hessian = sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - sum_left_gradient,
sum_hessian - sum_left_hessian, sum_hessian - sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
output->right_count = num_data - left_count; output->right_count = num_data - left_count;
output->right_sum_gradient = sum_gradient - sum_left_gradient; output->right_sum_gradient = sum_gradient - sum_left_gradient;
output->right_sum_hessian = sum_hessian - sum_left_hessian - kEpsilon; output->right_sum_hessian = sum_hessian - sum_left_hessian - kEpsilon;
...@@ -364,9 +364,9 @@ public: ...@@ -364,9 +364,9 @@ public:
output->default_left = false; output->default_left = false;
double gain_shift = GetLeafSplitGain( double gain_shift = GetLeafSplitGain(
sum_gradient, sum_hessian, sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None; bool is_full_categorical = meta_->missing_type == MissingType::None;
int used_bin = meta_->num_bin - 1 + is_full_categorical; int used_bin = meta_->num_bin - 1 + is_full_categorical;
if (threshold >= static_cast<uint32_t>(used_bin)) { if (threshold >= static_cast<uint32_t>(used_bin)) {
...@@ -375,7 +375,7 @@ public: ...@@ -375,7 +375,7 @@ public:
return; return;
} }
double l2 = meta_->tree_config->lambda_l2; double l2 = meta_->config->lambda_l2;
data_size_t left_count = data_[threshold].cnt; data_size_t left_count = data_[threshold].cnt;
data_size_t right_count = num_data - left_count; data_size_t right_count = num_data - left_count;
double sum_left_hessian = data_[threshold].sum_hessians + kEpsilon; double sum_left_hessian = data_[threshold].sum_hessians + kEpsilon;
...@@ -384,11 +384,11 @@ public: ...@@ -384,11 +384,11 @@ public:
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
// current split gain // current split gain
double current_gain = GetLeafSplitGain(sum_right_gradient, sum_right_hessian, double current_gain = GetLeafSplitGain(sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2, meta_->config->lambda_l1, l2,
meta_->tree_config->max_delta_step) meta_->config->max_delta_step)
+ GetLeafSplitGain(sum_left_gradient, sum_right_hessian, + GetLeafSplitGain(sum_left_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2, meta_->config->lambda_l1, l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
if (std::isnan(current_gain) || current_gain <= min_gain_shift) { if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
output->gain = kMinScore; output->gain = kMinScore;
Log::Warning("Gain with forced split worse than without split"); Log::Warning("Gain with forced split worse than without split");
...@@ -396,14 +396,14 @@ public: ...@@ -396,14 +396,14 @@ public:
} }
output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian,
meta_->tree_config->lambda_l1, l2, meta_->config->lambda_l1, l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
output->left_count = left_count; output->left_count = left_count;
output->left_sum_gradient = sum_left_gradient; output->left_sum_gradient = sum_left_gradient;
output->left_sum_hessian = sum_left_hessian - kEpsilon; output->left_sum_hessian = sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_right_gradient, sum_right_hessian, output->right_output = CalculateSplittedLeafOutput(sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2, meta_->config->lambda_l1, l2,
meta_->tree_config->max_delta_step); meta_->config->max_delta_step);
output->right_count = right_count; output->right_count = right_count;
output->right_sum_gradient = sum_gradient - sum_left_gradient; output->right_sum_gradient = sum_gradient - sum_left_gradient;
output->right_sum_hessian = sum_right_hessian - kEpsilon; output->right_sum_hessian = sum_right_hessian - kEpsilon;
...@@ -530,20 +530,20 @@ private: ...@@ -530,20 +530,20 @@ private:
sum_right_hessian += data_[t].sum_hessians; sum_right_hessian += data_[t].sum_hessians;
right_count += data_[t].cnt; right_count += data_[t].cnt;
// if data not enough, or sum hessian too small // if data not enough, or sum hessian too small
if (right_count < meta_->tree_config->min_data_in_leaf if (right_count < meta_->config->min_data_in_leaf
|| sum_right_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue; || sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t left_count = num_data - right_count; data_size_t left_count = num_data - right_count;
// if data not enough // if data not enough
if (left_count < meta_->tree_config->min_data_in_leaf) break; if (left_count < meta_->config->min_data_in_leaf) break;
double sum_left_hessian = sum_hessian - sum_right_hessian; double sum_left_hessian = sum_hessian - sum_right_hessian;
// if sum hessian too small // if sum hessian too small
if (sum_left_hessian < meta_->tree_config->min_sum_hessian_in_leaf) break; if (sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) break;
double sum_left_gradient = sum_gradient - sum_right_gradient; double sum_left_gradient = sum_gradient - sum_right_gradient;
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type); min_constraint, max_constraint, meta_->monotone_type);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -590,20 +590,20 @@ private: ...@@ -590,20 +590,20 @@ private:
left_count += data_[t].cnt; left_count += data_[t].cnt;
} }
// if data not enough, or sum hessian too small // if data not enough, or sum hessian too small
if (left_count < meta_->tree_config->min_data_in_leaf if (left_count < meta_->config->min_data_in_leaf
|| sum_left_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue; || sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t right_count = num_data - left_count; data_size_t right_count = num_data - left_count;
// if data not enough // if data not enough
if (right_count < meta_->tree_config->min_data_in_leaf) break; if (right_count < meta_->config->min_data_in_leaf) break;
double sum_right_hessian = sum_hessian - sum_left_hessian; double sum_right_hessian = sum_hessian - sum_left_hessian;
// if sum hessian too small // if sum hessian too small
if (sum_right_hessian < meta_->tree_config->min_sum_hessian_in_leaf) break; if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) break;
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
// current split gain // current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type); min_constraint, max_constraint, meta_->monotone_type);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -625,14 +625,14 @@ private: ...@@ -625,14 +625,14 @@ private:
// update split information // update split information
output->threshold = best_threshold; output->threshold = best_threshold;
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint); min_constraint, max_constraint);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian, sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint); min_constraint, max_constraint);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
...@@ -697,7 +697,7 @@ public: ...@@ -697,7 +697,7 @@ public:
} }
} }
void DynamicChangeSize(const Dataset* train_data, const TreeConfig* tree_config, int cache_size, int total_size) { void DynamicChangeSize(const Dataset* train_data, const Config* config, int cache_size, int total_size) {
if (feature_metas_.empty()) { if (feature_metas_.empty()) {
int num_feature = train_data->num_features(); int num_feature = train_data->num_features();
feature_metas_.resize(num_feature); feature_metas_.resize(num_feature);
...@@ -712,7 +712,7 @@ public: ...@@ -712,7 +712,7 @@ public:
} else { } else {
feature_metas_[i].bias = 0; feature_metas_[i].bias = 0;
} }
feature_metas_[i].tree_config = tree_config; feature_metas_[i].config = config;
feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type(); feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type();
} }
} }
...@@ -748,11 +748,11 @@ public: ...@@ -748,11 +748,11 @@ public:
OMP_THROW_EX(); OMP_THROW_EX();
} }
void ResetConfig(const TreeConfig* tree_config) { void ResetConfig(const Config* config) {
int size = static_cast<int>(feature_metas_.size()); int size = static_cast<int>(feature_metas_.size());
#pragma omp parallel for schedule(static, 512) if(size >= 1024) #pragma omp parallel for schedule(static, 512) if(size >= 1024)
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
feature_metas_[i].tree_config = tree_config; feature_metas_[i].config = config;
} }
} }
/*! /*!
......
...@@ -8,8 +8,8 @@ namespace LightGBM { ...@@ -8,8 +8,8 @@ namespace LightGBM {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
FeatureParallelTreeLearner<TREELEARNER_T>::FeatureParallelTreeLearner(const TreeConfig* tree_config) FeatureParallelTreeLearner<TREELEARNER_T>::FeatureParallelTreeLearner(const Config* config)
:TREELEARNER_T(tree_config) { :TREELEARNER_T(config) {
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -22,8 +22,8 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, ...@@ -22,8 +22,8 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data,
TREELEARNER_T::Init(train_data, is_constant_hessian); TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank(); rank_ = Network::rank();
num_machines_ = Network::num_machines(); num_machines_ = Network::num_machines();
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2); input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2); output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
} }
...@@ -60,7 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con ...@@ -60,7 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// update best split // update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
namespace LightGBM { namespace LightGBM {
GPUTreeLearner::GPUTreeLearner(const TreeConfig* tree_config) GPUTreeLearner::GPUTreeLearner(const Config* config)
:SerialTreeLearner(tree_config) { :SerialTreeLearner(config) {
use_bagging_ = false; use_bagging_ = false;
Log::Info("This is the GPU trainer!!"); Log::Info("This is the GPU trainer!!");
} }
...@@ -39,7 +39,7 @@ void GPUTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) { ...@@ -39,7 +39,7 @@ void GPUTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
// some additional variables needed for GPU trainer // some additional variables needed for GPU trainer
num_feature_groups_ = train_data_->num_feature_groups(); num_feature_groups_ = train_data_->num_feature_groups();
// Initialize GPU buffers and kernels // Initialize GPU buffers and kernels
InitGPU(tree_config_->gpu_platform_id, tree_config_->gpu_device_id); InitGPU(config_->gpu_platform_id, config_->gpu_device_id);
} }
// some functions used for debugging the GPU histogram construction // some functions used for debugging the GPU histogram construction
...@@ -304,7 +304,7 @@ void GPUTreeLearner::AllocateGPUMemory() { ...@@ -304,7 +304,7 @@ void GPUTreeLearner::AllocateGPUMemory() {
device_data_indices_ = std::unique_ptr<boost::compute::vector<data_size_t>>(new boost::compute::vector<data_size_t>(allocated_num_data_, ctx_)); device_data_indices_ = std::unique_ptr<boost::compute::vector<data_size_t>>(new boost::compute::vector<data_size_t>(allocated_num_data_, ctx_));
boost::compute::fill(device_data_indices_->begin(), device_data_indices_->end(), 0, queue_); boost::compute::fill(device_data_indices_->begin(), device_data_indices_->end(), 0, queue_);
// histogram bin entry size depends on the precision (single/double) // histogram bin entry size depends on the precision (single/double)
hist_bin_entry_sz_ = tree_config_->gpu_use_dp ? sizeof(HistogramBinEntry) : sizeof(GPUHistogramBinEntry); hist_bin_entry_sz_ = config_->gpu_use_dp ? sizeof(HistogramBinEntry) : sizeof(GPUHistogramBinEntry);
Log::Info("Size of histogram bin entry: %d", hist_bin_entry_sz_); Log::Info("Size of histogram bin entry: %d", hist_bin_entry_sz_);
// create output buffer, each feature has a histogram with device_bin_size_ bins, // create output buffer, each feature has a histogram with device_bin_size_ bins,
// each work group generates a sub-histogram of dword_features_ features. // each work group generates a sub-histogram of dword_features_ features.
...@@ -598,7 +598,7 @@ void GPUTreeLearner::BuildGPUKernels() { ...@@ -598,7 +598,7 @@ void GPUTreeLearner::BuildGPUKernels() {
std::ostringstream opts; std::ostringstream opts;
// compile the GPU kernel depending if double precision is used, constant hessian is used, etc // compile the GPU kernel depending if double precision is used, constant hessian is used, etc
opts << " -D POWER_FEATURE_WORKGROUPS=" << i opts << " -D POWER_FEATURE_WORKGROUPS=" << i
<< " -D USE_CONSTANT_BUF=" << use_constants << " -D USE_DP_FLOAT=" << int(tree_config_->gpu_use_dp) << " -D USE_CONSTANT_BUF=" << use_constants << " -D USE_DP_FLOAT=" << int(config_->gpu_use_dp)
<< " -D CONST_HESSIAN=" << int(is_constant_hessian_) << " -D CONST_HESSIAN=" << int(is_constant_hessian_)
<< " -cl-mad-enable -cl-no-signed-zeros -cl-fast-relaxed-math"; << " -cl-mad-enable -cl-no-signed-zeros -cl-fast-relaxed-math";
#if GPU_DEBUG >= 1 #if GPU_DEBUG >= 1
...@@ -1006,7 +1006,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u ...@@ -1006,7 +1006,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
ptr_smaller_leaf_hist_data); ptr_smaller_leaf_hist_data);
// wait for GPU to finish, only if GPU is actually used // wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) { if (is_gpu_used) {
if (tree_config_->gpu_use_dp) { if (config_->gpu_use_dp) {
// use double precision // use double precision
WaitAndGetHistograms<HistogramBinEntry>(ptr_smaller_leaf_hist_data); WaitAndGetHistograms<HistogramBinEntry>(ptr_smaller_leaf_hist_data);
} }
...@@ -1060,7 +1060,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u ...@@ -1060,7 +1060,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
ptr_larger_leaf_hist_data); ptr_larger_leaf_hist_data);
// wait for GPU to finish, only if GPU is actually used // wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) { if (is_gpu_used) {
if (tree_config_->gpu_use_dp) { if (config_->gpu_use_dp) {
// use double precision // use double precision
WaitAndGetHistograms<HistogramBinEntry>(ptr_larger_leaf_hist_data); WaitAndGetHistograms<HistogramBinEntry>(ptr_larger_leaf_hist_data);
} }
......
...@@ -37,7 +37,7 @@ namespace LightGBM { ...@@ -37,7 +37,7 @@ namespace LightGBM {
*/ */
class GPUTreeLearner: public SerialTreeLearner { class GPUTreeLearner: public SerialTreeLearner {
public: public:
explicit GPUTreeLearner(const TreeConfig* tree_config); explicit GPUTreeLearner(const Config* tree_config);
~GPUTreeLearner(); ~GPUTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override; void ResetTrainingData(const Dataset* train_data) override;
...@@ -270,7 +270,7 @@ namespace LightGBM { ...@@ -270,7 +270,7 @@ namespace LightGBM {
class GPUTreeLearner: public SerialTreeLearner { class GPUTreeLearner: public SerialTreeLearner {
public: public:
#pragma warning(disable : 4702) #pragma warning(disable : 4702)
explicit GPUTreeLearner(const TreeConfig* tree_config) : SerialTreeLearner(tree_config) { explicit GPUTreeLearner(const Config* tree_config) : SerialTreeLearner(tree_config) {
Log::Fatal("GPU Tree Learner was not enabled in this build.\n" Log::Fatal("GPU Tree Learner was not enabled in this build.\n"
"Please recompile with CMake option -DUSE_GPU=1"); "Please recompile with CMake option -DUSE_GPU=1");
} }
......
...@@ -22,7 +22,7 @@ namespace LightGBM { ...@@ -22,7 +22,7 @@ namespace LightGBM {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
class FeatureParallelTreeLearner: public TREELEARNER_T { class FeatureParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config); explicit FeatureParallelTreeLearner(const Config* config);
~FeatureParallelTreeLearner(); ~FeatureParallelTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
...@@ -48,10 +48,10 @@ private: ...@@ -48,10 +48,10 @@ private:
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
class DataParallelTreeLearner: public TREELEARNER_T { class DataParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit DataParallelTreeLearner(const TreeConfig* tree_config); explicit DataParallelTreeLearner(const Config* config);
~DataParallelTreeLearner(); ~DataParallelTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const Config* config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
void FindBestSplits() override; void FindBestSplits() override;
...@@ -101,10 +101,10 @@ private: ...@@ -101,10 +101,10 @@ private:
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
class VotingParallelTreeLearner: public TREELEARNER_T { class VotingParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit VotingParallelTreeLearner(const TreeConfig* tree_config); explicit VotingParallelTreeLearner(const Config* config);
~VotingParallelTreeLearner() { } ~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const Config* config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override; bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
...@@ -137,7 +137,7 @@ protected: ...@@ -137,7 +137,7 @@ protected:
private: private:
/*! \brief Tree config used in local mode */ /*! \brief Tree config used in local mode */
TreeConfig local_tree_config_; Config local_config_;
/*! \brief Voting size */ /*! \brief Voting size */
int top_k_; int top_k_;
/*! \brief Rank of local machine*/ /*! \brief Rank of local machine*/
......
...@@ -19,9 +19,9 @@ std::chrono::duration<double, std::milli> split_time; ...@@ -19,9 +19,9 @@ std::chrono::duration<double, std::milli> split_time;
std::chrono::duration<double, std::milli> ordered_bin_time; std::chrono::duration<double, std::milli> ordered_bin_time;
#endif // TIMETAG #endif // TIMETAG
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config) SerialTreeLearner::SerialTreeLearner(const Config* config)
:tree_config_(tree_config) { :config_(config) {
random_ = Random(tree_config_->feature_fraction_seed); random_ = Random(config_->feature_fraction_seed);
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
...@@ -47,22 +47,22 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian ...@@ -47,22 +47,22 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
is_constant_hessian_ = is_constant_hessian; is_constant_hessian_ = is_constant_hessian;
int max_cache_size = 0; int max_cache_size = 0;
// Get the max size of pool // Get the max size of pool
if (tree_config_->histogram_pool_size <= 0) { if (config_->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves; max_cache_size = config_->num_leaves;
} else { } else {
size_t total_histogram_size = 0; size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) { for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i); total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
} }
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size); max_cache_size = static_cast<int>(config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
} }
// at least need 2 leaves // at least need 2 leaves
max_cache_size = std::max(2, max_cache_size); max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves); max_cache_size = std::min(max_cache_size, config_->num_leaves);
histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves); histogram_pool_.DynamicChangeSize(train_data_, config_, max_cache_size, config_->num_leaves);
// push split information for all leaves // push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves); best_split_per_leaf_.resize(config_->num_leaves);
// get ordered bin // get ordered bin
train_data_->CreateOrderedBins(&ordered_bins_); train_data_->CreateOrderedBins(&ordered_bins_);
...@@ -79,7 +79,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian ...@@ -79,7 +79,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
larger_leaf_splits_.reset(new LeafSplits(train_data_->num_data())); larger_leaf_splits_.reset(new LeafSplits(train_data_->num_data()));
// initialize data partition // initialize data partition
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves)); data_partition_.reset(new DataPartition(num_data_, config_->num_leaves));
is_feature_used_.resize(num_features_); is_feature_used_.resize(num_features_);
valid_feature_indices_ = train_data_->ValidFeatureIndices(); valid_feature_indices_ = train_data_->ValidFeatureIndices();
// initialize ordered gradients and hessians // initialize ordered gradients and hessians
...@@ -124,33 +124,33 @@ void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) { ...@@ -124,33 +124,33 @@ void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) {
} }
} }
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) { void SerialTreeLearner::ResetConfig(const Config* config) {
if (tree_config_->num_leaves != tree_config->num_leaves) { if (config_->num_leaves != config->num_leaves) {
tree_config_ = tree_config; config_ = config;
int max_cache_size = 0; int max_cache_size = 0;
// Get the max size of pool // Get the max size of pool
if (tree_config->histogram_pool_size <= 0) { if (config->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves; max_cache_size = config_->num_leaves;
} else { } else {
size_t total_histogram_size = 0; size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) { for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i); total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
} }
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size); max_cache_size = static_cast<int>(config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
} }
// at least need 2 leaves // at least need 2 leaves
max_cache_size = std::max(2, max_cache_size); max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves); max_cache_size = std::min(max_cache_size, config_->num_leaves);
histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves); histogram_pool_.DynamicChangeSize(train_data_, config_, max_cache_size, config_->num_leaves);
// push split information for all leaves // push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves); best_split_per_leaf_.resize(config_->num_leaves);
data_partition_->ResetLeaves(tree_config_->num_leaves); data_partition_->ResetLeaves(config_->num_leaves);
} else { } else {
tree_config_ = tree_config; config_ = config;
} }
histogram_pool_.ResetConfig(tree_config_); histogram_pool_.ResetConfig(config_);
} }
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian, Json& forced_split_json) { Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian, Json& forced_split_json) {
...@@ -167,7 +167,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -167,7 +167,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
init_train_time += std::chrono::steady_clock::now() - start_time; init_train_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves)); auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves));
// root leaf // root leaf
int left_leaf = 0; int left_leaf = 0;
int cur_depth = 1; int cur_depth = 1;
...@@ -181,7 +181,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -181,7 +181,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
&right_leaf, &cur_depth, &aborted_last_force_split); &right_leaf, &cur_depth, &aborted_last_force_split);
} }
for (int split = init_splits; split < tree_config_->num_leaves - 1; ++split) { for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
#endif #endif
...@@ -236,7 +236,7 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* ...@@ -236,7 +236,7 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
sum_hess += hessians[idx]; sum_hess += hessians[idx];
} }
double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess, double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess,
tree_config_->lambda_l1, tree_config_->lambda_l2, tree_config_->max_delta_step); config_->lambda_l1, config_->lambda_l2, config_->max_delta_step);
tree->SetLeafOutput(i, output* tree->shrinkage()); tree->SetLeafOutput(i, output* tree->shrinkage());
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
...@@ -254,8 +254,8 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -254,8 +254,8 @@ void SerialTreeLearner::BeforeTrain() {
// reset histogram pool // reset histogram pool
histogram_pool_.ResetMap(); histogram_pool_.ResetMap();
if (tree_config_->feature_fraction < 1) { if (config_->feature_fraction < 1) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size()*tree_config_->feature_fraction); int used_feature_cnt = static_cast<int>(valid_feature_indices_.size()*config_->feature_fraction);
// at least use one feature // at least use one feature
used_feature_cnt = std::max(used_feature_cnt, 1); used_feature_cnt = std::max(used_feature_cnt, 1);
// initialize used features // initialize used features
...@@ -281,7 +281,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -281,7 +281,7 @@ void SerialTreeLearner::BeforeTrain() {
data_partition_->Init(); data_partition_->Init();
// reset the splits for leaves // reset the splits for leaves
for (int i = 0; i < tree_config_->num_leaves; ++i) { for (int i = 0; i < config_->num_leaves; ++i) {
best_split_per_leaf_[i].Reset(); best_split_per_leaf_[i].Reset();
} }
...@@ -308,7 +308,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -308,7 +308,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(ordered_bin_indices_.size()); ++i) { for (int i = 0; i < static_cast<int>(ordered_bin_indices_.size()); ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
ordered_bins_[ordered_bin_indices_[i]]->Init(nullptr, tree_config_->num_leaves); ordered_bins_[ordered_bin_indices_[i]]->Init(nullptr, config_->num_leaves);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -329,7 +329,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -329,7 +329,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(ordered_bin_indices_.size()); ++i) { for (int i = 0; i < static_cast<int>(ordered_bin_indices_.size()); ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
ordered_bins_[ordered_bin_indices_[i]]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves); ordered_bins_[ordered_bin_indices_[i]]->Init(is_data_in_leaf_.data(), config_->num_leaves);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -346,9 +346,9 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -346,9 +346,9 @@ void SerialTreeLearner::BeforeTrain() {
bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) { bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
// check depth of current leaf // check depth of current leaf
if (tree_config_->max_depth > 0) { if (config_->max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf // only need to check left leaf, since right leaf is in same level of left leaf
if (tree->leaf_depth(left_leaf) >= tree_config_->max_depth) { if (tree->leaf_depth(left_leaf) >= config_->max_depth) {
best_split_per_leaf_[left_leaf].gain = kMinScore; best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) { if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore; best_split_per_leaf_[right_leaf].gain = kMinScore;
...@@ -359,8 +359,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int ...@@ -359,8 +359,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf); data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf); data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// no enough data to continue // no enough data to continue
if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2) if (num_data_in_right_child < static_cast<data_size_t>(config_->min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) { && num_data_in_left_child < static_cast<data_size_t>(config_->min_data_in_leaf * 2)) {
best_split_per_leaf_[left_leaf].gain = kMinScore; best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) { if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore; best_split_per_leaf_[right_leaf].gain = kMinScore;
......
...@@ -33,7 +33,7 @@ namespace LightGBM { ...@@ -33,7 +33,7 @@ namespace LightGBM {
*/ */
class SerialTreeLearner: public TreeLearner { class SerialTreeLearner: public TreeLearner {
public: public:
explicit SerialTreeLearner(const TreeConfig* tree_config); explicit SerialTreeLearner(const Config* config);
~SerialTreeLearner(); ~SerialTreeLearner();
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
void ResetTrainingData(const Dataset* train_data) override; void ResetTrainingData(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const Config* config) override;
Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian, Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian,
Json& forced_split_json) override; Json& forced_split_json) override;
...@@ -163,7 +163,7 @@ protected: ...@@ -163,7 +163,7 @@ protected:
/*! \brief used to cache historical histogram to speed up*/ /*! \brief used to cache historical histogram to speed up*/
HistogramPool histogram_pool_; HistogramPool histogram_pool_;
/*! \brief config of tree learner*/ /*! \brief config of tree learner*/
const TreeConfig* tree_config_; const Config* config_;
int num_threads_; int num_threads_;
std::vector<int> ordered_bin_indices_; std::vector<int> ordered_bin_indices_;
bool is_constant_hessian_; bool is_constant_hessian_;
......
...@@ -6,27 +6,27 @@ ...@@ -6,27 +6,27 @@
namespace LightGBM { namespace LightGBM {
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const TreeConfig* tree_config) { TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) {
if (device_type == std::string("cpu")) { if (device_type == std::string("cpu")) {
if (learner_type == std::string("serial")) { if (learner_type == std::string("serial")) {
return new SerialTreeLearner(tree_config); return new SerialTreeLearner(config);
} else if (learner_type == std::string("feature")) { } else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<SerialTreeLearner>(tree_config); return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
} else if (learner_type == std::string("data")) { } else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<SerialTreeLearner>(tree_config); return new DataParallelTreeLearner<SerialTreeLearner>(config);
} else if (learner_type == std::string("voting")) { } else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<SerialTreeLearner>(tree_config); return new VotingParallelTreeLearner<SerialTreeLearner>(config);
} }
} }
else if (device_type == std::string("gpu")) { else if (device_type == std::string("gpu")) {
if (learner_type == std::string("serial")) { if (learner_type == std::string("serial")) {
return new GPUTreeLearner(tree_config); return new GPUTreeLearner(config);
} else if (learner_type == std::string("feature")) { } else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<GPUTreeLearner>(tree_config); return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
} else if (learner_type == std::string("data")) { } else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<GPUTreeLearner>(tree_config); return new DataParallelTreeLearner<GPUTreeLearner>(config);
} else if (learner_type == std::string("voting")) { } else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<GPUTreeLearner>(tree_config); return new VotingParallelTreeLearner<GPUTreeLearner>(config);
} }
} }
return nullptr; return nullptr;
......
...@@ -10,9 +10,9 @@ ...@@ -10,9 +10,9 @@
namespace LightGBM { namespace LightGBM {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
VotingParallelTreeLearner<TREELEARNER_T>::VotingParallelTreeLearner(const TreeConfig* tree_config) VotingParallelTreeLearner<TREELEARNER_T>::VotingParallelTreeLearner(const Config* config)
:TREELEARNER_T(tree_config) { :TREELEARNER_T(config) {
top_k_ = this->tree_config_->top_k; top_k_ = this->config_->top_k;
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -46,16 +46,16 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -46,16 +46,16 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
smaller_buffer_read_start_pos_.resize(this->num_features_); smaller_buffer_read_start_pos_.resize(this->num_features_);
larger_buffer_read_start_pos_.resize(this->num_features_); larger_buffer_read_start_pos_.resize(this->num_features_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves); global_data_count_in_leaf_.resize(this->config_->num_leaves);
smaller_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data())); smaller_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data())); larger_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data()));
local_tree_config_ = *this->tree_config_; local_config_ = *this->config_;
local_tree_config_.min_data_in_leaf /= num_machines_; local_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_; local_config_.min_sum_hessian_in_leaf /= num_machines_;
this->histogram_pool_.ResetConfig(&local_tree_config_); this->histogram_pool_.ResetConfig(&local_config_);
// initialize histograms for global // initialize histograms for global
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]); smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
...@@ -75,7 +75,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -75,7 +75,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
} else { } else {
feature_metas_[i].bias = 0; feature_metas_[i].bias = 0;
} }
feature_metas_[i].tree_config = this->tree_config_; feature_metas_[i].config = this->config_;
feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type(); feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type();
} }
uint64_t offset = 0; uint64_t offset = 0;
...@@ -92,18 +92,18 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -92,18 +92,18 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const TreeConfig* tree_config) { void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
TREELEARNER_T::ResetConfig(tree_config); TREELEARNER_T::ResetConfig(config);
local_tree_config_ = *this->tree_config_; local_config_ = *this->config_;
local_tree_config_.min_data_in_leaf /= num_machines_; local_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_; local_config_.min_sum_hessian_in_leaf /= num_machines_;
this->histogram_pool_.ResetConfig(&local_tree_config_); this->histogram_pool_.ResetConfig(&local_config_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves); global_data_count_in_leaf_.resize(this->config_->num_leaves);
for (size_t i = 0; i < feature_metas_.size(); ++i) { for (size_t i = 0; i < feature_metas_.size(); ++i) {
feature_metas_[i].tree_config = this->tree_config_; feature_metas_[i].config = this->config_;
} }
} }
...@@ -451,7 +451,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -451,7 +451,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// copy back // copy back
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split;
......
...@@ -257,6 +257,7 @@ ...@@ -257,6 +257,7 @@
<ClCompile Include="..\src\c_api.cpp" /> <ClCompile Include="..\src\c_api.cpp" />
<ClCompile Include="..\src\io\bin.cpp" /> <ClCompile Include="..\src\io\bin.cpp" />
<ClCompile Include="..\src\io\config.cpp" /> <ClCompile Include="..\src\io\config.cpp" />
<ClCompile Include="..\src\io\config_auto.cpp" />
<ClCompile Include="..\src\io\dataset.cpp" /> <ClCompile Include="..\src\io\dataset.cpp" />
<ClCompile Include="..\src\io\dataset_loader.cpp" /> <ClCompile Include="..\src\io\dataset_loader.cpp" />
<ClCompile Include="..\src\io\file_io.cpp" /> <ClCompile Include="..\src\io\file_io.cpp" />
......
...@@ -299,5 +299,8 @@ ...@@ -299,5 +299,8 @@
<ClCompile Include="..\src\io\json11.cpp"> <ClCompile Include="..\src\io\json11.cpp">
<Filter>src\io</Filter> <Filter>src\io</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="..\src\io\config_auto.cpp">
<Filter>src\io</Filter>
</ClCompile>
</ItemGroup> </ItemGroup>
</Project> </Project>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment