parser.cpp 6.46 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5

6
#include <string>
7
#include <algorithm>
Guolin Ke's avatar
Guolin Ke committed
8
#include <fstream>
Guolin Ke's avatar
Guolin Ke committed
9
#include <functional>
10
#include <iostream>
Guolin Ke's avatar
Guolin Ke committed
11
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
12

13
14
#include "parser.hpp"

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
17
// Counts how many ',', '\t' and ':' characters occur in the NUL-terminated
// string `str`. Each counter is reset to zero first and then incremented in
// place, so the outputs always reflect only this string.
void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt) {
  *comma_cnt = 0;
  *tab_cnt = 0;
  *colon_cnt = 0;
  for (const char* cursor = str; *cursor != '\0'; ++cursor) {
    switch (*cursor) {
      case ',':
        ++*comma_cnt;
        break;
      case '\t':
        ++*tab_cnt;
        break;
      case ':':
        ++*colon_cnt;
        break;
      default:
        break;
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
32
int GetLabelIdxForLibsvm(const std::string& str, int num_features, int label_idx) {
Guolin Ke's avatar
Guolin Ke committed
33
34
35
  if (num_features <= 0) {
    return label_idx;
  }
Guolin Ke's avatar
Guolin Ke committed
36
37
38
  auto str2 = Common::Trim(str);
  auto pos_space = str2.find_first_of(" \f\n\r\t\v");
  auto pos_colon = str2.find_first_of(":");
39
  if (pos_space == std::string::npos || pos_space < pos_colon) {
Guolin Ke's avatar
Guolin Ke committed
40
    return label_idx;
41
42
  } else {
    return -1;
43
44
45
  }
}

Guolin Ke's avatar
Guolin Ke committed
46
// Decides whether a tab-separated line carries a label column.
//
// If the trimmed line splits on '\t' into exactly `num_features` fields,
// every field is a feature and -1 (no label column) is returned; otherwise
// `label_idx` is returned. When `num_features <= 0` the caller's
// `label_idx` is returned unchanged.
int GetLabelIdxForTSV(const std::string& str, int num_features, int label_idx) {
  if (num_features <= 0) return label_idx;

  const std::string trimmed = Common::Trim(str);
  const auto fields = Common::Split(trimmed.c_str(), '\t');
  const bool only_features = static_cast<int>(fields.size()) == num_features;
  return only_features ? -1 : label_idx;
}

Guolin Ke's avatar
Guolin Ke committed
59
// Decides whether a comma-separated line carries a label column.
//
// If the trimmed line splits on ',' into exactly `num_features` fields,
// every field is a feature and -1 (no label column) is returned; otherwise
// `label_idx` is returned. When `num_features <= 0` the caller's
// `label_idx` is returned unchanged.
int GetLabelIdxForCSV(const std::string& str, int num_features, int label_idx) {
  if (num_features <= 0) return label_idx;

  const std::string trimmed = Common::Trim(str);
  const auto fields = Common::Split(trimmed.c_str(), ',');
  const bool only_features = static_cast<int>(fields.size()) == num_features;
  return only_features ? -1 : label_idx;
}

Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
// Text formats recognised by GetDataType().
// Scoped so the generic enumerator names do not leak into namespace LightGBM;
// enumerator order (and therefore numeric values) is unchanged.
enum class DataType {
  INVALID,  // format could not be determined
  CSV,      // comma-separated values
  TSV,      // tab-separated values
  LIBSVM    // sparse `index:value` pairs
};

Guolin Ke's avatar
Guolin Ke committed
79
80
81
82
// Reads one line from `ss` into `line`, refilling `ss` from `reader` when the
// line runs into the end of the currently buffered chunk.
//
// While the stream reports eof (the last getline did not find a '\n' before
// the chunk ended), the next chunk of up to `buffer_size` bytes is read into
// `buffer`, installed as the stream's new contents, and its first line is
// appended to `line`. Stops when a newline is found or the reader returns no
// more data. NOTE(review): assumes `buffer` holds at least `buffer_size`
// bytes — confirm at call sites.
void GetLine(std::stringstream* ss, std::string* line, const VirtualFileReader* reader, std::vector<char>* buffer, size_t buffer_size) {
  std::getline(*ss, *line);
  while (ss->eof()) {
    const size_t chunk_len = reader->Read(buffer->data(), buffer_size);
    if (chunk_len == 0) {
      return;
    }
    ss->clear();
    ss->str(std::string(buffer->data(), chunk_len));
    std::string continuation;
    std::getline(*ss, continuation);
    line->append(continuation);
  }
}

94
std::vector<std::string> ReadKLineFromFile(const char* filename, bool header, int k) {
95
96
  auto reader = VirtualFileReader::Make(filename);
  if (!reader->Init()) {
97
    Log::Fatal("Data file %s doesn't exist.", filename);
Guolin Ke's avatar
Guolin Ke committed
98
  }
99
100
101
  std::vector<std::string> ret;
  std::string cur_line;
  const size_t buffer_size = 1024 * 1024;
102
103
104
  auto buffer = std::vector<char>(buffer_size);
  size_t read_len = reader->Read(buffer.data(), buffer_size);
  if (read_len <= 0) {
105
    Log::Fatal("Data file %s couldn't be read.", filename);
106
  }
107
108
  std::string read_str = std::string(buffer.data(), read_len);
  std::stringstream tmp_file(read_str);
Guolin Ke's avatar
Guolin Ke committed
109
  if (header) {
Guolin Ke's avatar
Guolin Ke committed
110
    if (!tmp_file.eof()) {
111
      GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size);
Guolin Ke's avatar
Guolin Ke committed
112
113
    }
  }
114
115
116
  for (int i = 0; i < k; ++i) {
    if (!tmp_file.eof()) {
      GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size);
Guolin Ke's avatar
Guolin Ke committed
117
118
119
120
      cur_line = Common::Trim(cur_line);
      if (!cur_line.empty()) {
        ret.push_back(cur_line);
      }
121
122
123
    } else {
      break;
    }
Guolin Ke's avatar
Guolin Ke committed
124
  }
125
126
127
128
129
130
131
  if (ret.empty()) {
    Log::Fatal("Data file %s should have at least one line.", filename);
  } else if (ret.size() == 1) {
    Log::Warning("Data file %s only has one line.", filename);
  }
  return ret;
}
Guolin Ke's avatar
Guolin Ke committed
132

133
// Guesses the text format of the sampled data `lines` and its column count.
//
// Detection heuristic:
//   - one line:  any ':' -> LIBSVM, else any '\t' -> TSV, else any ',' -> CSV
//   - two+ lines: ':' in either of the first two lines -> LIBSVM; otherwise
//     the same non-zero tab count on the first two lines -> TSV; otherwise
//     the same non-zero comma count -> CSV. For TSV/CSV every remaining line
//     must have the same delimiter count as the first, else INVALID.
//
// `*num_col` is written only when a valid type is detected:
//   - CSV/TSV: delimiter count of the first line + 1
//   - LIBSVM:  largest feature index found + 1. Each line's index is taken
//     from its last `index:value` pair, which assumes indices are listed in
//     ascending order within a line (NOTE(review): confirm for all inputs).
// Returns DataType::INVALID for empty input or an unrecognised format.
DataType GetDataType(const std::vector<std::string>& lines, int* num_col) {
  DataType type = DataType::INVALID;
  if (lines.empty()) {
    return type;
  }
  int comma_cnt = 0;
  int tab_cnt = 0;
  int colon_cnt = 0;
  GetStatistic(lines[0].c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
  if (lines.size() == 1) {
    if (colon_cnt > 0) {
      type = DataType::LIBSVM;
    } else if (tab_cnt > 0) {
      type = DataType::TSV;
    } else if (comma_cnt > 0) {
      type = DataType::CSV;
    }
  } else {
    int comma_cnt2 = 0;
    int tab_cnt2 = 0;
    int colon_cnt2 = 0;
    GetStatistic(lines[1].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
    if (colon_cnt > 0 || colon_cnt2 > 0) {
      type = DataType::LIBSVM;
    } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
      type = DataType::TSV;
    } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
      type = DataType::CSV;
    }
    if (type == DataType::TSV || type == DataType::CSV) {
      // Validate the guess: every remaining line must use the same number of
      // delimiters as the first line.
      for (size_t i = 2; i < lines.size(); ++i) {
        GetStatistic(lines[i].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
        if (type == DataType::TSV && tab_cnt2 != tab_cnt) {
          type = DataType::INVALID;
          break;
        } else if (type == DataType::CSV && comma_cnt != comma_cnt2) {
          type = DataType::INVALID;
          break;
        }
      }
    }
  }
  if (type == DataType::LIBSVM) {
    int max_col_idx = 0;
    for (const auto& raw_line : lines) {
      const std::string str = Common::Trim(raw_line);
      const std::string::size_type colon_pos = str.rfind(':');
      if (colon_pos == std::string::npos) {
        // No `index:value` pair (e.g. a label-only row): skip it so the label
        // is not mistaken for a feature index.
        continue;
      }
      // The feature index is the token between the last whitespace before
      // the last ':' and that ':'.
      const std::string::size_type space_pos = str.find_last_of(" \f\t\v", colon_pos);
      std::string::size_type idx_begin = 0;
      if (space_pos != std::string::npos) {
        idx_begin = space_pos + 1;
      }
      const std::string idx_str = str.substr(idx_begin, colon_pos - idx_begin);
      int cur_idx = 0;
      Common::Atoi(idx_str.c_str(), &cur_idx);
      max_col_idx = std::max(cur_idx, max_col_idx);
    }
    *num_col = max_col_idx + 1;
  } else if (type == DataType::CSV) {
    *num_col = comma_cnt + 1;
  } else if (type == DataType::TSV) {
    *num_col = tab_cnt + 1;
  }
  return type;
}

// Factory: samples the first lines of `filename`, detects its format, and
// returns a newly allocated parser (ownership passes to the caller).
//
// `label_idx` is forwarded to the format-specific label check, which may
// turn it into -1 when the sampled first line appears to have no label
// column; an informational message is logged in that case. Reports a fatal
// error when the format cannot be recognised.
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) {
  const int n_read_line = 20;
  const std::vector<std::string> sample_lines = ReadKLineFromFile(filename, header, n_read_line);
  int num_col = 0;
  const DataType type = GetDataType(sample_lines, &num_col);
  if (type == DataType::INVALID) {
    Log::Fatal("Unknown format of training data.");
  }

  std::unique_ptr<Parser> parser;
  int output_label_index = -1;
  switch (type) {
    case DataType::LIBSVM:
      output_label_index = GetLabelIdxForLibsvm(sample_lines[0], num_features, label_idx);
      parser.reset(new LibSVMParser(output_label_index, num_col));
      break;
    case DataType::TSV:
      output_label_index = GetLabelIdxForTSV(sample_lines[0], num_features, label_idx);
      parser.reset(new TSVParser(output_label_index, num_col));
      break;
    case DataType::CSV:
      output_label_index = GetLabelIdxForCSV(sample_lines[0], num_features, label_idx);
      parser.reset(new CSVParser(output_label_index, num_col));
      break;
    default:
      break;
  }

  if (output_label_index < 0 && label_idx >= 0) {
    Log::Info("Data file %s doesn't contain a label column.", filename);
  }
  return parser.release();
}

}  // namespace LightGBM