Commit 63eddae0 authored by Guolin Ke's avatar Guolin Ke
Browse files

provide a light weight interface for reset learning rate

parent 19512d82
......@@ -51,6 +51,12 @@ public:
*/
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
virtual void ResetShrinkageRate(double shrinkage_rate) = 0;
/*!
* \brief Add a validation data
* \param valid_data Validation data
......
......@@ -128,9 +128,11 @@ def param_dict_to_str(data):
return ""
pairs = []
for key, val in data.items():
if isinstance(val, list):
pairs.append(str(key)+'='+','.join(val))
elif isinstance(val, (int, float, str, bool)):
if is_str(val):
pairs.append(str(key)+'='+str(val))
elif isinstance(val, (list, tuple)):
pairs.append(str(key)+'='+','.join(map(str,val)))
elif isinstance(val, (int, float, bool)):
pairs.append(str(key)+'='+str(val))
else:
raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__))
......
......@@ -144,6 +144,11 @@ void Application::LoadData() {
}
}
train_metric_.shrink_to_fit();
if (config_.metric_types.size() > 0) {
// only when have metrics then need to construct validation data
// Add validation data, if it exists
for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
// add
......@@ -171,6 +176,7 @@ void Application::LoadData() {
}
valid_datas_.shrink_to_fit();
valid_metrics_.shrink_to_fit();
}
auto end_time = std::chrono::high_resolution_clock::now();
// output used time on each iteration
Log::Info("Finished loading data in %f seconds",
......
......@@ -68,6 +68,14 @@ public:
*/
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
void ResetShrinkageRate(double shrinkage_rate) override {
shrinkage_rate_ = shrinkage_rate;
}
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
......
......@@ -72,8 +72,13 @@ public:
Log::Fatal("cannot change boosting_type during training");
}
config_.Set(param);
if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
// only need to set learning rate
boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
} else {
ResetTrainingData(train_data_);
}
}
void AddValidData(const Dataset* valid_data) {
std::lock_guard<std::mutex> lock(mutex_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment