Commit ef221275 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix #991 (#992)

* refine categorical split

* a bug fix

* fix a bug
parent 4aa32967
......@@ -306,7 +306,11 @@ namespace LightGBM {
// sort by counts
Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true);
// avoid first bin is zero
if (distinct_values_int[0] == 0 && counts_int.size() > 1) {
if (distinct_values_int[0] == 0 || (counts_int.size() == 1 && na_cnt > 0)) {
if (counts_int.size() == 1) {
counts_int.push_back(0);
distinct_values_int.push_back(distinct_values_int[0] + 1);
}
std::swap(counts_int[0], counts_int[1]);
std::swap(distinct_values_int[0], distinct_values_int[1]);
}
......
......@@ -206,11 +206,17 @@ public:
output->cat_threshold = std::vector<uint32_t>(output->num_cat_threshold);
if (best_dir == 1) {
for (int i = 0; i < output->num_cat_threshold; ++i) {
output->cat_threshold[i] = sorted_idx[i];
auto t = sorted_idx[i];
if (data_[t].cnt > 0) {
output->cat_threshold[i] = t;
}
}
} else {
for (int i = 0; i < output->num_cat_threshold; ++i) {
output->cat_threshold[i] = sorted_idx[used_bin - 1 - i];
auto t = sorted_idx[used_bin - 1 - i];
if (data_[t].cnt > 0) {
output->cat_threshold[i] = t;
}
}
}
}
......
......@@ -241,6 +241,68 @@ class TestEngine(unittest.TestCase):
pred = gbm.predict(X_train)
np.testing.assert_almost_equal(pred, y)
def test_categorical_handle2(self):
    """Train on one categorical column that alternates between 0 and NaN.

    The labels are 0 for the rows holding 0 and 1 for the rows holding NaN.
    After a single boosting round, predictions on the training data must
    match the labels (checked with ``assert_almost_equal``).
    """
    feature_values = [0, np.nan] * 3
    labels = [0, 1] * 3
    # Single-column design matrix: one row per sample.
    X_train = np.array(feature_values).reshape(len(feature_values), 1)
    y_train = np.array(labels)
    train_set = lgb.Dataset(X_train, y_train)
    valid_set = lgb.Dataset(X_train, y_train)
    params = {
        'objective': 'regression',
        'metric': 'auc',
        'verbose': -1,
        'boost_from_average': False,
        'min_data': 1,
        'num_leaves': 2,
        'learning_rate': 1,
        'min_data_in_bin': 1,
        'min_data_per_group': 1,
        'zero_as_missing': False,
        'categorical_column': 0
    }
    eval_history = {}
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=1,
        valid_sets=valid_set,
        verbose_eval=True,
        evals_result=eval_history,
    )
    predictions = booster.predict(X_train)
    np.testing.assert_almost_equal(predictions, labels)
def test_categorical_handle3(self):
    """Train on one categorical column that alternates between 11 and NaN.

    The labels are 0 for the rows holding 11 and 1 for the rows holding NaN.
    After a single boosting round, predictions on the training data must
    match the labels (checked with ``assert_almost_equal``).
    """
    feature_values = [11, np.nan] * 3
    labels = [0, 1] * 3
    # Single-column design matrix: one row per sample.
    X_train = np.array(feature_values).reshape(len(feature_values), 1)
    y_train = np.array(labels)
    train_set = lgb.Dataset(X_train, y_train)
    valid_set = lgb.Dataset(X_train, y_train)
    params = {
        'objective': 'regression',
        'metric': 'auc',
        'verbose': -1,
        'boost_from_average': False,
        'min_data': 1,
        'num_leaves': 2,
        'learning_rate': 1,
        'min_data_in_bin': 1,
        'min_data_per_group': 1,
        'zero_as_missing': False,
        'categorical_column': 0
    }
    eval_history = {}
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=1,
        valid_sets=valid_set,
        verbose_eval=True,
        evals_result=eval_history,
    )
    predictions = booster.predict(X_train)
    np.testing.assert_almost_equal(predictions, labels)
def test_multiclass(self):
X, y = load_digits(10, True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment