"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "3ef3a4895e9e13ed38c0b3e0309374ad714e6876"
Commit df358b2d authored by cbecker's avatar cbecker Committed by Guolin Ke
Browse files

Added success return value to LoadFileToBoosting and SaveModelToFile (#234)

parent 82fcfa0e
...@@ -143,14 +143,16 @@ public: ...@@ -143,14 +143,16 @@ public:
* \param num_used_model Number of model that want to save, -1 means save all * \param num_used_model Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not * \param is_finish Is training finished or not
* \param filename Filename that want to save to * \param filename Filename that want to save to
* \return true if succeeded
*/ */
virtual void SaveModelToFile(int num_iterations, const char* filename) const = 0; virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
* \param model_str The string of model * \param model_str The string of model
* \return true if succeeded
*/ */
virtual void LoadModelFromString(const std::string& model_str) = 0; virtual bool LoadModelFromString(const std::string& model_str) = 0;
/*! /*!
* \brief Get max feature index of this model * \brief Get max feature index of this model
...@@ -192,7 +194,7 @@ public: ...@@ -192,7 +194,7 @@ public:
/*! \brief Disable copy */ /*! \brief Disable copy */
Boosting(const Boosting&) = delete; Boosting(const Boosting&) = delete;
static void LoadFileToBoosting(Boosting* boosting, const char* filename); static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
/*! /*!
* \brief Create boosting object * \brief Create boosting object
......
...@@ -10,7 +10,7 @@ std::string GetBoostingTypeFromModelFile(const char* filename) { ...@@ -10,7 +10,7 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
return type; return type;
} }
void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) { bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
if (boosting != nullptr) { if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true); TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines(); model_reader.ReadAllLines();
...@@ -18,8 +18,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) { ...@@ -18,8 +18,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
for (auto& line : model_reader.Lines()) { for (auto& line : model_reader.Lines()) {
str_buf << line << '\n'; str_buf << line << '\n';
} }
boosting->LoadModelFromString(str_buf.str()); if (!boosting->LoadModelFromString(str_buf.str()))
return false;
} }
return true;
} }
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) { Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
......
...@@ -509,7 +509,7 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -509,7 +509,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str(); return str_buf.str();
} }
void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
/*! \brief File to write models */ /*! \brief File to write models */
std::ofstream output_file; std::ofstream output_file;
output_file.open(filename); output_file.open(filename);
...@@ -553,9 +553,11 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { ...@@ -553,9 +553,11 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
} }
output_file.close(); output_file.close();
return (bool)output_file;
} }
void GBDT::LoadModelFromString(const std::string& model_str) { bool GBDT::LoadModelFromString(const std::string& model_str) {
// use serialized string to restore this object // use serialized string to restore this object
models_.clear(); models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n'); std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
...@@ -566,7 +568,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -566,7 +568,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_); Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
} else { } else {
Log::Fatal("Model file doesn't specify the number of classes"); Log::Fatal("Model file doesn't specify the number of classes");
return; return false;
} }
// get index of label // get index of label
line = Common::FindFromLines(lines, "label_index="); line = Common::FindFromLines(lines, "label_index=");
...@@ -574,7 +576,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -574,7 +576,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_); Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
} else { } else {
Log::Fatal("Model file doesn't specify the label index"); Log::Fatal("Model file doesn't specify the label index");
return; return false;
} }
// get max_feature_idx first // get max_feature_idx first
line = Common::FindFromLines(lines, "max_feature_idx="); line = Common::FindFromLines(lines, "max_feature_idx=");
...@@ -582,7 +584,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -582,7 +584,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_); Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
} else { } else {
Log::Fatal("Model file doesn't specify max_feature_idx"); Log::Fatal("Model file doesn't specify max_feature_idx");
return; return false;
} }
// get sigmoid parameter // get sigmoid parameter
line = Common::FindFromLines(lines, "sigmoid="); line = Common::FindFromLines(lines, "sigmoid=");
...@@ -597,11 +599,11 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -597,11 +599,11 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), " "); feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), " ");
if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) { if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_names"); Log::Fatal("Wrong size of feature_names");
return; return false;
} }
} else { } else {
Log::Fatal("Model file doesn't contain feature names"); Log::Fatal("Model file doesn't contain feature names");
return; return false;
} }
// get tree models // get tree models
...@@ -624,6 +626,8 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -624,6 +626,8 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
num_init_iteration_ = num_iteration_for_pred_; num_init_iteration_ = num_iteration_for_pred_;
iter_ = 0; iter_ = 0;
return true;
} }
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const { std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
......
...@@ -156,12 +156,12 @@ public: ...@@ -156,12 +156,12 @@ public:
* \param is_finish Is training finished or not * \param is_finish Is training finished or not
* \param filename Filename that want to save to * \param filename Filename that want to save to
*/ */
virtual void SaveModelToFile(int num_iterations, const char* filename) const override ; virtual bool SaveModelToFile(int num_iterations, const char* filename) const override ;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
*/ */
void LoadModelFromString(const std::string& model_str) override; bool LoadModelFromString(const std::string& model_str) override;
/*! /*!
* \brief Get max feature index of this model * \brief Get max feature index of this model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment