Unverified commit a86a211b, authored by shiyu1994 and committed by GitHub
Browse files

fix bug in corner case of hist bin mismatch (#3694)

parent c7c4e084
...@@ -556,18 +556,11 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of ...@@ -556,18 +556,11 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of
} }
} }
sum_dense_ratio /= ncol; sum_dense_ratio /= ncol;
const int offset = (1.0f - sum_dense_ratio) >=
MultiValBin::multi_val_bin_sparse_threshold ? 1 : 0;
int num_total_bin = offset;
for (int gid = 0; gid < num_groups_; ++gid) { for (int gid = 0; gid < num_groups_; ++gid) {
if (feature_groups_[gid]->is_multi_val_) { if (feature_groups_[gid]->is_multi_val_) {
for (int fid = 0; fid < feature_groups_[gid]->num_feature_; ++fid) { for (int fid = 0; fid < feature_groups_[gid]->num_feature_; ++fid) {
const auto& bin_mapper = feature_groups_[gid]->bin_mappers_[fid]; const auto& bin_mapper = feature_groups_[gid]->bin_mappers_[fid];
most_freq_bins.push_back(bin_mapper->GetMostFreqBin()); most_freq_bins.push_back(bin_mapper->GetMostFreqBin());
num_total_bin += bin_mapper->num_bin();
if (most_freq_bins.back() == 0) {
num_total_bin -= offset;
}
#pragma omp parallel for schedule(static, 1) #pragma omp parallel for schedule(static, 1)
for (int tid = 0; tid < num_threads; ++tid) { for (int tid = 0; tid < num_threads; ++tid) {
iters[tid].emplace_back( iters[tid].emplace_back(
...@@ -576,7 +569,6 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of ...@@ -576,7 +569,6 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of
} }
} else { } else {
most_freq_bins.push_back(0); most_freq_bins.push_back(0);
num_total_bin += feature_groups_[gid]->bin_offsets_.back() - offset;
for (int tid = 0; tid < num_threads; ++tid) { for (int tid = 0; tid < num_threads; ++tid) {
iters[tid].emplace_back(feature_groups_[gid]->FeatureGroupIterator()); iters[tid].emplace_back(feature_groups_[gid]->FeatureGroupIterator());
} }
...@@ -586,7 +578,7 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of ...@@ -586,7 +578,7 @@ MultiValBin* Dataset::GetMultiBinFromAllFeatures(const std::vector<uint32_t>& of
Log::Debug("Dataset::GetMultiBinFromAllFeatures: sparse rate %f", Log::Debug("Dataset::GetMultiBinFromAllFeatures: sparse rate %f",
1.0 - sum_dense_ratio); 1.0 - sum_dense_ratio);
ret.reset(MultiValBin::CreateMultiValBin( ret.reset(MultiValBin::CreateMultiValBin(
num_data_, num_total_bin, static_cast<int>(most_freq_bins.size()), num_data_, offsets.back(), static_cast<int>(most_freq_bins.size()),
1.0 - sum_dense_ratio, offsets)); 1.0 - sum_dense_ratio, offsets));
PushDataToMultiValBin(num_data_, most_freq_bins, offsets, &iters, ret.get()); PushDataToMultiValBin(num_data_, most_freq_bins, offsets, &iters, ret.get());
ret->FinishLoad(); ret->FinishLoad();
......
...@@ -176,6 +176,9 @@ void MultiValBinWrapper::CopyMultiValBinSubset( ...@@ -176,6 +176,9 @@ void MultiValBinWrapper::CopyMultiValBinSubset(
if (feature_groups[i]->is_multi_val_) { if (feature_groups[i]->is_multi_val_) {
for (int j = 0; j < feature_groups[i]->num_feature_; ++j) { for (int j = 0; j < feature_groups[i]->num_feature_; ++j) {
const auto& bin_mapper = feature_groups[i]->bin_mappers_[j]; const auto& bin_mapper = feature_groups[i]->bin_mappers_[j];
if (i == 0 && j == 0 && bin_mapper->GetMostFreqBin() > 0) {
num_total_bin = 1;
}
int cur_num_bin = bin_mapper->num_bin(); int cur_num_bin = bin_mapper->num_bin();
if (bin_mapper->GetMostFreqBin() == 0) { if (bin_mapper->GetMostFreqBin() == 0) {
cur_num_bin -= offset; cur_num_bin -= offset;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment