Unverified Commit 15e3aecc authored by Mohamed Ziada's avatar Mohamed Ziada Committed by GitHub
Browse files

[python-package] adding max_category_values parameter to create_tree_digraph...

[python-package] adding max_category_values parameter to create_tree_digraph method (fixes #5687) (#5818)
parent 1c8a7abd
...@@ -455,6 +455,7 @@ def _to_graphviz( ...@@ -455,6 +455,7 @@ def _to_graphviz(
orientation: str, orientation: str,
constraints: Optional[List[int]], constraints: Optional[List[int]],
example_case: Optional[Union[np.ndarray, pd_DataFrame]], example_case: Optional[Union[np.ndarray, pd_DataFrame]],
max_category_values: int,
**kwargs: Any **kwargs: Any
) -> Any: ) -> Any:
"""Convert specified tree to graphviz instance. """Convert specified tree to graphviz instance.
...@@ -477,6 +478,7 @@ def _to_graphviz( ...@@ -477,6 +478,7 @@ def _to_graphviz(
"""Recursively add node or edge.""" """Recursively add node or edge."""
fillcolor = 'white' fillcolor = 'white'
style = '' style = ''
tooltip = None
if highlight: if highlight:
color = 'blue' color = 'blue'
penwidth = '3' penwidth = '3'
...@@ -487,6 +489,7 @@ def _to_graphviz( ...@@ -487,6 +489,7 @@ def _to_graphviz(
shape = "rectangle" shape = "rectangle"
l_dec = 'yes' l_dec = 'yes'
r_dec = 'no' r_dec = 'no'
threshold = root['threshold']
if root['decision_type'] == '<=': if root['decision_type'] == '<=':
operator = "&#8804;" operator = "&#8804;"
elif root['decision_type'] == '==': elif root['decision_type'] == '==':
...@@ -513,7 +516,13 @@ def _to_graphviz( ...@@ -513,7 +516,13 @@ def _to_graphviz(
missing_type_str=root['missing_type'], missing_type_str=root['missing_type'],
default_left=root['default_left'] default_left=root['default_left']
) )
label += f"<B>{_float2str(root['threshold'], precision)}</B>" if root['decision_type'] == '==':
category_values = root['threshold'].split('||')
if len(category_values) > max_category_values:
tooltip = root['threshold']
threshold = '||'.join(category_values[:2]) + '||...||' + category_values[-1]
label += f"<B>{_float2str(threshold, precision)}</B>"
for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]: for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]:
if info in show_info: if info in show_info:
output = info.split('_')[-1] output = info.split('_')[-1]
...@@ -557,7 +566,7 @@ def _to_graphviz( ...@@ -557,7 +566,7 @@ def _to_graphviz(
if "data_percentage" in show_info: if "data_percentage" in show_info:
label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data" label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data"
label = f"<{label}>" label = f"<{label}>"
graph.node(name, label=label, shape=shape, style=style, fillcolor=fillcolor, color=color, penwidth=penwidth) graph.node(name, label=label, shape=shape, style=style, fillcolor=fillcolor, color=color, penwidth=penwidth, tooltip=tooltip)
if parent is not None: if parent is not None:
graph.edge(parent, name, decision, color=color, penwidth=penwidth) graph.edge(parent, name, decision, color=color, penwidth=penwidth)
...@@ -603,6 +612,7 @@ def create_tree_digraph( ...@@ -603,6 +612,7 @@ def create_tree_digraph(
precision: Optional[int] = 3, precision: Optional[int] = 3,
orientation: str = 'horizontal', orientation: str = 'horizontal',
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None, example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
max_category_values: int = 10,
**kwargs: Any **kwargs: Any
) -> Any: ) -> Any:
"""Create a digraph representation of specified tree. """Create a digraph representation of specified tree.
...@@ -646,6 +656,22 @@ def create_tree_digraph( ...@@ -646,6 +656,22 @@ def create_tree_digraph(
example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None) example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
Single row with the same structure as the training data. Single row with the same structure as the training data.
If not None, the plot will highlight the path that sample takes through the tree. If not None, the plot will highlight the path that sample takes through the tree.
max_category_values : int, optional (default=10)
The maximum number of category values to display in tree nodes. If the number of thresholds is greater than this value, the thresholds will be collapsed in the node label and displayed in full in the node tooltip instead.
.. warning::
Consider wrapping the SVG string of the tree graph with ``IPython.display.HTML`` when running on JupyterLab to get the `tooltip <https://graphviz.org/docs/attrs/tooltip>`_ working right.
Example:
.. code-block:: python
from IPython.display import HTML
graph = lgb.create_tree_digraph(clf, max_category_values=5)
HTML(graph._repr_image_svg_xml())
**kwargs **kwargs
Other parameters passed to ``Digraph`` constructor. Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters. Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
...@@ -699,6 +725,7 @@ def create_tree_digraph( ...@@ -699,6 +725,7 @@ def create_tree_digraph(
orientation=orientation, orientation=orientation,
constraints=monotone_constraints, constraints=monotone_constraints,
example_case=example_case, example_case=example_case,
max_category_values=max_category_values,
**kwargs **kwargs
) )
......
# coding: utf-8 # coding: utf-8
import numpy as np import numpy as np
import pandas as pd
import pytest import pytest
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
...@@ -21,6 +22,17 @@ def breast_cancer_split(): ...@@ -21,6 +22,17 @@ def breast_cancer_split():
test_size=0.1, random_state=1) test_size=0.1, random_state=1)
def _categorical_data(category_values_lower_bound, category_values_upper_bound):
    """Build an all-categorical version of the breast cancer dataset.

    Every feature column is quantile-binned into an unordered pandas
    categorical. The number of categories for each column is drawn from a
    fixed-seed random generator in the half-open range
    ``[category_values_lower_bound, category_values_upper_bound)``.

    Returns
    -------
    tuple
        ``(X_df, y)`` where ``X_df`` is a DataFrame with one ``cat_col_<i>``
        column per original feature and ``y`` is the unchanged target.
    """
    features, target = load_breast_cancer(return_X_y=True)
    n_features = features.shape[1]
    # Fixed seed keeps the per-column category counts reproducible across runs.
    rng = np.random.RandomState(0)
    categories_per_column = rng.randint(
        category_values_lower_bound,
        category_values_upper_bound,
        size=n_features,
    )
    columns = {}
    for col_idx, n_categories in enumerate(categories_per_column):
        # n_categories equal-probability bins need n_categories + 1 quantile edges.
        quantile_edges = np.linspace(0, 1, num=n_categories + 1)
        binned = pd.qcut(features[:, col_idx], q=quantile_edges, labels=range(n_categories))
        columns[f"cat_col_{col_idx}"] = binned.as_unordered()
    return pd.DataFrame(columns), target
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def train_data(breast_cancer_split): def train_data(breast_cancer_split):
X_train, _, y_train, _ = breast_cancer_split X_train, _, y_train, _ = breast_cancer_split
...@@ -188,6 +200,86 @@ def test_create_tree_digraph(breast_cancer_split): ...@@ -188,6 +200,86 @@ def test_create_tree_digraph(breast_cancer_split):
assert 'count' not in graph_body assert 'count' not in graph_body
@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed')
def test_tree_with_categories_below_max_category_values():
    """Check that category lists are not collapsed when within ``max_category_values``.

    The data is generated with ``_categorical_data(2, 10)`` and the tree is
    drawn with ``max_category_values=10``, so the ``'||...||'`` collapse
    marker must not appear in the graph body.
    """
    features, target = _categorical_data(2, 10)
    classifier = lgb.LGBMClassifier(
        n_estimators=10,
        num_leaves=3,
        min_data_in_bin=1,
        force_col_wise=True,
        deterministic=True,
        num_threads=1,
        seed=708,
        verbose=-1,
    )
    classifier.fit(features, target)

    # A tree index far beyond the trees trained here must be rejected.
    with pytest.raises(IndexError):
        lgb.create_tree_digraph(classifier, tree_index=83)

    digraph = lgb.create_tree_digraph(
        classifier,
        tree_index=3,
        show_info=['split_gain', 'internal_value', 'internal_weight'],
        name='Tree4',
        node_attr={'color': 'red'},
        max_category_values=10,
    )
    digraph.render(view=False)

    assert isinstance(digraph, graphviz.Digraph)
    assert digraph.name == 'Tree4'
    # Only the explicitly passed node attribute should be set.
    assert len(digraph.node_attr) == 1
    assert digraph.node_attr['color'] == 'red'
    assert len(digraph.graph_attr) == 0
    assert len(digraph.edge_attr) == 0

    rendered = ''.join(digraph.body)
    for expected in ('leaf', 'gain', 'value', 'weight'):
        assert expected in rendered
    # 'data' and 'count' were not requested in show_info; '||...||' marks a collapsed list.
    for unexpected in ('data', 'count', '||...||'):
        assert unexpected not in rendered
@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed')
def test_tree_with_categories_above_max_category_values():
    """Check that long category lists are collapsed when above ``max_category_values``.

    The data is generated with ``_categorical_data(20, 30)`` and the tree is
    drawn with ``max_category_values=4``; the test expects the ``'||...||'``
    collapse marker to appear in the graph body.
    """
    features, target = _categorical_data(20, 30)
    classifier = lgb.LGBMClassifier(
        n_estimators=10,
        num_leaves=3,
        min_data_in_bin=1,
        force_col_wise=True,
        deterministic=True,
        num_threads=1,
        seed=708,
        verbose=-1,
    )
    classifier.fit(features, target)

    # A tree index far beyond the trees trained here must be rejected.
    with pytest.raises(IndexError):
        lgb.create_tree_digraph(classifier, tree_index=83)

    digraph = lgb.create_tree_digraph(
        classifier,
        tree_index=9,
        show_info=['split_gain', 'internal_value', 'internal_weight'],
        name='Tree4',
        node_attr={'color': 'red'},
        max_category_values=4,
    )
    digraph.render(view=False)

    assert isinstance(digraph, graphviz.Digraph)
    assert digraph.name == 'Tree4'
    # Only the explicitly passed node attribute should be set.
    assert len(digraph.node_attr) == 1
    assert digraph.node_attr['color'] == 'red'
    assert len(digraph.graph_attr) == 0
    assert len(digraph.edge_attr) == 0

    rendered = ''.join(digraph.body)
    for expected in ('leaf', 'gain', 'value', 'weight'):
        assert expected in rendered
    # 'data' and 'count' were not requested in show_info.
    for unexpected in ('data', 'count'):
        assert unexpected not in rendered
    # The collapse marker shows that at least one category list was shortened.
    assert '||...||' in rendered
@pytest.mark.parametrize('use_missing', [True, False]) @pytest.mark.parametrize('use_missing', [True, False])
@pytest.mark.parametrize('zero_as_missing', [True, False]) @pytest.mark.parametrize('zero_as_missing', [True, False])
def test_numeric_split_direction(use_missing, zero_as_missing): def test_numeric_split_direction(use_missing, zero_as_missing):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment