Commit 7339ed64, authored by wxchan, committed by Guolin Ke
Browse files

replace whitespaces with underlines in feature name (#426)

* change whitespace to underline in feature names

* add test

* fix bug

* fix bug

* warning -> fatal
parent d8bb5784
...@@ -495,10 +495,20 @@ public: ...@@ -495,10 +495,20 @@ public:
inline void set_feature_names(const std::vector<std::string>& feature_names) { inline void set_feature_names(const std::vector<std::string>& feature_names) {
if (feature_names.size() != static_cast<size_t>(num_total_features_)) { if (feature_names.size() != static_cast<size_t>(num_total_features_)) {
Log::Warning("size of feature_names error, should equal with total number of features"); Log::Fatal("Size of feature_names error, should equal with total number of features");
return;
} }
feature_names_ = std::vector<std::string>(feature_names); feature_names_ = std::vector<std::string>(feature_names);
// replace ' ' in feature_names with '_'
bool spaceInFeatureName = false;
for (auto& feature_name: feature_names_){
if (feature_name.find(' ') != std::string::npos){
spaceInFeatureName = true;
std::replace(feature_name.begin(), feature_name.end(), ' ', '_');
}
}
if (spaceInFeatureName){
Log::Warning("Find whitespaces in feature_names, replace with underlines");
}
} }
inline std::vector<std::string> feature_infos() const { inline std::vector<std::string> feature_infos() const {
......
...@@ -770,7 +770,7 @@ class Dataset(object): ...@@ -770,7 +770,7 @@ class Dataset(object):
"""create valid""" """create valid"""
self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, reference=self.reference, self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, reference=self.reference,
weight=self.weight, group=self.group, predictor=self._predictor, weight=self.weight, group=self.group, predictor=self._predictor,
silent=self.silent, params=self.params) silent=self.silent, feature_name=self.feature_name, params=self.params)
else: else:
"""construct subset""" """construct subset"""
used_indices = list_to_1d_numpy(self.used_indices, np.int32, name='used_indices') used_indices = list_to_1d_numpy(self.used_indices, np.int32, name='used_indices')
...@@ -1004,6 +1004,7 @@ class Dataset(object): ...@@ -1004,6 +1004,7 @@ class Dataset(object):
feature_name : list of str feature_name : list of str
Feature names Feature names
""" """
if feature_name != 'auto':
self.feature_name = feature_name self.feature_name = feature_name
if self.handle is not None and feature_name is not None and feature_name != 'auto': if self.handle is not None and feature_name is not None and feature_name != 'auto':
if len(feature_name) != self.num_feature(): if len(feature_name) != self.num_feature():
......
...@@ -28,7 +28,7 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -28,7 +28,7 @@ void DatasetLoader::SetHeader(const char* filename) {
// get column names // get column names
if (io_config_.has_header) { if (io_config_.has_header) {
std::string first_line = text_reader.first_line(); std::string first_line = text_reader.first_line();
feature_names_ = Common::Split(first_line.c_str(), "\t ,"); feature_names_ = Common::Split(first_line.c_str(), "\t,");
} }
// load label idx first // load label idx first
...@@ -509,8 +509,8 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -509,8 +509,8 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
OMP_THROW_EX(); OMP_THROW_EX();
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data)); auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->feature_names_ = feature_names_;
dataset->Construct(bin_mappers, sample_indices, num_per_col, total_sample_size, io_config_); dataset->Construct(bin_mappers, sample_indices, num_per_col, total_sample_size, io_config_);
dataset->set_feature_names(feature_names_);
return dataset.release(); return dataset.release();
} }
...@@ -704,7 +704,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -704,7 +704,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
feature_names_.push_back(str_buf.str()); feature_names_.push_back(str_buf.str());
} }
} }
dataset->feature_names_ = feature_names_; dataset->set_feature_names(feature_names_);
std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size()); std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size());
const data_size_t filter_cnt = static_cast<data_size_t>( const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_); static_cast<double>(io_config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
......
...@@ -164,8 +164,12 @@ class TestEngine(unittest.TestCase): ...@@ -164,8 +164,12 @@ class TestEngine(unittest.TestCase):
def test_feature_name(self):
    """Check that trained boosters report the expected feature names.

    First trains with names that already use underscores and expects them
    back unchanged. Then trains with the same names written with a space
    instead of the underscore and expects the underscore form back.
    """
    lgb_train, _ = template.test_template(return_data=True)
    feature_names = ['f_' + str(i) for i in range(13)]
    gbm = lgb.train({'verbose': -1}, lgb_train, num_boost_round=5, feature_name=feature_names)
    self.assertListEqual(feature_names, gbm.feature_name())
    # test feature_names with whitespaces
    feature_names_with_space = ['f ' + str(i) for i in range(13)]
    gbm = lgb.train({'verbose': -1}, lgb_train, num_boost_round=5, feature_name=feature_names_with_space)
    # 'f 0' is expected to come back as 'f_0', i.e. equal to feature_names.
    self.assertListEqual(feature_names, gbm.feature_name())
def test_save_load_copy_pickle(self): def test_save_load_copy_pickle(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment