Commit 962b7eb0 authored by Guolin Ke's avatar Guolin Ke
Browse files

change to std::lock_guard

parent 3484e898
...@@ -30,7 +30,7 @@ public: ...@@ -30,7 +30,7 @@ public:
Booster(const Dataset* train_data, Booster(const Dataset* train_data,
const char* parameters) { const char* parameters) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
config_.Set(param); config_.Set(param);
// create boosting // create boosting
...@@ -43,13 +43,11 @@ public: ...@@ -43,13 +43,11 @@ public:
// initialize the boosting // initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
lock.unlock();
} }
void MergeFrom(const Booster* other) { void MergeFrom(const Booster* other) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
boosting_->MergeFrom(other->boosting_.get()); boosting_->MergeFrom(other->boosting_.get());
lock.unlock();
} }
~Booster() { ~Booster() {
...@@ -57,17 +55,16 @@ public: ...@@ -57,17 +55,16 @@ public:
} }
void ResetTrainingData(const Dataset* train_data) { void ResetTrainingData(const Dataset* train_data) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
train_data_ = train_data; train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data_); ConstructObjectAndTrainingMetrics(train_data_);
// initialize the boosting // initialize the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_, boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
lock.unlock();
} }
void ResetConfig(const char* parameters) { void ResetConfig(const char* parameters) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
if (param.count("num_class")) { if (param.count("num_class")) {
Log::Fatal("cannot change num class during training"); Log::Fatal("cannot change num class during training");
...@@ -77,11 +74,10 @@ public: ...@@ -77,11 +74,10 @@ public:
} }
config_.Set(param); config_.Set(param);
ResetTrainingData(train_data_); ResetTrainingData(train_data_);
lock.unlock();
} }
void AddValidData(const Dataset* valid_data) { void AddValidData(const Dataset* valid_data) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config)); auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
...@@ -92,30 +88,24 @@ public: ...@@ -92,30 +88,24 @@ public:
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
boosting_->AddValidDataset(valid_data, boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back())); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
lock.unlock();
} }
bool TrainOneIter() { bool TrainOneIter() {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool ret = boosting_->TrainOneIter(nullptr, nullptr, false); return boosting_->TrainOneIter(nullptr, nullptr, false);
lock.unlock();
return ret;
} }
bool TrainOneIter(const float* gradients, const float* hessians) { bool TrainOneIter(const float* gradients, const float* hessians) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool ret = boosting_->TrainOneIter(gradients, hessians, false); return boosting_->TrainOneIter(gradients, hessians, false);
lock.unlock();
return ret;
} }
void RollbackOneIter() { void RollbackOneIter() {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
boosting_->RollbackOneIter(); boosting_->RollbackOneIter();
lock.unlock();
} }
void PrepareForPrediction(int num_iteration, int predict_type) { void PrepareForPrediction(int num_iteration, int predict_type) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
boosting_->SetNumIterationForPred(num_iteration); boosting_->SetNumIterationForPred(num_iteration);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
...@@ -127,7 +117,6 @@ public: ...@@ -127,7 +117,6 @@ public:
is_raw_score = false; is_raw_score = false;
} }
predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf)); predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
lock.unlock();
} }
void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) { void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
...@@ -143,9 +132,7 @@ public: ...@@ -143,9 +132,7 @@ public:
} }
void SaveModelToFile(int num_iteration, const char* filename) { void SaveModelToFile(int num_iteration, const char* filename) {
std::unique_lock<std::mutex> lock(mutex_);
boosting_->SaveModelToFile(num_iteration, filename); boosting_->SaveModelToFile(num_iteration, filename);
lock.unlock();
} }
int GetEvalCounts() const { int GetEvalCounts() const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment