Unverified Commit 5e36167d authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] allow to specify orientation of a tree (#2605)

parent 4eb0745d
......@@ -374,7 +374,8 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax
def _to_graphviz(tree_info, show_info, feature_names, precision=3, constraints=None, **kwargs):
def _to_graphviz(tree_info, show_info, feature_names, precision=3,
orientation='horizontal', constraints=None, **kwargs):
"""Convert specified tree to graphviz instance.
See:
......@@ -441,7 +442,8 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=3, constraints=N
graph.edge(parent, name, decision)
graph = Digraph(**kwargs)
graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir="LR")
rankdir = "LR" if orientation == "horizontal" else "TB"
graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
if "internal_count" in tree_info['tree_structure']:
add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"])
else:
......@@ -471,7 +473,8 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=3, constraints=N
def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
old_name=None, old_comment=None, old_filename=None, old_directory=None,
old_format=None, old_engine=None, old_encoding=None, old_graph_attr=None,
old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False, **kwargs):
old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False,
orientation='horizontal', **kwargs):
"""Create a digraph representation of specified tree.
.. note::
......@@ -492,6 +495,9 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
'leaf_count', 'leaf_weight', 'data_percentage'.
precision : int or None, optional (default=3)
Used to restrict the display of floating point values to a certain precision.
orientation : string, optional (default='horizontal')
Orientation of the tree.
Can be 'horizontal' or 'vertical'.
**kwargs
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
......@@ -540,14 +546,15 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
if show_info is None:
show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names, precision, monotone_constraints, **kwargs)
graph = _to_graphviz(tree_info, show_info, feature_names, precision,
orientation, monotone_constraints, **kwargs)
return graph
def plot_tree(booster, ax=None, tree_index=0, figsize=None, dpi=None,
old_graph_attr=None, old_node_attr=None, old_edge_attr=None,
show_info=None, precision=3, **kwargs):
show_info=None, precision=3, orientation='horizontal', **kwargs):
"""Plot specified tree.
.. note::
......@@ -575,6 +582,9 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, dpi=None,
'leaf_count', 'leaf_weight', 'data_percentage'.
precision : int or None, optional (default=3)
Used to restrict the display of floating point values to a certain precision.
orientation : string, optional (default='horizontal')
Orientation of the tree.
Can be 'horizontal' or 'vertical'.
**kwargs
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
......@@ -605,7 +615,8 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, dpi=None,
_, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
graph = create_tree_digraph(booster=booster, tree_index=tree_index,
show_info=show_info, precision=precision, **kwargs)
show_info=show_info, precision=precision,
orientation=orientation, **kwargs)
s = BytesIO()
s.write(graph.pipe(format='png'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment