Commit c2e94f17 authored by Guolin Ke's avatar Guolin Ke
Browse files

refine reset_parameters logic

parent 714c6732
...@@ -51,12 +51,6 @@ public: ...@@ -51,12 +51,6 @@ public:
*/ */
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0; virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Reset the shrinkage rate used by the current boosting
* \param shrinkage_rate New shrinkage rate value
*/
virtual void ResetShrinkageRate(double shrinkage_rate) = 0;
/*! /*!
* \brief Add a validation data * \brief Add a validation data
* \param valid_data Validation data * \param valid_data Validation data
......
...@@ -22,11 +22,17 @@ public: ...@@ -22,11 +22,17 @@ public:
virtual ~TreeLearner() {} virtual ~TreeLearner() {}
/*! /*!
* \brief Initialize tree learner with training dataset and configs * \brief Initialize tree learner with training dataset
* \param train_data The used training data * \param train_data The used training data
*/ */
virtual void Init(const Dataset* train_data) = 0; virtual void Init(const Dataset* train_data) = 0;
/*!
* \brief Reset tree configs
* \param tree_config config of tree. NOTE(review): the learner implementations in this change
*        keep the pointer itself rather than a copy, so the pointed-to config presumably has
*        to stay alive for as long as the learner uses it -- confirm against callers
*/
virtual void ResetConfig(const TreeConfig* tree_config) = 0;
/*! /*!
* \brief training tree model on dataset * \brief training tree model on dataset
* \param gradients The first order gradients * \param gradients The first order gradients
...@@ -58,9 +64,10 @@ public: ...@@ -58,9 +64,10 @@ public:
/*! /*!
* \brief Create object of tree learner * \brief Create object of tree learner
* \param type Type of tree learner * \param type Type of tree learner
* \param tree_config config of tree
*/ */
static TreeLearner* CreateTreeLearner(TreeLearnerType type, static TreeLearner* CreateTreeLearner(TreeLearnerType type,
const TreeConfig& tree_config); const TreeConfig* tree_config);
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -35,7 +35,6 @@ public: ...@@ -35,7 +35,6 @@ public:
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) override { const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, object_function, training_metrics); GBDT::Init(config, train_data, object_function, training_metrics);
drop_rate_ = gbdt_config_->drop_rate;
shrinkage_rate_ = 1.0; shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->drop_seed); random_for_drop_ = Random(gbdt_config_->drop_seed);
} }
...@@ -53,6 +52,14 @@ public: ...@@ -53,6 +52,14 @@ public:
return false; return false;
} }
} }
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->drop_seed);
}
/*! /*!
* \brief Get current training score * \brief Get current training score
* \param out_len length of returned score * \param out_len length of returned score
...@@ -81,9 +88,9 @@ private: ...@@ -81,9 +88,9 @@ private:
drop_index_.clear(); drop_index_.clear();
// select dropping tree indexes based on drop_rate // select dropping tree indexes based on drop_rate
// if drop rate is too small, skip this step, drop one tree randomly // if drop rate is too small, skip this step, drop one tree randomly
if (drop_rate_ > kEpsilon) { if (gbdt_config_->drop_rate > kEpsilon) {
for (int i = 0; i < iter_; ++i) { for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextDouble() < drop_rate_) { if (random_for_drop_.NextDouble() < gbdt_config_->drop_rate) {
drop_index_.push_back(i); drop_index_.push_back(i);
} }
} }
...@@ -123,8 +130,6 @@ private: ...@@ -123,8 +130,6 @@ private:
} }
/*! \brief The indexes of dropping trees */ /*! \brief The indexes of dropping trees */
std::vector<int> drop_index_; std::vector<int> drop_index_;
/*! \brief Dropping rate */
double drop_rate_;
/*! \brief Random generator, used to select dropping trees */ /*! \brief Random generator, used to select dropping trees */
Random random_for_drop_; Random random_for_drop_;
/*! \brief Flag that the score is updated on current iter or not */ /*! \brief Flag that the score is updated on current iter or not */
......
...@@ -33,41 +33,57 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -33,41 +33,57 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_ = 0; max_feature_idx_ = 0;
num_class_ = config->num_class; num_class_ = config->num_class;
train_data_ = nullptr; train_data_ = nullptr;
gbdt_config_ = nullptr;
ResetTrainingData(config, train_data, object_function, training_metrics); ResetTrainingData(config, train_data, object_function, training_metrics);
} }
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) { const std::vector<const Metric*>& training_metrics) {
if (train_data == nullptr) { return; }
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) { if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers"); Log::Fatal("cannot reset training data, since new training data has different bin mappers");
} }
gbdt_config_ = config; early_stopping_round_ = new_config->early_stopping_round;
early_stopping_round_ = gbdt_config_->early_stopping_round; shrinkage_rate_ = new_config->learning_rate;
shrinkage_rate_ = gbdt_config_->learning_rate; random_ = Random(new_config->bagging_seed);
random_ = Random(gbdt_config_->bagging_seed);
// create tree learner // create tree learner, only create once
tree_learner_.clear(); if (gbdt_config_ == nullptr) {
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
}
// init tree learner
if (train_data_ != train_data) {
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
}
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config)); tree_learner_[i]->ResetConfig(&new_config->tree_config);
new_tree_learner->Init(train_data);
// init tree learner
tree_learner_.push_back(std::move(new_tree_learner));
} }
tree_learner_.shrink_to_fit();
object_function_ = object_function; object_function_ = object_function;
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
sigmoid_ = -1.0f; sigmoid_ = -1.0f;
if (object_function_ != nullptr if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) { && std::string(object_function_->GetName()) == std::string("binary")) {
// only binary classification need sigmoid transform // only binary classification need sigmoid transform
sigmoid_ = gbdt_config_->sigmoid; sigmoid_ = new_config->sigmoid;
} }
if (train_data_ != train_data) { if (train_data_ != train_data) {
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
// not same training data, need reset score and others // not same training data, need reset score and others
// create score tracker // create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data, num_class_)); train_score_updater_.reset(new ScoreUpdater(train_data, num_class_));
...@@ -88,8 +104,13 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -88,8 +104,13 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
max_feature_idx_ = train_data->num_total_features() - 1; max_feature_idx_ = train_data->num_total_features() - 1;
// get label index // get label index
label_idx_ = train_data->label_idx(); label_idx_ = train_data->label_idx();
}
if (train_data_ != train_data
|| gbdt_config_ == nullptr
|| (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer // if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) { if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_); out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_); bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else { } else {
...@@ -100,6 +121,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -100,6 +121,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
} }
} }
train_data_ = train_data; train_data_ = train_data;
gbdt_config_.reset(new_config.release());
} }
void GBDT::AddValidDataset(const Dataset* valid_data, void GBDT::AddValidDataset(const Dataset* valid_data,
......
...@@ -68,14 +68,6 @@ public: ...@@ -68,14 +68,6 @@ public:
*/ */
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override; void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*!
* \brief Reset the shrinkage rate used by the current boosting
* \param shrinkage_rate New value stored into shrinkage_rate_
*/
void ResetShrinkageRate(double shrinkage_rate) override {
shrinkage_rate_ = shrinkage_rate;
}
/*! /*!
* \brief Adding a validation dataset * \brief Adding a validation dataset
* \param valid_data Validation dataset * \param valid_data Validation dataset
...@@ -245,7 +237,7 @@ protected: ...@@ -245,7 +237,7 @@ protected:
/*! \brief Pointer to training data */ /*! \brief Pointer to training data */
const Dataset* train_data_; const Dataset* train_data_;
/*! \brief Config of gbdt */ /*! \brief Config of gbdt */
const BoostingConfig* gbdt_config_; std::unique_ptr<BoostingConfig> gbdt_config_;
/*! \brief Tree learner, will use this class to learn trees */ /*! \brief Tree learner, will use this class to learn trees */
std::vector<std::unique_ptr<TreeLearner>> tree_learner_; std::vector<std::unique_ptr<TreeLearner>> tree_learner_;
/*! \brief Objective function */ /*! \brief Objective function */
......
...@@ -40,12 +40,14 @@ public: ...@@ -40,12 +40,14 @@ public:
Log::Warning("continued train from model is not support for c_api, \ Log::Warning("continued train from model is not support for c_api, \
please use continued train with input score"); please use continued train with input score");
} }
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr)); boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data);
// initialize the boosting // initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
ResetTrainingData(train_data);
} }
void MergeFrom(const Booster* other) { void MergeFrom(const Booster* other) {
...@@ -60,13 +62,34 @@ public: ...@@ -60,13 +62,34 @@ public:
void ResetTrainingData(const Dataset* train_data) { void ResetTrainingData(const Dataset* train_data) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
train_data_ = train_data; train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data_); // create objective function
// initialize the boosting objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// reset the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_, boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
void ResetConfig(const char* parameters) { void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
if (param.count("num_class")) { if (param.count("num_class")) {
Log::Fatal("cannot change num class during training"); Log::Fatal("cannot change num class during training");
...@@ -77,21 +100,28 @@ public: ...@@ -77,21 +100,28 @@ public:
if (param.count("metric")) { if (param.count("metric")) {
Log::Fatal("cannot change metric during training"); Log::Fatal("cannot change metric during training");
} }
{
std::lock_guard<std::mutex> lock(mutex_); config_.Set(param);
config_.Set(param);
}
if (config_.num_threads > 0) { if (config_.num_threads > 0) {
std::lock_guard<std::mutex> lock(mutex_);
omp_set_num_threads(config_.num_threads); omp_set_num_threads(config_.num_threads);
} }
if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
// only need to set learning rate if (param.count("objective")) {
std::lock_guard<std::mutex> lock(mutex_); // create objective function
boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate); objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
} else { config_.objective_config));
ResetTrainingData(train_data_); if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
} }
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
void AddValidData(const Dataset* valid_data) { void AddValidData(const Dataset* valid_data) {
...@@ -107,6 +137,7 @@ public: ...@@ -107,6 +137,7 @@ public:
boosting_->AddValidDataset(valid_data, boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back())); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
} }
bool TrainOneIter() { bool TrainOneIter() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return boosting_->TrainOneIter(nullptr, nullptr, false); return boosting_->TrainOneIter(nullptr, nullptr, false);
...@@ -142,10 +173,12 @@ public: ...@@ -142,10 +173,12 @@ public:
} }
std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) { std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
std::lock_guard<std::mutex> lock(mutex_);
return predictor_->GetPredictFunction()(features); return predictor_->GetPredictFunction()(features);
} }
void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) { void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
std::lock_guard<std::mutex> lock(mutex_);
predictor_->Predict(data_filename, result_filename, data_has_header); predictor_->Predict(data_filename, result_filename, data_has_header);
} }
...@@ -180,29 +213,6 @@ public: ...@@ -180,29 +213,6 @@ public:
private: private:
void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective functions");
}
// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data->metadata(), train_data->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data->metadata(), train_data->num_data());
}
}
const Dataset* train_data_; const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_; std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */ /*! \brief All configs */
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace LightGBM { namespace LightGBM {
DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig& tree_config) DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) { :SerialTreeLearner(tree_config) {
} }
...@@ -37,10 +37,13 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) { ...@@ -37,10 +37,13 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) {
buffer_write_start_pos_.resize(num_features_); buffer_write_start_pos_.resize(num_features_);
buffer_read_start_pos_.resize(num_features_); buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_.num_leaves); global_data_count_in_leaf_.resize(tree_config_->num_leaves);
} }
void DataParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
SerialTreeLearner::ResetConfig(tree_config);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
}
void DataParallelTreeLearner::BeforeTrain() { void DataParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain(); SerialTreeLearner::BeforeTrain();
......
...@@ -276,6 +276,10 @@ public: ...@@ -276,6 +276,10 @@ public:
*/ */
void set_is_splittable(bool val) { is_splittable_ = val; } void set_is_splittable(bool val) { is_splittable_ = val; }
void ResetConfig(const TreeConfig* tree_config) {
tree_config_ = tree_config;
}
private: private:
/*! /*!
* \brief Calculate the split gain based on regularized sum_gradients and sum_hessians * \brief Calculate the split gain based on regularized sum_gradients and sum_hessians
...@@ -336,6 +340,8 @@ public: ...@@ -336,6 +340,8 @@ public:
* \brief Constructor * \brief Constructor
*/ */
HistogramPool() { HistogramPool() {
cache_size_ = 0;
total_size_ = 0;
} }
/*! /*!
...@@ -348,7 +354,7 @@ public: ...@@ -348,7 +354,7 @@ public:
* \param cache_size Max cache size * \param cache_size Max cache size
* \param total_size Total size will be used * \param total_size Total size will be used
*/ */
void ResetSize(int cache_size, int total_size) { void Reset(int cache_size, int total_size) {
cache_size_ = cache_size; cache_size_ = cache_size;
// at least need 2 bucket to store smaller leaf and larger leaf // at least need 2 bucket to store smaller leaf and larger leaf
CHECK(cache_size_ >= 2); CHECK(cache_size_ >= 2);
...@@ -382,6 +388,7 @@ public: ...@@ -382,6 +388,7 @@ public:
* \param obj_create_fun that used to generate object * \param obj_create_fun that used to generate object
*/ */
void Fill(std::function<FeatureHistogram*()> obj_create_fun) { void Fill(std::function<FeatureHistogram*()> obj_create_fun) {
fill_func_ = obj_create_fun;
pool_.clear(); pool_.clear();
pool_.resize(cache_size_); pool_.resize(cache_size_);
for (int i = 0; i < cache_size_; ++i) { for (int i = 0; i < cache_size_; ++i) {
...@@ -389,6 +396,23 @@ public: ...@@ -389,6 +396,23 @@ public:
} }
} }
/*!
* \brief Change the pool size, creating histograms only for newly added cache slots
* \param cache_size Max cache size
* \param total_size Total size will be used
*/
void DynamicChangeSize(int cache_size, int total_size) {
  // remember how many slots existed before Reset overwrites cache_size_
  const int previous_cache_size = cache_size_;
  Reset(cache_size, total_size);
  pool_.resize(cache_size_);
  // slots past the previous size are populated with the generator saved by Fill
  int slot = previous_cache_size;
  while (slot < cache_size_) {
    pool_[slot].reset(fill_func_());
    ++slot;
  }
}
/*!
* \brief Forward a new tree config to every histogram held in the pool
* \param tree_config config of tree
* \param array_size Number of histograms visited in each cache slot
*/
void ResetConfig(const TreeConfig* tree_config, int array_size) {
  for (int slot = 0; slot < cache_size_; ++slot) {
    for (int feature = 0; feature < array_size; ++feature) {
      pool_[slot][feature].ResetConfig(tree_config);
    }
  }
}
/*! /*!
* \brief Get data for the specific index * \brief Get data for the specific index
* \param idx which index want to get * \param idx which index want to get
...@@ -446,6 +470,7 @@ public: ...@@ -446,6 +470,7 @@ public:
private: private:
std::vector<std::unique_ptr<FeatureHistogram[]>> pool_; std::vector<std::unique_ptr<FeatureHistogram[]>> pool_;
std::function<FeatureHistogram*()> fill_func_;
int cache_size_; int cache_size_;
int total_size_; int total_size_;
bool is_enough_ = false; bool is_enough_ = false;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace LightGBM { namespace LightGBM {
FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig& tree_config) FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) { :SerialTreeLearner(tree_config) {
} }
......
...@@ -20,7 +20,7 @@ namespace LightGBM { ...@@ -20,7 +20,7 @@ namespace LightGBM {
*/ */
class FeatureParallelTreeLearner: public SerialTreeLearner { class FeatureParallelTreeLearner: public SerialTreeLearner {
public: public:
explicit FeatureParallelTreeLearner(const TreeConfig& tree_config); explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
~FeatureParallelTreeLearner(); ~FeatureParallelTreeLearner();
virtual void Init(const Dataset* train_data); virtual void Init(const Dataset* train_data);
...@@ -45,9 +45,10 @@ private: ...@@ -45,9 +45,10 @@ private:
*/ */
class DataParallelTreeLearner: public SerialTreeLearner { class DataParallelTreeLearner: public SerialTreeLearner {
public: public:
explicit DataParallelTreeLearner(const TreeConfig& tree_config); explicit DataParallelTreeLearner(const TreeConfig* tree_config);
~DataParallelTreeLearner(); ~DataParallelTreeLearner();
void Init(const Dataset* train_data) override; void Init(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
void FindBestThresholds() override; void FindBestThresholds() override;
...@@ -96,10 +97,10 @@ private: ...@@ -96,10 +97,10 @@ private:
*/ */
class VotingParallelTreeLearner: public SerialTreeLearner { class VotingParallelTreeLearner: public SerialTreeLearner {
public: public:
explicit VotingParallelTreeLearner(const TreeConfig& tree_config); explicit VotingParallelTreeLearner(const TreeConfig* tree_config);
~VotingParallelTreeLearner() { } ~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data) override; void Init(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
bool BeforeFindBestSplit(int left_leaf, int right_leaf) override; bool BeforeFindBestSplit(int left_leaf, int right_leaf) override;
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
namespace LightGBM { namespace LightGBM {
SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config) SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
:tree_config_(tree_config){ :tree_config_(tree_config){
random_ = Random(tree_config.feature_fraction_seed); random_ = Random(tree_config_->feature_fraction_seed);
} }
SerialTreeLearner::~SerialTreeLearner() { SerialTreeLearner::~SerialTreeLearner() {
...@@ -22,32 +22,32 @@ void SerialTreeLearner::Init(const Dataset* train_data) { ...@@ -22,32 +22,32 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
num_features_ = train_data_->num_features(); num_features_ = train_data_->num_features();
int max_cache_size = 0; int max_cache_size = 0;
// Get the max size of pool // Get the max size of pool
if (tree_config_.histogram_pool_size < 0) { if (tree_config_->histogram_pool_size <= 0) {
max_cache_size = tree_config_.num_leaves; max_cache_size = tree_config_->num_leaves;
} else { } else {
size_t total_histogram_size = 0; size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) { for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin(); total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
} }
max_cache_size = static_cast<int>(tree_config_.histogram_pool_size * 1024 * 1024 / total_histogram_size); max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
} }
// at least need 2 leaves // at least need 2 leaves
max_cache_size = std::max(2, max_cache_size); max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_.num_leaves); max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
histogram_pool_.ResetSize(max_cache_size, tree_config_.num_leaves); histogram_pool_.Reset(max_cache_size, tree_config_->num_leaves);
auto histogram_create_function = [this]() { auto histogram_create_function = [this]() {
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]); auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) { for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j), tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, &tree_config_); j, tree_config_);
} }
return tmp_histogram_array.release(); return tmp_histogram_array.release();
}; };
histogram_pool_.Fill(histogram_create_function); histogram_pool_.Fill(histogram_create_function);
// push split information for all leaves // push split information for all leaves
best_split_per_leaf_.resize(tree_config_.num_leaves); best_split_per_leaf_.resize(tree_config_->num_leaves);
// initialize ordered_bins_ with nullptr // initialize ordered_bins_ with nullptr
ordered_bins_.resize(num_features_); ordered_bins_.resize(num_features_);
...@@ -69,7 +69,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) { ...@@ -69,7 +69,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data())); larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
// initialize data partition // initialize data partition
data_partition_.reset(new DataPartition(num_data_, tree_config_.num_leaves)); data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
is_feature_used_.resize(num_features_); is_feature_used_.resize(num_features_);
...@@ -84,19 +84,49 @@ void SerialTreeLearner::Init(const Dataset* train_data) { ...@@ -84,19 +84,49 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
} }
/*!
* \brief Reset tree configs
*        When the number of leaves changes, the histogram pool, the per-leaf split
*        buffer and the data partition are resized; the config pointer is forwarded to
*        the pooled histograms and the feature sampling generator is re-seeded in any case.
* \param tree_config config of tree; only the pointer is stored
*/
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  // compare against the old config before replacing it
  const bool leaves_changed = (tree_config_->num_leaves != tree_config->num_leaves);
  tree_config_ = tree_config;

  if (leaves_changed) {
    const int num_leaves = tree_config_->num_leaves;
    // without a positive pool size limit, every leaf gets a cache slot
    int max_cache_size = num_leaves;
    if (tree_config_->histogram_pool_size > 0) {
      // size of the histograms of all features for a single leaf
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
      }
      max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
    }
    // at least 2 slots, at most one per leaf
    max_cache_size = std::min(std::max(2, max_cache_size), num_leaves);
    histogram_pool_.DynamicChangeSize(max_cache_size, num_leaves);

    // split information and data partition are sized by the number of leaves
    best_split_per_leaf_.resize(num_leaves);
    data_partition_.reset(new DataPartition(num_data_, num_leaves));
  }

  histogram_pool_.ResetConfig(tree_config_, train_data_->num_features());
  random_ = Random(tree_config_->feature_fraction_seed);
}
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) { Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
gradients_ = gradients; gradients_ = gradients;
hessians_ = hessians; hessians_ = hessians;
// some initial works before training // some initial works before training
BeforeTrain(); BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(tree_config_.num_leaves)); auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves));
// save pointer to last trained tree // save pointer to last trained tree
last_trained_tree_ = tree.get(); last_trained_tree_ = tree.get();
// root leaf // root leaf
int left_leaf = 0; int left_leaf = 0;
// only the root leaf can be split the first time // only the root leaf can be split the first time
int right_leaf = -1; int right_leaf = -1;
for (int split = 0; split < tree_config_.num_leaves - 1; split++) { for (int split = 0; split < tree_config_->num_leaves - 1; split++) {
// some initial works before finding best split // some initial works before finding best split
if (BeforeFindBestSplit(left_leaf, right_leaf)) { if (BeforeFindBestSplit(left_leaf, right_leaf)) {
// find best threshold for every feature // find best threshold for every feature
...@@ -121,6 +151,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -121,6 +151,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
} }
void SerialTreeLearner::BeforeTrain() { void SerialTreeLearner::BeforeTrain() {
// reset histogram pool // reset histogram pool
histogram_pool_.ResetMap(); histogram_pool_.ResetMap();
// initialize used features // initialize used features
...@@ -128,7 +159,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -128,7 +159,7 @@ void SerialTreeLearner::BeforeTrain() {
is_feature_used_[i] = false; is_feature_used_[i] = false;
} }
// Get used feature at current tree // Get used feature at current tree
int used_feature_cnt = static_cast<int>(num_features_*tree_config_.feature_fraction); int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt); auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
for (auto idx : used_feature_indices) { for (auto idx : used_feature_indices) {
is_feature_used_[idx] = true; is_feature_used_[idx] = true;
...@@ -138,7 +169,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -138,7 +169,7 @@ void SerialTreeLearner::BeforeTrain() {
data_partition_->Init(); data_partition_->Init();
// reset the splits for leaves // reset the splits for leaves
for (int i = 0; i < tree_config_.num_leaves; ++i) { for (int i = 0; i < tree_config_->num_leaves; ++i) {
best_split_per_leaf_[i].Reset(); best_split_per_leaf_[i].Reset();
} }
...@@ -177,7 +208,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -177,7 +208,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) { for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) { if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(nullptr, tree_config_.num_leaves); ordered_bins_[i]->Init(nullptr, tree_config_->num_leaves);
} }
} }
} else { } else {
...@@ -196,7 +227,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -196,7 +227,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) { for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) { if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_.num_leaves); ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
} }
} }
} }
...@@ -205,9 +236,9 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -205,9 +236,9 @@ void SerialTreeLearner::BeforeTrain() {
bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) { bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
// check depth of current leaf // check depth of current leaf
if (tree_config_.max_depth > 0) { if (tree_config_->max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf // only need to check left leaf, since right leaf is in same level of left leaf
if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_.max_depth) { if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
best_split_per_leaf_[left_leaf].gain = kMinScore; best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) { if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore; best_split_per_leaf_[right_leaf].gain = kMinScore;
...@@ -218,8 +249,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) { ...@@ -218,8 +249,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf); data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf); data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// no enough data to continue // no enough data to continue
if (num_data_in_right_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2) if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)) { && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
best_split_per_leaf_[left_leaf].gain = kMinScore; best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) { if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore; best_split_per_leaf_[right_leaf].gain = kMinScore;
......
...@@ -26,12 +26,14 @@ namespace LightGBM { ...@@ -26,12 +26,14 @@ namespace LightGBM {
*/ */
class SerialTreeLearner: public TreeLearner { class SerialTreeLearner: public TreeLearner {
public: public:
explicit SerialTreeLearner(const TreeConfig& tree_config); explicit SerialTreeLearner(const TreeConfig* tree_config);
~SerialTreeLearner(); ~SerialTreeLearner();
void Init(const Dataset* train_data) override; void Init(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
Tree* Train(const score_t* gradients, const score_t *hessians) override; Tree* Train(const score_t* gradients, const score_t *hessians) override;
void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override { void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {
...@@ -153,7 +155,7 @@ protected: ...@@ -153,7 +155,7 @@ protected:
/*! \brief used to cache historical histogram to speed up*/ /*! \brief used to cache historical histogram to speed up*/
HistogramPool histogram_pool_; HistogramPool histogram_pool_;
/*! \brief config of tree learner*/ /*! \brief config of tree learner*/
const TreeConfig& tree_config_; const TreeConfig* tree_config_;
}; };
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace LightGBM { namespace LightGBM {
TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig& tree_config) { TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig* tree_config) {
if (type == TreeLearnerType::kSerialTreeLearner) { if (type == TreeLearnerType::kSerialTreeLearner) {
return new SerialTreeLearner(tree_config); return new SerialTreeLearner(tree_config);
} else if (type == TreeLearnerType::kFeatureParallelTreelearner) { } else if (type == TreeLearnerType::kFeatureParallelTreelearner) {
......
...@@ -9,9 +9,9 @@ ...@@ -9,9 +9,9 @@
namespace LightGBM { namespace LightGBM {
VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig& tree_config) VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) { :SerialTreeLearner(tree_config) {
top_k_ = tree_config.top_k; top_k_ = tree_config_->top_k;
} }
void VotingParallelTreeLearner::Init(const Dataset* train_data) { void VotingParallelTreeLearner::Init(const Dataset* train_data) {
...@@ -44,34 +44,41 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) { ...@@ -44,34 +44,41 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) {
smaller_buffer_read_start_pos_.resize(num_features_); smaller_buffer_read_start_pos_.resize(num_features_);
larger_buffer_read_start_pos_.resize(num_features_); larger_buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_.num_leaves); global_data_count_in_leaf_.resize(tree_config_->num_leaves);
smaller_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data())); smaller_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data())); larger_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
local_tree_config_ = tree_config_; local_tree_config_ = *tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_; local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_; local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
auto histogram_create_function = [this]() { histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, &local_tree_config_);
}
return tmp_histogram_array.release();
};
histogram_pool_.Fill(histogram_create_function);
// initialize histograms for global // initialize histograms for global
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]); smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
larger_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]); larger_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
for (int j = 0; j < num_features_; ++j) { for (int j = 0; j < num_features_; ++j) {
smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_); smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_); larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
} }
} }
/*!
 * \brief Reset tree configs for the voting parallel tree learner
 * \param tree_config config of tree
 */
void VotingParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  // base-class reset runs first; every statement below reads the member tree_config_
  // (presumably updated by that call -- confirm in SerialTreeLearner::ResetConfig)
  SerialTreeLearner::ResetConfig(tree_config);

  // NOTE(review): top_k_ is assigned from the config in the constructor but is
  // not refreshed here -- confirm whether that is intentional.

  // rebuild the local config: a copy of the current config whose per-leaf
  // minimums are divided by the number of machines
  local_tree_config_ = *tree_config_;
  local_tree_config_.min_data_in_leaf /= num_machines_;
  local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;

  // the histogram pool is reset with the local (scaled) config
  histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());

  // one global data count entry per leaf
  global_data_count_in_leaf_.resize(tree_config_->num_leaves);

  // the global histograms are reset with the unscaled config
  for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
    auto& smaller_hist = smaller_leaf_histogram_array_global_[feature_idx];
    auto& larger_hist = larger_leaf_histogram_array_global_[feature_idx];
    smaller_hist.ResetConfig(tree_config_);
    larger_hist.ResetConfig(tree_config_);
  }
}
void VotingParallelTreeLearner::BeforeTrain() { void VotingParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain(); SerialTreeLearner::BeforeTrain();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment