Commit fa643b3d authored by Guolin Ke's avatar Guolin Ke
Browse files

simplify the logic of load model from string

parent 01e10529
...@@ -81,14 +81,14 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli ...@@ -81,14 +81,14 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret; return ret;
} }
template<typename T> inline static std::string FindFromLines(const std::vector<std::string>& lines, const char* key_word) {
inline std::string Join(const std::vector<T>& data, char delimiters) { for (auto& line : lines) {
std::stringstream result_stream_buf; size_t find_pos = line.find(key_word);
result_stream_buf << data[0]; if (find_pos != std::string::npos) {
for (size_t i = 1; i < data.size(); ++i) { return line;
result_stream_buf << delimiters << data[i]; }
} }
return result_stream_buf.str(); return "";
} }
inline static const char* Atoi(const char* p, int* out) { inline static const char* Atoi(const char* p, int* out) {
...@@ -310,7 +310,8 @@ inline static std::vector<int> StringToIntArray(const std::string& str, char del ...@@ -310,7 +310,8 @@ inline static std::vector<int> StringToIntArray(const std::string& str, char del
return ret; return ret;
} }
inline static std::string Join(const std::vector<std::string>& strs, char delimiter) { template<typename T>
inline static std::string Join(const std::vector<T>& strs, char delimiter) {
if (strs.size() <= 0) { if (strs.size() <= 0) {
return std::string(""); return std::string("");
} }
...@@ -323,7 +324,8 @@ inline static std::string Join(const std::vector<std::string>& strs, char delimi ...@@ -323,7 +324,8 @@ inline static std::string Join(const std::vector<std::string>& strs, char delimi
return ss.str(); return ss.str();
} }
inline static std::string Join(const std::vector<std::string>& strs, size_t start, size_t end, char delimiter) { template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, char delimiter) {
if (end - start <= 0) { if (end - start <= 0) {
return std::string(""); return std::string("");
} }
......
...@@ -403,79 +403,40 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -403,79 +403,40 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
// use serialized string to restore this object // use serialized string to restore this object
models_.clear(); models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n'); std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
size_t i = 0;
// get number of classes // get number of classes
while (i < lines.size()) { auto line = Common::FindFromLines(lines, "num_class=");
size_t find_pos = lines[i].find("num_class="); if (line.size() > 0) {
if (find_pos != std::string::npos) { Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '='); } else {
Common::Atoi(strs[1].c_str(), &num_class_);
++i;
break;
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't specify the number of classes"); Log::Fatal("Model file doesn't specify the number of classes");
return; return;
} }
// get index of label // get index of label
i = 0; line = Common::FindFromLines(lines, "label_index=");
while (i < lines.size()) { if (line.size() > 0) {
size_t find_pos = lines[i].find("label_index="); Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
if (find_pos != std::string::npos) { } else {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &label_idx_);
++i;
break;
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't specify the label index"); Log::Fatal("Model file doesn't specify the label index");
return; return;
} }
// get max_feature_idx first // get max_feature_idx first
i = 0; line = Common::FindFromLines(lines, "max_feature_idx=");
while (i < lines.size()) { if (line.size() > 0) {
size_t find_pos = lines[i].find("max_feature_idx="); Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
if (find_pos != std::string::npos) { } else {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &max_feature_idx_);
++i;
break;
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't specify max_feature_idx"); Log::Fatal("Model file doesn't specify max_feature_idx");
return; return;
} }
// get sigmoid parameter // get sigmoid parameter
i = 0; line = Common::FindFromLines(lines, "sigmoid=");
while (i < lines.size()) { if (line.size() > 0) {
size_t find_pos = lines[i].find("sigmoid="); Common::Atof(Common::Split(line.c_str(), '=')[1].c_str(), &sigmoid_);
if (find_pos != std::string::npos) { } else {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atof(strs[1].c_str(), &sigmoid_);
++i;
break;
} else {
++i;
}
}
// if sigmoid doesn't exists
if (i == lines.size()) {
sigmoid_ = -1.0f; sigmoid_ = -1.0f;
} }
// get tree models // get tree models
i = 0; size_t i = 0;
while (i < lines.size()) { while (i < lines.size()) {
size_t find_pos = lines[i].find("Tree="); size_t find_pos = lines[i].find("Tree=");
if (find_pos != std::string::npos) { if (find_pos != std::string::npos) {
...@@ -483,7 +444,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -483,7 +444,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
int start = static_cast<int>(i); int start = static_cast<int>(i);
while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; } while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
int end = static_cast<int>(i); int end = static_cast<int>(i);
std::string tree_str = Common::Join(lines, start, end, '\n'); std::string tree_str = Common::Join<std::string>(lines, start, end, '\n');
models_.push_back(new Tree(tree_str)); models_.push_back(new Tree(tree_str));
} else { } else {
++i; ++i;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment