Unverified commit 680f4b08, authored by José Morales and committed by GitHub
Browse files

[python-package] highlight the path a sample takes through a tree in...


[python-package] highlight the path a sample takes through a tree in `plot_tree` and `create_tree_digraph` (fixes #4784) (#5119)

* highlight path in plot_tree

* lint

* rename x to example_case. support categorical features. add test

* lint

* check for exactly one row. test empty example_case

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* handle missing values in numeric splits

* remove literal. add categorical split function

* make categorical feature more important. lint

* add enum. update categorical split. apply suggestions

* update numeric split decision

* lint

* Update python-package/lightgbm/plotting.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 6be79860
......@@ -6,6 +6,7 @@ import json
import warnings
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from functools import wraps
from os import SEEK_END, environ
from os.path import getsize
......@@ -22,6 +23,10 @@ from .libpath import find_lib_path
# Magnitudes at or below this bound are treated as zero by ``_is_zero``.
ZERO_THRESHOLD = 1e-35


def _is_zero(x: float) -> bool:
    """Return True when ``x`` lies within ``ZERO_THRESHOLD`` of zero (bounds included).

    NaN compares false against the threshold, so it is never reported as zero.
    """
    return abs(x) <= ZERO_THRESHOLD
def _get_sample_count(total_nrow: int, params: str) -> int:
sample_cnt = ctypes.c_int(0)
_safe_call(_LIB.LGBM_GetSampleCount(
......@@ -32,6 +37,12 @@ def _get_sample_count(total_nrow: int, params: str) -> int:
return sample_cnt.value
class _MissingType(Enum):
    """Missing-value handling mode attached to a tree split.

    The member values are the strings the numeric split-direction helper in the
    plotting module passes to this enum's constructor (taken from a split
    node's ``'missing_type'`` entry).
    """

    # No value is routed to the split's default direction; the numeric split helper
    # replaces a NaN feature value with 0.0 before comparing it to the threshold.
    NONE = 'None'
    # NaN feature values follow the split's default direction.
    NAN = 'NaN'
    # Feature values that are (near) zero follow the split's default direction;
    # NaN is first replaced with 0.0 and therefore follows it as well.
    ZERO = 'Zero'
class _DummyLogger:
def info(self, msg: str) -> None:
print(msg)
......
# coding: utf-8
"""Plotting library."""
import math
from copy import deepcopy
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from .basic import Booster, _log_warning
from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED
from .basic import ZERO_THRESHOLD, Booster, _data_from_pandas, _is_zero, _log_warning, _MissingType
from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, pd_DataFrame
from .sklearn import LGBMModel
......@@ -414,6 +415,30 @@ def plot_metric(
return ax
def _determine_direction_for_numeric_split(
    fval: float,
    threshold: float,
    missing_type_str: str,
    default_left: bool,
) -> str:
    """Decide which child of a numeric (``<=``) split a feature value goes to.

    Parameters
    ----------
    fval : float
        Feature value of the sample at the split feature. May be NaN.
    threshold : float
        Split threshold; values less than or equal to it go left.
    missing_type_str : str
        Value of a ``_MissingType`` member ('None', 'NaN' or 'Zero').
    default_left : bool
        Whether values treated as missing are sent to the left child.

    Returns
    -------
    direction : str
        Either 'left' or 'right'.
    """
    missing_type = _MissingType(missing_type_str)
    value_is_nan = math.isnan(fval)

    if missing_type == _MissingType.NAN:
        # NaN is the missing marker: only NaN values follow the default direction.
        follows_default = value_is_nan
    else:
        # NaN is not a missing marker for this split, so it is handled as 0.0.
        if value_is_nan:
            fval = 0.0
        follows_default = missing_type == _MissingType.ZERO and _is_zero(fval)

    if follows_default:
        return 'left' if default_left else 'right'
    return 'left' if fval <= threshold else 'right'
def _determine_direction_for_categorical_split(fval: float, thresholds: str) -> str:
    """Decide which child of a categorical (``==``) split a feature value goes to.

    Parameters
    ----------
    fval : float
        Feature value of the sample at the split feature. It is truncated to an
        integer category code before the lookup.
    thresholds : str
        Category codes sent to the left child, joined with ``'||'`` (e.g. ``'1||3'``).

    Returns
    -------
    direction : str
        'left' when the truncated value is one of the listed categories, 'right'
        otherwise. NaN, infinite and negative values always go 'right'.
    """
    # NaN and +/-inf cannot be truncated to an integer category code
    # (int() raises on them), so they are handled like an unknown category.
    if not math.isfinite(fval):
        return 'right'
    category = int(fval)
    if category < 0:
        return 'right'
    left_categories = {int(t) for t in thresholds.split('||')}
    return 'left' if category in left_categories else 'right'
def _to_graphviz(
tree_info: Dict[str, Any],
show_info: List[str],
......@@ -421,6 +446,7 @@ def _to_graphviz(
precision: Optional[int] = 3,
orientation: str = 'horizontal',
constraints: Optional[List[int]] = None,
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
**kwargs: Any
) -> Any:
"""Convert specified tree to graphviz instance.
......@@ -433,23 +459,40 @@ def _to_graphviz(
else:
raise ImportError('You must install graphviz and restart your session to plot tree.')
def add(root, total_count, parent=None, decision=None):
def add(root, total_count, parent=None, decision=None, highlight=False):
"""Recursively add node or edge."""
fillcolor = 'white'
style = ''
if highlight:
color = 'blue'
penwidth = '3'
else:
color = 'black'
penwidth = '1'
if 'split_index' in root: # non-leaf
shape = "rectangle"
l_dec = 'yes'
r_dec = 'no'
if root['decision_type'] == '<=':
lte_symbol = "&#8804;"
operator = lte_symbol
operator = "&#8804;"
elif root['decision_type'] == '==':
operator = "="
else:
raise ValueError('Invalid decision type in tree model.')
name = f"split{root['split_index']}"
split_feature = root['split_feature']
if feature_names is not None:
label = f"<B>{feature_names[root['split_feature']]}</B> {operator}"
label = f"<B>{feature_names[split_feature]}</B> {operator}"
else:
label = f"feature <B>{root['split_feature']}</B> {operator} "
label = f"feature <B>{split_feature}</B> {operator} "
direction = None
if example_case is not None:
if root['decision_type'] == '==':
direction = _determine_direction_for_categorical_split(example_case[split_feature], root['threshold'])
else:
direction = _determine_direction_for_numeric_split(
example_case[split_feature], root['threshold'], root['missing_type'], root['default_left']
)
label += f"<B>{_float2str(root['threshold'], precision)}</B>"
for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]:
if info in show_info:
......@@ -461,8 +504,6 @@ def _to_graphviz(
elif info == "data_percentage":
label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data"
fillcolor = "white"
style = ""
if constraints:
if constraints[root['split_feature']] == 1:
fillcolor = "#ddffdd" # light green
......@@ -470,10 +511,10 @@ def _to_graphviz(
fillcolor = "#ffdddd" # light red
style = "filled"
label = f"<{label}>"
graph.node(name, label=label, shape="rectangle", style=style, fillcolor=fillcolor)
add(root['left_child'], total_count, name, l_dec)
add(root['right_child'], total_count, name, r_dec)
add(root['left_child'], total_count, name, l_dec, highlight and direction == "left")
add(root['right_child'], total_count, name, r_dec, highlight and direction == "right")
else: # leaf
shape = "ellipse"
name = f"leaf{root['leaf_index']}"
label = f"leaf {root['leaf_index']}: "
label += f"<B>{_float2str(root['leaf_value'], precision)}</B>"
......@@ -484,15 +525,15 @@ def _to_graphviz(
if "data_percentage" in show_info:
label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data"
label = f"<{label}>"
graph.node(name, label=label)
graph.node(name, label=label, shape=shape, style=style, fillcolor=fillcolor, color=color, penwidth=penwidth)
if parent is not None:
graph.edge(parent, name, decision)
graph.edge(parent, name, decision, color=color, penwidth=penwidth)
graph = Digraph(**kwargs)
rankdir = "LR" if orientation == "horizontal" else "TB"
graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
if "internal_count" in tree_info['tree_structure']:
add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"])
add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"], highlight=example_case is not None)
else:
raise Exception("Cannot plot trees with no split")
......@@ -523,6 +564,7 @@ def create_tree_digraph(
show_info: Optional[List[str]] = None,
precision: Optional[int] = 3,
orientation: str = 'horizontal',
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
**kwargs: Any
) -> Any:
"""Create a digraph representation of specified tree.
......@@ -563,6 +605,9 @@ def create_tree_digraph(
orientation : str, optional (default='horizontal')
Orientation of the tree.
Can be 'horizontal' or 'vertical'.
example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
Single row with the same structure as the training data.
If not None, the plot will highlight the path that sample takes through the tree.
**kwargs
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
......@@ -594,8 +639,17 @@ def create_tree_digraph(
if show_info is None:
show_info = []
if example_case is not None:
if not isinstance(example_case, (np.ndarray, pd_DataFrame)) or example_case.ndim != 2:
raise ValueError('example_case must be a numpy 2-D array or a pandas DataFrame')
if example_case.shape[0] != 1:
raise ValueError('example_case must have a single row.')
if isinstance(example_case, pd_DataFrame):
example_case = _data_from_pandas(example_case, None, None, booster.pandas_categorical)[0]
example_case = example_case[0]
graph = _to_graphviz(tree_info, show_info, feature_names, precision,
orientation, monotone_constraints, **kwargs)
orientation, monotone_constraints, example_case=example_case, **kwargs)
return graph
......@@ -609,6 +663,7 @@ def plot_tree(
show_info: Optional[List[str]] = None,
precision: Optional[int] = 3,
orientation: str = 'horizontal',
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
**kwargs: Any
) -> Any:
"""Plot specified tree.
......@@ -656,6 +711,9 @@ def plot_tree(
orientation : str, optional (default='horizontal')
Orientation of the tree.
Can be 'horizontal' or 'vertical'.
example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
Single row with the same structure as the training data.
If not None, the plot will highlight the path that sample takes through the tree.
**kwargs
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
......@@ -678,7 +736,7 @@ def plot_tree(
graph = create_tree_digraph(booster=booster, tree_index=tree_index,
show_info=show_info, precision=precision,
orientation=orientation, **kwargs)
orientation=orientation, example_case=example_case, **kwargs)
s = BytesIO()
s.write(graph.pipe(format='png'))
......
# coding: utf-8
import numpy as np
import pytest
from sklearn.model_selection import train_test_split
import lightgbm as lgb
from lightgbm.compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED
from lightgbm.compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, PANDAS_INSTALLED, pd_DataFrame
if MATPLOTLIB_INSTALLED:
import matplotlib
......@@ -11,7 +12,7 @@ if MATPLOTLIB_INSTALLED:
if GRAPHVIZ_INSTALLED:
import graphviz
from .utils import load_breast_cancer
from .utils import load_breast_cancer, make_synthetic_regression
@pytest.fixture(scope="module")
......@@ -187,6 +188,129 @@ def test_create_tree_digraph(breast_cancer_split):
assert 'count' not in graph_body
@pytest.mark.parametrize('use_missing', [True, False])
@pytest.mark.parametrize('zero_as_missing', [True, False])
def test_numeric_split_direction(use_missing, zero_as_missing):
    """Check that the numeric split helper reaches the same leaf as ``Booster.predict``.

    A one-tree model is trained on data containing all-zero rows (and all-NaN rows
    when ``use_missing`` is set); the dumped tree is then walked by hand with
    ``_determine_direction_for_numeric_split`` and the leaf reached is compared
    with the leaf index reported by ``predict(..., pred_leaf=True)``.
    """
    if use_missing and zero_as_missing:
        pytest.skip('use_missing and zero_as_missing both set to True')

    X, y = make_synthetic_regression()
    rng = np.random.RandomState(0)
    zero_mask = rng.rand(X.shape[0]) < 0.05
    X[zero_mask, :] = 0
    if use_missing:
        # NaN rows are drawn only from rows that were not zeroed out
        nan_mask = ~zero_mask & (rng.rand(X.shape[0]) < 0.1)
        X[nan_mask, :] = np.nan

    train_params = {
        'num_leaves': 127,
        'min_child_samples': 1,
        'use_missing': use_missing,
        'zero_as_missing': zero_as_missing,
    }
    bst = lgb.train(train_params, lgb.Dataset(X, y), num_boost_round=1)

    def leaf_reached_by(case):
        """Walk the first dumped tree for a single-row ``case`` and return the leaf index."""
        current = bst.dump_model()['tree_info'][0]['tree_structure']
        while 'decision_type' in current:
            went = lgb.plotting._determine_direction_for_numeric_split(
                case[0][current['split_feature']],
                current['threshold'],
                current['missing_type'],
                current['default_left'],
            )
            current = current['left_child'] if went == 'left' else current['right_child']
        return current['leaf_index']

    case_with_zero = X[zero_mask][[0]]
    expected_leaf_zero = bst.predict(case_with_zero, pred_leaf=True)[0]
    assert leaf_reached_by(case_with_zero) == expected_leaf_zero

    if use_missing:
        case_with_nan = X[nan_mask][[0]]
        expected_leaf_nan = bst.predict(case_with_nan, pred_leaf=True)[0]
        assert leaf_reached_by(case_with_nan) == expected_leaf_nan
        # zero rows and NaN rows must not end up in the same leaf
        assert expected_leaf_zero != expected_leaf_nan
@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed')
def test_example_case_in_tree_digraph():
    """Check that ``create_tree_digraph`` highlights exactly the path of ``example_case``.

    For every tree of a small model with one numeric and one categorical feature,
    the dumped tree is walked with the split-direction helpers. Every node and
    edge on that path must be drawn in blue, every other element in black, and
    the leaf reached must match the one reported by ``Booster.predict``.
    """
    rng = np.random.RandomState(0)
    x1 = rng.rand(100)
    cat = rng.randint(1, 3, size=x1.size)
    X = np.vstack([x1, cat]).T
    y = x1 + 2 * cat
    feature_name = ['x1', 'cat']
    ds = lgb.Dataset(X, y, feature_name=feature_name, categorical_feature=['cat'])

    num_round = 3
    bst = lgb.train({'num_leaves': 7}, ds, num_boost_round=num_round)
    mod = bst.dump_model()
    example_case = X[[0]]
    makes_categorical_splits = False
    for i in range(num_round):
        graph = lgb.create_tree_digraph(bst, example_case=example_case, tree_index=i)
        gbody = graph.body
        # Positions in this tree's graph body that lie on the highlighted path.
        # Reset for every tree: positions collected for one tree say nothing
        # about the elements at the same positions in another tree's body.
        seen_indices = set()
        node = mod['tree_info'][i]['tree_structure']
        while 'decision_type' in node:  # iterate through the splits
            split_index = node['split_index']

            node_in_graph = [n for n in gbody if f'split{split_index}' in n and '->' not in n]
            assert len(node_in_graph) == 1
            seen_indices.add(gbody.index(node_in_graph[0]))

            # the root split has no incoming edge, so this list may be empty
            edge_to_node = [e for e in gbody if f'-> split{split_index}' in e]
            if node['decision_type'] == '<=':
                direction = lgb.plotting._determine_direction_for_numeric_split(
                    example_case[0][node['split_feature']],
                    node['threshold'],
                    node['missing_type'],
                    node['default_left'],
                )
            else:
                makes_categorical_splits = True
                direction = lgb.plotting._determine_direction_for_categorical_split(
                    example_case[0][node['split_feature']], node['threshold']
                )
            node = node['left_child'] if direction == 'left' else node['right_child']
            assert 'color=blue' in node_in_graph[0]
            if edge_to_node:
                assert len(edge_to_node) == 1
                assert 'color=blue' in edge_to_node[0]
                seen_indices.add(gbody.index(edge_to_node[0]))

        # we're in a leaf now
        leaf_index = node['leaf_index']
        leaf_in_graph = [n for n in gbody if f'leaf{leaf_index}' in n and '->' not in n]
        edge_to_leaf = [e for e in gbody if f'-> leaf{leaf_index}' in e]
        assert len(leaf_in_graph) == 1
        assert 'color=blue' in leaf_in_graph[0]
        assert len(edge_to_leaf) == 1
        assert 'color=blue' in edge_to_leaf[0]
        seen_indices.update([gbody.index(leaf_in_graph[0]), gbody.index(edge_to_leaf[0])])

        # check that the rest of the elements have black color
        remaining_elements = [
            element
            for position, element in enumerate(gbody)
            if position not in seen_indices and 'graph' not in element
        ]
        assert all('color=black' in element for element in remaining_elements)

        # check that we got to the expected leaf
        expected_leaf = bst.predict(example_case, start_iteration=i, num_iteration=1, pred_leaf=True)[0]
        assert leaf_index == expected_leaf
    assert makes_categorical_splits
@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed')
@pytest.mark.parametrize('input_type', ['array', 'dataframe'])
def test_empty_example_case_on_tree_digraph_raises_error(input_type):
    """Check that an ``example_case`` with zero rows is rejected with a ``ValueError``.

    Runs once with a numpy array and once with a pandas DataFrame (skipped when
    pandas is not installed).
    """
    X, y = make_synthetic_regression()
    if input_type == 'dataframe':
        if not PANDAS_INSTALLED:
            # Message passed positionally: the ``reason`` keyword of ``pytest.skip``
            # is not accepted by older pytest releases, the positional form is.
            pytest.skip('pandas is not installed')
        X = pd_DataFrame(X)
    ds = lgb.Dataset(X, y)
    bst = lgb.train({'num_leaves': 3}, ds, num_boost_round=1)
    # zero-row slice that keeps the column structure of the training data
    example_case = X[:0]
    if input_type == 'dataframe':
        example_case = pd_DataFrame(example_case)
    with pytest.raises(ValueError, match='example_case must have a single row.'):
        lgb.create_tree_digraph(bst, tree_index=0, example_case=example_case)
@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed')
def test_plot_metrics(params, breast_cancer_split, train_data):
X_train, X_test, y_train, y_test = breast_cancer_split
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment