Unverified Commit e502ed01 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Avoid most_freq_bin being 0 for categorical features (#2824)

* Avoid most_freq_bin being 0 for categorical features

* Apply suggestions from code review

* add tests

* update test

* Apply suggestions from code review

* Apply suggestions from code review
parent b305a432
...@@ -510,20 +510,24 @@ namespace LightGBM { ...@@ -510,20 +510,24 @@ namespace LightGBM {
if (!is_trivial_) { if (!is_trivial_) {
default_bin_ = ValueToBin(0); default_bin_ = ValueToBin(0);
most_freq_bin_ =
static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin));
if (bin_type_ == BinType::CategoricalBin) { if (bin_type_ == BinType::CategoricalBin) {
CHECK(default_bin_ > 0); if (most_freq_bin_ == 0) {
CHECK(num_bin_ > 1);
// FIXME: how to enable `most_freq_bin_ = 0` for categorical features
most_freq_bin_ = 1;
}
} }
} const double max_sparse_rate =
if (!is_trivial_) { static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
most_freq_bin_ = static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin)); // When most_freq_bin_ != default_bin_, there are some additional data loading costs.
// calculate sparse rate // so use most_freq_bin_ = default_bin_ when there is not so sparse
sparse_rate_ = static_cast<double>(cnt_in_bin[default_bin_]) / total_sample_cnt; if (most_freq_bin_ != default_bin_ && max_sparse_rate < kSparseThreshold) {
const double max_sparse_rate = static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
if (most_freq_bin_ != default_bin_ && max_sparse_rate > 0.7f) {
sparse_rate_ = max_sparse_rate;
} else {
most_freq_bin_ = default_bin_; most_freq_bin_ = default_bin_;
} }
sparse_rate_ =
static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
} else { } else {
sparse_rate_ = 1.0f; sparse_rate_ = 1.0f;
} }
......
...@@ -335,6 +335,43 @@ class TestEngine(unittest.TestCase): ...@@ -335,6 +335,43 @@ class TestEngine(unittest.TestCase):
self.assertGreater(ret, 0.999) self.assertGreater(ret, 0.999)
self.assertAlmostEqual(evals_result['valid_0']['auc'][-1], ret, places=5) self.assertAlmostEqual(evals_result['valid_0']['auc'][-1], ret, places=5)
def test_categorical_non_zero_inputs(self):
    """Fit a single categorical column whose values are never zero.

    The only feature takes the values 1 and 2, and the label is fully
    determined by it (1 -> 1, 2 -> 0).  One boosting round with two leaves
    and a learning rate of 1 is trained on this data; the test then checks
    that the predictions reproduce the labels, that the AUC is above 0.999,
    and that the AUC recorded during training agrees with the one computed
    from the predictions.
    """
    categories = [1, 1, 1, 1, 1, 1, 2, 2]
    labels = [1, 1, 1, 1, 1, 1, 0, 0]

    # One column, one row per sample.
    features = np.array(categories).reshape(-1, 1)
    targets = np.array(labels)

    # The same data is used for training and for validation.
    train_set = lgb.Dataset(features, targets)
    valid_set = lgb.Dataset(features, targets)

    params = {
        'objective': 'regression',
        'metric': 'auc',
        'verbose': -1,
        'boost_from_average': False,
        'min_data': 1,
        'num_leaves': 2,
        'learning_rate': 1,
        'min_data_in_bin': 1,
        'min_data_per_group': 1,
        'cat_smooth': 1,
        'cat_l2': 0,
        'max_cat_to_onehot': 1,
        'zero_as_missing': False,
        'categorical_column': 0
    }

    history = {}
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=1,
        valid_sets=valid_set,
        verbose_eval=False,
        evals_result=history,
    )

    predictions = booster.predict(features)

    # Predictions must reproduce the labels.
    np.testing.assert_allclose(predictions, labels)

    auc = roc_auc_score(targets, predictions)
    self.assertGreater(auc, 0.999)

    # The last AUC logged for the validation set must match the recomputed one.
    logged_auc = history['valid_0']['auc'][-1]
    self.assertAlmostEqual(logged_auc, auc, places=5)
def test_multiclass(self): def test_multiclass(self):
X, y = load_digits(10, True) X, y = load_digits(10, True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment