Commit 3484e898 authored by Guolin Ke's avatar Guolin Ke
Browse files

Make the booster thread-safe

parent 9006d3f2
...@@ -128,11 +128,10 @@ public: ...@@ -128,11 +128,10 @@ public:
/*! /*!
* \brief save model to file * \brief save model to file
* \param num_used_model number of model that want to save, -1 means save all * \param num_iterations Iterations that want to save, -1 means save all
* \param is_finish is training finished or not
* \param filename filename that want to save to * \param filename filename that want to save to
*/ */
virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) = 0; virtual void SaveModelToFile(int num_iterations, const char* filename) const = 0;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
......
...@@ -473,7 +473,7 @@ SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std:: ...@@ -473,7 +473,7 @@ SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::
// exception handle and error msg // exception handle and error msg
static std::string& LastErrorMsg() { static std::string err_msg("Everything is fine"); return err_msg; } static std::string& LastErrorMsg() { static thread_local std::string err_msg("Everything is fine"); return err_msg; }
inline void LGBM_SetLastError(const char* msg) { inline void LGBM_SetLastError(const char* msg) {
LastErrorMsg() = msg; LastErrorMsg() = msg;
......
...@@ -89,7 +89,7 @@ private: ...@@ -89,7 +89,7 @@ private:
// a trick to use static variable in header file. // a trick to use static variable in header file.
// May be not good, but avoid to use an additional cpp file // May be not good, but avoid to use an additional cpp file
static LogLevel& GetLevel() { static LogLevel level; return level; } static LogLevel& GetLevel() { static thread_local LogLevel level = LogLevel::Info; return level; }
}; };
......
...@@ -225,11 +225,10 @@ void Application::Train() { ...@@ -225,11 +225,10 @@ void Application::Train() {
// output used time per iteration // output used time per iteration
Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double, Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1); std::milli>(end_time - start_time) * 1e-3, iter + 1);
boosting_->SaveModelToFile(NO_LIMIT, is_finished, config_.io_config.output_model.c_str());
} }
is_finished = true; is_finished = true;
// save model to file // save model to file
boosting_->SaveModelToFile(NO_LIMIT, is_finished, config_.io_config.output_model.c_str()); boosting_->SaveModelToFile(NO_LIMIT, config_.io_config.output_model.c_str());
Log::Info("Finished training"); Log::Info("Finished training");
} }
......
...@@ -67,18 +67,7 @@ public: ...@@ -67,18 +67,7 @@ public:
*out_len = train_score_updater_->num_data() * num_class_; *out_len = train_score_updater_->num_data() * num_class_;
return train_score_updater_->score(); return train_score_updater_->score();
} }
/*!
* \brief Save the model to a file, at most once per training run.
*        The write is deferred until training has finished; after the first
*        successful save (saved_model_size_ becomes >= 0 — set by the base
*        implementation) subsequent calls are no-ops.
* \param num_iteration Number of iterations to save, -1 means save all
* \param is_finish Whether training has finished; nothing is written otherwise
* \param filename Path of the file to write the model to
*/
void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override {
// only save model once when is_finish = true
if (is_finish && saved_model_size_ < 0) {
GBDT::SaveModelToFile(num_iteration, is_finish, filename);
}
}
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace LightGBM { namespace LightGBM {
GBDT::GBDT() GBDT::GBDT()
:saved_model_size_(-1), :num_iteration_for_pred_(0),
num_iteration_for_pred_(0),
num_init_iteration_(0) { num_init_iteration_(0) {
} }
...@@ -30,7 +29,6 @@ GBDT::~GBDT() { ...@@ -30,7 +29,6 @@ GBDT::~GBDT() {
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) { const std::vector<const Metric*>& training_metrics) {
iter_ = 0; iter_ = 0;
saved_model_size_ = -1;
num_iteration_for_pred_ = 0; num_iteration_for_pred_ = 0;
max_feature_idx_ = 0; max_feature_idx_ = 0;
num_class_ = config->num_class; num_class_ = config->num_class;
...@@ -395,56 +393,41 @@ void GBDT::Boosting() { ...@@ -395,56 +393,41 @@ void GBDT::Boosting() {
GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data()); GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
} }
void GBDT::SaveModelToFile(int num_iteration, bool is_finish, const char* filename) { void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
// first time to this function, open file /*! \brief File to write models */
if (saved_model_size_ < 0) { std::ofstream outpu_file;
model_output_file_.open(filename); outpu_file.open(filename);
// output model type // output model type
model_output_file_ << Name() << std::endl; outpu_file << Name() << std::endl;
// output number of class // output number of class
model_output_file_ << "num_class=" << num_class_ << std::endl; outpu_file << "num_class=" << num_class_ << std::endl;
// output label index // output label index
model_output_file_ << "label_index=" << label_idx_ << std::endl; outpu_file << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx // output max_feature_idx
model_output_file_ << "max_feature_idx=" << max_feature_idx_ << std::endl; outpu_file << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output objective name // output objective name
if (object_function_ != nullptr) { if (object_function_ != nullptr) {
model_output_file_ << "objective=" << object_function_->GetName() << std::endl; outpu_file << "objective=" << object_function_->GetName() << std::endl;
} }
// output sigmoid parameter // output sigmoid parameter
model_output_file_ << "sigmoid=" << sigmoid_ << std::endl; outpu_file << "sigmoid=" << sigmoid_ << std::endl;
model_output_file_ << std::endl; outpu_file << std::endl;
saved_model_size_ = 0;
}
// already saved
if (!model_output_file_.is_open()) {
return;
}
int num_used_model = 0; int num_used_model = 0;
if (num_iteration == NO_LIMIT) { if (num_iteration == NO_LIMIT) {
num_used_model = static_cast<int>(models_.size()); num_used_model = static_cast<int>(models_.size());
} else { } else {
num_used_model = num_iteration * num_class_; num_used_model = num_iteration * num_class_;
} }
int rest = num_used_model - early_stopping_round_ * num_class_;
// output tree models // output tree models
for (int i = saved_model_size_; i < rest; ++i) { for (int i = 0; i < num_used_model; ++i) {
model_output_file_ << "Tree=" << i << std::endl; outpu_file << "Tree=" << i << std::endl;
model_output_file_ << models_[i]->ToString() << std::endl; outpu_file << models_[i]->ToString() << std::endl;
} }
saved_model_size_ = std::max(saved_model_size_, rest); outpu_file << std::endl << FeatureImportance() << std::endl;
outpu_file.close();
model_output_file_.flush();
// training finished, can close file
if (is_finish) {
for (int i = saved_model_size_; i < num_used_model; ++i) {
model_output_file_ << "Tree=" << i << std::endl;
model_output_file_ << models_[i]->ToString() << std::endl;
}
model_output_file_ << std::endl << FeatureImportance() << std::endl;
model_output_file_.close();
}
} }
void GBDT::LoadModelFromString(const std::string& model_str) { void GBDT::LoadModelFromString(const std::string& model_str) {
......
...@@ -138,11 +138,11 @@ public: ...@@ -138,11 +138,11 @@ public:
/*! /*!
* \brief save model to file * \brief save model to file
* \param num_iteration -1 means save all * \param num_iterations Iterations that want to save, -1 means save all
* \param is_finish is training finished or not
* \param filename filename that want to save to * \param filename filename that want to save to
*/ */
virtual void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override; virtual void SaveModelToFile(int num_iterations, const char* filename) const override ;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
*/ */
...@@ -274,10 +274,6 @@ protected: ...@@ -274,10 +274,6 @@ protected:
double sigmoid_; double sigmoid_;
/*! \brief Index of label column */ /*! \brief Index of label column */
data_size_t label_idx_; data_size_t label_idx_;
/*! \brief Saved number of models */
int saved_model_size_;
/*! \brief File to write models */
std::ofstream model_output_file_;
/*! \brief number of used model */ /*! \brief number of used model */
int num_iteration_for_pred_; int num_iteration_for_pred_;
/*! \brief Shrinkage rate for one iteration */ /*! \brief Shrinkage rate for one iteration */
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
#include <mutex>
#include "./application/predictor.hpp" #include "./application/predictor.hpp"
...@@ -29,6 +30,7 @@ public: ...@@ -29,6 +30,7 @@ public:
Booster(const Dataset* train_data, Booster(const Dataset* train_data,
const char* parameters) { const char* parameters) {
std::unique_lock<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
config_.Set(param); config_.Set(param);
// create boosting // create boosting
...@@ -41,48 +43,31 @@ public: ...@@ -41,48 +43,31 @@ public:
// initialize the boosting // initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
lock.unlock();
} }
void MergeFrom(const Booster* other) { void MergeFrom(const Booster* other) {
std::unique_lock<std::mutex> lock(mutex_);
boosting_->MergeFrom(other->boosting_.get()); boosting_->MergeFrom(other->boosting_.get());
lock.unlock();
} }
~Booster() { ~Booster() {
} }
// Rebuilds the objective function and the training metrics from the current
// config_ and initializes them against train_data. Called whenever the
// training data (and therefore this derived state) must be reconstructed.
void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
// create objective function; a null result means the objective type was not
// recognized, i.e. the user supplies gradients/hessians externally
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective functions");
}
// create training metrics; unrecognized metric types yield nullptr and are
// silently skipped
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data->metadata(), train_data->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// initialize the objective function (skipped for self-defined objectives)
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data->metadata(), train_data->num_data());
}
}
void ResetTrainingData(const Dataset* train_data) { void ResetTrainingData(const Dataset* train_data) {
std::unique_lock<std::mutex> lock(mutex_);
train_data_ = train_data; train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data_); ConstructObjectAndTrainingMetrics(train_data_);
// initialize the boosting // initialize the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_, boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
lock.unlock();
} }
void ResetConfig(const char* parameters) { void ResetConfig(const char* parameters) {
std::unique_lock<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
if (param.count("num_class")) { if (param.count("num_class")) {
Log::Fatal("cannot change num class during training"); Log::Fatal("cannot change num class during training");
...@@ -92,9 +77,11 @@ public: ...@@ -92,9 +77,11 @@ public:
} }
config_.Set(param); config_.Set(param);
ResetTrainingData(train_data_); ResetTrainingData(train_data_);
lock.unlock();
} }
void AddValidData(const Dataset* valid_data) { void AddValidData(const Dataset* valid_data) {
std::unique_lock<std::mutex> lock(mutex_);
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config)); auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
...@@ -105,20 +92,30 @@ public: ...@@ -105,20 +92,30 @@ public:
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
boosting_->AddValidDataset(valid_data, boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back())); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
lock.unlock();
} }
bool TrainOneIter() { bool TrainOneIter() {
return boosting_->TrainOneIter(nullptr, nullptr, false); std::unique_lock<std::mutex> lock(mutex_);
bool ret = boosting_->TrainOneIter(nullptr, nullptr, false);
lock.unlock();
return ret;
} }
bool TrainOneIter(const float* gradients, const float* hessians) { bool TrainOneIter(const float* gradients, const float* hessians) {
return boosting_->TrainOneIter(gradients, hessians, false); std::unique_lock<std::mutex> lock(mutex_);
bool ret = boosting_->TrainOneIter(gradients, hessians, false);
lock.unlock();
return ret;
} }
void RollbackOneIter() { void RollbackOneIter() {
std::unique_lock<std::mutex> lock(mutex_);
boosting_->RollbackOneIter(); boosting_->RollbackOneIter();
lock.unlock();
} }
void PrepareForPrediction(int num_iteration, int predict_type) { void PrepareForPrediction(int num_iteration, int predict_type) {
std::unique_lock<std::mutex> lock(mutex_);
boosting_->SetNumIterationForPred(num_iteration); boosting_->SetNumIterationForPred(num_iteration);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
...@@ -130,6 +127,7 @@ public: ...@@ -130,6 +127,7 @@ public:
is_raw_score = false; is_raw_score = false;
} }
predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf)); predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
lock.unlock();
} }
void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) { void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
...@@ -145,7 +143,9 @@ public: ...@@ -145,7 +143,9 @@ public:
} }
void SaveModelToFile(int num_iteration, const char* filename) { void SaveModelToFile(int num_iteration, const char* filename) {
boosting_->SaveModelToFile(num_iteration, true, filename); std::unique_lock<std::mutex> lock(mutex_);
boosting_->SaveModelToFile(num_iteration, filename);
lock.unlock();
} }
int GetEvalCounts() const { int GetEvalCounts() const {
...@@ -170,6 +170,30 @@ public: ...@@ -170,6 +170,30 @@ public:
const Boosting* GetBoosting() const { return boosting_.get(); } const Boosting* GetBoosting() const { return boosting_.get(); }
private: private:
void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective functions");
}
// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data->metadata(), train_data->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data->metadata(), train_data->num_data());
}
}
const Dataset* train_data_; const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_; std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */ /*! \brief All configs */
...@@ -182,7 +206,8 @@ private: ...@@ -182,7 +206,8 @@ private:
std::unique_ptr<ObjectiveFunction> objective_fun_; std::unique_ptr<ObjectiveFunction> objective_fun_;
/*! \brief Using predictor for prediction task */ /*! \brief Using predictor for prediction task */
std::unique_ptr<Predictor> predictor_; std::unique_ptr<Predictor> predictor_;
/*! \brief mutex for threading safe call */
std::mutex mutex_;
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment