Commit 63eddae0 authored by Guolin Ke's avatar Guolin Ke
Browse files

provide a light weight interface for reset learning rate

parent 19512d82
...@@ -51,6 +51,12 @@ public: ...@@ -51,6 +51,12 @@ public:
*/ */
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0; virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
virtual void ResetShrinkageRate(double shrinkage_rate) = 0;
/*! /*!
* \brief Add a validation data * \brief Add a validation data
* \param valid_data Validation data * \param valid_data Validation data
......
...@@ -128,9 +128,11 @@ def param_dict_to_str(data): ...@@ -128,9 +128,11 @@ def param_dict_to_str(data):
return "" return ""
pairs = [] pairs = []
for key, val in data.items(): for key, val in data.items():
if isinstance(val, list): if is_str(val):
pairs.append(str(key)+'='+','.join(val)) pairs.append(str(key)+'='+str(val))
elif isinstance(val, (int, float, str, bool)): elif isinstance(val, (list, tuple)):
pairs.append(str(key)+'='+','.join(map(str,val)))
elif isinstance(val, (int, float, bool)):
pairs.append(str(key)+'='+str(val)) pairs.append(str(key)+'='+str(val))
else: else:
raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__)) raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__))
......
...@@ -144,6 +144,11 @@ void Application::LoadData() { ...@@ -144,6 +144,11 @@ void Application::LoadData() {
} }
} }
train_metric_.shrink_to_fit(); train_metric_.shrink_to_fit();
if (config_.metric_types.size() > 0) {
// only when have metrics then need to construct validation data
// Add validation data, if it exists // Add validation data, if it exists
for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) { for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
// add // add
...@@ -171,6 +176,7 @@ void Application::LoadData() { ...@@ -171,6 +176,7 @@ void Application::LoadData() {
} }
valid_datas_.shrink_to_fit(); valid_datas_.shrink_to_fit();
valid_metrics_.shrink_to_fit(); valid_metrics_.shrink_to_fit();
}
auto end_time = std::chrono::high_resolution_clock::now(); auto end_time = std::chrono::high_resolution_clock::now();
// output used time on each iteration // output used time on each iteration
Log::Info("Finished loading data in %f seconds", Log::Info("Finished loading data in %f seconds",
......
...@@ -68,6 +68,14 @@ public: ...@@ -68,6 +68,14 @@ public:
*/ */
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override; void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
void ResetShrinkageRate(double shrinkage_rate) override {
shrinkage_rate_ = shrinkage_rate;
}
/*! /*!
* \brief Adding a validation dataset * \brief Adding a validation dataset
* \param valid_data Validation dataset * \param valid_data Validation dataset
......
...@@ -72,8 +72,13 @@ public: ...@@ -72,8 +72,13 @@ public:
Log::Fatal("cannot change boosting_type during training"); Log::Fatal("cannot change boosting_type during training");
} }
config_.Set(param); config_.Set(param);
if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
// only need to set learning rate
boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
} else {
ResetTrainingData(train_data_); ResetTrainingData(train_data_);
} }
}
void AddValidData(const Dataset* valid_data) { void AddValidData(const Dataset* valid_data) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment