"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "cebdc2a8c436dfc92c6169b2f54ddaecee827cb4"
Commit 9fd5bc25 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug for get feature name from freed dataset

parent 7dec4dec
...@@ -107,6 +107,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -107,6 +107,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
max_feature_idx_ = train_data->num_total_features() - 1; max_feature_idx_ = train_data->num_total_features() - 1;
// get label index // get label index
label_idx_ = train_data->label_idx(); label_idx_ = train_data->label_idx();
// get feature names
feature_names_ = train_data->feature_names();
} }
if ((train_data_ != train_data && train_data != nullptr) if ((train_data_ != train_data && train_data != nullptr)
...@@ -472,14 +474,8 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -472,14 +474,8 @@ std::string GBDT::DumpModel(int num_iteration) const {
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl; str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
str_buf << "\"sigmoid\":" << sigmoid_ << "," << std::endl; str_buf << "\"sigmoid\":" << sigmoid_ << "," << std::endl;
// output feature names
auto feature_names = std::ref(feature_names_);
if (train_data_ != nullptr) {
feature_names = std::ref(train_data_->feature_names());
}
str_buf << "\"feature_names\":[\"" str_buf << "\"feature_names\":[\""
<< Common::Join(feature_names.get(), "\",\"") << "\"]," << Common::Join(feature_names_, "\",\"") << "\"],"
<< std::endl; << std::endl;
str_buf << "\"tree_info\":["; str_buf << "\"tree_info\":[";
...@@ -521,12 +517,8 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { ...@@ -521,12 +517,8 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
} }
// output sigmoid parameter // output sigmoid parameter
output_file << "sigmoid=" << sigmoid_ << std::endl; output_file << "sigmoid=" << sigmoid_ << std::endl;
// output feature names
auto feature_names = std::ref(feature_names_); output_file << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
if (train_data_ != nullptr) {
feature_names = std::ref(train_data_->feature_names());
}
output_file << "feature_names=" << Common::Join(feature_names.get(), " ") << std::endl;
output_file << std::endl; output_file << std::endl;
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
...@@ -619,10 +611,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -619,10 +611,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
} }
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const { std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
auto feature_names = std::ref(feature_names_);
if (train_data_ != nullptr) {
feature_names = std::ref(train_data_->feature_names());
}
std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0); std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
for (size_t iter = 0; iter < models_.size(); ++iter) { for (size_t iter = 0; iter < models_.size(); ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) { for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
...@@ -633,7 +622,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const { ...@@ -633,7 +622,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
std::vector<std::pair<size_t, std::string>> pairs; std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) { for (size_t i = 0; i < feature_importances.size(); ++i) {
if (feature_importances[i] > 0) { if (feature_importances[i] > 0) {
pairs.emplace_back(feature_importances[i], feature_names.get().at(i)); pairs.emplace_back(feature_importances[i], feature_names_[i]);
} }
} }
// sort the importance // sort the importance
......
...@@ -22,7 +22,7 @@ class FeatureParallelTreeLearner: public SerialTreeLearner { ...@@ -22,7 +22,7 @@ class FeatureParallelTreeLearner: public SerialTreeLearner {
public: public:
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config); explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
~FeatureParallelTreeLearner(); ~FeatureParallelTreeLearner();
virtual void Init(const Dataset* train_data); void Init(const Dataset* train_data) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment