parser.cpp 4.76 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include "parser.hpp"

#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
// Counts how many ',' , '\t' and ':' characters appear in the NUL-terminated
// string `str`. Each counter is reset to zero first, then incremented once per
// matching character.
void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt) {
  *comma_cnt = 0;
  *tab_cnt = 0;
  *colon_cnt = 0;
  for (const char* cursor = str; *cursor != '\0'; ++cursor) {
    switch (*cursor) {
      case ',':
        ++(*comma_cnt);
        break;
      case '\t':
        ++(*tab_cnt);
        break;
      case ':':
        ++(*colon_cnt);
        break;
      default:
        break;
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
30
int GetLabelIdxForLibsvm(const std::string& str, int num_features, int label_idx) {
Guolin Ke's avatar
Guolin Ke committed
31
32
33
  if (num_features <= 0) {
    return label_idx;
  }
Guolin Ke's avatar
Guolin Ke committed
34
35
36
  auto str2 = Common::Trim(str);
  auto pos_space = str2.find_first_of(" \f\n\r\t\v");
  auto pos_colon = str2.find_first_of(":");
37
  if (pos_space == std::string::npos || pos_space < pos_colon) {
Guolin Ke's avatar
Guolin Ke committed
38
    return label_idx;
39
40
  } else {
    return -1;
41
42
43
  }
}

Guolin Ke's avatar
Guolin Ke committed
44
// Decides whether a TSV line carries a label column.
//
// The number of columns is ('\t' count) + 1. This is the same count that
// CreateParser uses to size the TSVParser (tab_cnt + 1), so empty fields
// (missing values) still count as columns. If the row has exactly
// num_features columns there is no room for a label and -1 is returned;
// otherwise label_idx is returned. When num_features <= 0 the line is not
// inspected and label_idx is returned unchanged.
int GetLabelIdxForTSV(const std::string& str, int num_features, int label_idx) {
  if (num_features <= 0) {
    return label_idx;
  }
  // Count delimiters rather than non-empty tokens. Otherwise a labelled row
  // with an empty field, e.g. "1\t\t3" with 2 features, would look
  // label-less.
  const int num_columns = static_cast<int>(std::count(str.begin(), str.end(), '\t')) + 1;
  if (num_columns == num_features) {
    return -1;
  }
  return label_idx;
}

Guolin Ke's avatar
Guolin Ke committed
57
// Decides whether a CSV line carries a label column.
//
// The number of columns is (',' count) + 1. This is the same count that
// CreateParser uses to size the CSVParser (comma_cnt + 1), so empty fields
// (missing values) still count as columns. If the row has exactly
// num_features columns there is no room for a label and -1 is returned;
// otherwise label_idx is returned. When num_features <= 0 the line is not
// inspected and label_idx is returned unchanged.
int GetLabelIdxForCSV(const std::string& str, int num_features, int label_idx) {
  if (num_features <= 0) {
    return label_idx;
  }
  // Count delimiters rather than non-empty tokens. Otherwise a labelled row
  // with an empty field, e.g. "1,,3" with 2 features, would look label-less.
  const int num_columns = static_cast<int>(std::count(str.begin(), str.end(), ',')) + 1;
  if (num_columns == num_features) {
    return -1;
  }
  return label_idx;
}

Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
74
75
76
// Text formats that CreateParser can recognise from the first data lines of a
// file. INVALID means none of the formats matched.
// Scoped (enum class) so the short enumerator names do not leak into the
// enclosing namespace; use sites spell them as DataType::X.
enum class DataType {
  INVALID,
  CSV,
  TSV,
  LIBSVM
};

Guolin Ke's avatar
Guolin Ke committed
77
78
79
80
// Reads one line from *ss into *line, refilling *ss from `reader` whenever the
// line runs past the end of the currently loaded chunk.
//
// std::getline sets eof on *ss when it reaches the end of the loaded text
// without finding '\n'. In that case up to buffer_size bytes are read into
// *buffer, *ss is reset to hold that new chunk, and the next piece of the line
// is appended to *line. This repeats until a getline call stops before the end
// of its chunk, or until reader->Read returns 0.
void GetLine(std::stringstream* ss, std::string* line, const VirtualFileReader* reader, std::vector<char>* buffer, size_t buffer_size) {
  std::getline(*ss, *line);
  std::string continuation;  // std::getline clears it before every read
  for (;;) {
    if (!ss->eof()) {
      return;
    }
    const size_t bytes_read = reader->Read(buffer->data(), buffer_size);
    if (bytes_read == 0) {  // size_t: equivalent to the "<= 0" check
      return;
    }
    ss->clear();
    ss->str(std::string(buffer->data(), bytes_read));
    std::getline(*ss, continuation);
    line->append(continuation);
  }
}

Guolin Ke's avatar
Guolin Ke committed
92
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) {
93
94
  auto reader = VirtualFileReader::Make(filename);
  if (!reader->Init()) {
95
    Log::Fatal("Data file %s doesn't exist", filename);
Guolin Ke's avatar
Guolin Ke committed
96
97
  }
  std::string line1, line2;
98
99
100
101
102
103
104
105
  size_t buffer_size = 64 * 1024;
  auto buffer = std::vector<char>(buffer_size);
  size_t read_len = reader->Read(buffer.data(), buffer_size);
  if (read_len <= 0) {
    Log::Fatal("Data file %s couldn't be read", filename);
  }

  std::stringstream tmp_file(std::string(buffer.data(), read_len));
Guolin Ke's avatar
Guolin Ke committed
106
  if (header) {
Guolin Ke's avatar
Guolin Ke committed
107
    if (!tmp_file.eof()) {
Guolin Ke's avatar
Guolin Ke committed
108
      GetLine(&tmp_file, &line1, reader.get(), &buffer, buffer_size);
Guolin Ke's avatar
Guolin Ke committed
109
110
    }
  }
Guolin Ke's avatar
Guolin Ke committed
111
  if (!tmp_file.eof()) {
Guolin Ke's avatar
Guolin Ke committed
112
    GetLine(&tmp_file, &line1, reader.get(), &buffer, buffer_size);
Guolin Ke's avatar
Guolin Ke committed
113
  } else {
114
    Log::Fatal("Data file %s should have at least one line", filename);
Guolin Ke's avatar
Guolin Ke committed
115
116
  }
  if (!tmp_file.eof()) {
Guolin Ke's avatar
Guolin Ke committed
117
    GetLine(&tmp_file, &line2, reader.get(), &buffer, buffer_size);
Guolin Ke's avatar
Guolin Ke committed
118
  } else {
119
    Log::Warning("Data file %s only has one line", filename);
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
  }
  int comma_cnt = 0, comma_cnt2 = 0;
  int tab_cnt = 0, tab_cnt2 = 0;
  int colon_cnt = 0, colon_cnt2 = 0;
  // Get some statistic from 2 line
  GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
  GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
127
128


Guolin Ke's avatar
Guolin Ke committed
129
130

  DataType type = DataType::INVALID;
Guolin Ke's avatar
Guolin Ke committed
131
132
133
  if (line2.size() == 0) {
    // if only have one line on file
    if (colon_cnt > 0) {
Guolin Ke's avatar
Guolin Ke committed
134
      type = DataType::LIBSVM;
Guolin Ke's avatar
Guolin Ke committed
135
    } else if (tab_cnt > 0) {
Guolin Ke's avatar
Guolin Ke committed
136
      type = DataType::TSV;
Guolin Ke's avatar
Guolin Ke committed
137
    } else if (comma_cnt > 0) {
Guolin Ke's avatar
Guolin Ke committed
138
139
      type = DataType::CSV;
    }
Guolin Ke's avatar
Guolin Ke committed
140
141
  } else {
    if (colon_cnt > 0 || colon_cnt2 > 0) {
Guolin Ke's avatar
Guolin Ke committed
142
143
144
      type = DataType::LIBSVM;
    } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
      type = DataType::TSV;
Guolin Ke's avatar
Guolin Ke committed
145
      CHECK(tab_cnt == tab_cnt2);
Guolin Ke's avatar
Guolin Ke committed
146
    } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
Guolin Ke's avatar
Guolin Ke committed
147
      type = DataType::CSV;
Guolin Ke's avatar
Guolin Ke committed
148
      CHECK(comma_cnt == comma_cnt2);
Guolin Ke's avatar
Guolin Ke committed
149
150
    }
  }
Guolin Ke's avatar
Guolin Ke committed
151
  if (type == DataType::INVALID) {
152
    Log::Fatal("Unknown format of training data");
Guolin Ke's avatar
Guolin Ke committed
153
  }
Guolin Ke's avatar
Guolin Ke committed
154
  std::unique_ptr<Parser> ret;
Guolin Ke's avatar
Guolin Ke committed
155
156
  if (type == DataType::LIBSVM) {
    label_idx = GetLabelIdxForLibsvm(line1, num_features, label_idx);
Guolin Ke's avatar
Guolin Ke committed
157
    ret.reset(new LibSVMParser(label_idx));
158
  } else if (type == DataType::TSV) {
Guolin Ke's avatar
Guolin Ke committed
159
    label_idx = GetLabelIdxForTSV(line1, num_features, label_idx);
Guolin Ke's avatar
Guolin Ke committed
160
    ret.reset(new TSVParser(label_idx, tab_cnt + 1));
161
  } else if (type == DataType::CSV) {
Guolin Ke's avatar
Guolin Ke committed
162
    label_idx = GetLabelIdxForCSV(line1, num_features, label_idx);
Guolin Ke's avatar
Guolin Ke committed
163
    ret.reset(new CSVParser(label_idx, comma_cnt + 1));
Guolin Ke's avatar
Guolin Ke committed
164
165
166
  }

  if (label_idx < 0) {
167
    Log::Info("Data file %s doesn't contain a label column", filename);
Guolin Ke's avatar
Guolin Ke committed
168
  }
Guolin Ke's avatar
Guolin Ke committed
169
  return ret.release();
Guolin Ke's avatar
Guolin Ke committed
170
171
172
}

}  // namespace LightGBM