"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "0b3d9da2eb042e90168501ccb07f9367888bee32"
Commit afe63f1a authored by ww, committed by Guolin Ke
Browse files

Use feature name in prediction (#988)

* add feature

* add label idx move func to .cpp file

* fix

* move func to predictor

* restore files

* fix by commits

* move the non-used features to the end of the feature vector

* fix by commits

* fix bug by commits
parent db9ec217
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <LightGBM/utils/openmp_wrapper.h> #include <LightGBM/utils/openmp_wrapper.h>
#include <map>
#include <cstring> #include <cstring>
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
...@@ -128,12 +129,48 @@ public: ...@@ -128,12 +129,48 @@ public:
Log::Fatal("Could not recognize the data format of data file %s.", data_filename); Log::Fatal("Could not recognize the data format of data file %s.", data_filename);
} }
TextReader<data_size_t> predict_data_reader(data_filename, has_header);
std::unordered_map<int, int> feature_names_map_;
bool need_adjust = false;
if(has_header) {
std::string first_line = predict_data_reader.first_line();
std::vector<std::string> header = Common::Split(first_line.c_str(), "\t,");
header.erase(header.begin() + boosting_->LabelIdx());
for(int i = 0; i < static_cast<int>(header.size()); ++i) {
for(int j = 0; j < static_cast<int>(boosting_->FeatureNames().size()); ++j) {
if(header[i] == boosting_->FeatureNames()[j]) {
feature_names_map_[i] = j;
break;
}
}
}
for(auto s:feature_names_map_) {
if(s.first != s.second) {
need_adjust = true;
break;
}
}
}
// function for parse data // function for parse data
std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun; std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun;
double tmp_label; double tmp_label;
parser_fun = [this, &parser, &tmp_label] parser_fun = [this, &parser, &tmp_label, &need_adjust, &feature_names_map_]
(const char* buffer, std::vector<std::pair<int, double>>* feature) { (const char* buffer, std::vector<std::pair<int, double>>* feature) {
parser->ParseOneLine(buffer, feature, &tmp_label); parser->ParseOneLine(buffer, feature, &tmp_label);
if(need_adjust) {
int i = 0, j = static_cast<int>(feature->size());
while(i < j) {
if(feature_names_map_.find((*feature)[i].first) != feature_names_map_.end()) {
(*feature)[i].first = feature_names_map_[(*feature)[i].first];
++i;
}
else {
//move the non-used features to the end of the feature vector
std::swap((*feature)[i], (*feature)[--j]);
}
}
feature->resize(i);
}
}; };
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
...@@ -160,7 +197,6 @@ public: ...@@ -160,7 +197,6 @@ public:
fprintf(result_file, "%s\n", result_to_write[i].c_str()); fprintf(result_file, "%s\n", result_to_write[i].c_str());
} }
}; };
TextReader<data_size_t> predict_data_reader(data_filename, has_header);
predict_data_reader.ReadAllAndProcessParallel(process_fun); predict_data_reader.ReadAllAndProcessParallel(process_fun);
fclose(result_file); fclose(result_file);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment