Commit fa643b3d authored by Guolin Ke
Browse files

Simplify the logic of loading a model from a string

parent 01e10529
......@@ -81,14 +81,14 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret;
}
template<typename T>
inline std::string Join(const std::vector<T>& data, char delimiters) {
std::stringstream result_stream_buf;
result_stream_buf << data[0];
for (size_t i = 1; i < data.size(); ++i) {
result_stream_buf << delimiters << data[i];
// Scans `lines` in order and returns a copy of the first element that contains
// `key_word` as a substring (the match may start anywhere in the line).
inline static std::string FindFromLines(const std::vector<std::string>& lines, const char* key_word) {
for (auto& line : lines) {
// std::string::find yields npos when key_word does not occur in this line.
size_t find_pos = line.find(key_word);
if (find_pos != std::string::npos) {
return line;
}
}
// NOTE(review): this is a diff view without +/- markers. The next line looks
// like the removed tail of the old Join helper, and `return "";` looks like the
// new "no line matched" fallback -- confirm against the actual commit.
return result_stream_buf.str();
return "";
}
inline static const char* Atoi(const char* p, int* out) {
......@@ -310,7 +310,8 @@ inline static std::vector<int> StringToIntArray(const std::string& str, char del
return ret;
}
inline static std::string Join(const std::vector<std::string>& strs, char delimiter) {
template<typename T>
inline static std::string Join(const std::vector<T>& strs, char delimiter) {
if (strs.size() <= 0) {
return std::string("");
}
......@@ -323,7 +324,8 @@ inline static std::string Join(const std::vector<std::string>& strs, char delimi
return ss.str();
}
inline static std::string Join(const std::vector<std::string>& strs, size_t start, size_t end, char delimiter) {
template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, char delimiter) {
if (end - start <= 0) {
return std::string("");
}
......
......@@ -403,79 +403,40 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
// use serialized string to restore this object
models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
size_t i = 0;
// get number of classes
while (i < lines.size()) {
size_t find_pos = lines[i].find("num_class=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &num_class_);
++i;
break;
auto line = Common::FindFromLines(lines, "num_class=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't specify the number of classes");
return;
}
// get index of label
i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("label_index=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &label_idx_);
++i;
break;
line = Common::FindFromLines(lines, "label_index=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't specify the label index");
return;
}
// get max_feature_idx first
i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("max_feature_idx=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &max_feature_idx_);
++i;
break;
line = Common::FindFromLines(lines, "max_feature_idx=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't specify max_feature_idx");
return;
}
// get sigmoid parameter
i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("sigmoid=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atof(strs[1].c_str(), &sigmoid_);
++i;
break;
line = Common::FindFromLines(lines, "sigmoid=");
if (line.size() > 0) {
Common::Atof(Common::Split(line.c_str(), '=')[1].c_str(), &sigmoid_);
} else {
++i;
}
}
// if sigmoid doesn't exist
if (i == lines.size()) {
sigmoid_ = -1.0f;
}
// get tree models
i = 0;
size_t i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("Tree=");
if (find_pos != std::string::npos) {
......@@ -483,7 +444,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
int start = static_cast<int>(i);
while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
int end = static_cast<int>(i);
std::string tree_str = Common::Join(lines, start, end, '\n');
std::string tree_str = Common::Join<std::string>(lines, start, end, '\n');
models_.push_back(new Tree(tree_str));
} else {
++i;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment