Commit a4c022a7 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix #287

parent ae7bbb6f
...@@ -136,7 +136,8 @@ public: ...@@ -136,7 +136,8 @@ public:
*/ */
inline std::string bin_info() const { inline std::string bin_info() const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << '[' << min_val_ << ',' << max_val_ << ']'; str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << '[' << min_val_ << ':' << max_val_ << ']';
return str_buf.str(); return str_buf.str();
} }
......
...@@ -471,6 +471,20 @@ public: ...@@ -471,6 +471,20 @@ public:
feature_names_ = std::vector<std::string>(feature_names); feature_names_ = std::vector<std::string>(feature_names);
} }
/*!
* \brief Get one info string per original feature column
* \return Vector with num_total_features_ entries: "none" for a column whose entry in used_feature_map_ is -1, otherwise the bin_info() string of the bin mapper returned by FeatureBinMapper for the mapped index
*/
inline std::vector<std::string> feature_infos() const {
  std::vector<std::string> bufs;
  // The number of entries is known up front, so allocate once instead of
  // letting the vector grow repeatedly. Guard the cast so a non-positive
  // count never becomes a huge unsigned reservation.
  if (num_total_features_ > 0) {
    bufs.reserve(static_cast<size_t>(num_total_features_));
  }
  for (int i = 0; i < num_total_features_; ++i) {
    const int fidx = used_feature_map_[i];
    if (fidx == -1) {
      // Keep a placeholder so the position in the result still equals the column index
      bufs.emplace_back("none");
    } else {
      bufs.push_back(FeatureBinMapper(fidx)->bin_info());
    }
  }
  return bufs;
}
/*! \brief Get Number of data */ /*! \brief Get Number of data */
inline data_size_t num_data() const { return num_data_; } inline data_size_t num_data() const { return num_data_; }
......
...@@ -104,6 +104,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -104,6 +104,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
label_idx_ = train_data->label_idx(); label_idx_ = train_data->label_idx();
// get feature names // get feature names
feature_names_ = train_data->feature_names(); feature_names_ = train_data->feature_names();
feature_infos_ = train_data->feature_infos();
} }
if ((train_data_ != train_data && train_data != nullptr) if ((train_data_ != train_data && train_data != nullptr)
...@@ -558,6 +560,8 @@ std::string GBDT::SaveModelToString(int num_iterations) const { ...@@ -558,6 +560,8 @@ std::string GBDT::SaveModelToString(int num_iterations) const {
ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl; ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
ss << std::endl; ss << std::endl;
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
if (num_iterations > 0) { if (num_iterations > 0) {
...@@ -640,6 +644,18 @@ bool GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -640,6 +644,18 @@ bool GBDT::LoadModelFromString(const std::string& model_str) {
return false; return false;
} }
line = Common::FindFromLines(lines, "feature_infos=");
if (line.size() > 0) {
feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), " ");
if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_infos");
return false;
}
} else {
Log::Fatal("Model file doesn't contain feature infos");
return false;
}
// get tree models // get tree models
size_t i = 0; size_t i = 0;
while (i < lines.size()) { while (i < lines.size()) {
......
...@@ -329,6 +329,7 @@ protected: ...@@ -329,6 +329,7 @@ protected:
int num_init_iteration_; int num_init_iteration_;
/*! \brief Feature names */ /*! \brief Feature names */
std::vector<std::string> feature_names_; std::vector<std::string> feature_names_;
/*! \brief Per-feature info strings; taken from the training Dataset's feature_infos() and written to / read from the "feature_infos=" line of the model text */
std::vector<std::string> feature_infos_;
/*! \brief number of threads */ /*! \brief number of threads */
int num_threads_; int num_threads_;
/*! \brief Buffer for multi-threading bagging */ /*! \brief Buffer for multi-threading bagging */
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment