Commit b24190a6 authored by Guolin Ke's avatar Guolin Ke
Browse files

check the type of parameter, output sorted importance

parent b23a2c31
...@@ -254,7 +254,10 @@ inline bool ConfigBase::GetInt( ...@@ -254,7 +254,10 @@ inline bool ConfigBase::GetInt(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, int* out) { const std::string& name, int* out) {
if (params.count(name) > 0) { if (params.count(name) > 0) {
Common::Atoi(params.at(name).c_str(), out); if (!Common::AtoiAndCheck(params.at(name).c_str(), out)) {
Log::Fatal("Parameter %s should be int type, passed is [%s]",
name.c_str(), params.at(name).c_str());
}
return true; return true;
} }
return false; return false;
...@@ -264,7 +267,10 @@ inline bool ConfigBase::GetDouble( ...@@ -264,7 +267,10 @@ inline bool ConfigBase::GetDouble(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, double* out) { const std::string& name, double* out) {
if (params.count(name) > 0) { if (params.count(name) > 0) {
Common::Atof(params.at(name).c_str(), out); if (!Common::AtofAndCheck(params.at(name).c_str(), out)) {
Log::Fatal("Parameter %s should be float type, passed is [%s]",
name.c_str(), params.at(name).c_str());
}
return true; return true;
} }
return false; return false;
...@@ -276,10 +282,13 @@ inline bool ConfigBase::GetBool( ...@@ -276,10 +282,13 @@ inline bool ConfigBase::GetBool(
if (params.count(name) > 0) { if (params.count(name) > 0) {
std::string value = params.at(name); std::string value = params.at(name);
std::transform(value.begin(), value.end(), value.begin(), ::tolower); std::transform(value.begin(), value.end(), value.begin(), ::tolower);
if (value == std::string("false")) { if (value == std::string("false") || value == std::string("-")) {
*out = false; *out = false;
} else { } else if (value == std::string("true") || value == std::string("+")) {
*out = true; *out = true;
} else {
Log::Fatal("Parameter %s should be \"true\"/\"+\" or \"false\"/\"-\", passed is [%s]",
name.c_str(), params.at(name).c_str());
} }
return true; return true;
} }
......
...@@ -195,6 +195,22 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -195,6 +195,22 @@ inline static const char* Atof(const char* p, double* out) {
return p; return p;
} }
inline bool AtoiAndCheck(const char* p, int* out) {
const char* after = Atoi(p, out);
if (*after != '\0') {
return false;
}
return true;
}
inline bool AtofAndCheck(const char* p, double* out) {
const char* after = Atof(p, out);
if (*after != '\0') {
return false;
}
return true;
}
inline static const char* SkipSpaceAndTab(const char* p) { inline static const char* SkipSpaceAndTab(const char* p) {
while (*p == ' ' || *p == '\t') { while (*p == ' ' || *p == '\t') {
++p; ++p;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <chrono> #include <chrono>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
namespace LightGBM { namespace LightGBM {
...@@ -374,16 +375,30 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) { ...@@ -374,16 +375,30 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) {
} }
void GBDT::FeatureImportance(const int last_iter) { void GBDT::FeatureImportance(const int last_iter) {
size_t* feature_importances = new size_t[max_feature_idx_ + 1]{0}; std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
for (int iter = 0; iter < last_iter; ++iter) { for (int iter = 0; iter < last_iter; ++iter) {
for (int split_idx = 0; split_idx < models_.at(iter)->num_leaves() - 1; ++split_idx) { for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
++feature_importances[models_.at(iter)->split_feature_real(split_idx)]; ++feature_importances[models_[iter]->split_feature_real(split_idx)];
} }
} }
std::string ret = Common::ArrayToString(feature_importances, max_feature_idx_ + 1, ' '); // store the importance first
fprintf(output_model_file, "feature importances=%s\n", ret.c_str()); std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
pairs.emplace_back(feature_importances[i], train_data_->feature_names()[i]);
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[](const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
// write to model file
fprintf(output_model_file, "\nfeature importances:\n");
for (size_t i = 0; i < pairs.size(); ++i) {
fprintf(output_model_file, "%s=%s\n", pairs[i].second.c_str(),
std::to_string(pairs[i].first).c_str());
}
fflush(output_model_file); fflush(output_model_file);
delete[] feature_importances;
} }
double GBDT::PredictRaw(const double* value) const { double GBDT::PredictRaw(const double* value) const {
......
...@@ -54,9 +54,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -54,9 +54,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
Log::Fatal("cannot find label column: %s in data file", name.c_str()); Log::Fatal("cannot find label column: %s in data file", name.c_str());
} }
} else { } else {
size_t pos = 0; if (!Common::AtoiAndCheck(io_config.label_column.c_str(), &label_idx_)) {
label_idx_ = std::stoi(io_config.label_column, &pos);
if (pos != io_config.label_column.size()) {
Log::Fatal("label_column is not a number, \ Log::Fatal("label_column is not a number, \
if you want to use column name, \ if you want to use column name, \
please add prefix \"name:\" before column name"); please add prefix \"name:\" before column name");
...@@ -84,9 +82,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -84,9 +82,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
} }
} else { } else {
for (auto token : Common::Split(io_config.ignore_column.c_str(), ',')) { for (auto token : Common::Split(io_config.ignore_column.c_str(), ',')) {
size_t pos = 0; int tmp = 0;
int tmp = std::stoi(token, &pos); if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
if (pos != token.size()) {
Log::Fatal("ignore_column is not a number, \ Log::Fatal("ignore_column is not a number, \
if you want to use column name, \ if you want to use column name, \
please add prefix \"name:\" before column name"); please add prefix \"name:\" before column name");
...@@ -110,9 +107,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -110,9 +107,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
Log::Fatal("cannot find weight column: %s in data file", name.c_str()); Log::Fatal("cannot find weight column: %s in data file", name.c_str());
} }
} else { } else {
size_t pos = 0; if (!Common::AtoiAndCheck(io_config.weight_column.c_str(), &weight_idx_)) {
weight_idx_ = std::stoi(io_config.weight_column, &pos);
if (pos != io_config.weight_column.size()) {
Log::Fatal("weight_column is not a number, \ Log::Fatal("weight_column is not a number, \
if you want to use column name, \ if you want to use column name, \
please add prefix \"name:\" before column name"); please add prefix \"name:\" before column name");
...@@ -136,9 +131,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -136,9 +131,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
Log::Fatal("cannot find group/query column: %s in data file", name.c_str()); Log::Fatal("cannot find group/query column: %s in data file", name.c_str());
} }
} else { } else {
size_t pos = 0; if (!Common::AtoiAndCheck(io_config.group_column.c_str(), &group_idx_)) {
group_idx_ = std::stoi(io_config.group_column, &pos);
if (pos != io_config.group_column.size()) {
Log::Fatal("group_column is not a number, \ Log::Fatal("group_column is not a number, \
if you want to use column name, \ if you want to use column name, \
please add prefix \"name:\" before column name"); please add prefix \"name:\" before column name");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment