"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "3051b7711874aaa39d26ab80b030bea9a71dd7d2"
Commit c0f8e38d authored by cbecker's avatar cbecker Committed by Guolin Ke
Browse files

Added `SaveModelToString` and modified `SaveModelToFile` to use it. (#235)

parent df358b2d
...@@ -147,6 +147,13 @@ public: ...@@ -147,6 +147,13 @@ public:
*/ */
virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0; virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;
/*!
* \brief Save model to string
* \param num_used_model Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int num_iterations) const = 0;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
* \param model_str The string of model * \param model_str The string of model
......
...@@ -509,49 +509,58 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -509,49 +509,58 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str(); return str_buf.str();
} }
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const { std::string GBDT::SaveModelToString(int num_iterations) const {
/*! \brief File to write models */ std::stringstream ss;
std::ofstream output_file;
output_file.open(filename);
// output model type // output model type
output_file << SubModelName() << std::endl; ss << SubModelName() << std::endl;
// output number of class // output number of class
output_file << "num_class=" << num_class_ << std::endl; ss << "num_class=" << num_class_ << std::endl;
// output label index // output label index
output_file << "label_index=" << label_idx_ << std::endl; ss << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx // output max_feature_idx
output_file << "max_feature_idx=" << max_feature_idx_ << std::endl; ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output objective name // output objective name
if (object_function_ != nullptr) { if (object_function_ != nullptr) {
output_file << "objective=" << object_function_->GetName() << std::endl; ss << "objective=" << object_function_->GetName() << std::endl;
} }
// output sigmoid parameter // output sigmoid parameter
output_file << "sigmoid=" << sigmoid_ << std::endl; ss << "sigmoid=" << sigmoid_ << std::endl;
output_file << "feature_names=" << Common::Join(feature_names_, " ") << std::endl; ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
output_file << std::endl; ss << std::endl;
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) { if (num_iterations > 0) {
num_used_model = std::min(num_iteration * num_class_, num_used_model); num_used_model = std::min(num_iterations * num_class_, num_used_model);
} }
// output tree models // output tree models
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model; ++i) {
output_file << "Tree=" << i << std::endl; ss << "Tree=" << i << std::endl;
output_file << models_[i]->ToString() << std::endl; ss << models_[i]->ToString() << std::endl;
} }
std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance(); std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance();
output_file << std::endl << "feature importances:" << std::endl; ss << std::endl << "feature importances:" << std::endl;
for (size_t i = 0; i < pairs.size(); ++i) { for (size_t i = 0; i < pairs.size(); ++i) {
output_file << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl; ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
} }
output_file << std::endl << "feature information:" << std::endl; ss << std::endl << "feature information:" << std::endl;
for (size_t i = 0; i < max_feature_idx_ + 1; ++i) { for (size_t i = 0; i < max_feature_idx_ + 1; ++i) {
output_file << feature_names_[i] << "=" << feature_infos_[i] << std::endl; ss << feature_names_[i] << "=" << feature_infos_[i] << std::endl;
} }
return ss.str();
}
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
/*! \brief File to write models */
std::ofstream output_file;
output_file.open(filename);
output_file << SaveModelToString(num_iteration);
output_file.close(); output_file.close();
return (bool)output_file; return (bool)output_file;
......
...@@ -158,6 +158,13 @@ public: ...@@ -158,6 +158,13 @@ public:
*/ */
virtual bool SaveModelToFile(int num_iterations, const char* filename) const override ; virtual bool SaveModelToFile(int num_iterations, const char* filename) const override ;
/*!
* \brief Save model to string
* \param num_used_model Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int num_iterations) const override ;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
*/ */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment