Unverified Commit e2f11b05 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Fix a possible problem when reading the number of columns from a LIBSVM file. (#3242)

parent 2c742d57
......@@ -129,7 +129,56 @@ std::vector<std::string> ReadKLineFromFile(const char* filename, bool header, in
return ret;
}
DataType GetDataType(const std::vector<std::string>& lines, int* num_col) {
// Scans the leading lines of a LIBSVM-style text file and returns the largest
// feature index found in the last "index:value" token of those lines.
//
// filename: path handed to VirtualFileReader::Make.
// header:   when true, the first line is consumed and ignored before scanning.
//
// At most `max_line` (8192) lines are examined, and the scan stops early once
// `stop_round` (128) consecutive lines have failed to raise the maximum.
// Only the last token of each line is inspected.
// NOTE(review): this presumably relies on feature indices being ascending
// within a line - confirm against the data format the callers accept.
//
// Returns the largest index seen (not a count). Calls Log::Fatal when the
// reader cannot be initialised or yields no bytes, and CHECK_GT fails when no
// index greater than zero was found.
int GetNumColFromLIBSVMFile(const char* filename, bool header) {
  auto reader = VirtualFileReader::Make(filename);
  if (!reader->Init()) {
    Log::Fatal("Data file %s doesn't exist.", filename);
  }
  std::string cur_line;
  const size_t buffer_size = 1024 * 1024;
  auto buffer = std::vector<char>(buffer_size);
  const size_t read_len = reader->Read(buffer.data(), buffer_size);
  // read_len is unsigned, so "nothing was read" is the only observable failure.
  if (read_len == 0) {
    Log::Fatal("Data file %s couldn't be read.", filename);
  }
  std::string read_str = std::string(buffer.data(), read_len);
  std::stringstream tmp_file(read_str);
  if (header) {
    if (!tmp_file.eof()) {
      // Discard the header line.
      GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size);
    }
  }
  int max_col_idx = 0;
  int max_line_idx = 0;
  const int stop_round = 1 << 7;
  const int max_line = 1 << 13;
  for (int i = 0; i < max_line && !tmp_file.eof(); ++i) {
    GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size);
    cur_line = Common::Trim(cur_line);
    int cur_idx = 0;
    const auto colon_pos = cur_line.find_last_of(':');
    // A line without ':' carries no feature (e.g. a label-only row); its
    // label must not be mistaken for a feature index.
    if (colon_pos != std::string::npos) {
      const auto space_pos = cur_line.find_last_of(" \f\t\v");
      const size_t token_begin =
          (space_pos == std::string::npos) ? 0 : space_pos + 1;
      // Parse only the digits between the last whitespace and the last colon.
      // If the whitespace comes after the colon there is no well-formed
      // trailing token, so the line contributes nothing.
      if (token_begin < colon_pos) {
        const auto sub_str =
            cur_line.substr(token_begin, colon_pos - token_begin);
        Common::Atoi(sub_str.c_str(), &cur_idx);
      }
    }
    if (cur_idx > max_col_idx) {
      max_col_idx = cur_idx;
      max_line_idx = i;
    }
    // Stop once the maximum has been stable for `stop_round` lines.
    if (i - max_line_idx >= stop_round) {
      break;
    }
  }
  CHECK_GT(max_col_idx, 0);
  return max_col_idx;
}
DataType GetDataType(const char* filename, bool header,
const std::vector<std::string>& lines, int* num_col) {
DataType type = DataType::INVALID;
if (lines.empty()) {
return type;
......@@ -173,16 +222,7 @@ DataType GetDataType(const std::vector<std::string>& lines, int* num_col) {
}
}
if (type == DataType::LIBSVM) {
int max_col_idx = 0;
for (size_t i = 0; i < lines.size(); ++i) {
auto str = Common::Trim(lines[i]);
auto colon_pos = str.find_last_of(":");
auto space_pos = str.find_last_of(" \f\t\v");
auto sub_str = str.substr(space_pos + 1, space_pos - colon_pos - 1);
int cur_idx = 0;
Common::Atoi(sub_str.c_str(), &cur_idx);
max_col_idx = std::max(cur_idx, max_col_idx);
}
int max_col_idx = GetNumColFromLIBSVMFile(filename, header);
*num_col = max_col_idx + 1;
} else if (type == DataType::CSV) {
*num_col = comma_cnt + 1;
......@@ -193,10 +233,10 @@ DataType GetDataType(const std::vector<std::string>& lines, int* num_col) {
}
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) {
const int n_read_line = 20;
const int n_read_line = 32;
auto lines = ReadKLineFromFile(filename, header, n_read_line);
int num_col = 0;
DataType type = GetDataType(lines, &num_col);
DataType type = GetDataType(filename, header, lines, &num_col);
if (type == DataType::INVALID) {
Log::Fatal("Unknown format of training data.");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment