Unverified commit d8a34df9, authored by Guolin Ke, committed by GitHub
Browse files

fix subset bug (#2748)



* fix subset bug

* typo

* add fixme tag

* bin mapper

* fix test
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 1c1a2765
...@@ -46,19 +46,21 @@ class FeatureGroup { ...@@ -46,19 +46,21 @@ class FeatureGroup {
num_total_bin_ += num_bin; num_total_bin_ += num_bin;
bin_offsets_.emplace_back(num_total_bin_); bin_offsets_.emplace_back(num_total_bin_);
} }
if (is_multi_val_) { CreateBinData(num_data, is_multi_val_, true, false);
multi_bin_data_.clear(); }
for (int i = 0; i < num_feature_; ++i) {
int addi = bin_mappers_[i]->GetMostFreqBin() == 0 ? 0 : 1; FeatureGroup(const FeatureGroup& other, int num_data) {
if (bin_mappers_[i]->sparse_rate() >= kSparseThreshold) { num_feature_ = other.num_feature_;
multi_bin_data_.emplace_back(Bin::CreateSparseBin(num_data, bin_mappers_[i]->num_bin() + addi)); is_multi_val_ = other.is_multi_val_;
} else { is_sparse_ = other.is_sparse_;
multi_bin_data_.emplace_back(Bin::CreateDenseBin(num_data, bin_mappers_[i]->num_bin() + addi)); num_total_bin_ = other.num_total_bin_;
} bin_offsets_ = other.bin_offsets_;
}
} else { bin_mappers_.reserve(other.bin_mappers_.size());
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_)); for (auto& bin_mapper : other.bin_mappers_) {
bin_mappers_.emplace_back(new BinMapper(*bin_mapper));
} }
CreateBinData(num_data, is_multi_val_, !is_sparse_, is_sparse_);
} }
FeatureGroup(std::vector<std::unique_ptr<BinMapper>>* bin_mappers, FeatureGroup(std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
...@@ -76,13 +78,7 @@ class FeatureGroup { ...@@ -76,13 +78,7 @@ class FeatureGroup {
num_total_bin_ += num_bin; num_total_bin_ += num_bin;
bin_offsets_.emplace_back(num_total_bin_); bin_offsets_.emplace_back(num_total_bin_);
} }
if (bin_mappers_[0]->sparse_rate() >= kSparseThreshold) { CreateBinData(num_data, false, false, false);
is_sparse_ = true;
bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
} else {
is_sparse_ = false;
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
}
} }
/*! /*!
...@@ -167,6 +163,16 @@ class FeatureGroup { ...@@ -167,6 +163,16 @@ class FeatureGroup {
} }
} }
/*!
 * \brief Resize the bin storage owned by this feature group
 * \param num_data Row count forwarded to the ReSize call of each bin container
 */
void ReSize(int num_data) {
  if (is_multi_val_) {
    // Multi-val layout: one bin container per feature, each resized in turn.
    for (int fidx = 0; fidx < num_feature_; ++fidx) {
      multi_bin_data_[fidx]->ReSize(num_data);
    }
    return;
  }
  // Otherwise a single shared container holds the whole group.
  bin_data_->ReSize(num_data);
}
inline void CopySubset(const FeatureGroup* full_feature, const data_size_t* used_indices, data_size_t num_used_indices) { inline void CopySubset(const FeatureGroup* full_feature, const data_size_t* used_indices, data_size_t num_used_indices) {
if (!is_multi_val_) { if (!is_multi_val_) {
bin_data_->CopySubset(full_feature->bin_data_.get(), used_indices, num_used_indices); bin_data_->CopySubset(full_feature->bin_data_.get(), used_indices, num_used_indices);
...@@ -327,6 +333,34 @@ class FeatureGroup { ...@@ -327,6 +333,34 @@ class FeatureGroup {
} }
private: private:
void CreateBinData(int num_data, bool is_multi_val, bool force_dense, bool force_sparse) {
if (is_multi_val) {
multi_bin_data_.clear();
for (int i = 0; i < num_feature_; ++i) {
int addi = bin_mappers_[i]->GetMostFreqBin() == 0 ? 0 : 1;
if (bin_mappers_[i]->sparse_rate() >= kSparseThreshold) {
multi_bin_data_.emplace_back(Bin::CreateSparseBin(
num_data, bin_mappers_[i]->num_bin() + addi));
} else {
multi_bin_data_.emplace_back(
Bin::CreateDenseBin(num_data, bin_mappers_[i]->num_bin() + addi));
}
}
is_multi_val_ = true;
} else {
if (force_sparse || (!force_dense && num_feature_ == 1 &&
bin_mappers_[0]->sparse_rate() >= kSparseThreshold)) {
is_sparse_ = true;
bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
} else {
is_sparse_ = false;
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
}
is_multi_val_ = false;
}
}
/*! \brief Number of features */ /*! \brief Number of features */
int num_feature_; int num_feature_;
/*! \brief Bin mapper for sub features */ /*! \brief Bin mapper for sub features */
......
...@@ -720,15 +720,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) { ...@@ -720,15 +720,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
num_groups_ = dataset->num_groups_; num_groups_ = dataset->num_groups_;
// copy feature bin mapper data // copy feature bin mapper data
for (int i = 0; i < num_groups_; ++i) { for (int i = 0; i < num_groups_; ++i) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers; feature_groups_.emplace_back(new FeatureGroup(*dataset->feature_groups_[i], num_data_));
for (int j = 0; j < dataset->feature_groups_[i]->num_feature_; ++j) {
bin_mappers.emplace_back(new BinMapper(*(dataset->feature_groups_[i]->bin_mappers_[j])));
}
feature_groups_.emplace_back(new FeatureGroup(
dataset->feature_groups_[i]->num_feature_,
dataset->feature_groups_[i]->is_multi_val_,
&bin_mappers,
num_data_));
} }
feature_groups_.shrink_to_fit(); feature_groups_.shrink_to_fit();
used_feature_map_ = dataset->used_feature_map_; used_feature_map_ = dataset->used_feature_map_;
...@@ -806,7 +798,7 @@ void Dataset::ReSize(data_size_t num_data) { ...@@ -806,7 +798,7 @@ void Dataset::ReSize(data_size_t num_data) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int group = 0; group < num_groups_; ++group) { for (int group = 0; group < num_groups_; ++group) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
feature_groups_[group]->bin_data_->ReSize(num_data_); feature_groups_[group]->ReSize(num_data_);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -1399,6 +1391,8 @@ void Dataset::AddFeaturesFrom(Dataset* other) { ...@@ -1399,6 +1391,8 @@ void Dataset::AddFeaturesFrom(Dataset* other) {
PushVector(&group_feature_cnt_, other->group_feature_cnt_); PushVector(&group_feature_cnt_, other->group_feature_cnt_);
PushVector(&forced_bin_bounds_, other->forced_bin_bounds_); PushVector(&forced_bin_bounds_, other->forced_bin_bounds_);
feature_groups_.reserve(other->feature_groups_.size()); feature_groups_.reserve(other->feature_groups_.size());
// FIXME: fix the multiple multi-val feature groups, they need to be merged
// into one multi-val group
for (auto& fg : other->feature_groups_) { for (auto& fg : other->feature_groups_) {
feature_groups_.emplace_back(new FeatureGroup(*fg)); feature_groups_.emplace_back(new FeatureGroup(*fg));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment