Commit 63eddae0 authored by Guolin Ke's avatar Guolin Ke
Browse files

provide a light weight interface for reset learning rate

parent 19512d82
......@@ -51,6 +51,12 @@ public:
*/
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
virtual void ResetShrinkageRate(double shrinkage_rate) = 0;
/*!
* \brief Add a validation data
* \param valid_data Validation data
......
......@@ -128,9 +128,11 @@ def param_dict_to_str(data):
return ""
pairs = []
for key, val in data.items():
if isinstance(val, list):
pairs.append(str(key)+'='+','.join(val))
elif isinstance(val, (int, float, str, bool)):
if is_str(val):
pairs.append(str(key)+'='+str(val))
elif isinstance(val, (list, tuple)):
pairs.append(str(key)+'='+','.join(map(str,val)))
elif isinstance(val, (int, float, bool)):
pairs.append(str(key)+'='+str(val))
else:
raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__))
......
......@@ -144,33 +144,39 @@ void Application::LoadData() {
}
}
train_metric_.shrink_to_fit();
// Add validation data, if it exists
for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
// add
auto new_dataset = std::unique_ptr<Dataset>(
dataset_loader.LoadFromFileAlignWithOtherDataset(
config_.io_config.valid_data_filenames[i].c_str(),
train_data_.get())
);
valid_datas_.push_back(std::move(new_dataset));
// need save binary file
if (config_.io_config.is_save_binary_file) {
valid_datas_.back()->SaveBinaryFile(nullptr);
}
// add metric for validation data
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(valid_datas_.back()->metadata(),
valid_datas_.back()->num_data());
valid_metrics_.back().push_back(std::move(metric));
if (config_.metric_types.size() > 0) {
// only when have metrics then need to construct validation data
// Add validation data, if it exists
for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
// add
auto new_dataset = std::unique_ptr<Dataset>(
dataset_loader.LoadFromFileAlignWithOtherDataset(
config_.io_config.valid_data_filenames[i].c_str(),
train_data_.get())
);
valid_datas_.push_back(std::move(new_dataset));
// need save binary file
if (config_.io_config.is_save_binary_file) {
valid_datas_.back()->SaveBinaryFile(nullptr);
}
// add metric for validation data
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(valid_datas_.back()->metadata(),
valid_datas_.back()->num_data());
valid_metrics_.back().push_back(std::move(metric));
}
valid_metrics_.back().shrink_to_fit();
}
valid_metrics_.back().shrink_to_fit();
valid_datas_.shrink_to_fit();
valid_metrics_.shrink_to_fit();
}
valid_datas_.shrink_to_fit();
valid_metrics_.shrink_to_fit();
auto end_time = std::chrono::high_resolution_clock::now();
// output used time on each iteration
Log::Info("Finished loading data in %f seconds",
......
......@@ -68,6 +68,14 @@ public:
*/
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
void ResetShrinkageRate(double shrinkage_rate) override {
shrinkage_rate_ = shrinkage_rate;
}
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
......
......@@ -72,7 +72,12 @@ public:
Log::Fatal("cannot change boosting_type during training");
}
config_.Set(param);
ResetTrainingData(train_data_);
if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
// only need to set learning rate
boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
} else {
ResetTrainingData(train_data_);
}
}
void AddValidData(const Dataset* valid_data) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment