Unverified Commit 680f4b08 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[python-package] highlight the path a sample takes through a tree in...


[python-package] highlight the path a sample takes through a tree in `plot_tree` and `create_tree_digraph` (fixes #4784) (#5119)

* highlight path in plot_tree

* lint

* rename x to example_case. support categorical features. add test

* lint

* check for exactly one row. test empty example_case

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* handle missing values in numeric splits

* remove literal. add categorical split function

* make categorical feature more important. lint

* add enum. update categorical split. apply suggestions

* update numeric split decision

* lint

* Update python-package/lightgbm/plotting.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 6be79860
...@@ -6,6 +6,7 @@ import json ...@@ -6,6 +6,7 @@ import json
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from enum import Enum
from functools import wraps from functools import wraps
from os import SEEK_END, environ from os import SEEK_END, environ
from os.path import getsize from os.path import getsize
...@@ -22,6 +23,10 @@ from .libpath import find_lib_path ...@@ -22,6 +23,10 @@ from .libpath import find_lib_path
ZERO_THRESHOLD = 1e-35 ZERO_THRESHOLD = 1e-35
def _is_zero(x: float) -> bool:
return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD
def _get_sample_count(total_nrow: int, params: str) -> int: def _get_sample_count(total_nrow: int, params: str) -> int:
sample_cnt = ctypes.c_int(0) sample_cnt = ctypes.c_int(0)
_safe_call(_LIB.LGBM_GetSampleCount( _safe_call(_LIB.LGBM_GetSampleCount(
...@@ -32,6 +37,12 @@ def _get_sample_count(total_nrow: int, params: str) -> int: ...@@ -32,6 +37,12 @@ def _get_sample_count(total_nrow: int, params: str) -> int:
return sample_cnt.value return sample_cnt.value
class _MissingType(Enum):
NONE = 'None'
NAN = 'NaN'
ZERO = 'Zero'
class _DummyLogger: class _DummyLogger:
def info(self, msg: str) -> None: def info(self, msg: str) -> None:
print(msg) print(msg)
......
# coding: utf-8 # coding: utf-8
"""Plotting library.""" """Plotting library."""
import math
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from .basic import Booster, _log_warning from .basic import ZERO_THRESHOLD, Booster, _data_from_pandas, _is_zero, _log_warning, _MissingType
from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, pd_DataFrame
from .sklearn import LGBMModel from .sklearn import LGBMModel
...@@ -414,6 +415,30 @@ def plot_metric( ...@@ -414,6 +415,30 @@ def plot_metric(
return ax return ax
def _determine_direction_for_numeric_split(
fval: float,
threshold: float,
missing_type_str: str,
default_left: bool,
) -> str:
missing_type = _MissingType(missing_type_str)
if math.isnan(fval) and missing_type != _MissingType.NAN:
fval = 0.0
if ((missing_type == _MissingType.ZERO and _is_zero(fval))
or (missing_type == _MissingType.NAN and math.isnan(fval))):
direction = 'left' if default_left else 'right'
else:
direction = 'left' if fval <= threshold else 'right'
return direction
def _determine_direction_for_categorical_split(fval: float, thresholds: str) -> str:
if math.isnan(fval) or int(fval) < 0:
return 'right'
int_thresholds = {int(t) for t in thresholds.split('||')}
return 'left' if int(fval) in int_thresholds else 'right'
def _to_graphviz( def _to_graphviz(
tree_info: Dict[str, Any], tree_info: Dict[str, Any],
show_info: List[str], show_info: List[str],
...@@ -421,6 +446,7 @@ def _to_graphviz( ...@@ -421,6 +446,7 @@ def _to_graphviz(
precision: Optional[int] = 3, precision: Optional[int] = 3,
orientation: str = 'horizontal', orientation: str = 'horizontal',
constraints: Optional[List[int]] = None, constraints: Optional[List[int]] = None,
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
**kwargs: Any **kwargs: Any
) -> Any: ) -> Any:
"""Convert specified tree to graphviz instance. """Convert specified tree to graphviz instance.
...@@ -433,23 +459,40 @@ def _to_graphviz( ...@@ -433,23 +459,40 @@ def _to_graphviz(
else: else:
raise ImportError('You must install graphviz and restart your session to plot tree.') raise ImportError('You must install graphviz and restart your session to plot tree.')
def add(root, total_count, parent=None, decision=None): def add(root, total_count, parent=None, decision=None, highlight=False):
"""Recursively add node or edge.""" """Recursively add node or edge."""
fillcolor = 'white'
style = ''
if highlight:
color = 'blue'
penwidth = '3'
else:
color = 'black'
penwidth = '1'
if 'split_index' in root: # non-leaf if 'split_index' in root: # non-leaf
shape = "rectangle"
l_dec = 'yes' l_dec = 'yes'
r_dec = 'no' r_dec = 'no'
if root['decision_type'] == '<=': if root['decision_type'] == '<=':
lte_symbol = "&#8804;" operator = "&#8804;"
operator = lte_symbol
elif root['decision_type'] == '==': elif root['decision_type'] == '==':
operator = "=" operator = "="
else: else:
raise ValueError('Invalid decision type in tree model.') raise ValueError('Invalid decision type in tree model.')
name = f"split{root['split_index']}" name = f"split{root['split_index']}"
split_feature = root['split_feature']
if feature_names is not None: if feature_names is not None:
label = f"<B>{feature_names[root['split_feature']]}</B> {operator}" label = f"<B>{feature_names[split_feature]}</B> {operator}"
else: else:
label = f"feature <B>{root['split_feature']}</B> {operator} " label = f"feature <B>{split_feature}</B> {operator} "
direction = None
if example_case is not None:
if root['decision_type'] == '==':
direction = _determine_direction_for_categorical_split(example_case[split_feature], root['threshold'])
else:
direction = _determine_direction_for_numeric_split(
example_case[split_feature], root['threshold'], root['missing_type'], root['default_left']
)
label += f"<B>{_float2str(root['threshold'], precision)}</B>" label += f"<B>{_float2str(root['threshold'], precision)}</B>"
for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]: for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]:
if info in show_info: if info in show_info:
...@@ -461,8 +504,6 @@ def _to_graphviz( ...@@ -461,8 +504,6 @@ def _to_graphviz(
elif info == "data_percentage": elif info == "data_percentage":
label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data" label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data"
fillcolor = "white"
style = ""
if constraints: if constraints:
if constraints[root['split_feature']] == 1: if constraints[root['split_feature']] == 1:
fillcolor = "#ddffdd" # light green fillcolor = "#ddffdd" # light green
...@@ -470,10 +511,10 @@ def _to_graphviz( ...@@ -470,10 +511,10 @@ def _to_graphviz(
fillcolor = "#ffdddd" # light red fillcolor = "#ffdddd" # light red
style = "filled" style = "filled"
label = f"<{label}>" label = f"<{label}>"
graph.node(name, label=label, shape="rectangle", style=style, fillcolor=fillcolor) add(root['left_child'], total_count, name, l_dec, highlight and direction == "left")
add(root['left_child'], total_count, name, l_dec) add(root['right_child'], total_count, name, r_dec, highlight and direction == "right")
add(root['right_child'], total_count, name, r_dec)
else: # leaf else: # leaf
shape = "ellipse"
name = f"leaf{root['leaf_index']}" name = f"leaf{root['leaf_index']}"
label = f"leaf {root['leaf_index']}: " label = f"leaf {root['leaf_index']}: "
label += f"<B>{_float2str(root['leaf_value'], precision)}</B>" label += f"<B>{_float2str(root['leaf_value'], precision)}</B>"
...@@ -484,15 +525,15 @@ def _to_graphviz( ...@@ -484,15 +525,15 @@ def _to_graphviz(
if "data_percentage" in show_info: if "data_percentage" in show_info:
label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data" label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data"
label = f"<{label}>" label = f"<{label}>"
graph.node(name, label=label) graph.node(name, label=label, shape=shape, style=style, fillcolor=fillcolor, color=color, penwidth=penwidth)
if parent is not None: if parent is not None:
graph.edge(parent, name, decision) graph.edge(parent, name, decision, color=color, penwidth=penwidth)
graph = Digraph(**kwargs) graph = Digraph(**kwargs)
rankdir = "LR" if orientation == "horizontal" else "TB" rankdir = "LR" if orientation == "horizontal" else "TB"
graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir) graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
if "internal_count" in tree_info['tree_structure']: if "internal_count" in tree_info['tree_structure']:
add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"]) add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"], highlight=example_case is not None)
else: else:
raise Exception("Cannot plot trees with no split") raise Exception("Cannot plot trees with no split")
...@@ -523,6 +564,7 @@ def create_tree_digraph( ...@@ -523,6 +564,7 @@ def create_tree_digraph(
show_info: Optional[List[str]] = None, show_info: Optional[List[str]] = None,
precision: Optional[int] = 3, precision: Optional[int] = 3,
orientation: str = 'horizontal', orientation: str = 'horizontal',
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
**kwargs: Any **kwargs: Any
) -> Any: ) -> Any:
"""Create a digraph representation of specified tree. """Create a digraph representation of specified tree.
...@@ -563,6 +605,9 @@ def create_tree_digraph( ...@@ -563,6 +605,9 @@ def create_tree_digraph(
orientation : str, optional (default='horizontal') orientation : str, optional (default='horizontal')
Orientation of the tree. Orientation of the tree.
Can be 'horizontal' or 'vertical'. Can be 'horizontal' or 'vertical'.
example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
Single row with the same structure as the training data.
If not None, the plot will highlight the path that sample takes through the tree.
**kwargs **kwargs
Other parameters passed to ``Digraph`` constructor. Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters. Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
...@@ -594,8 +639,17 @@ def create_tree_digraph( ...@@ -594,8 +639,17 @@ def create_tree_digraph(
if show_info is None: if show_info is None:
show_info = [] show_info = []
if example_case is not None:
if not isinstance(example_case, (np.ndarray, pd_DataFrame)) or example_case.ndim != 2:
raise ValueError('example_case must be a numpy 2-D array or a pandas DataFrame')
if example_case.shape[0] != 1:
raise ValueError('example_case must have a single row.')
if isinstance(example_case, pd_DataFrame):
example_case = _data_from_pandas(example_case, None, None, booster.pandas_categorical)[0]
example_case = example_case[0]
graph = _to_graphviz(tree_info, show_info, feature_names, precision, graph = _to_graphviz(tree_info, show_info, feature_names, precision,
orientation, monotone_constraints, **kwargs) orientation, monotone_constraints, example_case=example_case, **kwargs)
return graph return graph
...@@ -609,6 +663,7 @@ def plot_tree( ...@@ -609,6 +663,7 @@ def plot_tree(
show_info: Optional[List[str]] = None, show_info: Optional[List[str]] = None,
precision: Optional[int] = 3, precision: Optional[int] = 3,
orientation: str = 'horizontal', orientation: str = 'horizontal',
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
**kwargs: Any **kwargs: Any
) -> Any: ) -> Any:
"""Plot specified tree. """Plot specified tree.
...@@ -656,6 +711,9 @@ def plot_tree( ...@@ -656,6 +711,9 @@ def plot_tree(
orientation : str, optional (default='horizontal') orientation : str, optional (default='horizontal')
Orientation of the tree. Orientation of the tree.
Can be 'horizontal' or 'vertical'. Can be 'horizontal' or 'vertical'.
example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
Single row with the same structure as the training data.
If not None, the plot will highlight the path that sample takes through the tree.
**kwargs **kwargs
Other parameters passed to ``Digraph`` constructor. Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters. Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
...@@ -678,7 +736,7 @@ def plot_tree( ...@@ -678,7 +736,7 @@ def plot_tree(
graph = create_tree_digraph(booster=booster, tree_index=tree_index, graph = create_tree_digraph(booster=booster, tree_index=tree_index,
show_info=show_info, precision=precision, show_info=show_info, precision=precision,
orientation=orientation, **kwargs) orientation=orientation, example_case=example_case, **kwargs)
s = BytesIO() s = BytesIO()
s.write(graph.pipe(format='png')) s.write(graph.pipe(format='png'))
......
# coding: utf-8 # coding: utf-8
import numpy as np
import pytest import pytest
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import lightgbm as lgb import lightgbm as lgb
from lightgbm.compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED from lightgbm.compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, PANDAS_INSTALLED, pd_DataFrame
if MATPLOTLIB_INSTALLED: if MATPLOTLIB_INSTALLED:
import matplotlib import matplotlib
...@@ -11,7 +12,7 @@ if MATPLOTLIB_INSTALLED: ...@@ -11,7 +12,7 @@ if MATPLOTLIB_INSTALLED:
if GRAPHVIZ_INSTALLED: if GRAPHVIZ_INSTALLED:
import graphviz import graphviz
from .utils import load_breast_cancer from .utils import load_breast_cancer, make_synthetic_regression
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -187,6 +188,129 @@ def test_create_tree_digraph(breast_cancer_split): ...@@ -187,6 +188,129 @@ def test_create_tree_digraph(breast_cancer_split):
assert 'count' not in graph_body assert 'count' not in graph_body
@pytest.mark.parametrize('use_missing', [True, False])
@pytest.mark.parametrize('zero_as_missing', [True, False])
def test_numeric_split_direction(use_missing, zero_as_missing):
if use_missing and zero_as_missing:
pytest.skip('use_missing and zero_as_missing both set to True')
X, y = make_synthetic_regression()
rng = np.random.RandomState(0)
zero_mask = rng.rand(X.shape[0]) < 0.05
X[zero_mask, :] = 0
if use_missing:
nan_mask = ~zero_mask & (rng.rand(X.shape[0]) < 0.1)
X[nan_mask, :] = np.nan
ds = lgb.Dataset(X, y)
params = {
'num_leaves': 127,
'min_child_samples': 1,
'use_missing': use_missing,
'zero_as_missing': zero_as_missing,
}
bst = lgb.train(params, ds, num_boost_round=1)
case_with_zero = X[zero_mask][[0]]
expected_leaf_zero = bst.predict(case_with_zero, pred_leaf=True)[0]
node = bst.dump_model()['tree_info'][0]['tree_structure']
while 'decision_type' in node:
direction = lgb.plotting._determine_direction_for_numeric_split(
case_with_zero[0][node['split_feature']], node['threshold'], node['missing_type'], node['default_left']
)
node = node['left_child'] if direction == 'left' else node['right_child']
assert node['leaf_index'] == expected_leaf_zero
if use_missing:
case_with_nan = X[nan_mask][[0]]
expected_leaf_nan = bst.predict(case_with_nan, pred_leaf=True)[0]
node = bst.dump_model()['tree_info'][0]['tree_structure']
while 'decision_type' in node:
direction = lgb.plotting._determine_direction_for_numeric_split(
case_with_nan[0][node['split_feature']], node['threshold'], node['missing_type'], node['default_left']
)
node = node['left_child'] if direction == 'left' else node['right_child']
assert node['leaf_index'] == expected_leaf_nan
assert expected_leaf_zero != expected_leaf_nan
@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed')
def test_example_case_in_tree_digraph():
rng = np.random.RandomState(0)
x1 = rng.rand(100)
cat = rng.randint(1, 3, size=x1.size)
X = np.vstack([x1, cat]).T
y = x1 + 2 * cat
feature_name = ['x1', 'cat']
ds = lgb.Dataset(X, y, feature_name=feature_name, categorical_feature=['cat'])
num_round = 3
bst = lgb.train({'num_leaves': 7}, ds, num_boost_round=num_round)
mod = bst.dump_model()
example_case = X[[0]]
makes_categorical_splits = False
seen_indices = set()
for i in range(num_round):
graph = lgb.create_tree_digraph(bst, example_case=example_case, tree_index=i)
gbody = graph.body
node = mod['tree_info'][i]['tree_structure']
while 'decision_type' in node: # iterate through the splits
split_index = node['split_index']
node_in_graph = [n for n in gbody if f'split{split_index}' in n and '->' not in n]
assert len(node_in_graph) == 1
seen_indices.add(gbody.index(node_in_graph[0]))
edge_to_node = [e for e in gbody if f'-> split{split_index}' in e]
if node['decision_type'] == '<=':
direction = lgb.plotting._determine_direction_for_numeric_split(
example_case[0][node['split_feature']], node['threshold'], node['missing_type'], node['default_left'])
else:
makes_categorical_splits = True
direction = lgb.plotting._determine_direction_for_categorical_split(
example_case[0][node['split_feature']], node['threshold']
)
node = node['left_child'] if direction == 'left' else node['right_child']
assert 'color=blue' in node_in_graph[0]
if edge_to_node:
assert len(edge_to_node) == 1
assert 'color=blue' in edge_to_node[0]
seen_indices.add(gbody.index(edge_to_node[0]))
# we're in a leaf now
leaf_index = node['leaf_index']
leaf_in_graph = [n for n in gbody if f'leaf{leaf_index}' in n and '->' not in n]
edge_to_leaf = [e for e in gbody if f'-> leaf{leaf_index}' in e]
assert len(leaf_in_graph) == 1
assert 'color=blue' in leaf_in_graph[0]
assert len(edge_to_leaf) == 1
assert 'color=blue' in edge_to_leaf[0]
seen_indices.update([gbody.index(leaf_in_graph[0]), gbody.index(edge_to_leaf[0])])
# check that the rest of the elements have black color
remaining_elements = [e for i, e in enumerate(graph.body) if i not in seen_indices and 'graph' not in e]
assert all('color=black' in e for e in remaining_elements)
# check that we got to the expected leaf
expected_leaf = bst.predict(example_case, start_iteration=i, num_iteration=1, pred_leaf=True)[0]
assert leaf_index == expected_leaf
assert makes_categorical_splits
@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed')
@pytest.mark.parametrize('input_type', ['array', 'dataframe'])
def test_empty_example_case_on_tree_digraph_raises_error(input_type):
X, y = make_synthetic_regression()
if input_type == 'dataframe':
if not PANDAS_INSTALLED:
pytest.skip(reason='pandas is not installed')
X = pd_DataFrame(X)
ds = lgb.Dataset(X, y)
bst = lgb.train({'num_leaves': 3}, ds, num_boost_round=1)
example_case = X[:0]
if input_type == 'dataframe':
example_case = pd_DataFrame(example_case)
with pytest.raises(ValueError, match='example_case must have a single row.'):
lgb.create_tree_digraph(bst, tree_index=0, example_case=example_case)
@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed') @pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed')
def test_plot_metrics(params, breast_cancer_split, train_data): def test_plot_metrics(params, breast_cancer_split, train_data):
X_train, X_test, y_train, y_test = breast_cancer_split X_train, X_test, y_train, y_test = breast_cancer_split
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment