Commit c0f8e38d authored by cbecker's avatar cbecker Committed by Guolin Ke
Browse files

Added `SaveModelToString` and modified `SaveModelToFile` to use it. (#235)

parent df358b2d
...@@ -147,6 +147,13 @@ public: ...@@ -147,6 +147,13 @@ public:
*/ */
virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0; virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;
/*!
* \brief Save model to string
* \param num_used_model Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int num_iterations) const = 0;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
* \param model_str The string of model * \param model_str The string of model
......
...@@ -509,48 +509,57 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -509,48 +509,57 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str(); return str_buf.str();
} }
std::string GBDT::SaveModelToString(int num_iterations) const {
std::stringstream ss;
// output model type
ss << SubModelName() << std::endl;
// output number of class
ss << "num_class=" << num_class_ << std::endl;
// output label index
ss << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx
ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output objective name
if (object_function_ != nullptr) {
ss << "objective=" << object_function_->GetName() << std::endl;
}
// output sigmoid parameter
ss << "sigmoid=" << sigmoid_ << std::endl;
ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
ss << std::endl;
int num_used_model = static_cast<int>(models_.size());
if (num_iterations > 0) {
num_used_model = std::min(num_iterations * num_class_, num_used_model);
}
// output tree models
for (int i = 0; i < num_used_model; ++i) {
ss << "Tree=" << i << std::endl;
ss << models_[i]->ToString() << std::endl;
}
std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
ss << std::endl << "feature importances:" << std::endl;
for (size_t i = 0; i < pairs.size(); ++i) {
ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
}
ss << std::endl << "feature information:" << std::endl;
for (size_t i = 0; i < max_feature_idx_ + 1; ++i) {
ss << feature_names_[i] << "=" << feature_infos_[i] << std::endl;
}
return ss.str();
}
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const { bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
/*! \brief File to write models */ /*! \brief File to write models */
std::ofstream output_file; std::ofstream output_file;
output_file.open(filename); output_file.open(filename);
// output model type
output_file << SubModelName() << std::endl;
// output number of class
output_file << "num_class=" << num_class_ << std::endl;
// output label index
output_file << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx
output_file << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output objective name
if (object_function_ != nullptr) {
output_file << "objective=" << object_function_->GetName() << std::endl;
}
// output sigmoid parameter
output_file << "sigmoid=" << sigmoid_ << std::endl;
output_file << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
output_file << std::endl;
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_class_, num_used_model);
}
// output tree models
for (int i = 0; i < num_used_model; ++i) {
output_file << "Tree=" << i << std::endl;
output_file << models_[i]->ToString() << std::endl;
}
std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance(); output_file << SaveModelToString(num_iteration);
output_file << std::endl << "feature importances:" << std::endl;
for (size_t i = 0; i < pairs.size(); ++i) {
output_file << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
}
output_file << std::endl << "feature information:" << std::endl;
for (size_t i = 0; i < max_feature_idx_ + 1; ++i) {
output_file << feature_names_[i] << "=" << feature_infos_[i] << std::endl;
}
output_file.close(); output_file.close();
......
...@@ -158,6 +158,13 @@ public: ...@@ -158,6 +158,13 @@ public:
*/ */
virtual bool SaveModelToFile(int num_iterations, const char* filename) const override ; virtual bool SaveModelToFile(int num_iterations, const char* filename) const override ;
/*!
* \brief Save model to string
* \param num_used_model Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int num_iterations) const override ;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
*/ */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment