Commit df358b2d authored by cbecker's avatar cbecker Committed by Guolin Ke
Browse files

Added success return value to LoadFileToBoosting and SaveModelToFile (#234)

parent 82fcfa0e
......@@ -143,14 +143,16 @@ public:
* \param num_used_model Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
* \return true if succeeded
*/
virtual void SaveModelToFile(int num_iterations, const char* filename) const = 0;
virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;
/*!
* \brief Restore from a serialized string
* \param model_str The string of model
* \return true if succeeded
*/
virtual void LoadModelFromString(const std::string& model_str) = 0;
virtual bool LoadModelFromString(const std::string& model_str) = 0;
/*!
* \brief Get max feature index of this model
......@@ -192,7 +194,7 @@ public:
/*! \brief Disable copy */
Boosting(const Boosting&) = delete;
static void LoadFileToBoosting(Boosting* boosting, const char* filename);
static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
/*!
* \brief Create boosting object
......
......@@ -10,7 +10,7 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
return type;
}
void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
......@@ -18,8 +18,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
}
boosting->LoadModelFromString(str_buf.str());
if (!boosting->LoadModelFromString(str_buf.str()))
return false;
}
return true;
}
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
......
......@@ -509,7 +509,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str();
}
void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
/*! \brief File to write models */
std::ofstream output_file;
output_file.open(filename);
......@@ -553,9 +553,11 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
}
output_file.close();
return (bool)output_file;
}
void GBDT::LoadModelFromString(const std::string& model_str) {
bool GBDT::LoadModelFromString(const std::string& model_str) {
// use serialized string to restore this object
models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
......@@ -566,7 +568,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
} else {
Log::Fatal("Model file doesn't specify the number of classes");
return;
return false;
}
// get index of label
line = Common::FindFromLines(lines, "label_index=");
......@@ -574,7 +576,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
} else {
Log::Fatal("Model file doesn't specify the label index");
return;
return false;
}
// get max_feature_idx first
line = Common::FindFromLines(lines, "max_feature_idx=");
......@@ -582,7 +584,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
} else {
Log::Fatal("Model file doesn't specify max_feature_idx");
return;
return false;
}
// get sigmoid parameter
line = Common::FindFromLines(lines, "sigmoid=");
......@@ -597,11 +599,11 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), " ");
if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_names");
return;
return false;
}
} else {
Log::Fatal("Model file doesn't contain feature names");
return;
return false;
}
// get tree models
......@@ -624,6 +626,8 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
num_init_iteration_ = num_iteration_for_pred_;
iter_ = 0;
return true;
}
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
......
......@@ -156,12 +156,12 @@ public:
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
*/
virtual void SaveModelToFile(int num_iterations, const char* filename) const override ;
virtual bool SaveModelToFile(int num_iterations, const char* filename) const override ;
/*!
* \brief Restore from a serialized string
*/
void LoadModelFromString(const std::string& model_str) override;
bool LoadModelFromString(const std::string& model_str) override;
/*!
* \brief Get max feature index of this model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment