Commit ffb134cc authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

[python] disabled split value histogram for categorical features (#2045)

* disabled split value histogram for categorical features

* updated test for cat. feature

* updated docs
parent 7ab94e6b
...@@ -2437,6 +2437,11 @@ class Booster(object): ...@@ -2437,6 +2437,11 @@ class Booster(object):
The feature name or index the histogram is calculated for. The feature name or index the histogram is calculated for.
If int, interpreted as index. If int, interpreted as index.
If string, interpreted as name. If string, interpreted as name.
Note
----
Categorical features are not supported.
bins : int, string or None, optional (default=None) bins : int, string or None, optional (default=None)
The maximum number of bins. The maximum number of bins.
If None, or int and > number of unique split values and ``xgboost_style=True``, If None, or int and > number of unique split values and ``xgboost_style=True``,
...@@ -2464,6 +2469,9 @@ class Booster(object): ...@@ -2464,6 +2469,9 @@ class Booster(object):
else: else:
split_feature = root['split_feature'] split_feature = root['split_feature']
if split_feature == feature: if split_feature == feature:
if isinstance(root['threshold'], string_type):
raise LightGBMError('Cannot compute split value histogram for the categorical feature')
else:
values.append(root['threshold']) values.append(root['threshold'])
add(root['left_child']) add(root['left_child'])
add(root['right_child']) add(root['right_child'])
......
...@@ -1245,17 +1245,17 @@ class TestEngine(unittest.TestCase): ...@@ -1245,17 +1245,17 @@ class TestEngine(unittest.TestCase):
def test_get_split_value_histogram(self): def test_get_split_value_histogram(self):
X, y = load_boston(True) X, y = load_boston(True)
lgb_train = lgb.Dataset(X, y) lgb_train = lgb.Dataset(X, y, categorical_feature=[2])
gbm = lgb.train({'verbose': -1}, lgb_train, num_boost_round=20) gbm = lgb.train({'verbose': -1}, lgb_train, num_boost_round=20)
# test XGBoost-style return value # test XGBoost-style return value
params = {'feature': 0, 'xgboost_style': True} params = {'feature': 0, 'xgboost_style': True}
self.assertTupleEqual(gbm.get_split_value_histogram(**params).shape, (10, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(**params).shape, (9, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=999, **params).shape, (10, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(bins=999, **params).shape, (9, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=-1, **params).shape, (1, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(bins=-1, **params).shape, (1, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=0, **params).shape, (1, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(bins=0, **params).shape, (1, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=1, **params).shape, (1, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(bins=1, **params).shape, (1, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=2, **params).shape, (2, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(bins=2, **params).shape, (2, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=6, **params).shape, (6, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(bins=6, **params).shape, (5, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=7, **params).shape, (6, 2)) self.assertTupleEqual(gbm.get_split_value_histogram(bins=7, **params).shape, (6, 2))
if lgb.compat.PANDAS_INSTALLED: if lgb.compat.PANDAS_INSTALLED:
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
...@@ -1277,8 +1277,8 @@ class TestEngine(unittest.TestCase): ...@@ -1277,8 +1277,8 @@ class TestEngine(unittest.TestCase):
) )
# test numpy-style return value # test numpy-style return value
hist, bins = gbm.get_split_value_histogram(0) hist, bins = gbm.get_split_value_histogram(0)
self.assertEqual(len(hist), 22) self.assertEqual(len(hist), 23)
self.assertEqual(len(bins), 23) self.assertEqual(len(bins), 24)
hist, bins = gbm.get_split_value_histogram(0, bins=999) hist, bins = gbm.get_split_value_histogram(0, bins=999)
self.assertEqual(len(hist), 999) self.assertEqual(len(hist), 999)
self.assertEqual(len(bins), 1000) self.assertEqual(len(bins), 1000)
...@@ -1316,3 +1316,5 @@ class TestEngine(unittest.TestCase): ...@@ -1316,3 +1316,5 @@ class TestEngine(unittest.TestCase):
mask = hist_vals > 0 mask = hist_vals > 0
np.testing.assert_array_equal(hist_vals[mask], hist[:, 1]) np.testing.assert_array_equal(hist_vals[mask], hist[:, 1])
np.testing.assert_almost_equal(bin_edges[1:][mask], hist[:, 0]) np.testing.assert_almost_equal(bin_edges[1:][mask], hist[:, 0])
# test histogram is disabled for categorical features
self.assertRaises(lgb.basic.LightGBMError, gbm.get_split_value_histogram, 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment