Unverified Commit c30ace21 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Fix categorical features whose values are all negative (#1547)

* fix categorical features whose values are all negative

* fix a bug
parent 00a125d5
...@@ -316,55 +316,58 @@ namespace LightGBM { ...@@ -316,55 +316,58 @@ namespace LightGBM {
} }
} }
} }
// sort by counts
Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true);
// avoid the first bin being zero
if (distinct_values_int[0] == 0) {
if (counts_int.size() == 1) {
counts_int.push_back(0);
distinct_values_int.push_back(distinct_values_int[0] + 1);
}
std::swap(counts_int[0], counts_int[1]);
std::swap(distinct_values_int[0], distinct_values_int[1]);
}
// will ignore categories with small counts
int cut_cnt = static_cast<int>((total_sample_cnt - na_cnt) * 0.99f);
size_t cur_cat = 0;
categorical_2_bin_.clear();
bin_2_categorical_.clear();
num_bin_ = 0; num_bin_ = 0;
int used_cnt = 0; int rest_cnt = total_sample_cnt - na_cnt;
max_bin = std::min(static_cast<int>(distinct_values_int.size()), max_bin); if (rest_cnt > 0) {
cnt_in_bin.clear(); // sort by counts
while (cur_cat < distinct_values_int.size() Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true);
&& (used_cnt < cut_cnt || num_bin_ < max_bin)) { // avoid first bin is zero
if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) { if (distinct_values_int[0] == 0) {
break; if (counts_int.size() == 1) {
counts_int.push_back(0);
distinct_values_int.push_back(distinct_values_int[0] + 1);
}
std::swap(counts_int[0], counts_int[1]);
std::swap(distinct_values_int[0], distinct_values_int[1]);
} }
bin_2_categorical_.push_back(distinct_values_int[cur_cat]); // will ignore the categorical of small counts
categorical_2_bin_[distinct_values_int[cur_cat]] = static_cast<unsigned int>(num_bin_); int cut_cnt = static_cast<int>((total_sample_cnt - na_cnt) * 0.99f);
used_cnt += counts_int[cur_cat]; size_t cur_cat = 0;
cnt_in_bin.push_back(counts_int[cur_cat]); categorical_2_bin_.clear();
++num_bin_; bin_2_categorical_.clear();
++cur_cat; int used_cnt = 0;
} max_bin = std::min(static_cast<int>(distinct_values_int.size()), max_bin);
// need an additional bin for NaN cnt_in_bin.clear();
if (cur_cat == distinct_values_int.size() && na_cnt > 0) { while (cur_cat < distinct_values_int.size()
// use -1 to represent NaN && (used_cnt < cut_cnt || num_bin_ < max_bin)) {
bin_2_categorical_.push_back(-1); if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) {
categorical_2_bin_[-1] = num_bin_; break;
cnt_in_bin.push_back(0); }
++num_bin_; bin_2_categorical_.push_back(distinct_values_int[cur_cat]);
} categorical_2_bin_[distinct_values_int[cur_cat]] = static_cast<unsigned int>(num_bin_);
// Use MissingType::None to represent this bin contains all categoricals used_cnt += counts_int[cur_cat];
if (cur_cat == distinct_values_int.size() && na_cnt == 0) { cnt_in_bin.push_back(counts_int[cur_cat]);
missing_type_ = MissingType::None; ++num_bin_;
} else if (na_cnt == 0) { ++cur_cat;
missing_type_ = MissingType::Zero; }
} else { // need an additional bin for NaN
missing_type_ = MissingType::NaN; if (cur_cat == distinct_values_int.size() && na_cnt > 0) {
// use -1 to represent NaN
bin_2_categorical_.push_back(-1);
categorical_2_bin_[-1] = num_bin_;
cnt_in_bin.push_back(0);
++num_bin_;
}
// Use MissingType::None to indicate that the bins contain all categories
if (cur_cat == distinct_values_int.size() && na_cnt == 0) {
missing_type_ = MissingType::None;
} else if (na_cnt == 0) {
missing_type_ = MissingType::Zero;
} else {
missing_type_ = MissingType::NaN;
}
cnt_in_bin.back() += static_cast<int>(total_sample_cnt - used_cnt);
} }
cnt_in_bin.back() += static_cast<int>(total_sample_cnt - used_cnt);
} }
// check trivial (num_bin_ == 1) feature // check trivial (num_bin_ == 1) feature
...@@ -384,8 +387,12 @@ namespace LightGBM { ...@@ -384,8 +387,12 @@ namespace LightGBM {
CHECK(default_bin_ > 0); CHECK(default_bin_ > 0);
} }
} }
// calculate sparse rate if (!is_trival_) {
sparse_rate_ = static_cast<double>(cnt_in_bin[default_bin_]) / static_cast<double>(total_sample_cnt); // calculate sparse rate
sparse_rate_ = static_cast<double>(cnt_in_bin[default_bin_]) / static_cast<double>(total_sample_cnt);
} else {
sparse_rate_ = 1.0f;
}
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment