"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b793cd821c1d887a74b2d680d2d299f0275fb7a6"
Commit 9fd5bc25 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug for get feature name from freed dataset

parent 7dec4dec
......@@ -107,6 +107,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
max_feature_idx_ = train_data->num_total_features() - 1;
// get label index
label_idx_ = train_data->label_idx();
// get feature names
feature_names_ = train_data->feature_names();
}
if ((train_data_ != train_data && train_data != nullptr)
......@@ -472,14 +474,8 @@ std::string GBDT::DumpModel(int num_iteration) const {
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
str_buf << "\"sigmoid\":" << sigmoid_ << "," << std::endl;
// output feature names
auto feature_names = std::ref(feature_names_);
if (train_data_ != nullptr) {
feature_names = std::ref(train_data_->feature_names());
}
str_buf << "\"feature_names\":[\""
<< Common::Join(feature_names.get(), "\",\"") << "\"],"
<< Common::Join(feature_names_, "\",\"") << "\"],"
<< std::endl;
str_buf << "\"tree_info\":[";
......@@ -521,12 +517,8 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
}
// output sigmoid parameter
output_file << "sigmoid=" << sigmoid_ << std::endl;
// output feature names
auto feature_names = std::ref(feature_names_);
if (train_data_ != nullptr) {
feature_names = std::ref(train_data_->feature_names());
}
output_file << "feature_names=" << Common::Join(feature_names.get(), " ") << std::endl;
output_file << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
output_file << std::endl;
int num_used_model = static_cast<int>(models_.size());
......@@ -619,10 +611,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
}
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
auto feature_names = std::ref(feature_names_);
if (train_data_ != nullptr) {
feature_names = std::ref(train_data_->feature_names());
}
std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
for (size_t iter = 0; iter < models_.size(); ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
......@@ -633,7 +622,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
if (feature_importances[i] > 0) {
pairs.emplace_back(feature_importances[i], feature_names.get().at(i));
pairs.emplace_back(feature_importances[i], feature_names_[i]);
}
}
// sort the importance
......
......@@ -22,7 +22,7 @@ class FeatureParallelTreeLearner: public SerialTreeLearner {
public:
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
~FeatureParallelTreeLearner();
virtual void Init(const Dataset* train_data);
void Init(const Dataset* train_data) override;
protected:
void BeforeTrain() override;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment