Commit b41e0f0a authored by Guolin Ke
Browse files

more flexible reset config/training data logic for boosting

parent 5b4ee9db
...@@ -40,19 +40,15 @@ public: ...@@ -40,19 +40,15 @@ public:
* \param other * \param other
*/ */
virtual void MergeFrom(const Boosting* other) = 0; virtual void MergeFrom(const Boosting* other) = 0;
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
virtual void ResetConfig(const BoostingConfig* config) = 0;
/*! /*!
* \brief Reset training data for current boosting * \brief Reset training data for current boosting
* \param config Configs for boosting
* \param train_data Training data * \param train_data Training data
* \param object_function Training objective function * \param object_function Training objective function
* \param training_metrics Training metric * \param training_metrics Training metric
*/ */
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0; virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*! /*!
* \brief Add a validation data * \brief Add a validation data
......
...@@ -72,6 +72,8 @@ public: ...@@ -72,6 +72,8 @@ public:
inline bool GetBool( inline bool GetBool(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, bool* out); const std::string& name, bool* out);
static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);
}; };
/*! \brief Types of boosting */ /*! \brief Types of boosting */
...@@ -231,7 +233,7 @@ public: ...@@ -231,7 +233,7 @@ public:
MetricConfig metric_config; MetricConfig metric_config;
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
void LoadFromString(const char* str);
private: private:
void GetBoostingType(const std::unordered_map<std::string, std::string>& params); void GetBoostingType(const std::unordered_map<std::string, std::string>& params);
......
...@@ -29,52 +29,23 @@ GBDT::~GBDT() { ...@@ -29,52 +29,23 @@ GBDT::~GBDT() {
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) { const std::vector<const Metric*>& training_metrics) {
gbdt_config_ = config;
iter_ = 0; iter_ = 0;
saved_model_size_ = -1; saved_model_size_ = -1;
num_iteration_for_pred_ = 0; num_iteration_for_pred_ = 0;
max_feature_idx_ = 0; max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
num_class_ = config->num_class; num_class_ = config->num_class;
train_data_ = nullptr; train_data_ = nullptr;
ResetTrainingData(train_data, object_function, training_metrics); ResetTrainingData(config, train_data, object_function, training_metrics);
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
}
void GBDT::ResetConfig(const BoostingConfig* config) {
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
// create tree learner
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data_);
// init tree learner
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
// if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else {
out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear();
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
}
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
} }
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) { void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) { if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers"); Log::Fatal("cannot reset training data, since new training data has different bin mappers");
} }
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
train_data_ = train_data; train_data_ = train_data;
// create tree learner // create tree learner
tree_learner_.clear(); tree_learner_.clear();
...@@ -120,6 +91,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* ...@@ -120,6 +91,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
bag_data_cnt_ = num_data_; bag_data_cnt_ = num_data_;
bag_data_indices_.clear(); bag_data_indices_.clear();
} }
random_ = Random(gbdt_config_->bagging_seed);
// update score // update score
for (int i = 0; i < iter_; ++i) { for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
......
...@@ -44,19 +44,13 @@ public: ...@@ -44,19 +44,13 @@ public:
} }
} }
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
void ResetConfig(const BoostingConfig* config) override;
/*! /*!
* \brief Reset training data for current boosting * \brief Reset training data for current boosting
* \param train_data Training data * \param train_data Training data
* \param object_function Training objective function * \param object_function Training objective function
* \param training_metrics Training metric * \param training_metrics Training metric
*/ */
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override; void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*! /*!
* \brief Adding a validation dataset * \brief Adding a validation dataset
......
...@@ -29,7 +29,8 @@ public: ...@@ -29,7 +29,8 @@ public:
Booster(const Dataset* train_data, Booster(const Dataset* train_data,
const char* parameters) { const char* parameters) {
config_.LoadFromString(parameters); auto param = ConfigBase::Str2Map(parameters);
config_.Set(param);
// create boosting // create boosting
if (config_.io_config.input_model.size() > 0) { if (config_.io_config.input_model.size() > 0) {
Log::Warning("continued train from model is not support for c_api, \ Log::Warning("continued train from model is not support for c_api, \
...@@ -74,9 +75,23 @@ public: ...@@ -74,9 +75,23 @@ public:
} }
void ResetTrainingData(const Dataset* train_data) { void ResetTrainingData(const Dataset* train_data) {
ConstructObjectAndTrainingMetrics(train_data); train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data_);
// initialize the boosting // initialize the boosting
boosting_->ResetTrainingData(train_data, objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
void ResetConfig(const char* parameters) {
auto param = ConfigBase::Str2Map(parameters);
if (param.count("num_class")) {
Log::Fatal("cannot change num class during training");
}
if (param.count("boosting_type")) {
Log::Fatal("cannot change boosting_type during training");
}
config_.Set(param);
ResetTrainingData(train_data_);
} }
void AddValidData(const Dataset* valid_data) { void AddValidData(const Dataset* valid_data) {
...@@ -154,10 +169,6 @@ public: ...@@ -154,10 +169,6 @@ public:
return idx; return idx;
} }
void ResetBoostingConfig(const char* parameters) {
config_.LoadFromString(parameters);
boosting_->ResetConfig(&config_.boosting_config);
}
void RollbackOneIter() { void RollbackOneIter() {
boosting_->RollbackOneIter(); boosting_->RollbackOneIter();
...@@ -166,6 +177,7 @@ public: ...@@ -166,6 +177,7 @@ public:
const Boosting* GetBoosting() const { return boosting_.get(); } const Boosting* GetBoosting() const { return boosting_.get(); }
private: private:
const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_; std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */ /*! \brief All configs */
OverallConfig config_; OverallConfig config_;
...@@ -193,9 +205,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, ...@@ -193,9 +205,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN(); API_BEGIN();
OverallConfig config; auto param = ConfigBase::Str2Map(parameters);
config.LoadFromString(parameters); IOConfig io_config;
DatasetLoader loader(config.io_config, nullptr); io_config.Set(param);
DatasetLoader loader(io_config, nullptr);
loader.SetHeader(filename); loader.SetHeader(filename);
if (reference == nullptr) { if (reference == nullptr) {
*out = loader.LoadFromFile(filename); *out = loader.LoadFromFile(filename);
...@@ -224,15 +237,16 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -224,15 +237,16 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN(); API_BEGIN();
OverallConfig config; auto param = ConfigBase::Str2Map(parameters);
config.LoadFromString(parameters); IOConfig io_config;
DatasetLoader loader(config.io_config, nullptr); io_config.Set(param);
DatasetLoader loader(io_config, nullptr);
std::unique_ptr<Dataset> ret; std::unique_ptr<Dataset> ret;
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.io_config.data_random_seed); Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values(ncol); std::vector<std::vector<double>> sample_values(ncol);
for (size_t i = 0; i < sample_indices.size(); ++i) { for (size_t i = 0; i < sample_indices.size(); ++i) {
...@@ -246,10 +260,10 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -246,10 +260,10 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
} }
ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow)); ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
} else { } else {
ret.reset(new Dataset(nrow, config.io_config.num_class)); ret.reset(new Dataset(nrow, io_config.num_class));
ret->CopyFeatureMapperFrom( ret->CopyFeatureMapperFrom(
reinterpret_cast<const Dataset*>(*reference), reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse); io_config.is_enable_sparse);
} }
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
...@@ -275,16 +289,17 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -275,16 +289,17 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN(); API_BEGIN();
OverallConfig config; auto param = ConfigBase::Str2Map(parameters);
config.LoadFromString(parameters); IOConfig io_config;
DatasetLoader loader(config.io_config, nullptr); io_config.Set(param);
DatasetLoader loader(io_config, nullptr);
std::unique_ptr<Dataset> ret; std::unique_ptr<Dataset> ret;
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1); int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.io_config.data_random_seed); Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values; std::vector<std::vector<double>> sample_values;
for (size_t i = 0; i < sample_indices.size(); ++i) { for (size_t i = 0; i < sample_indices.size(); ++i) {
...@@ -307,10 +322,10 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -307,10 +322,10 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
CHECK(num_col >= static_cast<int>(sample_values.size())); CHECK(num_col >= static_cast<int>(sample_values.size()));
ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow)); ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
} else { } else {
ret.reset(new Dataset(nrow, config.io_config.num_class)); ret.reset(new Dataset(nrow, io_config.num_class));
ret->CopyFeatureMapperFrom( ret->CopyFeatureMapperFrom(
reinterpret_cast<const Dataset*>(*reference), reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse); io_config.is_enable_sparse);
} }
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
...@@ -336,17 +351,18 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -336,17 +351,18 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN(); API_BEGIN();
OverallConfig config; auto param = ConfigBase::Str2Map(parameters);
config.LoadFromString(parameters); IOConfig io_config;
DatasetLoader loader(config.io_config, nullptr); io_config.Set(param);
DatasetLoader loader(io_config, nullptr);
std::unique_ptr<Dataset> ret; std::unique_ptr<Dataset> ret;
auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem); auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
int32_t nrow = static_cast<int32_t>(num_row); int32_t nrow = static_cast<int32_t>(num_row);
if (reference == nullptr) { if (reference == nullptr) {
Log::Warning("Construct from CSC format is not efficient"); Log::Warning("Construct from CSC format is not efficient");
// sample data first // sample data first
Random rand(config.io_config.data_random_seed); Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values(ncol_ptr - 1); std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
...@@ -356,10 +372,10 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -356,10 +372,10 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
} }
ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow)); ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
} else { } else {
ret.reset(new Dataset(nrow, config.io_config.num_class)); ret.reset(new Dataset(nrow, io_config.num_class));
ret->CopyFeatureMapperFrom( ret->CopyFeatureMapperFrom(
reinterpret_cast<const Dataset*>(*reference), reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse); io_config.is_enable_sparse);
} }
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
...@@ -500,7 +516,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle, ...@@ -500,7 +516,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) { DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->ResetBoostingConfig(parameters); ref_booster->ResetConfig(parameters);
API_END(); API_END();
} }
......
...@@ -10,9 +10,9 @@ ...@@ -10,9 +10,9 @@
namespace LightGBM { namespace LightGBM {
void OverallConfig::LoadFromString(const char* str) { std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* parameters) {
std::unordered_map<std::string, std::string> params; std::unordered_map<std::string, std::string> params;
auto args = Common::Split(str, " \t\n\r"); auto args = Common::Split(parameters, " \t\n\r");
for (auto arg : args) { for (auto arg : args) {
std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '='); std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
if (tmp_strs.size() == 2) { if (tmp_strs.size() == 2) {
...@@ -27,7 +27,7 @@ void OverallConfig::LoadFromString(const char* str) { ...@@ -27,7 +27,7 @@ void OverallConfig::LoadFromString(const char* str) {
} }
} }
ParameterAlias::KeyAliasTransform(&params); ParameterAlias::KeyAliasTransform(&params);
Set(params); return params;
} }
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) { void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment