Unverified Commit ae320e59 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix predict with header (#2643)

* fix predict with header

* avoid duplicated feature names
parent d1002776
...@@ -556,6 +556,7 @@ class Dataset { ...@@ -556,6 +556,7 @@ class Dataset {
Log::Fatal("Size of feature_names error, should equal with total number of features"); Log::Fatal("Size of feature_names error, should equal with total number of features");
} }
feature_names_ = std::vector<std::string>(feature_names); feature_names_ = std::vector<std::string>(feature_names);
std::unordered_set<std::string> feature_name_set;
// replace ' ' in feature_names with '_' // replace ' ' in feature_names with '_'
bool spaceInFeatureName = false; bool spaceInFeatureName = false;
for (auto& feature_name : feature_names_) { for (auto& feature_name : feature_names_) {
...@@ -571,6 +572,10 @@ class Dataset { ...@@ -571,6 +572,10 @@ class Dataset {
spaceInFeatureName = true; spaceInFeatureName = true;
std::replace(feature_name.begin(), feature_name.end(), ' ', '_'); std::replace(feature_name.begin(), feature_name.end(), ' ', '_');
} }
if (feature_name_set.count(feature_name) > 0) {
Log::Fatal("Feature (%s) appears more than one time.", feature_name.c_str());
}
feature_name_set.insert(feature_name);
} }
if (spaceInFeatureName) { if (spaceInFeatureName) {
Log::Warning("Find whitespaces in feature_names, replace with underlines"); Log::Warning("Find whitespaces in feature_names, replace with underlines");
......
...@@ -135,31 +135,38 @@ class Predictor { ...@@ -135,31 +135,38 @@ class Predictor {
if (!writer->Init()) { if (!writer->Init()) {
Log::Fatal("Prediction results file %s cannot be found", result_filename); Log::Fatal("Prediction results file %s cannot be found", result_filename);
} }
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx())); auto label_idx = header ? -1 : boosting_->LabelIdx();
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx));
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename); Log::Fatal("Could not recognize the data format of data file %s", data_filename);
} }
if (parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) { if (!header && parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1); Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1);
} }
TextReader<data_size_t> predict_data_reader(data_filename, header); TextReader<data_size_t> predict_data_reader(data_filename, header);
std::unordered_map<int, int> feature_names_map_; std::vector<int> feature_remapper(parser->NumFeatures(), -1);
bool need_adjust = false; bool need_adjust = false;
if (header) { if (header) {
std::string first_line = predict_data_reader.first_line(); std::string first_line = predict_data_reader.first_line();
std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,"); std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
header_words.erase(header_words.begin() + boosting_->LabelIdx()); std::unordered_map<std::string, int> header_mapper;
for (int i = 0; i < static_cast<int>(header_words.size()); ++i) { for (int i = 0; i < static_cast<int>(header_words.size()); ++i) {
for (int j = 0; j < static_cast<int>(boosting_->FeatureNames().size()); ++j) { if (header_mapper.count(header_words[i]) > 0) {
if (header_words[i] == boosting_->FeatureNames()[j]) { Log::Fatal("Feature (%s) appears more than one time.", header_words[i].c_str());
feature_names_map_[i] = j; }
break; header_mapper[header_words[i]] = i;
} }
const auto& fnames = boosting_->FeatureNames();
for (int i = 0; i < static_cast<int>(fnames.size()); ++i) {
if (header_mapper.count(fnames[i]) <= 0) {
Log::Warning("Feature (%s) is missed in data file. If it is weight/query/group/ignore_column, you can ignore this warning.", fnames[i].c_str());
} else {
feature_remapper[header_mapper.at(fnames[i])] = i;
} }
} }
for (auto s : feature_names_map_) { for (int i = 0; i < static_cast<int>(feature_remapper.size()); ++i) {
if (s.first != s.second) { if (feature_remapper[i] >= 0 && i != feature_remapper[i]) {
need_adjust = true; need_adjust = true;
break; break;
} }
...@@ -174,8 +181,8 @@ class Predictor { ...@@ -174,8 +181,8 @@ class Predictor {
if (need_adjust) { if (need_adjust) {
int i = 0, j = static_cast<int>(feature->size()); int i = 0, j = static_cast<int>(feature->size());
while (i < j) { while (i < j) {
if (feature_names_map_.find((*feature)[i].first) != feature_names_map_.end()) { if (feature_remapper[(*feature)[i].first] >= 0) {
(*feature)[i].first = feature_names_map_[(*feature)[i].first]; (*feature)[i].first = feature_remapper[(*feature)[i].first];
++i; ++i;
} else { } else {
// move the non-used features to the end of the feature vector // move the non-used features to the end of the feature vector
......
...@@ -201,18 +201,19 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features ...@@ -201,18 +201,19 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
Log::Fatal("Unknown format of training data."); Log::Fatal("Unknown format of training data.");
} }
std::unique_ptr<Parser> ret; std::unique_ptr<Parser> ret;
int output_label_index = -1;
if (type == DataType::LIBSVM) { if (type == DataType::LIBSVM) {
label_idx = GetLabelIdxForLibsvm(lines[0], num_features, label_idx); output_label_index = GetLabelIdxForLibsvm(lines[0], num_features, label_idx);
ret.reset(new LibSVMParser(label_idx, num_col)); ret.reset(new LibSVMParser(output_label_index, num_col));
} else if (type == DataType::TSV) { } else if (type == DataType::TSV) {
label_idx = GetLabelIdxForTSV(lines[0], num_features, label_idx); output_label_index = GetLabelIdxForTSV(lines[0], num_features, label_idx);
ret.reset(new TSVParser(label_idx, num_col)); ret.reset(new TSVParser(output_label_index, num_col));
} else if (type == DataType::CSV) { } else if (type == DataType::CSV) {
label_idx = GetLabelIdxForCSV(lines[0], num_features, label_idx); output_label_index = GetLabelIdxForCSV(lines[0], num_features, label_idx);
ret.reset(new CSVParser(label_idx, num_col)); ret.reset(new CSVParser(output_label_index, num_col));
} }
if (label_idx < 0) { if (output_label_index < 0 && label_idx >= 0) {
Log::Info("Data file %s doesn't contain a label column.", filename); Log::Info("Data file %s doesn't contain a label column.", filename);
} }
return ret.release(); return ret.release();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment