Commit ef221275 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix #991 (#992)

* refine categorical split

* a bug fix

* fix a bug
parent 4aa32967
...@@ -306,7 +306,11 @@ namespace LightGBM { ...@@ -306,7 +306,11 @@ namespace LightGBM {
// sort by counts // sort by counts
Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true); Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true);
// avoid first bin is zero // avoid first bin is zero
if (distinct_values_int[0] == 0 && counts_int.size() > 1) { if (distinct_values_int[0] == 0 || (counts_int.size() == 1 && na_cnt > 0)) {
if (counts_int.size() == 1) {
counts_int.push_back(0);
distinct_values_int.push_back(distinct_values_int[0] + 1);
}
std::swap(counts_int[0], counts_int[1]); std::swap(counts_int[0], counts_int[1]);
std::swap(distinct_values_int[0], distinct_values_int[1]); std::swap(distinct_values_int[0], distinct_values_int[1]);
} }
......
...@@ -206,11 +206,17 @@ public: ...@@ -206,11 +206,17 @@ public:
output->cat_threshold = std::vector<uint32_t>(output->num_cat_threshold); output->cat_threshold = std::vector<uint32_t>(output->num_cat_threshold);
if (best_dir == 1) { if (best_dir == 1) {
for (int i = 0; i < output->num_cat_threshold; ++i) { for (int i = 0; i < output->num_cat_threshold; ++i) {
output->cat_threshold[i] = sorted_idx[i]; auto t = sorted_idx[i];
if (data_[t].cnt > 0) {
output->cat_threshold[i] = t;
}
} }
} else { } else {
for (int i = 0; i < output->num_cat_threshold; ++i) { for (int i = 0; i < output->num_cat_threshold; ++i) {
output->cat_threshold[i] = sorted_idx[used_bin - 1 - i]; auto t = sorted_idx[used_bin - 1 - i];
if (data_[t].cnt > 0) {
output->cat_threshold[i] = t;
}
} }
} }
} }
......
...@@ -241,6 +241,68 @@ class TestEngine(unittest.TestCase): ...@@ -241,6 +241,68 @@ class TestEngine(unittest.TestCase):
pred = gbm.predict(X_train) pred = gbm.predict(X_train)
np.testing.assert_almost_equal(pred, y) np.testing.assert_almost_equal(pred, y)
def test_categorical_handle2(self):
    """Categorical column whose only observed values are 0 and NaN.

    Rows alternate between 0 (label 0) and NaN (label 1). A single boosting
    round with two leaves, learning rate 1 and no boost-from-average is
    trained with ``zero_as_missing`` disabled, and the predictions on the
    training rows are required to match the labels.
    """
    raw_feature = [0, np.nan, 0, np.nan, 0, np.nan]
    labels = [0, 1, 0, 1, 0, 1]
    X_train = np.array(raw_feature).reshape(len(raw_feature), 1)
    y_train = np.array(labels)
    # The same rows are used for both training and evaluation.
    train_set = lgb.Dataset(X_train, y_train)
    valid_set = lgb.Dataset(X_train, y_train)
    params = {
        'objective': 'regression',
        'metric': 'auc',
        'verbose': -1,
        'boost_from_average': False,
        'min_data': 1,
        'num_leaves': 2,
        'learning_rate': 1,
        'min_data_in_bin': 1,
        'min_data_per_group': 1,
        'zero_as_missing': False,
        'categorical_column': 0
    }
    history = {}
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=1,
        valid_sets=valid_set,
        verbose_eval=True,
        evals_result=history,
    )
    predictions = booster.predict(X_train)
    np.testing.assert_almost_equal(predictions, labels)
def test_categorical_handle3(self):
    """Categorical column whose only observed values are 11 and NaN.

    Rows alternate between 11 (label 0) and NaN (label 1). A single boosting
    round with two leaves, learning rate 1 and no boost-from-average is
    trained with ``zero_as_missing`` disabled, and the predictions on the
    training rows are required to match the labels.
    """
    raw_feature = [11, np.nan, 11, np.nan, 11, np.nan]
    labels = [0, 1, 0, 1, 0, 1]
    X_train = np.array(raw_feature).reshape(len(raw_feature), 1)
    y_train = np.array(labels)
    # The same rows are used for both training and evaluation.
    train_set = lgb.Dataset(X_train, y_train)
    valid_set = lgb.Dataset(X_train, y_train)
    params = {
        'objective': 'regression',
        'metric': 'auc',
        'verbose': -1,
        'boost_from_average': False,
        'min_data': 1,
        'num_leaves': 2,
        'learning_rate': 1,
        'min_data_in_bin': 1,
        'min_data_per_group': 1,
        'zero_as_missing': False,
        'categorical_column': 0
    }
    history = {}
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=1,
        valid_sets=valid_set,
        verbose_eval=True,
        evals_result=history,
    )
    predictions = booster.predict(X_train)
    np.testing.assert_almost_equal(predictions, labels)
def test_multiclass(self): def test_multiclass(self):
X, y = load_digits(10, True) X, y = load_digits(10, True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment