Commit da031707 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix size of feature_name error when tail features are all zeros

parent 9db054cf
...@@ -332,14 +332,14 @@ DllExport int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -332,14 +332,14 @@ DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
auto idx = sample_indices[i]; auto idx = sample_indices[i];
auto row = get_row_fun(static_cast<int>(idx)); auto row = get_row_fun(static_cast<int>(idx));
for (std::pair<int, double>& inner_data : row) { for (std::pair<int, double>& inner_data : row) {
if (std::fabs(inner_data.second) > 1e-15) { if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
if (static_cast<size_t>(inner_data.first) >= sample_values.size()) { // if need expand feature set
// if need expand feature set size_t need_size = inner_data.first - sample_values.size() + 1;
size_t need_size = inner_data.first - sample_values.size() + 1; for (size_t j = 0; j < need_size; ++j) {
for (size_t j = 0; j < need_size; ++j) { sample_values.emplace_back();
sample_values.emplace_back();
}
} }
}
if (std::fabs(inner_data.second) > 1e-15) {
// edit the feature value // edit the feature value
sample_values[inner_data.first].push_back(inner_data.second); sample_values[inner_data.first].push_back(inner_data.second);
} }
......
...@@ -496,6 +496,10 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) { ...@@ -496,6 +496,10 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) {
if (dataset->features_.empty()) { if (dataset->features_.empty()) {
Log::Fatal("No usable features in data file %s", dataset->data_filename_); Log::Fatal("No usable features in data file %s", dataset->data_filename_);
} }
if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
static_cast<int>(dataset->feature_names_.size()));
}
} }
std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata, std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
...@@ -616,14 +620,14 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -616,14 +620,14 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
// parse features // parse features
parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label); parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
for (std::pair<int, double>& inner_data : oneline_features) { for (std::pair<int, double>& inner_data : oneline_features) {
if (std::fabs(inner_data.second) > 1e-15) { if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
if (static_cast<size_t>(inner_data.first) >= sample_values.size()) { // if need expand feature set
// if need expand feature set size_t need_size = inner_data.first - sample_values.size() + 1;
size_t need_size = inner_data.first - sample_values.size() + 1; for (size_t j = 0; j < need_size; ++j) {
for (size_t j = 0; j < need_size; ++j) { sample_values.emplace_back();
sample_values.emplace_back();
}
} }
}
if (std::fabs(inner_data.second) > 1e-15) {
sample_values[inner_data.first].push_back(inner_data.second); sample_values[inner_data.first].push_back(inner_data.second);
} }
} }
...@@ -631,9 +635,14 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -631,9 +635,14 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
dataset->features_.clear(); dataset->features_.clear();
// -1 means doesn't use this feature if (feature_names_.empty()) {
dataset->used_feature_map_ = std::vector<int>(sample_values.size(), -1); // -1 means doesn't use this feature
dataset->num_total_features_ = static_cast<int>(sample_values.size()); dataset->used_feature_map_ = std::vector<int>(sample_values.size(), -1);
dataset->num_total_features_ = static_cast<int>(sample_values.size());
} else {
dataset->used_feature_map_ = std::vector<int>(feature_names_.size(), -1);
dataset->num_total_features_ = static_cast<int>(feature_names_.size());
}
// check the range of label_idx, weight_idx and group_idx // check the range of label_idx, weight_idx and group_idx
CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_); CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment