Commit 3a4608f4 authored by Guolin Ke's avatar Guolin Ke
Browse files

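Fix the split functions: skip empty tokens, add a single-character delimiter overload, and add SplitLines to handle both '\n' and '\r' line endings.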

parent 5633ef53
...@@ -44,6 +44,7 @@ inline static std::string& RemoveQuotationSymbol(std::string& str) { ...@@ -44,6 +44,7 @@ inline static std::string& RemoveQuotationSymbol(std::string& str) {
str.erase(0, str.find_first_not_of("'\"")); str.erase(0, str.find_first_not_of("'\""));
return str; return str;
} }
inline static bool StartsWith(const std::string& str, const std::string prefix) { inline static bool StartsWith(const std::string& str, const std::string prefix) {
if (str.substr(0, prefix.size()) == prefix) { if (str.substr(0, prefix.size()) == prefix) {
return true; return true;
...@@ -51,32 +52,79 @@ inline static bool StartsWith(const std::string& str, const std::string prefix) ...@@ -51,32 +52,79 @@ inline static bool StartsWith(const std::string& str, const std::string prefix)
return false; return false;
} }
} }
// Splits a C string on a single delimiter character.
//
// Consecutive delimiters, as well as leading and trailing delimiters, never
// produce empty tokens: the result contains only the non-empty runs of
// characters found between delimiters, in their original order.
//
// c_str:     NUL-terminated input; a null pointer is treated as empty input.
// delimiter: the character to split on.
// Returns the (possibly empty) list of non-empty tokens.
inline static std::vector<std::string> Split(const char* c_str, char delimiter) {
  std::vector<std::string> ret;
  // Constructing std::string from a null pointer is undefined behaviour.
  if (c_str == nullptr) {
    return ret;
  }
  const std::string str(c_str);
  // Jump straight to the first character that belongs to a token.
  size_t token_begin = str.find_first_not_of(delimiter);
  while (token_begin != std::string::npos) {
    const size_t token_end = str.find(delimiter, token_begin);
    if (token_end == std::string::npos) {
      // Last token runs to the end of the input.
      ret.push_back(str.substr(token_begin));
      break;
    }
    ret.push_back(str.substr(token_begin, token_end - token_begin));
    // Skip the whole run of delimiters before the next token.
    token_begin = str.find_first_not_of(delimiter, token_end);
  }
  return ret;
}
// Splits a C string into lines, treating both '\n' and '\r' as line endings.
//
// Any run of line-ending characters ("\n", "\r\n", "\n\n", ...) acts as a
// single separator, so blank lines are dropped and no empty strings are
// returned.
//
// c_str: NUL-terminated input; a null pointer is treated as empty input.
// Returns the non-empty lines in their original order.
inline static std::vector<std::string> SplitLines(const char* c_str) {
  std::vector<std::string> ret;
  // Constructing std::string from a null pointer is undefined behaviour.
  if (c_str == nullptr) {
    return ret;
  }
  static const char kLineEndings[] = "\r\n";
  const std::string str(c_str);
  // The library searches are bounds-checked, so no index ever passes the end.
  size_t line_begin = str.find_first_not_of(kLineEndings);
  while (line_begin != std::string::npos) {
    const size_t line_end = str.find_first_of(kLineEndings, line_begin);
    if (line_end == std::string::npos) {
      // Final line without a trailing line ending.
      ret.push_back(str.substr(line_begin));
      break;
    }
    ret.push_back(str.substr(line_begin, line_end - line_begin));
    // Skip the whole run of line endings before the next line.
    line_begin = str.find_first_not_of(kLineEndings, line_end);
  }
  return ret;
}
// Splits a C string wherever ANY character listed in `delimiters` occurs.
//
// Consecutive delimiters, as well as leading and trailing delimiters, never
// produce empty tokens: the result contains only the non-empty runs of
// non-delimiter characters, in their original order.
//
// c_str:      NUL-terminated input; a null pointer is treated as empty input.
// delimiters: NUL-terminated set of delimiter characters; a null pointer is
//             treated as an empty set, so the whole input is one token.
// Returns the (possibly empty) list of non-empty tokens.
inline static std::vector<std::string> Split(const char* c_str, const char* delimiters) {
  std::vector<std::string> ret;
  // Constructing std::string from a null pointer is undefined behaviour.
  if (c_str == nullptr) {
    return ret;
  }
  const char* const delimiter_set = (delimiters != nullptr) ? delimiters : "";
  const std::string str(c_str);
  // Jump straight to the first character that belongs to a token.
  size_t token_begin = str.find_first_not_of(delimiter_set);
  while (token_begin != std::string::npos) {
    const size_t token_end = str.find_first_of(delimiter_set, token_begin);
    if (token_end == std::string::npos) {
      // Last token runs to the end of the input.
      ret.push_back(str.substr(token_begin));
      break;
    }
    ret.push_back(str.substr(token_begin, token_end - token_begin));
    // Skip the whole run of delimiters before the next token.
    token_begin = str.find_first_not_of(delimiter_set, token_end);
  }
  return ret;
}
......
...@@ -875,7 +875,7 @@ bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const { ...@@ -875,7 +875,7 @@ bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
bool GBDT::LoadModelFromString(const std::string& model_str) { bool GBDT::LoadModelFromString(const std::string& model_str) {
// use serialized string to restore this object // use serialized string to restore this object
models_.clear(); models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n'); std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
// get number of classes // get number of classes
auto line = Common::FindFromLines(lines, "num_class="); auto line = Common::FindFromLines(lines, "num_class=");
...@@ -917,7 +917,7 @@ bool GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -917,7 +917,7 @@ bool GBDT::LoadModelFromString(const std::string& model_str) {
// get feature names // get feature names
line = Common::FindFromLines(lines, "feature_names="); line = Common::FindFromLines(lines, "feature_names=");
if (line.size() > 0) { if (line.size() > 0) {
feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), " "); feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) { if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_names"); Log::Fatal("Wrong size of feature_names");
return false; return false;
...@@ -929,7 +929,7 @@ bool GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -929,7 +929,7 @@ bool GBDT::LoadModelFromString(const std::string& model_str) {
line = Common::FindFromLines(lines, "feature_infos="); line = Common::FindFromLines(lines, "feature_infos=");
if (line.size() > 0) { if (line.size() > 0) {
feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), " "); feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) { if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_infos"); Log::Fatal("Wrong size of feature_infos");
return false; return false;
......
...@@ -441,7 +441,7 @@ std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) { ...@@ -441,7 +441,7 @@ std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) {
} }
Tree::Tree(const std::string& str) { Tree::Tree(const std::string& str) {
std::vector<std::string> lines = Common::Split(str.c_str(), '\n'); std::vector<std::string> lines = Common::SplitLines(str.c_str());
std::unordered_map<std::string, std::string> key_vals; std::unordered_map<std::string, std::string> key_vals;
for (const std::string& line : lines) { for (const std::string& line : lines) {
std::vector<std::string> tmp_strs = Common::Split(line.c_str(), '='); std::vector<std::string> tmp_strs = Common::Split(line.c_str(), '=');
......
...@@ -135,7 +135,7 @@ LGBM_SE LGBM_DatasetSetFeatureNames_R(LGBM_SE handle, ...@@ -135,7 +135,7 @@ LGBM_SE LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
LGBM_SE feature_names, LGBM_SE feature_names,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
auto vec_names = Common::Split(R_CHAR_PTR(feature_names), "\t"); auto vec_names = Common::Split(R_CHAR_PTR(feature_names), '\t');
std::vector<const char*> vec_sptr; std::vector<const char*> vec_sptr;
int len = static_cast<int>(vec_names.size()); int len = static_cast<int>(vec_names.size());
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
......
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
explicit BinaryLogloss(const std::vector<std::string>& strs) { explicit BinaryLogloss(const std::vector<std::string>& strs) {
sigmoid_ = -1; sigmoid_ = -1;
for (auto str : strs) { for (auto str : strs) {
auto tokens = Common::Split(str.c_str(), ":"); auto tokens = Common::Split(str.c_str(), ':');
if (tokens.size() == 2) { if (tokens.size() == 2) {
if (tokens[0] == std::string("sigmoid")) { if (tokens[0] == std::string("sigmoid")) {
Common::Atof(tokens[1].c_str(), &sigmoid_); Common::Atof(tokens[1].c_str(), &sigmoid_);
......
...@@ -22,7 +22,7 @@ public: ...@@ -22,7 +22,7 @@ public:
explicit MulticlassSoftmax(const std::vector<std::string>& strs) { explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
num_class_ = -1; num_class_ = -1;
for (auto str : strs) { for (auto str : strs) {
auto tokens = Common::Split(str.c_str(), ":"); auto tokens = Common::Split(str.c_str(), ':');
if (tokens.size() == 2) { if (tokens.size() == 2) {
if (tokens[0] == std::string("num_class")) { if (tokens[0] == std::string("num_class")) {
Common::Atoi(tokens[1].c_str(), &num_class_); Common::Atoi(tokens[1].c_str(), &num_class_);
...@@ -151,7 +151,7 @@ public: ...@@ -151,7 +151,7 @@ public:
num_class_ = -1; num_class_ = -1;
sigmoid_ = -1; sigmoid_ = -1;
for (auto str : strs) { for (auto str : strs) {
auto tokens = Common::Split(str.c_str(), ":"); auto tokens = Common::Split(str.c_str(), ':');
if (tokens.size() == 2) { if (tokens.size() == 2) {
if (tokens[0] == std::string("num_class")) { if (tokens[0] == std::string("num_class")) {
Common::Atoi(tokens[1].c_str(), &num_class_); Common::Atoi(tokens[1].c_str(), &num_class_);
......
...@@ -31,7 +31,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -31,7 +31,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} }
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
auto strs = Common::Split(str.c_str(), " "); auto strs = Common::Split(str.c_str(), ' ');
auto type = strs[0]; auto type = strs[0];
if (type == std::string("regression")) { if (type == std::string("regression")) {
return new RegressionL2loss(strs); return new RegressionL2loss(strs);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment