"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "6fc80528f15b92921ecffaaa14b6bddaa0de3404"
Unverified Commit 9713ff40 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on plotting code (#5708)

parent 885ea3ad
...@@ -451,10 +451,10 @@ def _to_graphviz( ...@@ -451,10 +451,10 @@ def _to_graphviz(
tree_info: Dict[str, Any], tree_info: Dict[str, Any],
show_info: List[str], show_info: List[str],
feature_names: Union[List[str], None], feature_names: Union[List[str], None],
precision: Optional[int] = 3, precision: Optional[int],
orientation: str = 'horizontal', orientation: str,
constraints: Optional[List[int]] = None, constraints: Optional[List[int]],
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None, example_case: Optional[Union[np.ndarray, pd_DataFrame]],
**kwargs: Any **kwargs: Any
) -> Any: ) -> Any:
"""Convert specified tree to graphviz instance. """Convert specified tree to graphviz instance.
...@@ -467,7 +467,13 @@ def _to_graphviz( ...@@ -467,7 +467,13 @@ def _to_graphviz(
else: else:
raise ImportError('You must install graphviz and restart your session to plot tree.') raise ImportError('You must install graphviz and restart your session to plot tree.')
def add(root, total_count, parent=None, decision=None, highlight=False): def add(
root: Dict[str, Any],
total_count: int,
parent: Optional[str],
decision: Optional[str],
highlight: bool
) -> None:
"""Recursively add node or edge.""" """Recursively add node or edge."""
fillcolor = 'white' fillcolor = 'white'
style = '' style = ''
...@@ -496,10 +502,16 @@ def _to_graphviz( ...@@ -496,10 +502,16 @@ def _to_graphviz(
direction = None direction = None
if example_case is not None: if example_case is not None:
if root['decision_type'] == '==': if root['decision_type'] == '==':
direction = _determine_direction_for_categorical_split(example_case[split_feature], root['threshold']) direction = _determine_direction_for_categorical_split(
fval=example_case[split_feature],
thresholds=root['threshold']
)
else: else:
direction = _determine_direction_for_numeric_split( direction = _determine_direction_for_numeric_split(
example_case[split_feature], root['threshold'], root['missing_type'], root['default_left'] fval=example_case[split_feature],
threshold=root['threshold'],
missing_type_str=root['missing_type'],
default_left=root['default_left']
) )
label += f"<B>{_float2str(root['threshold'], precision)}</B>" label += f"<B>{_float2str(root['threshold'], precision)}</B>"
for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]: for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]:
...@@ -519,8 +531,20 @@ def _to_graphviz( ...@@ -519,8 +531,20 @@ def _to_graphviz(
fillcolor = "#ffdddd" # light red fillcolor = "#ffdddd" # light red
style = "filled" style = "filled"
label = f"<{label}>" label = f"<{label}>"
add(root['left_child'], total_count, name, l_dec, highlight and direction == "left") add(
add(root['right_child'], total_count, name, r_dec, highlight and direction == "right") root=root['left_child'],
total_count=total_count,
parent=name,
decision=l_dec,
highlight=highlight and direction == "left"
)
add(
root=root['right_child'],
total_count=total_count,
parent=name,
decision=r_dec,
highlight=highlight and direction == "right"
)
else: # leaf else: # leaf
shape = "ellipse" shape = "ellipse"
name = f"leaf{root['leaf_index']}" name = f"leaf{root['leaf_index']}"
...@@ -541,7 +565,13 @@ def _to_graphviz( ...@@ -541,7 +565,13 @@ def _to_graphviz(
rankdir = "LR" if orientation == "horizontal" else "TB" rankdir = "LR" if orientation == "horizontal" else "TB"
graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir) graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
if "internal_count" in tree_info['tree_structure']: if "internal_count" in tree_info['tree_structure']:
add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"], highlight=example_case is not None) add(
root=tree_info['tree_structure'],
total_count=tree_info['tree_structure']["internal_count"],
parent=None,
decision=None,
highlight=example_case is not None
)
else: else:
raise Exception("Cannot plot trees with no split") raise Exception("Cannot plot trees with no split")
...@@ -653,11 +683,24 @@ def create_tree_digraph( ...@@ -653,11 +683,24 @@ def create_tree_digraph(
if example_case.shape[0] != 1: if example_case.shape[0] != 1:
raise ValueError('example_case must have a single row.') raise ValueError('example_case must have a single row.')
if isinstance(example_case, pd_DataFrame): if isinstance(example_case, pd_DataFrame):
example_case = _data_from_pandas(example_case, None, None, booster.pandas_categorical)[0] example_case = _data_from_pandas(
data=example_case,
feature_name=None,
categorical_feature=None,
pandas_categorical=booster.pandas_categorical
)[0]
example_case = example_case[0] example_case = example_case[0]
graph = _to_graphviz(tree_info, show_info, feature_names, precision, graph = _to_graphviz(
orientation, monotone_constraints, example_case=example_case, **kwargs) tree_info=tree_info,
show_info=show_info,
feature_names=feature_names,
precision=precision,
orientation=orientation,
constraints=monotone_constraints,
example_case=example_case,
**kwargs
)
return graph return graph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment