Unverified commit dc699574 authored by Guolin Ke, committed by GitHub

Refine config object (#1381)

* [WIP] refine config

* [wip] ready for the auto code generation

* auto-generate config code

* use with to open file

* fix bug

* fix pylint

* fix bug

* fix pylint

* fix bugs.

* tmp for failed test.

* fix tests.

* added nthreads alias

* added new aliases from new config.h

* fixed duplicated alias

* refactored parameter_generator.py

* added new aliases from config.h and removed remaining old names

* fix bugs & some missing aliases

* added aliases

* add more descriptions.

* add comment.
parent 497e60ed
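
The change is mechanical but sweeping: the separate ObjectiveConfig and TreeConfig objects are collapsed into a single Config that every objective and tree learner receives. A minimal sketch of the resulting shape, using only fields that appear in the hunks below; the real struct is auto-generated (see config_auto.cpp at the end of this diff), and the defaults shown are assumptions, not values taken from this commit:

```cpp
// Illustrative consolidation only -- not the generated LightGBM header.
struct Config {
  // formerly in ObjectiveConfig
  int num_class = 1;
  double sigmoid = 1.0;
  bool reg_sqrt = false;
  // formerly in TreeConfig
  int num_leaves = 31;
  double lambda_l1 = 0.0;
  double lambda_l2 = 0.0;
  int min_data_in_leaf = 20;
  double min_sum_hessian_in_leaf = 1e-3;
};
```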
......@@ -15,7 +15,7 @@ namespace LightGBM {
*/
class MulticlassSoftmax: public ObjectiveFunction {
public:
explicit MulticlassSoftmax(const ObjectiveConfig& config) {
explicit MulticlassSoftmax(const Config& config) {
num_class_ = config.num_class;
}
......@@ -138,7 +138,7 @@ private:
*/
class MulticlassOVA: public ObjectiveFunction {
public:
explicit MulticlassOVA(const ObjectiveConfig& config) {
explicit MulticlassOVA(const Config& config) {
num_class_ = config.num_class;
for (int i = 0; i < num_class_; ++i) {
binary_loss_.emplace_back(
......
......@@ -7,7 +7,7 @@
namespace LightGBM {
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const ObjectiveConfig& config) {
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
......
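
The alias strings accepted here ("mse", "rmse", and so on) are part of what this commit centralizes into generated code. A hedged sketch of the kind of alias normalization involved; the table below is illustrative, not the generated one:

```cpp
#include <string>
#include <unordered_map>

// All L2-family names collapse to "regression", mirroring the branch in
// CreateObjectiveFunction above. Illustrative table, not LightGBM's.
std::string NormalizeObjectiveAlias(const std::string& type) {
  static const std::unordered_map<std::string, std::string> kAlias = {
      {"regression_l2", "regression"}, {"mean_squared_error", "regression"},
      {"mse", "regression"}, {"l2_root", "regression"},
      {"root_mean_squared_error", "regression"}, {"rmse", "regression"}};
  auto it = kAlias.find(type);
  return it == kAlias.end() ? type : it->second;
}
```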
......@@ -18,15 +18,12 @@ namespace LightGBM {
*/
class LambdarankNDCG: public ObjectiveFunction {
public:
explicit LambdarankNDCG(const ObjectiveConfig& config) {
explicit LambdarankNDCG(const Config& config) {
sigmoid_ = static_cast<double>(config.sigmoid);
label_gain_ = config.label_gain;
// initialize DCG calculator
DCGCalculator::Init(config.label_gain);
// copy label gain to local
for (auto gain : config.label_gain) {
label_gain_.push_back(static_cast<double>(gain));
}
label_gain_.shrink_to_fit();
DCGCalculator::DefaultLabelGain(&label_gain_);
DCGCalculator::Init(label_gain_);
// will optimize NDCG@optimize_pos_at_
optimize_pos_at_ = config.max_position;
sigmoid_table_.clear();
......
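
The new code replaces the per-objective copy loop with DCGCalculator::DefaultLabelGain, which presumably fills in gains only when the user supplied none. A hedged sketch of that behavior, assuming the conventional NDCG gain 2^i - 1:

```cpp
#include <cmath>
#include <vector>

// Assumed behavior of DCGCalculator::DefaultLabelGain: leave a user-provided
// vector alone, otherwise fill in the standard 2^i - 1 relevance gains.
void DefaultLabelGain(std::vector<double>* label_gain) {
  if (!label_gain->empty()) return;
  const int kMaxLabel = 31;  // assumed cap on distinct relevance labels
  for (int i = 0; i < kMaxLabel; ++i) {
    label_gain->push_back(std::pow(2.0, i) - 1.0);
  }
}
```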
......@@ -63,7 +63,7 @@ namespace LightGBM {
*/
class RegressionL2loss: public ObjectiveFunction {
public:
explicit RegressionL2loss(const ObjectiveConfig& config) {
explicit RegressionL2loss(const Config& config) {
sqrt_ = config.reg_sqrt;
}
......@@ -174,7 +174,7 @@ protected:
*/
class RegressionL1loss: public RegressionL2loss {
public:
explicit RegressionL1loss(const ObjectiveConfig& config): RegressionL2loss(config) {
explicit RegressionL1loss(const Config& config): RegressionL2loss(config) {
}
explicit RegressionL1loss(const std::vector<std::string>& strs): RegressionL2loss(strs) {
......@@ -260,7 +260,7 @@ public:
*/
class RegressionHuberLoss: public RegressionL2loss {
public:
explicit RegressionHuberLoss(const ObjectiveConfig& config): RegressionL2loss(config) {
explicit RegressionHuberLoss(const Config& config): RegressionL2loss(config) {
alpha_ = static_cast<double>(config.alpha);
}
......@@ -315,7 +315,7 @@ private:
// http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html
class RegressionFairLoss: public RegressionL2loss {
public:
explicit RegressionFairLoss(const ObjectiveConfig& config): RegressionL2loss(config) {
explicit RegressionFairLoss(const Config& config): RegressionL2loss(config) {
c_ = static_cast<double>(config.fair_c);
}
......@@ -363,7 +363,7 @@ private:
*/
class RegressionPoissonLoss: public RegressionL2loss {
public:
explicit RegressionPoissonLoss(const ObjectiveConfig& config): RegressionL2loss(config) {
explicit RegressionPoissonLoss(const Config& config): RegressionL2loss(config) {
max_delta_step_ = static_cast<double>(config.poisson_max_delta_step);
if (sqrt_) {
Log::Warning("Cannot use sqrt transform in %s Regression, will auto disable it", GetName());
......@@ -444,7 +444,7 @@ private:
class RegressionQuantileloss : public RegressionL2loss {
public:
explicit RegressionQuantileloss(const ObjectiveConfig& config): RegressionL2loss(config) {
explicit RegressionQuantileloss(const Config& config): RegressionL2loss(config) {
alpha_ = static_cast<score_t>(config.alpha);
CHECK(alpha_ > 0 && alpha_ < 1);
}
......@@ -543,7 +543,7 @@ private:
*/
class RegressionMAPELOSS : public RegressionL1loss {
public:
explicit RegressionMAPELOSS(const ObjectiveConfig& config) : RegressionL1loss(config) {
explicit RegressionMAPELOSS(const Config& config) : RegressionL1loss(config) {
}
explicit RegressionMAPELOSS(const std::vector<std::string>& strs) : RegressionL1loss(strs) {
......@@ -644,7 +644,7 @@ private:
*/
class RegressionGammaLoss : public RegressionPoissonLoss {
public:
explicit RegressionGammaLoss(const ObjectiveConfig& config) : RegressionPoissonLoss(config) {
explicit RegressionGammaLoss(const Config& config) : RegressionPoissonLoss(config) {
}
explicit RegressionGammaLoss(const std::vector<std::string>& strs) : RegressionPoissonLoss(strs) {
......@@ -681,7 +681,7 @@ public:
*/
class RegressionTweedieLoss: public RegressionPoissonLoss {
public:
explicit RegressionTweedieLoss(const ObjectiveConfig& config) : RegressionPoissonLoss(config) {
explicit RegressionTweedieLoss(const Config& config) : RegressionPoissonLoss(config) {
rho_ = config.tweedie_variance_power;
}
......
......@@ -38,7 +38,7 @@ namespace LightGBM {
*/
class CrossEntropy: public ObjectiveFunction {
public:
explicit CrossEntropy(const ObjectiveConfig&) {
explicit CrossEntropy(const Config&) {
}
explicit CrossEntropy(const std::vector<std::string>&) {
......@@ -141,7 +141,7 @@ private:
*/
class CrossEntropyLambda: public ObjectiveFunction {
public:
explicit CrossEntropyLambda(const ObjectiveConfig&) {
explicit CrossEntropyLambda(const Config&) {
min_weight_ = max_weight_ = 0.0f;
}
......
......@@ -8,8 +8,8 @@
namespace LightGBM {
template <typename TREELEARNER_T>
DataParallelTreeLearner<TREELEARNER_T>::DataParallelTreeLearner(const TreeConfig* tree_config)
:TREELEARNER_T(tree_config) {
DataParallelTreeLearner<TREELEARNER_T>::DataParallelTreeLearner(const Config* config)
:TREELEARNER_T(config) {
}
template <typename TREELEARNER_T>
......@@ -37,13 +37,13 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo
buffer_write_start_pos_.resize(this->num_features_);
buffer_read_start_pos_.resize(this->num_features_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves);
global_data_count_in_leaf_.resize(this->config_->num_leaves);
}
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::ResetConfig(const TreeConfig* tree_config) {
TREELEARNER_T::ResetConfig(tree_config);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves);
void DataParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
TREELEARNER_T::ResetConfig(config);
global_data_count_in_leaf_.resize(this->config_->num_leaves);
}
template <typename TREELEARNER_T>
......@@ -236,7 +236,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold);
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
......
......@@ -20,7 +20,7 @@ public:
uint32_t default_bin;
int8_t monotone_type;
/*! \brief pointer to the config */
const TreeConfig* tree_config;
const Config* config;
BinType bin_type;
};
/*!
......@@ -84,8 +84,8 @@ public:
is_splittable_ = false;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
if (meta_->missing_type == MissingType::Zero) {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false);
......@@ -115,35 +115,35 @@ public:
data_size_t best_left_count = 0;
double best_sum_left_gradient = 0;
double best_sum_left_hessian = 0;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step);
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None;
int used_bin = meta_->num_bin - 1 + is_full_categorical;
std::vector<int> sorted_idx;
double l2 = meta_->tree_config->lambda_l2;
bool use_onehot = meta_->num_bin <= meta_->tree_config->max_cat_to_onehot;
double l2 = meta_->config->lambda_l2;
bool use_onehot = meta_->num_bin <= meta_->config->max_cat_to_onehot;
int best_threshold = -1;
int best_dir = 1;
if (use_onehot) {
for (int t = 0; t < used_bin; ++t) {
// if data not enough, or sum hessian too small
if (data_[t].cnt < meta_->tree_config->min_data_in_leaf
|| data_[t].sum_hessians < meta_->tree_config->min_sum_hessian_in_leaf) continue;
if (data_[t].cnt < meta_->config->min_data_in_leaf
|| data_[t].sum_hessians < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t other_count = num_data - data_[t].cnt;
// if data not enough
if (other_count < meta_->tree_config->min_data_in_leaf) continue;
if (other_count < meta_->config->min_data_in_leaf) continue;
double sum_other_hessian = sum_hessian - data_[t].sum_hessians - kEpsilon;
// if sum hessian too small
if (sum_other_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue;
if (sum_other_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
double sum_other_gradient = sum_gradient - data_[t].sum_gradients;
// current split gain
double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, data_[t].sum_gradients, data_[t].sum_hessians + kEpsilon,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint, 0);
// gain with split is worse than without split
if (current_gain <= min_gain_shift) continue;
......@@ -161,16 +161,16 @@ public:
}
} else {
for (int i = 0; i < used_bin; ++i) {
if (data_[i].cnt >= meta_->tree_config->cat_smooth) {
if (data_[i].cnt >= meta_->config->cat_smooth) {
sorted_idx.push_back(i);
}
}
used_bin = static_cast<int>(sorted_idx.size());
l2 += meta_->tree_config->cat_l2;
l2 += meta_->config->cat_l2;
auto ctr_fun = [this](double sum_grad, double sum_hess) {
return (sum_grad) / (sum_hess + meta_->tree_config->cat_smooth);
return (sum_grad) / (sum_hess + meta_->config->cat_smooth);
};
std::sort(sorted_idx.begin(), sorted_idx.end(),
[this, &ctr_fun](int i, int j) {
......@@ -181,13 +181,13 @@ public:
std::vector<int> start_position(1, 0);
find_direction.push_back(-1);
start_position.push_back(used_bin - 1);
const int max_num_cat = std::min(meta_->tree_config->max_cat_threshold, (used_bin + 1) / 2);
const int max_num_cat = std::min(meta_->config->max_cat_threshold, (used_bin + 1) / 2);
is_splittable_ = false;
for (size_t out_i = 0; out_i < find_direction.size(); ++out_i) {
auto dir = find_direction[out_i];
auto start_pos = start_position[out_i];
data_size_t min_data_per_group = meta_->tree_config->min_data_per_group;
data_size_t min_data_per_group = meta_->config->min_data_per_group;
data_size_t cnt_cur_group = 0;
double sum_left_gradient = 0.0f;
double sum_left_hessian = kEpsilon;
......@@ -201,13 +201,13 @@ public:
left_count += data_[t].cnt;
cnt_cur_group += data_[t].cnt;
if (left_count < meta_->tree_config->min_data_in_leaf
|| sum_left_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue;
if (left_count < meta_->config->min_data_in_leaf
|| sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t right_count = num_data - left_count;
if (right_count < meta_->tree_config->min_data_in_leaf || right_count < min_data_per_group) break;
if (right_count < meta_->config->min_data_in_leaf || right_count < min_data_per_group) break;
double sum_right_hessian = sum_hessian - sum_left_hessian;
if (sum_right_hessian < meta_->tree_config->min_sum_hessian_in_leaf) break;
if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) break;
if (cnt_cur_group < min_data_per_group) continue;
......@@ -215,7 +215,7 @@ public:
double sum_right_gradient = sum_gradient - sum_left_gradient;
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint, 0);
if (current_gain <= min_gain_shift) continue;
is_splittable_ = true;
......@@ -233,14 +233,14 @@ public:
if (is_splittable_) {
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint);
output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
min_constraint, max_constraint);
output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
......@@ -285,9 +285,9 @@ public:
uint32_t threshold, data_size_t num_data,
SplitInfo *output) {
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
meta_->tree_config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
// do stuff here
const int8_t bias = meta_->bias;
......@@ -325,11 +325,11 @@ public:
double sum_left_hessian = sum_hessian - sum_right_hessian;
data_size_t left_count = num_data - right_count;
double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
meta_->tree_config->max_delta_step)
meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step)
+ GetLeafSplitGain(sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
meta_->tree_config->max_delta_step);
meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step);
// gain with split is worse than without split
if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
......@@ -341,15 +341,15 @@ public:
// update split information
output->threshold = threshold;
output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
meta_->tree_config->max_delta_step);
meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step);
output->left_count = left_count;
output->left_sum_gradient = sum_left_gradient;
output->left_sum_hessian = sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - sum_left_gradient,
sum_hessian - sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
meta_->tree_config->max_delta_step);
meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step);
output->right_count = num_data - left_count;
output->right_sum_gradient = sum_gradient - sum_left_gradient;
output->right_sum_hessian = sum_hessian - sum_left_hessian - kEpsilon;
......@@ -364,9 +364,9 @@ public:
output->default_left = false;
double gain_shift = GetLeafSplitGain(
sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
meta_->tree_config->max_delta_step);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
meta_->config->lambda_l1, meta_->config->lambda_l2,
meta_->config->max_delta_step);
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None;
int used_bin = meta_->num_bin - 1 + is_full_categorical;
if (threshold >= static_cast<uint32_t>(used_bin)) {
......@@ -375,7 +375,7 @@ public:
return;
}
double l2 = meta_->tree_config->lambda_l2;
double l2 = meta_->config->lambda_l2;
data_size_t left_count = data_[threshold].cnt;
data_size_t right_count = num_data - left_count;
double sum_left_hessian = data_[threshold].sum_hessians + kEpsilon;
......@@ -384,11 +384,11 @@ public:
double sum_right_gradient = sum_gradient - sum_left_gradient;
// current split gain
double current_gain = GetLeafSplitGain(sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2,
meta_->tree_config->max_delta_step)
meta_->config->lambda_l1, l2,
meta_->config->max_delta_step)
+ GetLeafSplitGain(sum_left_gradient, sum_left_hessian,
meta_->tree_config->lambda_l1, l2,
meta_->tree_config->max_delta_step);
meta_->config->lambda_l1, l2,
meta_->config->max_delta_step);
if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
output->gain = kMinScore;
Log::Warning("Gain with forced split worse than without split");
......@@ -396,14 +396,14 @@ public:
}
output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian,
meta_->tree_config->lambda_l1, l2,
meta_->tree_config->max_delta_step);
meta_->config->lambda_l1, l2,
meta_->config->max_delta_step);
output->left_count = left_count;
output->left_sum_gradient = sum_left_gradient;
output->left_sum_hessian = sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, l2,
meta_->tree_config->max_delta_step);
meta_->config->lambda_l1, l2,
meta_->config->max_delta_step);
output->right_count = right_count;
output->right_sum_gradient = sum_gradient - sum_left_gradient;
output->right_sum_hessian = sum_right_hessian - kEpsilon;
......@@ -530,20 +530,20 @@ private:
sum_right_hessian += data_[t].sum_hessians;
right_count += data_[t].cnt;
// if data not enough, or sum hessian too small
if (right_count < meta_->tree_config->min_data_in_leaf
|| sum_right_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue;
if (right_count < meta_->config->min_data_in_leaf
|| sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t left_count = num_data - right_count;
// if data not enough
if (left_count < meta_->tree_config->min_data_in_leaf) break;
if (left_count < meta_->config->min_data_in_leaf) break;
double sum_left_hessian = sum_hessian - sum_right_hessian;
// if sum hessian too small
if (sum_left_hessian < meta_->tree_config->min_sum_hessian_in_leaf) break;
if (sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) break;
double sum_left_gradient = sum_gradient - sum_right_gradient;
// current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type);
// gain with split is worse than without split
if (current_gain <= min_gain_shift) continue;
......@@ -590,20 +590,20 @@ private:
left_count += data_[t].cnt;
}
// if data not enough, or sum hessian too small
if (left_count < meta_->tree_config->min_data_in_leaf
|| sum_left_hessian < meta_->tree_config->min_sum_hessian_in_leaf) continue;
if (left_count < meta_->config->min_data_in_leaf
|| sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) continue;
data_size_t right_count = num_data - left_count;
// if data not enough
if (right_count < meta_->tree_config->min_data_in_leaf) break;
if (right_count < meta_->config->min_data_in_leaf) break;
double sum_right_hessian = sum_hessian - sum_left_hessian;
// if sum hessian too small
if (sum_right_hessian < meta_->tree_config->min_sum_hessian_in_leaf) break;
if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) break;
double sum_right_gradient = sum_gradient - sum_left_gradient;
// current split gain
double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint, meta_->monotone_type);
// gain with split is worse than without split
if (current_gain <= min_gain_shift) continue;
......@@ -625,14 +625,14 @@ private:
// update split information
output->threshold = best_threshold;
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint);
output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, meta_->tree_config->max_delta_step,
meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step,
min_constraint, max_constraint);
output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
......@@ -697,7 +697,7 @@ public:
}
}
void DynamicChangeSize(const Dataset* train_data, const TreeConfig* tree_config, int cache_size, int total_size) {
void DynamicChangeSize(const Dataset* train_data, const Config* config, int cache_size, int total_size) {
if (feature_metas_.empty()) {
int num_feature = train_data->num_features();
feature_metas_.resize(num_feature);
......@@ -712,7 +712,7 @@ public:
} else {
feature_metas_[i].bias = 0;
}
feature_metas_[i].tree_config = tree_config;
feature_metas_[i].config = config;
feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type();
}
}
......@@ -748,11 +748,11 @@ public:
OMP_THROW_EX();
}
void ResetConfig(const TreeConfig* tree_config) {
void ResetConfig(const Config* config) {
int size = static_cast<int>(feature_metas_.size());
#pragma omp parallel for schedule(static, 512) if(size >= 1024)
for (int i = 0; i < size; ++i) {
feature_metas_[i].tree_config = tree_config;
feature_metas_[i].config = config;
}
}
/*!
......
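
Every call site in this file threads lambda_l1, lambda_l2, and max_delta_step into GetLeafSplitGain / GetSplitGains. A hedged sketch of the standard regularized gain those helpers compute; max_delta_step clipping and monotone constraints are omitted, so this is the textbook formula rather than LightGBM's exact implementation:

```cpp
#include <algorithm>
#include <cmath>

// Soft-threshold the gradient sum by the L1 penalty.
static double ThresholdL1(double g, double l1) {
  return std::copysign(std::max(0.0, std::fabs(g) - l1), g);
}

// Value of keeping gradient sum G and hessian sum H in one leaf:
// (|G| - l1)_+^2 / (H + l2).
static double LeafSplitGain(double G, double H, double l1, double l2) {
  const double g = ThresholdL1(G, l1);
  return g * g / (H + l2);
}

// A split's gain is the sum over both children; callers above compare it
// against gain_shift + min_gain_to_split before accepting the split.
static double SplitGain(double Gl, double Hl, double Gr, double Hr,
                        double l1, double l2) {
  return LeafSplitGain(Gl, Hl, l1, l2) + LeafSplitGain(Gr, Hr, l1, l2);
}
```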
......@@ -8,8 +8,8 @@ namespace LightGBM {
template <typename TREELEARNER_T>
FeatureParallelTreeLearner<TREELEARNER_T>::FeatureParallelTreeLearner(const TreeConfig* tree_config)
:TREELEARNER_T(tree_config) {
FeatureParallelTreeLearner<TREELEARNER_T>::FeatureParallelTreeLearner(const Config* config)
:TREELEARNER_T(config) {
}
template <typename TREELEARNER_T>
......@@ -22,8 +22,8 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data,
TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank();
num_machines_ = Network::num_machines();
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2);
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
}
......@@ -60,7 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold);
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
......
......@@ -15,8 +15,8 @@
namespace LightGBM {
GPUTreeLearner::GPUTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
GPUTreeLearner::GPUTreeLearner(const Config* config)
:SerialTreeLearner(config) {
use_bagging_ = false;
Log::Info("This is the GPU trainer!!");
}
......@@ -39,7 +39,7 @@ void GPUTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
// some additional variables needed for GPU trainer
num_feature_groups_ = train_data_->num_feature_groups();
// Initialize GPU buffers and kernels
InitGPU(tree_config_->gpu_platform_id, tree_config_->gpu_device_id);
InitGPU(config_->gpu_platform_id, config_->gpu_device_id);
}
// some functions used for debugging the GPU histogram construction
......@@ -304,7 +304,7 @@ void GPUTreeLearner::AllocateGPUMemory() {
device_data_indices_ = std::unique_ptr<boost::compute::vector<data_size_t>>(new boost::compute::vector<data_size_t>(allocated_num_data_, ctx_));
boost::compute::fill(device_data_indices_->begin(), device_data_indices_->end(), 0, queue_);
// histogram bin entry size depends on the precision (single/double)
hist_bin_entry_sz_ = tree_config_->gpu_use_dp ? sizeof(HistogramBinEntry) : sizeof(GPUHistogramBinEntry);
hist_bin_entry_sz_ = config_->gpu_use_dp ? sizeof(HistogramBinEntry) : sizeof(GPUHistogramBinEntry);
Log::Info("Size of histogram bin entry: %d", hist_bin_entry_sz_);
// create output buffer, each feature has a histogram with device_bin_size_ bins,
// each work group generates a sub-histogram of dword_features_ features.
......@@ -598,7 +598,7 @@ void GPUTreeLearner::BuildGPUKernels() {
std::ostringstream opts;
// compile the GPU kernel depending on whether double precision is used, constant hessian is used, etc.
opts << " -D POWER_FEATURE_WORKGROUPS=" << i
<< " -D USE_CONSTANT_BUF=" << use_constants << " -D USE_DP_FLOAT=" << int(tree_config_->gpu_use_dp)
<< " -D USE_CONSTANT_BUF=" << use_constants << " -D USE_DP_FLOAT=" << int(config_->gpu_use_dp)
<< " -D CONST_HESSIAN=" << int(is_constant_hessian_)
<< " -cl-mad-enable -cl-no-signed-zeros -cl-fast-relaxed-math";
#if GPU_DEBUG >= 1
......@@ -1006,7 +1006,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
ptr_smaller_leaf_hist_data);
// wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) {
if (tree_config_->gpu_use_dp) {
if (config_->gpu_use_dp) {
// use double precision
WaitAndGetHistograms<HistogramBinEntry>(ptr_smaller_leaf_hist_data);
}
......@@ -1060,7 +1060,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
ptr_larger_leaf_hist_data);
// wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) {
if (tree_config_->gpu_use_dp) {
if (config_->gpu_use_dp) {
// use double precision
WaitAndGetHistograms<HistogramBinEntry>(ptr_larger_leaf_hist_data);
}
......
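
gpu_use_dp selects between a double- and a single-precision histogram entry when sizing GPU buffers. A hedged sketch of the two layouts being measured by sizeof above; the field list is assumed from the CPU-side HistogramBinEntry, and the float variant's exact layout is an assumption:

```cpp
#include <cstdint>

// Assumed layouts: only the precision of the accumulators differs.
struct HistogramBinEntry {     // used when gpu_use_dp is true
  double sum_gradients;
  double sum_hessians;
  uint32_t cnt;
};

struct GPUHistogramBinEntry {  // single-precision default, smaller and faster
  float sum_gradients;
  float sum_hessians;
  uint32_t cnt;
};
```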
......@@ -37,7 +37,7 @@ namespace LightGBM {
*/
class GPUTreeLearner: public SerialTreeLearner {
public:
explicit GPUTreeLearner(const TreeConfig* tree_config);
explicit GPUTreeLearner(const Config* tree_config);
~GPUTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override;
......@@ -270,7 +270,7 @@ namespace LightGBM {
class GPUTreeLearner: public SerialTreeLearner {
public:
#pragma warning(disable : 4702)
explicit GPUTreeLearner(const TreeConfig* tree_config) : SerialTreeLearner(tree_config) {
explicit GPUTreeLearner(const Config* tree_config) : SerialTreeLearner(tree_config) {
Log::Fatal("GPU Tree Learner was not enabled in this build.\n"
"Please recompile with CMake option -DUSE_GPU=1");
}
......
......@@ -22,7 +22,7 @@ namespace LightGBM {
template <typename TREELEARNER_T>
class FeatureParallelTreeLearner: public TREELEARNER_T {
public:
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
explicit FeatureParallelTreeLearner(const Config* config);
~FeatureParallelTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
......@@ -48,10 +48,10 @@ private:
template <typename TREELEARNER_T>
class DataParallelTreeLearner: public TREELEARNER_T {
public:
explicit DataParallelTreeLearner(const TreeConfig* tree_config);
explicit DataParallelTreeLearner(const Config* config);
~DataParallelTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override;
void ResetConfig(const Config* config) override;
protected:
void BeforeTrain() override;
void FindBestSplits() override;
......@@ -101,10 +101,10 @@ private:
template <typename TREELEARNER_T>
class VotingParallelTreeLearner: public TREELEARNER_T {
public:
explicit VotingParallelTreeLearner(const TreeConfig* tree_config);
explicit VotingParallelTreeLearner(const Config* config);
~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override;
void ResetConfig(const Config* config) override;
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
......@@ -137,7 +137,7 @@ protected:
private:
/*! \brief Tree config used in local mode */
TreeConfig local_tree_config_;
Config local_config_;
/*! \brief Voting size */
int top_k_;
/*! \brief Rank of local machine*/
......
......@@ -19,9 +19,9 @@ std::chrono::duration<double, std::milli> split_time;
std::chrono::duration<double, std::milli> ordered_bin_time;
#endif // TIMETAG
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
:tree_config_(tree_config) {
random_ = Random(tree_config_->feature_fraction_seed);
SerialTreeLearner::SerialTreeLearner(const Config* config)
:config_(config) {
random_ = Random(config_->feature_fraction_seed);
#pragma omp parallel
#pragma omp master
{
......@@ -47,22 +47,22 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
is_constant_hessian_ = is_constant_hessian;
int max_cache_size = 0;
// Get the max size of pool
if (tree_config_->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves;
if (config_->histogram_pool_size <= 0) {
max_cache_size = config_->num_leaves;
} else {
size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
}
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
max_cache_size = static_cast<int>(config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
}
// at least need 2 leaves
max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
max_cache_size = std::min(max_cache_size, config_->num_leaves);
histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
histogram_pool_.DynamicChangeSize(train_data_, config_, max_cache_size, config_->num_leaves);
// push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves);
best_split_per_leaf_.resize(config_->num_leaves);
// get ordered bin
train_data_->CreateOrderedBins(&ordered_bins_);
......@@ -79,7 +79,7 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
larger_leaf_splits_.reset(new LeafSplits(train_data_->num_data()));
// initialize data partition
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
data_partition_.reset(new DataPartition(num_data_, config_->num_leaves));
is_feature_used_.resize(num_features_);
valid_feature_indices_ = train_data_->ValidFeatureIndices();
// initialize ordered gradients and hessians
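
The pool-size arithmetic above converts histogram_pool_size (in MB) into a number of cached leaf histograms, clamped to [2, num_leaves]. A small worked example with assumed numbers:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const double histogram_pool_size = 16.0;  // assumed config value, in MB
  const int num_leaves = 255;               // assumed config value
  // assumed dataset: 256 features x 255 bins x 20-byte HistogramBinEntry
  const std::size_t total_histogram_size = 256 * 255 * 20;
  int max_cache_size = static_cast<int>(
      histogram_pool_size * 1024 * 1024 / total_histogram_size);
  max_cache_size = std::max(2, max_cache_size);          // at least 2 leaves
  max_cache_size = std::min(max_cache_size, num_leaves);
  std::printf("cache %d of %d leaf histograms\n", max_cache_size, num_leaves);
  return 0;  // prints "cache 12 of 255 leaf histograms"
}
```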
......@@ -124,33 +124,33 @@ void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) {
}
}
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
if (tree_config_->num_leaves != tree_config->num_leaves) {
tree_config_ = tree_config;
void SerialTreeLearner::ResetConfig(const Config* config) {
if (config_->num_leaves != config->num_leaves) {
config_ = config;
int max_cache_size = 0;
// Get the max size of pool
if (tree_config->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves;
if (config->histogram_pool_size <= 0) {
max_cache_size = config_->num_leaves;
} else {
size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
}
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
max_cache_size = static_cast<int>(config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
}
// at least need 2 leaves
max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
max_cache_size = std::min(max_cache_size, config_->num_leaves);
histogram_pool_.DynamicChangeSize(train_data_, config_, max_cache_size, config_->num_leaves);
// push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves);
data_partition_->ResetLeaves(tree_config_->num_leaves);
best_split_per_leaf_.resize(config_->num_leaves);
data_partition_->ResetLeaves(config_->num_leaves);
} else {
tree_config_ = tree_config;
config_ = config;
}
histogram_pool_.ResetConfig(tree_config_);
histogram_pool_.ResetConfig(config_);
}
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian, Json& forced_split_json) {
......@@ -167,7 +167,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
init_train_time += std::chrono::steady_clock::now() - start_time;
#endif
auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves));
auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves));
// root leaf
int left_leaf = 0;
int cur_depth = 1;
......@@ -181,7 +181,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
&right_leaf, &cur_depth, &aborted_last_force_split);
}
for (int split = init_splits; split < tree_config_->num_leaves - 1; ++split) {
for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
#endif
......@@ -236,7 +236,7 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
sum_hess += hessians[idx];
}
double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess,
tree_config_->lambda_l1, tree_config_->lambda_l2, tree_config_->max_delta_step);
config_->lambda_l1, config_->lambda_l2, config_->max_delta_step);
tree->SetLeafOutput(i, output* tree->shrinkage());
OMP_LOOP_EX_END();
}
......@@ -254,8 +254,8 @@ void SerialTreeLearner::BeforeTrain() {
// reset histogram pool
histogram_pool_.ResetMap();
if (tree_config_->feature_fraction < 1) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size()*tree_config_->feature_fraction);
if (config_->feature_fraction < 1) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size()*config_->feature_fraction);
// at least use one feature
used_feature_cnt = std::max(used_feature_cnt, 1);
// initialize used features
......@@ -281,7 +281,7 @@ void SerialTreeLearner::BeforeTrain() {
data_partition_->Init();
// reset the splits for leaves
for (int i = 0; i < tree_config_->num_leaves; ++i) {
for (int i = 0; i < config_->num_leaves; ++i) {
best_split_per_leaf_[i].Reset();
}
......@@ -308,7 +308,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(ordered_bin_indices_.size()); ++i) {
OMP_LOOP_EX_BEGIN();
ordered_bins_[ordered_bin_indices_[i]]->Init(nullptr, tree_config_->num_leaves);
ordered_bins_[ordered_bin_indices_[i]]->Init(nullptr, config_->num_leaves);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......@@ -329,7 +329,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(ordered_bin_indices_.size()); ++i) {
OMP_LOOP_EX_BEGIN();
ordered_bins_[ordered_bin_indices_[i]]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
ordered_bins_[ordered_bin_indices_[i]]->Init(is_data_in_leaf_.data(), config_->num_leaves);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......@@ -346,9 +346,9 @@ void SerialTreeLearner::BeforeTrain() {
bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
// check depth of current leaf
if (tree_config_->max_depth > 0) {
if (config_->max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf
if (tree->leaf_depth(left_leaf) >= tree_config_->max_depth) {
if (tree->leaf_depth(left_leaf) >= config_->max_depth) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;
......@@ -359,8 +359,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// not enough data to continue
if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
if (num_data_in_right_child < static_cast<data_size_t>(config_->min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(config_->min_data_in_leaf * 2)) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;
......
......@@ -33,7 +33,7 @@ namespace LightGBM {
*/
class SerialTreeLearner: public TreeLearner {
public:
explicit SerialTreeLearner(const TreeConfig* tree_config);
explicit SerialTreeLearner(const Config* config);
~SerialTreeLearner();
......@@ -41,7 +41,7 @@ public:
void ResetTrainingData(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
void ResetConfig(const Config* config) override;
Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian,
Json& forced_split_json) override;
......@@ -163,7 +163,7 @@ protected:
/*! \brief used to cache historical histogram to speed up*/
HistogramPool histogram_pool_;
/*! \brief config of tree learner*/
const TreeConfig* tree_config_;
const Config* config_;
int num_threads_;
std::vector<int> ordered_bin_indices_;
bool is_constant_hessian_;
......
......@@ -6,27 +6,27 @@
namespace LightGBM {
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const TreeConfig* tree_config) {
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const Config* config) {
if (device_type == std::string("cpu")) {
if (learner_type == std::string("serial")) {
return new SerialTreeLearner(tree_config);
return new SerialTreeLearner(config);
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<SerialTreeLearner>(tree_config);
return new FeatureParallelTreeLearner<SerialTreeLearner>(config);
} else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<SerialTreeLearner>(tree_config);
return new DataParallelTreeLearner<SerialTreeLearner>(config);
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<SerialTreeLearner>(tree_config);
return new VotingParallelTreeLearner<SerialTreeLearner>(config);
}
}
else if (device_type == std::string("gpu")) {
if (learner_type == std::string("serial")) {
return new GPUTreeLearner(tree_config);
return new GPUTreeLearner(config);
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<GPUTreeLearner>(tree_config);
return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
} else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<GPUTreeLearner>(tree_config);
return new DataParallelTreeLearner<GPUTreeLearner>(config);
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<GPUTreeLearner>(tree_config);
return new VotingParallelTreeLearner<GPUTreeLearner>(config);
}
}
return nullptr;
......
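
A hedged usage sketch of the factory above; the surrounding setup is illustrative rather than LightGBM's actual boosting code, but the argument strings and the nullptr return for unknown combinations come straight from this hunk:

```cpp
#include <memory>

// Assumes Config and TreeLearner from the headers changed in this diff.
void BuildLearner(const Config& config) {
  std::unique_ptr<TreeLearner> learner(
      TreeLearner::CreateTreeLearner("serial", "cpu", &config));
  if (learner == nullptr) {
    // unknown learner_type/device_type combination
  }
}
```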
......@@ -10,9 +10,9 @@
namespace LightGBM {
template <typename TREELEARNER_T>
VotingParallelTreeLearner<TREELEARNER_T>::VotingParallelTreeLearner(const TreeConfig* tree_config)
:TREELEARNER_T(tree_config) {
top_k_ = this->tree_config_->top_k;
VotingParallelTreeLearner<TREELEARNER_T>::VotingParallelTreeLearner(const Config* config)
:TREELEARNER_T(config) {
top_k_ = this->config_->top_k;
}
template <typename TREELEARNER_T>
......@@ -46,16 +46,16 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
smaller_buffer_read_start_pos_.resize(this->num_features_);
larger_buffer_read_start_pos_.resize(this->num_features_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves);
global_data_count_in_leaf_.resize(this->config_->num_leaves);
smaller_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data()));
local_tree_config_ = *this->tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
local_config_ = *this->config_;
local_config_.min_data_in_leaf /= num_machines_;
local_config_.min_sum_hessian_in_leaf /= num_machines_;
this->histogram_pool_.ResetConfig(&local_tree_config_);
this->histogram_pool_.ResetConfig(&local_config_);
// initialize histograms for global
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
......@@ -75,7 +75,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
} else {
feature_metas_[i].bias = 0;
}
feature_metas_[i].tree_config = this->tree_config_;
feature_metas_[i].config = this->config_;
feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type();
}
uint64_t offset = 0;
......@@ -92,18 +92,18 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
}
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const TreeConfig* tree_config) {
TREELEARNER_T::ResetConfig(tree_config);
void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
TREELEARNER_T::ResetConfig(config);
local_tree_config_ = *this->tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
local_config_ = *this->config_;
local_config_.min_data_in_leaf /= num_machines_;
local_config_.min_sum_hessian_in_leaf /= num_machines_;
this->histogram_pool_.ResetConfig(&local_tree_config_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves);
this->histogram_pool_.ResetConfig(&local_config_);
global_data_count_in_leaf_.resize(this->config_->num_leaves);
for (size_t i = 0; i < feature_metas_.size(); ++i) {
feature_metas_[i].tree_config = this->tree_config_;
feature_metas_[i].config = this->config_;
}
}
......@@ -451,7 +451,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold);
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// copy back
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split;
......
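
Dividing min_data_in_leaf and min_sum_hessian_in_leaf by num_machines_ reflects that each worker holds roughly 1/num_machines_ of the rows, so local pre-filtering must not be stricter than the global constraint (which is still enforced through global_data_count_in_leaf_ after synchronization). A worked example with assumed numbers:

```cpp
#include <cstdio>

int main() {
  const int num_machines = 4;        // assumed cluster size
  const int min_data_in_leaf = 100;  // assumed global config value
  // each worker sees ~1/4 of the rows, so its local threshold is 100/4 = 25
  std::printf("local min_data_in_leaf = %d\n", min_data_in_leaf / num_machines);
  return 0;
}
```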
......@@ -257,6 +257,7 @@
<ClCompile Include="..\src\c_api.cpp" />
<ClCompile Include="..\src\io\bin.cpp" />
<ClCompile Include="..\src\io\config.cpp" />
<ClCompile Include="..\src\io\config_auto.cpp" />
<ClCompile Include="..\src\io\dataset.cpp" />
<ClCompile Include="..\src\io\dataset_loader.cpp" />
<ClCompile Include="..\src\io\file_io.cpp" />
......
......@@ -299,5 +299,8 @@
<ClCompile Include="..\src\io\json11.cpp">
<Filter>src\io</Filter>
</ClCompile>
<ClCompile Include="..\src\io\config_auto.cpp">
<Filter>src\io</Filter>
</ClCompile>
</ItemGroup>
</Project>
\ No newline at end of file