parser.cpp 3.66 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include "parser.hpp"

#include <iostream>
#include <fstream>
Guolin Ke's avatar
Guolin Ke committed
5
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
6
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
7
8
9

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
10
// Counts the candidate separator characters found in one line of text.
//
// str:        NUL-terminated line to scan; a null pointer is treated as an
//             empty line.
// comma_cnt:  out-param, receives the number of ',' characters.
// tab_cnt:    out-param, receives the number of '\t' characters.
// colon_cnt:  out-param, receives the number of ':' characters.
//
// All three counters are reset to zero before scanning, so the values they
// held on entry are ignored.
void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt) {
  *comma_cnt = 0;
  *tab_cnt = 0;
  *colon_cnt = 0;
  if (str == nullptr) {
    return;
  }
  // Walk the pointer itself rather than an int index: a signed index would
  // overflow (undefined behaviour) on a line longer than INT_MAX characters.
  for (const char* cur = str; *cur != '\0'; ++cur) {
    switch (*cur) {
      case ',':
        ++(*comma_cnt);
        break;
      case '\t':
        ++(*tab_cnt);
        break;
      case ':':
        ++(*colon_cnt);
        break;
      default:
        break;
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
25
26
27
28
int GetLabelIdxForLibsvm(std::string& str, int num_features, int label_idx) {
  if (num_features <= 0) {
    return label_idx;
  }
29
30
31
  str = Common::Trim(str);
  auto pos_space = str.find_first_of(" \f\n\r\t\v");
  auto pos_colon = str.find_first_of(":");
32
  if (pos_space == std::string::npos || pos_space < pos_colon) {
Guolin Ke's avatar
Guolin Ke committed
33
    return label_idx;
34
35
  } else {
    return -1;
36
37
38
  }
}

Guolin Ke's avatar
Guolin Ke committed
39
40
41
42
// Decides which label index to use for a tab-separated line.
//
// str:          sample line; it is overwritten with the result of
//               Common::Trim.
// num_features: expected number of feature columns; when <= 0 no check is
//               made and label_idx is returned as is.
// label_idx:    label index to keep when the line has a label column.
//
// Returns -1 when splitting the line on '\t' yields exactly num_features
// tokens (every column is a feature, so none is left for a label); otherwise
// returns label_idx.
int GetLabelIdxForTSV(std::string& str, int num_features, int label_idx) {
  if (num_features > 0) {
    str = Common::Trim(str);
    const int column_cnt = static_cast<int>(Common::Split(str.c_str(), '\t').size());
    if (column_cnt == num_features) {
      return -1;
    }
  }
  return label_idx;
}

Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
// Decides which label index to use for a comma-separated line.
//
// str:          sample line; it is overwritten with the result of
//               Common::Trim.
// num_features: expected number of feature columns; when <= 0 no check is
//               made and label_idx is returned as is.
// label_idx:    label index to keep when the line has a label column.
//
// Returns -1 when splitting the line on ',' yields exactly num_features
// tokens (every column is a feature, so none is left for a label); otherwise
// returns label_idx.
int GetLabelIdxForCSV(std::string& str, int num_features, int label_idx) {
  if (num_features > 0) {
    str = Common::Trim(str);
    const int column_cnt = static_cast<int>(Common::Split(str.c_str(), ',').size());
    if (column_cnt == num_features) {
      return -1;
    }
  }
  return label_idx;
}

Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
69
70
71
// Text formats that Parser::CreateParser chooses between when it inspects
// the first lines of a data file.
enum DataType {
  // No known format detected; the starting value before detection runs.
  INVALID,
  // Columns separated by ','.
  CSV,
  // Columns separated by '\t'.
  TSV,
  // Lines that contain ':' (presumably "index:value" pairs of the LibSVM
  // sparse format).
  LIBSVM
};

Guolin Ke's avatar
Guolin Ke committed
72
// Builds a parser for a data file by sniffing the file's text format.
//
// The first one or two data lines (after an optional header row) are scanned
// for ':', '\t' and ',' to choose between the LibSVM, TSV and CSV parsers.
//
// filename:     path of the data file to inspect.
// has_header:   when true, the first line of the file is skipped.
// num_features: forwarded to the GetLabelIdxFor* helpers, which use it (when
//               positive) to decide whether the first data line has a label
//               column.
// label_idx:    label column index; replaced by the helper's verdict, where
//               -1 means the file has no label column.
//
// Returns a heap-allocated parser whose ownership passes to the caller.
// Failures (file cannot be opened, no data line, unrecognised format) are
// reported through Log::Fatal.
// NOTE(review): presumably Log::Fatal does not return -- confirm. If it does
// return for an unrecognised format, the result is a null pointer.
Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) {
  std::ifstream tmp_file;
  tmp_file.open(filename);
  if (!tmp_file.is_open()) {
    Log::Fatal("Data file %s doesn't exist", filename);
  }
  if (has_header) {
    // The header row is read only to be skipped. It goes into its own string
    // so that a failed read afterwards cannot leave header text in line1.
    std::string header;
    std::getline(tmp_file, header);
  }
  std::string line1, line2;
  // Test the result of getline itself: eof() only becomes true after a read
  // has hit the end of the file, so checking it before reading misses empty
  // files and single-line files that end with a newline.
  if (!std::getline(tmp_file, line1)) {
    Log::Fatal("Data file %s should have at least one line", filename);
  }
  if (!std::getline(tmp_file, line2)) {
    Log::Warning("Data file %s only has one line", filename);
  }
  tmp_file.close();

  int comma_cnt = 0, comma_cnt2 = 0;
  int tab_cnt = 0, tab_cnt2 = 0;
  int colon_cnt = 0, colon_cnt2 = 0;
  // Gather separator statistics from the two sample lines.
  GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
  GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);

  DataType type = DataType::INVALID;
  if (line2.size() == 0) {
    // Only one usable line: decide from that line alone. ':' takes priority
    // over '\t', which takes priority over ','.
    if (colon_cnt > 0) {
      type = DataType::LIBSVM;
    } else if (tab_cnt > 0) {
      type = DataType::TSV;
    } else if (comma_cnt > 0) {
      type = DataType::CSV;
    }
  } else {
    // Two lines: a ':' on either line selects LibSVM; TSV and CSV additionally
    // require both lines to contain the same number of separators.
    if (colon_cnt > 0 || colon_cnt2 > 0) {
      type = DataType::LIBSVM;
    } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
      type = DataType::TSV;
    } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
      type = DataType::CSV;
    }
  }
  if (type == DataType::INVALID) {
    Log::Fatal("Unknown format of training data");
  }

  std::unique_ptr<Parser> ret;
  if (type == DataType::LIBSVM) {
    label_idx = GetLabelIdxForLibsvm(line1, num_features, label_idx);
    ret.reset(new LibSVMParser(label_idx));
  } else if (type == DataType::TSV) {
    label_idx = GetLabelIdxForTSV(line1, num_features, label_idx);
    ret.reset(new TSVParser(label_idx));
  } else if (type == DataType::CSV) {
    label_idx = GetLabelIdxForCSV(line1, num_features, label_idx);
    ret.reset(new CSVParser(label_idx));
  }

  if (label_idx < 0) {
    Log::Info("Data file %s doesn't contain a label column", filename);
  }
  // Hand the raw pointer to the caller, as the declared return type requires.
  return ret.release();
}

}  // namespace LightGBM