"tests/vscode:/vscode.git/clone" did not exist on "f185695617d29e446e021fed1261ba84f09ff992"
Unverified commit 03910760, authored by Guolin Ke and committed by GitHub
Browse files

fix zero bin in categorical split (#3305)

* fix zero bin

* some fix

* fix bin mapping

* fix

* fix bug

* use stable sort

* fix cat forced split

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review
parent 27c9aa88
...@@ -456,7 +456,9 @@ class MultiValBin { ...@@ -456,7 +456,9 @@ class MultiValBin {
inline uint32_t BinMapper::ValueToBin(double value) const { inline uint32_t BinMapper::ValueToBin(double value) const {
if (std::isnan(value)) { if (std::isnan(value)) {
if (missing_type_ == MissingType::NaN) { if (bin_type_ == BinType::CategoricalBin) {
return 0;
} else if (missing_type_ == MissingType::NaN) {
return num_bin_ - 1; return num_bin_ - 1;
} else { } else {
value = 0.0f; value = 0.0f;
...@@ -482,12 +484,12 @@ inline uint32_t BinMapper::ValueToBin(double value) const { ...@@ -482,12 +484,12 @@ inline uint32_t BinMapper::ValueToBin(double value) const {
int int_value = static_cast<int>(value); int int_value = static_cast<int>(value);
// convert negative value to NaN bin // convert negative value to NaN bin
if (int_value < 0) { if (int_value < 0) {
return num_bin_ - 1; return 0;
} }
if (categorical_2_bin_.count(int_value)) { if (categorical_2_bin_.count(int_value)) {
return categorical_2_bin_.at(int_value); return categorical_2_bin_.at(int_value);
} else { } else {
return num_bin_ - 1; return 0;
} }
} }
} }
......
...@@ -439,7 +439,6 @@ namespace LightGBM { ...@@ -439,7 +439,6 @@ namespace LightGBM {
} }
} }
} }
num_bin_ = 0;
int rest_cnt = static_cast<int>(total_sample_cnt - na_cnt); int rest_cnt = static_cast<int>(total_sample_cnt - na_cnt);
if (rest_cnt > 0) { if (rest_cnt > 0) {
const int SPARSE_RATIO = 100; const int SPARSE_RATIO = 100;
...@@ -449,23 +448,25 @@ namespace LightGBM { ...@@ -449,23 +448,25 @@ namespace LightGBM {
} }
// sort by counts // sort by counts
Common::SortForPair<int, int>(&counts_int, &distinct_values_int, 0, true); Common::SortForPair<int, int>(&counts_int, &distinct_values_int, 0, true);
// avoid first bin is zero
if (distinct_values_int[0] == 0) {
if (counts_int.size() == 1) {
counts_int.push_back(0);
distinct_values_int.push_back(distinct_values_int[0] + 1);
}
std::swap(counts_int[0], counts_int[1]);
std::swap(distinct_values_int[0], distinct_values_int[1]);
}
// will ignore the categorical of small counts // will ignore the categorical of small counts
int cut_cnt = static_cast<int>((total_sample_cnt - na_cnt) * 0.99f); int cut_cnt = static_cast<int>(
Common::RoundInt((total_sample_cnt - na_cnt) * 0.99f));
size_t cur_cat = 0; size_t cur_cat = 0;
categorical_2_bin_.clear(); categorical_2_bin_.clear();
bin_2_categorical_.clear(); bin_2_categorical_.clear();
int used_cnt = 0; int used_cnt = 0;
max_bin = std::min(static_cast<int>(distinct_values_int.size()), max_bin); int distinct_cnt = static_cast<int>(distinct_values_int.size());
if (na_cnt > 0) {
++distinct_cnt;
}
max_bin = std::min(distinct_cnt, max_bin);
cnt_in_bin.clear(); cnt_in_bin.clear();
// Push the dummy bin for NaN
bin_2_categorical_.push_back(-1);
categorical_2_bin_[-1] = 0;
cnt_in_bin.push_back(0);
num_bin_ = 1;
while (cur_cat < distinct_values_int.size() while (cur_cat < distinct_values_int.size()
&& (used_cnt < cut_cnt || num_bin_ < max_bin)) { && (used_cnt < cut_cnt || num_bin_ < max_bin)) {
if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) { if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) {
...@@ -478,21 +479,14 @@ namespace LightGBM { ...@@ -478,21 +479,14 @@ namespace LightGBM {
++num_bin_; ++num_bin_;
++cur_cat; ++cur_cat;
} }
// need an additional bin for NaN
if (cur_cat == distinct_values_int.size() && na_cnt > 0) {
// use -1 to represent NaN
bin_2_categorical_.push_back(-1);
categorical_2_bin_[-1] = num_bin_;
cnt_in_bin.push_back(0);
++num_bin_;
}
// Use MissingType::None to represent this bin contains all categoricals // Use MissingType::None to represent this bin contains all categoricals
if (cur_cat == distinct_values_int.size() && na_cnt == 0) { if (cur_cat == distinct_values_int.size() && na_cnt == 0) {
missing_type_ = MissingType::None; missing_type_ = MissingType::None;
} else { } else {
missing_type_ = MissingType::NaN; missing_type_ = MissingType::NaN;
} }
cnt_in_bin.back() += static_cast<int>(total_sample_cnt - used_cnt); // fix count of NaN bin
cnt_in_bin[0] = static_cast<int>(total_sample_cnt - used_cnt);
} }
} }
...@@ -511,13 +505,6 @@ namespace LightGBM { ...@@ -511,13 +505,6 @@ namespace LightGBM {
default_bin_ = ValueToBin(0); default_bin_ = ValueToBin(0);
most_freq_bin_ = most_freq_bin_ =
static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin)); static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin));
if (bin_type_ == BinType::CategoricalBin) {
if (most_freq_bin_ == 0) {
CHECK_GT(num_bin_, 1);
// FIXME: how to enable `most_freq_bin_ = 0` for categorical features
most_freq_bin_ = 1;
}
}
const double max_sparse_rate = const double max_sparse_rate =
static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt; static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
// When most_freq_bin_ != default_bin_, there are some additional data loading costs. // When most_freq_bin_ != default_bin_, there are some additional data loading costs.
......
...@@ -318,7 +318,9 @@ class DenseBin : public Bin { ...@@ -318,7 +318,9 @@ class DenseBin : public Bin {
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) { int8_t offset = most_freq_bin == 0 ? 1 : 0;
if (most_freq_bin > 0 &&
Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
...@@ -330,7 +332,7 @@ class DenseBin : public Bin { ...@@ -330,7 +332,7 @@ class DenseBin : public Bin {
} else if (!USE_MIN_BIN && bin == 0) { } else if (!USE_MIN_BIN && bin == 0) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (Common::FindInBitset(threshold, num_threahold, } else if (Common::FindInBitset(threshold, num_threahold,
bin - min_bin)) { bin - min_bin + offset)) {
lte_indices[lte_count++] = idx; lte_indices[lte_count++] = idx;
} else { } else {
gt_indices[gt_count++] = idx; gt_indices[gt_count++] = idx;
......
...@@ -364,7 +364,8 @@ class SparseBin : public Bin { ...@@ -364,7 +364,8 @@ class SparseBin : public Bin {
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
SparseBinIterator<VAL_T> iterator(this, data_indices[0]); SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) { int8_t offset = most_freq_bin == 0 ? 1 : 0;
if (most_freq_bin > 0 && Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
...@@ -376,7 +377,7 @@ class SparseBin : public Bin { ...@@ -376,7 +377,7 @@ class SparseBin : public Bin {
} else if (!USE_MIN_BIN && bin == 0) { } else if (!USE_MIN_BIN && bin == 0) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (Common::FindInBitset(threshold, num_threahold, } else if (Common::FindInBitset(threshold, num_threahold,
bin - min_bin)) { bin - min_bin + offset)) {
lte_indices[lte_count++] = idx; lte_indices[lte_count++] = idx;
} else { } else {
gt_indices[gt_count++] = idx; gt_indices[gt_count++] = idx;
......
...@@ -300,8 +300,10 @@ class FeatureHistogram { ...@@ -300,8 +300,10 @@ class FeatureHistogram {
} }
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None; const int8_t offset = meta_->offset;
int used_bin = meta_->num_bin - 1 + is_full_categorical; const int bin_start = 1 - offset;
const int bin_end = meta_->num_bin - offset;
int used_bin = -1;
std::vector<int> sorted_idx; std::vector<int> sorted_idx;
double l2 = meta_->config->lambda_l2; double l2 = meta_->config->lambda_l2;
...@@ -312,11 +314,11 @@ class FeatureHistogram { ...@@ -312,11 +314,11 @@ class FeatureHistogram {
int rand_threshold = 0; int rand_threshold = 0;
if (use_onehot) { if (use_onehot) {
if (USE_RAND) { if (USE_RAND) {
if (used_bin > 0) { if (bin_end - bin_start > 0) {
rand_threshold = meta_->rand.NextInt(0, used_bin); rand_threshold = meta_->rand.NextInt(bin_start, bin_end);
} }
} }
for (int t = 0; t < used_bin; ++t) { for (int t = bin_start; t < bin_end; ++t) {
const auto grad = GET_GRAD(data_, t); const auto grad = GET_GRAD(data_, t);
const auto hess = GET_HESS(data_, t); const auto hess = GET_HESS(data_, t);
data_size_t cnt = data_size_t cnt =
...@@ -366,7 +368,7 @@ class FeatureHistogram { ...@@ -366,7 +368,7 @@ class FeatureHistogram {
} }
} }
} else { } else {
for (int i = 0; i < used_bin; ++i) { for (int i = bin_start; i < bin_end; ++i) {
if (Common::RoundInt(GET_HESS(data_, i) * cnt_factor) >= if (Common::RoundInt(GET_HESS(data_, i) * cnt_factor) >=
meta_->config->cat_smooth) { meta_->config->cat_smooth) {
sorted_idx.push_back(i); sorted_idx.push_back(i);
...@@ -379,8 +381,8 @@ class FeatureHistogram { ...@@ -379,8 +381,8 @@ class FeatureHistogram {
auto ctr_fun = [this](double sum_grad, double sum_hess) { auto ctr_fun = [this](double sum_grad, double sum_hess) {
return (sum_grad) / (sum_hess + meta_->config->cat_smooth); return (sum_grad) / (sum_hess + meta_->config->cat_smooth);
}; };
std::sort(sorted_idx.begin(), sorted_idx.end(), std::stable_sort(
[this, &ctr_fun](int i, int j) { sorted_idx.begin(), sorted_idx.end(), [this, &ctr_fun](int i, int j) {
return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) < return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j)); ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
}); });
...@@ -489,19 +491,19 @@ class FeatureHistogram { ...@@ -489,19 +491,19 @@ class FeatureHistogram {
if (use_onehot) { if (use_onehot) {
output->num_cat_threshold = 1; output->num_cat_threshold = 1;
output->cat_threshold = output->cat_threshold =
std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold)); std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold + offset));
} else { } else {
output->num_cat_threshold = best_threshold + 1; output->num_cat_threshold = best_threshold + 1;
output->cat_threshold = output->cat_threshold =
std::vector<uint32_t>(output->num_cat_threshold); std::vector<uint32_t>(output->num_cat_threshold);
if (best_dir == 1) { if (best_dir == 1) {
for (int i = 0; i < output->num_cat_threshold; ++i) { for (int i = 0; i < output->num_cat_threshold; ++i) {
auto t = sorted_idx[i]; auto t = sorted_idx[i] + offset;
output->cat_threshold[i] = t; output->cat_threshold[i] = t;
} }
} else { } else {
for (int i = 0; i < output->num_cat_threshold; ++i) { for (int i = 0; i < output->num_cat_threshold; ++i) {
auto t = sorted_idx[used_bin - 1 - i]; auto t = sorted_idx[used_bin - 1 - i] + offset;
output->cat_threshold[i] = t; output->cat_threshold[i] = t;
} }
} }
...@@ -649,16 +651,14 @@ class FeatureHistogram { ...@@ -649,16 +651,14 @@ class FeatureHistogram {
double gain_shift = GetLeafGainGivenOutput<true>( double gain_shift = GetLeafGainGivenOutput<true>(
sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, parent_output); sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, parent_output);
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None; if (threshold >= static_cast<uint32_t>(meta_->num_bin) || threshold == 0) {
int used_bin = meta_->num_bin - 1 + is_full_categorical;
if (threshold >= static_cast<uint32_t>(used_bin)) {
output->gain = kMinScore; output->gain = kMinScore;
Log::Warning("Invalid categorical threshold split"); Log::Warning("Invalid categorical threshold split");
return; return;
} }
const double cnt_factor = num_data / sum_hessian; const double cnt_factor = num_data / sum_hessian;
const auto grad = GET_GRAD(data_, threshold); const auto grad = GET_GRAD(data_, threshold - meta_->offset);
const auto hess = GET_HESS(data_, threshold); const auto hess = GET_HESS(data_, threshold - meta_->offset);
data_size_t cnt = data_size_t cnt =
static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor)); static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment