Commit 714c6732 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bugs in reset parameters

parent 7a81f7bd
...@@ -41,6 +41,7 @@ public: ...@@ -41,6 +41,7 @@ public:
please use continued train with input score"); please use continued train with input score");
} }
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr)); boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data); ConstructObjectAndTrainingMetrics(train_data);
// initialize the boosting // initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
...@@ -66,7 +67,6 @@ public: ...@@ -66,7 +67,6 @@ public:
} }
void ResetConfig(const char* parameters) { void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
if (param.count("num_class")) { if (param.count("num_class")) {
Log::Fatal("cannot change num class during training"); Log::Fatal("cannot change num class during training");
...@@ -74,12 +74,20 @@ public: ...@@ -74,12 +74,20 @@ public:
if (param.count("boosting_type")) { if (param.count("boosting_type")) {
Log::Fatal("cannot change boosting_type during training"); Log::Fatal("cannot change boosting_type during training");
} }
config_.Set(param); if (param.count("metric")) {
Log::Fatal("cannot change metric during training");
}
{
std::lock_guard<std::mutex> lock(mutex_);
config_.Set(param);
}
if (config_.num_threads > 0) { if (config_.num_threads > 0) {
std::lock_guard<std::mutex> lock(mutex_);
omp_set_num_threads(config_.num_threads); omp_set_num_threads(config_.num_threads);
} }
if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) { if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
// only need to set learning rate // only need to set learning rate
std::lock_guard<std::mutex> lock(mutex_);
boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate); boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
} else { } else {
ResetTrainingData(train_data_); ResetTrainingData(train_data_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment