Commit de2f6ab2 authored by Guolin Ke
Browse files

Fix a bug when using label_idx: take the parser's total column count into account when computing the number of features

parent 1a4a7281
......@@ -264,6 +264,8 @@ public:
virtual void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const = 0;
virtual int TotalColumns() const = 0;
/*!
* \brief Create a parser object; the format is chosen automatically depending on the file
* \param filename Name of one data file
......
......@@ -767,8 +767,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
if (feature_names_.empty()) {
// -1 means doesn't use this feature
dataset->used_feature_map_ = std::vector<int>(sample_values.size(), -1);
dataset->num_total_features_ = static_cast<int>(sample_values.size());
dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->TotalColumns() - 1);
dataset->used_feature_map_ = std::vector<int>(dataset->num_total_features_, -1);
} else {
dataset->used_feature_map_ = std::vector<int>(feature_names_.size(), -1);
dataset->num_total_features_ = static_cast<int>(feature_names_.size());
......
......@@ -137,8 +137,10 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
type = DataType::LIBSVM;
} else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
type = DataType::TSV;
CHECK(tab_cnt == tab_cnt2);
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
type = DataType::CSV;
CHECK(comma_cnt == comma_cnt2);
}
}
if (type == DataType::INVALID) {
......@@ -151,11 +153,11 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
}
else if (type == DataType::TSV) {
label_idx = GetLabelIdxForTSV(line1, num_features, label_idx);
ret.reset(new TSVParser(label_idx));
ret.reset(new TSVParser(label_idx, tab_cnt + 1));
}
else if (type == DataType::CSV) {
label_idx = GetLabelIdxForCSV(line1, num_features, label_idx);
ret.reset(new CSVParser(label_idx));
ret.reset(new CSVParser(label_idx, comma_cnt + 1));
}
if (label_idx < 0) {
......
......@@ -14,8 +14,8 @@ namespace LightGBM {
class CSVParser: public Parser {
public:
explicit CSVParser(int label_idx)
:label_idx_(label_idx) {
explicit CSVParser(int label_idx, int total_columns)
:label_idx_(label_idx), total_columns_(total_columns) {
}
inline void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const override {
......@@ -40,14 +40,19 @@ public:
}
}
}
inline int TotalColumns() const override {
return total_columns_;
}
private:
int label_idx_ = 0;
int total_columns_ = -1;
};
class TSVParser: public Parser {
public:
explicit TSVParser(int label_idx)
:label_idx_(label_idx) {
explicit TSVParser(int label_idx, int total_columns)
:label_idx_(label_idx), total_columns_(total_columns) {
}
inline void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const override {
......@@ -70,8 +75,13 @@ public:
}
}
}
inline int TotalColumns() const override {
return total_columns_;
}
private:
int label_idx_ = 0;
int total_columns_ = -1;
};
class LibSVMParser: public Parser {
......@@ -104,6 +114,10 @@ public:
str = Common::SkipSpaceAndTab(str);
}
}
inline int TotalColumns() const override {
return -1;
}
private:
int label_idx_ = 0;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment