Commit de2f6ab2 authored by Guolin Ke
Browse files

fix a bug when using label_idx

parent 1a4a7281
...@@ -264,6 +264,8 @@ public: ...@@ -264,6 +264,8 @@ public:
virtual void ParseOneLine(const char* str, virtual void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const = 0; std::vector<std::pair<int, double>>* out_features, double* out_label) const = 0;
virtual int TotalColumns() const = 0;
/*! /*!
* \brief Create a parser object; the format is chosen automatically based on the file * \brief Create a parser object; the format is chosen automatically based on the file
* \param filename One Filename of data * \param filename One Filename of data
......
...@@ -767,8 +767,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -767,8 +767,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
if (feature_names_.empty()) { if (feature_names_.empty()) {
// -1 means doesn't use this feature // -1 means doesn't use this feature
dataset->used_feature_map_ = std::vector<int>(sample_values.size(), -1); dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->TotalColumns() - 1);
dataset->num_total_features_ = static_cast<int>(sample_values.size()); dataset->used_feature_map_ = std::vector<int>(dataset->num_total_features_, -1);
} else { } else {
dataset->used_feature_map_ = std::vector<int>(feature_names_.size(), -1); dataset->used_feature_map_ = std::vector<int>(feature_names_.size(), -1);
dataset->num_total_features_ = static_cast<int>(feature_names_.size()); dataset->num_total_features_ = static_cast<int>(feature_names_.size());
......
...@@ -137,8 +137,10 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -137,8 +137,10 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
type = DataType::LIBSVM; type = DataType::LIBSVM;
} else if (tab_cnt == tab_cnt2 && tab_cnt > 0) { } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
type = DataType::TSV; type = DataType::TSV;
CHECK(tab_cnt == tab_cnt2);
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
type = DataType::CSV; type = DataType::CSV;
CHECK(comma_cnt == comma_cnt2);
} }
} }
if (type == DataType::INVALID) { if (type == DataType::INVALID) {
...@@ -151,11 +153,11 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -151,11 +153,11 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
} }
else if (type == DataType::TSV) { else if (type == DataType::TSV) {
label_idx = GetLabelIdxForTSV(line1, num_features, label_idx); label_idx = GetLabelIdxForTSV(line1, num_features, label_idx);
ret.reset(new TSVParser(label_idx)); ret.reset(new TSVParser(label_idx, tab_cnt + 1));
} }
else if (type == DataType::CSV) { else if (type == DataType::CSV) {
label_idx = GetLabelIdxForCSV(line1, num_features, label_idx); label_idx = GetLabelIdxForCSV(line1, num_features, label_idx);
ret.reset(new CSVParser(label_idx)); ret.reset(new CSVParser(label_idx, comma_cnt + 1));
} }
if (label_idx < 0) { if (label_idx < 0) {
......
...@@ -14,8 +14,8 @@ namespace LightGBM { ...@@ -14,8 +14,8 @@ namespace LightGBM {
class CSVParser: public Parser { class CSVParser: public Parser {
public: public:
explicit CSVParser(int label_idx) explicit CSVParser(int label_idx, int total_columns)
:label_idx_(label_idx) { :label_idx_(label_idx), total_columns_(total_columns) {
} }
inline void ParseOneLine(const char* str, inline void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const override { std::vector<std::pair<int, double>>* out_features, double* out_label) const override {
...@@ -40,14 +40,19 @@ public: ...@@ -40,14 +40,19 @@ public:
} }
} }
} }
/*!
 * \brief Report the total column count held by this CSV parser
 * \return The value stored in total_columns_, which the constructor
 *         initializes from its total_columns argument
 */
inline int TotalColumns() const override { return total_columns_; }
private: private:
int label_idx_ = 0; int label_idx_ = 0;
int total_columns_ = -1;
}; };
class TSVParser: public Parser { class TSVParser: public Parser {
public: public:
explicit TSVParser(int label_idx) explicit TSVParser(int label_idx, int total_columns)
:label_idx_(label_idx) { :label_idx_(label_idx), total_columns_(total_columns) {
} }
inline void ParseOneLine(const char* str, inline void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const override { std::vector<std::pair<int, double>>* out_features, double* out_label) const override {
...@@ -70,8 +75,13 @@ public: ...@@ -70,8 +75,13 @@ public:
} }
} }
} }
/*!
 * \brief Report the total column count held by this TSV parser
 * \return The value stored in total_columns_, which the constructor
 *         initializes from its total_columns argument
 */
inline int TotalColumns() const override { return total_columns_; }
private: private:
int label_idx_ = 0; int label_idx_ = 0;
int total_columns_ = -1;
}; };
class LibSVMParser: public Parser { class LibSVMParser: public Parser {
...@@ -104,6 +114,10 @@ public: ...@@ -104,6 +114,10 @@ public:
str = Common::SkipSpaceAndTab(str); str = Common::SkipSpaceAndTab(str);
} }
} }
/*!
 * \brief Column count query; this parser has no stored column count
 * \return Always -1
 * NOTE(review): presumably -1 signals "unknown" for the sparse LibSVM
 * format -- confirm against the callers of TotalColumns().
 */
inline int TotalColumns() const override { return -1; }
private: private:
int label_idx_ = 0; int label_idx_ = 0;
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment