Unverified Commit dc699574 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Refine config object (#1381)

* [WIP] refine config

* [wip] ready for the auto code generate

* auto generate config codes

* use with to open file

* fix bug

* fix pylint

* fix bug

* fix pylint

* fix bugs.

* tmp for failed test.

* fix tests.

* added nthreads alias

* added new aliases from new config.h

* fixed duplicated alias

* refactored parameter_generator.py

* added new aliases from config.h and removed remaining old names

* fix bugs & some miss alias

* added aliases

* add more descriptions.

* add comment.
parent 497e60ed
...@@ -15,7 +15,7 @@ namespace LightGBM { ...@@ -15,7 +15,7 @@ namespace LightGBM {
*/ */
class MulticlassSoftmax: public ObjectiveFunction { class MulticlassSoftmax: public ObjectiveFunction {
public: public:
explicit MulticlassSoftmax(const ObjectiveConfig& config) { explicit MulticlassSoftmax(const Config& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
} }
...@@ -138,7 +138,7 @@ private: ...@@ -138,7 +138,7 @@ private:
*/ */
class MulticlassOVA: public ObjectiveFunction { class MulticlassOVA: public ObjectiveFunction {
public: public:
explicit MulticlassOVA(const ObjectiveConfig& config) { explicit MulticlassOVA(const Config& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
binary_loss_.emplace_back( binary_loss_.emplace_back(
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace LightGBM { namespace LightGBM {
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const ObjectiveConfig& config) { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
if (type == std::string("regression") || type == std::string("regression_l2") if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse") || type == std::string("mean_squared_error") || type == std::string("mse")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) { || type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
......
...@@ -18,15 +18,12 @@ namespace LightGBM { ...@@ -18,15 +18,12 @@ namespace LightGBM {
*/ */
class LambdarankNDCG: public ObjectiveFunction { class LambdarankNDCG: public ObjectiveFunction {
public: public:
explicit LambdarankNDCG(const ObjectiveConfig& config) { explicit LambdarankNDCG(const Config& config) {
sigmoid_ = static_cast<double>(config.sigmoid); sigmoid_ = static_cast<double>(config.sigmoid);
label_gain_ = config.label_gain;
// initialize DCG calculator // initialize DCG calculator
DCGCalculator::Init(config.label_gain); DCGCalculator::DefaultLabelGain(&label_gain_);
// copy lable gain to local DCGCalculator::Init(label_gain_);
for (auto gain : config.label_gain) {
label_gain_.push_back(static_cast<double>(gain));
}
label_gain_.shrink_to_fit();
// will optimize NDCG@optimize_pos_at_ // will optimize NDCG@optimize_pos_at_
optimize_pos_at_ = config.max_position; optimize_pos_at_ = config.max_position;
sigmoid_table_.clear(); sigmoid_table_.clear();
......
...@@ -63,7 +63,7 @@ namespace LightGBM { ...@@ -63,7 +63,7 @@ namespace LightGBM {
*/ */
class RegressionL2loss: public ObjectiveFunction { class RegressionL2loss: public ObjectiveFunction {
public: public:
explicit RegressionL2loss(const ObjectiveConfig& config) { explicit RegressionL2loss(const Config& config) {
sqrt_ = config.reg_sqrt; sqrt_ = config.reg_sqrt;
} }
...@@ -174,7 +174,7 @@ protected: ...@@ -174,7 +174,7 @@ protected:
*/ */
class RegressionL1loss: public RegressionL2loss { class RegressionL1loss: public RegressionL2loss {
public: public:
explicit RegressionL1loss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionL1loss(const Config& config): RegressionL2loss(config) {
} }
explicit RegressionL1loss(const std::vector<std::string>& strs): RegressionL2loss(strs) { explicit RegressionL1loss(const std::vector<std::string>& strs): RegressionL2loss(strs) {
...@@ -260,7 +260,7 @@ public: ...@@ -260,7 +260,7 @@ public:
*/ */
class RegressionHuberLoss: public RegressionL2loss { class RegressionHuberLoss: public RegressionL2loss {
public: public:
explicit RegressionHuberLoss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionHuberLoss(const Config& config): RegressionL2loss(config) {
alpha_ = static_cast<double>(config.alpha); alpha_ = static_cast<double>(config.alpha);
} }
...@@ -315,7 +315,7 @@ private: ...@@ -315,7 +315,7 @@ private:
// http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html // http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html
class RegressionFairLoss: public RegressionL2loss { class RegressionFairLoss: public RegressionL2loss {
public: public:
explicit RegressionFairLoss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionFairLoss(const Config& config): RegressionL2loss(config) {
c_ = static_cast<double>(config.fair_c); c_ = static_cast<double>(config.fair_c);
} }
...@@ -363,7 +363,7 @@ private: ...@@ -363,7 +363,7 @@ private:
*/ */
class RegressionPoissonLoss: public RegressionL2loss { class RegressionPoissonLoss: public RegressionL2loss {
public: public:
explicit RegressionPoissonLoss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionPoissonLoss(const Config& config): RegressionL2loss(config) {
max_delta_step_ = static_cast<double>(config.poisson_max_delta_step); max_delta_step_ = static_cast<double>(config.poisson_max_delta_step);
if (sqrt_) { if (sqrt_) {
Log::Warning("Cannot use sqrt transform in %s Regression, will auto disable it", GetName()); Log::Warning("Cannot use sqrt transform in %s Regression, will auto disable it", GetName());
...@@ -444,7 +444,7 @@ private: ...@@ -444,7 +444,7 @@ private:
class RegressionQuantileloss : public RegressionL2loss { class RegressionQuantileloss : public RegressionL2loss {
public: public:
explicit RegressionQuantileloss(const ObjectiveConfig& config): RegressionL2loss(config) { explicit RegressionQuantileloss(const Config& config): RegressionL2loss(config) {
alpha_ = static_cast<score_t>(config.alpha); alpha_ = static_cast<score_t>(config.alpha);
CHECK(alpha_ > 0 && alpha_ < 1); CHECK(alpha_ > 0 && alpha_ < 1);
} }
...@@ -543,7 +543,7 @@ private: ...@@ -543,7 +543,7 @@ private:
*/ */
class RegressionMAPELOSS : public RegressionL1loss { class RegressionMAPELOSS : public RegressionL1loss {
public: public:
explicit RegressionMAPELOSS(const ObjectiveConfig& config) : RegressionL1loss(config) { explicit RegressionMAPELOSS(const Config& config) : RegressionL1loss(config) {
} }
explicit RegressionMAPELOSS(const std::vector<std::string>& strs) : RegressionL1loss(strs) { explicit RegressionMAPELOSS(const std::vector<std::string>& strs) : RegressionL1loss(strs) {
...@@ -644,7 +644,7 @@ private: ...@@ -644,7 +644,7 @@ private:
*/ */
class RegressionGammaLoss : public RegressionPoissonLoss { class RegressionGammaLoss : public RegressionPoissonLoss {
public: public:
explicit RegressionGammaLoss(const ObjectiveConfig& config) : RegressionPoissonLoss(config) { explicit RegressionGammaLoss(const Config& config) : RegressionPoissonLoss(config) {
} }
explicit RegressionGammaLoss(const std::vector<std::string>& strs) : RegressionPoissonLoss(strs) { explicit RegressionGammaLoss(const std::vector<std::string>& strs) : RegressionPoissonLoss(strs) {
...@@ -681,7 +681,7 @@ public: ...@@ -681,7 +681,7 @@ public:
*/ */
class RegressionTweedieLoss: public RegressionPoissonLoss { class RegressionTweedieLoss: public RegressionPoissonLoss {
public: public:
explicit RegressionTweedieLoss(const ObjectiveConfig& config) : RegressionPoissonLoss(config) { explicit RegressionTweedieLoss(const Config& config) : RegressionPoissonLoss(config) {
rho_ = config.tweedie_variance_power; rho_ = config.tweedie_variance_power;
} }
......
...@@ -38,7 +38,7 @@ namespace LightGBM { ...@@ -38,7 +38,7 @@ namespace LightGBM {
*/ */
class CrossEntropy: public ObjectiveFunction { class CrossEntropy: public ObjectiveFunction {
public: public:
explicit CrossEntropy(const ObjectiveConfig&) { explicit CrossEntropy(const Config&) {
} }
explicit CrossEntropy(const std::vector<std::string>&) { explicit CrossEntropy(const std::vector<std::string>&) {
...@@ -141,7 +141,7 @@ private: ...@@ -141,7 +141,7 @@ private:
*/ */
class CrossEntropyLambda: public ObjectiveFunction { class CrossEntropyLambda: public ObjectiveFunction {
public: public:
explicit CrossEntropyLambda(const ObjectiveConfig&) { explicit CrossEntropyLambda(const Config&) {
min_weight_ = max_weight_ = 0.0f; min_weight_ = max_weight_ = 0.0f;
} }
......
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
namespace LightGBM { namespace LightGBM {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
DataParallelTreeLearner<TREELEARNER_T>::DataParallelTreeLearner(const TreeConfig* tree_config) DataParallelTreeLearner<TREELEARNER_T>::DataParallelTreeLearner(const Config* config)
:TREELEARNER_T(tree_config) { :TREELEARNER_T(config) {
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -37,13 +37,13 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo ...@@ -37,13 +37,13 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo
buffer_write_start_pos_.resize(this->num_features_); buffer_write_start_pos_.resize(this->num_features_);
buffer_read_start_pos_.resize(this->num_features_); buffer_read_start_pos_.resize(this->num_features_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves); global_data_count_in_leaf_.resize(this->config_->num_leaves);
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::ResetConfig(const TreeConfig* tree_config) { void DataParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
TREELEARNER_T::ResetConfig(tree_config); TREELEARNER_T::ResetConfig(config);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves); global_data_count_in_leaf_.resize(this->config_->num_leaves);
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -236,7 +236,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -236,7 +236,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// set best split // set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
......
This diff is collapsed.
...@@ -8,8 +8,8 @@ namespace LightGBM { ...@@ -8,8 +8,8 @@ namespace LightGBM {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
FeatureParallelTreeLearner<TREELEARNER_T>::FeatureParallelTreeLearner(const TreeConfig* tree_config) FeatureParallelTreeLearner<TREELEARNER_T>::FeatureParallelTreeLearner(const Config* config)
:TREELEARNER_T(tree_config) { :TREELEARNER_T(config) {
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
...@@ -22,8 +22,8 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, ...@@ -22,8 +22,8 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data,
TREELEARNER_T::Init(train_data, is_constant_hessian); TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank(); rank_ = Network::rank();
num_machines_ = Network::num_machines(); num_machines_ = Network::num_machines();
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2); input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2); output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
} }
...@@ -60,7 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con ...@@ -60,7 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
// update best split // update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
namespace LightGBM { namespace LightGBM {
GPUTreeLearner::GPUTreeLearner(const TreeConfig* tree_config) GPUTreeLearner::GPUTreeLearner(const Config* config)
:SerialTreeLearner(tree_config) { :SerialTreeLearner(config) {
use_bagging_ = false; use_bagging_ = false;
Log::Info("This is the GPU trainer!!"); Log::Info("This is the GPU trainer!!");
} }
...@@ -39,7 +39,7 @@ void GPUTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) { ...@@ -39,7 +39,7 @@ void GPUTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
// some additional variables needed for GPU trainer // some additional variables needed for GPU trainer
num_feature_groups_ = train_data_->num_feature_groups(); num_feature_groups_ = train_data_->num_feature_groups();
// Initialize GPU buffers and kernels // Initialize GPU buffers and kernels
InitGPU(tree_config_->gpu_platform_id, tree_config_->gpu_device_id); InitGPU(config_->gpu_platform_id, config_->gpu_device_id);
} }
// some functions used for debugging the GPU histogram construction // some functions used for debugging the GPU histogram construction
...@@ -304,7 +304,7 @@ void GPUTreeLearner::AllocateGPUMemory() { ...@@ -304,7 +304,7 @@ void GPUTreeLearner::AllocateGPUMemory() {
device_data_indices_ = std::unique_ptr<boost::compute::vector<data_size_t>>(new boost::compute::vector<data_size_t>(allocated_num_data_, ctx_)); device_data_indices_ = std::unique_ptr<boost::compute::vector<data_size_t>>(new boost::compute::vector<data_size_t>(allocated_num_data_, ctx_));
boost::compute::fill(device_data_indices_->begin(), device_data_indices_->end(), 0, queue_); boost::compute::fill(device_data_indices_->begin(), device_data_indices_->end(), 0, queue_);
// histogram bin entry size depends on the precision (single/double) // histogram bin entry size depends on the precision (single/double)
hist_bin_entry_sz_ = tree_config_->gpu_use_dp ? sizeof(HistogramBinEntry) : sizeof(GPUHistogramBinEntry); hist_bin_entry_sz_ = config_->gpu_use_dp ? sizeof(HistogramBinEntry) : sizeof(GPUHistogramBinEntry);
Log::Info("Size of histogram bin entry: %d", hist_bin_entry_sz_); Log::Info("Size of histogram bin entry: %d", hist_bin_entry_sz_);
// create output buffer, each feature has a histogram with device_bin_size_ bins, // create output buffer, each feature has a histogram with device_bin_size_ bins,
// each work group generates a sub-histogram of dword_features_ features. // each work group generates a sub-histogram of dword_features_ features.
...@@ -598,7 +598,7 @@ void GPUTreeLearner::BuildGPUKernels() { ...@@ -598,7 +598,7 @@ void GPUTreeLearner::BuildGPUKernels() {
std::ostringstream opts; std::ostringstream opts;
// compile the GPU kernel depending if double precision is used, constant hessian is used, etc // compile the GPU kernel depending if double precision is used, constant hessian is used, etc
opts << " -D POWER_FEATURE_WORKGROUPS=" << i opts << " -D POWER_FEATURE_WORKGROUPS=" << i
<< " -D USE_CONSTANT_BUF=" << use_constants << " -D USE_DP_FLOAT=" << int(tree_config_->gpu_use_dp) << " -D USE_CONSTANT_BUF=" << use_constants << " -D USE_DP_FLOAT=" << int(config_->gpu_use_dp)
<< " -D CONST_HESSIAN=" << int(is_constant_hessian_) << " -D CONST_HESSIAN=" << int(is_constant_hessian_)
<< " -cl-mad-enable -cl-no-signed-zeros -cl-fast-relaxed-math"; << " -cl-mad-enable -cl-no-signed-zeros -cl-fast-relaxed-math";
#if GPU_DEBUG >= 1 #if GPU_DEBUG >= 1
...@@ -1006,7 +1006,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u ...@@ -1006,7 +1006,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
ptr_smaller_leaf_hist_data); ptr_smaller_leaf_hist_data);
// wait for GPU to finish, only if GPU is actually used // wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) { if (is_gpu_used) {
if (tree_config_->gpu_use_dp) { if (config_->gpu_use_dp) {
// use double precision // use double precision
WaitAndGetHistograms<HistogramBinEntry>(ptr_smaller_leaf_hist_data); WaitAndGetHistograms<HistogramBinEntry>(ptr_smaller_leaf_hist_data);
} }
...@@ -1060,7 +1060,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u ...@@ -1060,7 +1060,7 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
ptr_larger_leaf_hist_data); ptr_larger_leaf_hist_data);
// wait for GPU to finish, only if GPU is actually used // wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) { if (is_gpu_used) {
if (tree_config_->gpu_use_dp) { if (config_->gpu_use_dp) {
// use double precision // use double precision
WaitAndGetHistograms<HistogramBinEntry>(ptr_larger_leaf_hist_data); WaitAndGetHistograms<HistogramBinEntry>(ptr_larger_leaf_hist_data);
} }
......
...@@ -37,7 +37,7 @@ namespace LightGBM { ...@@ -37,7 +37,7 @@ namespace LightGBM {
*/ */
class GPUTreeLearner: public SerialTreeLearner { class GPUTreeLearner: public SerialTreeLearner {
public: public:
explicit GPUTreeLearner(const TreeConfig* tree_config); explicit GPUTreeLearner(const Config* tree_config);
~GPUTreeLearner(); ~GPUTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override; void ResetTrainingData(const Dataset* train_data) override;
...@@ -270,7 +270,7 @@ namespace LightGBM { ...@@ -270,7 +270,7 @@ namespace LightGBM {
class GPUTreeLearner: public SerialTreeLearner { class GPUTreeLearner: public SerialTreeLearner {
public: public:
#pragma warning(disable : 4702) #pragma warning(disable : 4702)
explicit GPUTreeLearner(const TreeConfig* tree_config) : SerialTreeLearner(tree_config) { explicit GPUTreeLearner(const Config* tree_config) : SerialTreeLearner(tree_config) {
Log::Fatal("GPU Tree Learner was not enabled in this build.\n" Log::Fatal("GPU Tree Learner was not enabled in this build.\n"
"Please recompile with CMake option -DUSE_GPU=1"); "Please recompile with CMake option -DUSE_GPU=1");
} }
......
...@@ -22,7 +22,7 @@ namespace LightGBM { ...@@ -22,7 +22,7 @@ namespace LightGBM {
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
class FeatureParallelTreeLearner: public TREELEARNER_T { class FeatureParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config); explicit FeatureParallelTreeLearner(const Config* config);
~FeatureParallelTreeLearner(); ~FeatureParallelTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
...@@ -48,10 +48,10 @@ private: ...@@ -48,10 +48,10 @@ private:
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
class DataParallelTreeLearner: public TREELEARNER_T { class DataParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit DataParallelTreeLearner(const TreeConfig* tree_config); explicit DataParallelTreeLearner(const Config* config);
~DataParallelTreeLearner(); ~DataParallelTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const Config* config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
void FindBestSplits() override; void FindBestSplits() override;
...@@ -101,10 +101,10 @@ private: ...@@ -101,10 +101,10 @@ private:
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
class VotingParallelTreeLearner: public TREELEARNER_T { class VotingParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit VotingParallelTreeLearner(const TreeConfig* tree_config); explicit VotingParallelTreeLearner(const Config* config);
~VotingParallelTreeLearner() { } ~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const Config* config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override; bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
...@@ -137,7 +137,7 @@ protected: ...@@ -137,7 +137,7 @@ protected:
private: private:
/*! \brief Tree config used in local mode */ /*! \brief Tree config used in local mode */
TreeConfig local_tree_config_; Config local_config_;
/*! \brief Voting size */ /*! \brief Voting size */
int top_k_; int top_k_;
/*! \brief Rank of local machine*/ /*! \brief Rank of local machine*/
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -257,6 +257,7 @@ ...@@ -257,6 +257,7 @@
<ClCompile Include="..\src\c_api.cpp" /> <ClCompile Include="..\src\c_api.cpp" />
<ClCompile Include="..\src\io\bin.cpp" /> <ClCompile Include="..\src\io\bin.cpp" />
<ClCompile Include="..\src\io\config.cpp" /> <ClCompile Include="..\src\io\config.cpp" />
<ClCompile Include="..\src\io\config_auto.cpp" />
<ClCompile Include="..\src\io\dataset.cpp" /> <ClCompile Include="..\src\io\dataset.cpp" />
<ClCompile Include="..\src\io\dataset_loader.cpp" /> <ClCompile Include="..\src\io\dataset_loader.cpp" />
<ClCompile Include="..\src\io\file_io.cpp" /> <ClCompile Include="..\src\io\file_io.cpp" />
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment