Commit af7a2544 authored by Nikita Titov's avatar Nikita Titov Committed by Tsukasa OMOTO
Browse files

[python] use kwargs in tree plotting functions (#1630)

* use kwargs in tree plotting functions

* relaxed version
parent b0087754
...@@ -10,7 +10,7 @@ from io import BytesIO ...@@ -10,7 +10,7 @@ from io import BytesIO
import numpy as np import numpy as np
from .basic import Booster from .basic import Booster
from .compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED, range_, string_type from .compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED, LGBMDeprecationWarning, range_, string_type
from .sklearn import LGBMModel from .sklearn import LGBMModel
...@@ -253,14 +253,11 @@ def plot_metric(booster, metric=None, dataset_names=None, ...@@ -253,14 +253,11 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax return ax
def _to_graphviz(tree_info, show_info, feature_names, precision=None, def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False):
"""Convert specified tree to graphviz instance. """Convert specified tree to graphviz instance.
See: See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph - https://graphviz.readthedocs.io/en/stable/api.html#digraph
""" """
if GRAPHVIZ_INSTALLED: if GRAPHVIZ_INSTALLED:
from graphviz import Digraph from graphviz import Digraph
...@@ -304,24 +301,22 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, ...@@ -304,24 +301,22 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None,
if parent is not None: if parent is not None:
graph.edge(parent, name, decision) graph.edge(parent, name, decision)
graph = Digraph(name=name, comment=comment, filename=filename, directory=directory, graph = Digraph(**kwargs)
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
add(tree_info['tree_structure']) add(tree_info['tree_structure'])
return graph return graph
def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
name=None, comment=None, filename=None, directory=None, old_name=None, old_comment=None, old_filename=None, old_directory=None,
format=None, engine=None, encoding=None, graph_attr=None, old_format=None, old_engine=None, old_encoding=None, old_graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False): old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False, **kwargs):
"""Create a digraph representation of specified tree. """Create a digraph representation of specified tree.
Note Note
---- ----
For more information please visit For more information please visit
http://graphviz.readthedocs.io/en/stable/api.html#digraph. https://graphviz.readthedocs.io/en/stable/api.html#digraph.
Parameters Parameters
---------- ----------
...@@ -334,34 +329,9 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, ...@@ -334,34 +329,9 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'. Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None) precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision. Used to restrict the display of floating point values to a certain precision.
name : string or None, optional (default=None) **kwargs : other parameters
Graph name used in the source code. Other parameters passed to ``Digraph`` constructor.
comment : string or None, optional (default=None) Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
Comment added to the first line of the source.
filename : string or None, optional (default=None)
Filename for saving the source.
If None, ``name`` + '.gv' is used.
directory : string or None, optional (default=None)
(Sub)directory for source saving and rendering.
format : string or None, optional (default=None)
Rendering output format ('pdf', 'png', ...).
engine : string or None, optional (default=None)
Layout command used ('dot', 'neato', ...).
encoding : string or None, optional (default=None)
Encoding for saving the source.
graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph.
All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes.
All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges.
All attributes and values must be strings or bytes-like objects.
body : list of strings or None, optional (default=None)
Lines to add to the graph body.
strict : bool, optional (default=False)
Whether rendering should merge multi-edges.
Returns Returns
------- -------
...@@ -373,6 +343,23 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, ...@@ -373,6 +343,23 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
elif not isinstance(booster, Booster): elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.') raise TypeError('booster must be Booster or LGBMModel.')
for param_name in ['old_name', 'old_comment', 'old_filename', 'old_directory',
'old_format', 'old_engine', 'old_encoding', 'old_graph_attr',
'old_node_attr', 'old_edge_attr', 'old_body']:
param = locals().get(param_name)
if param is not None:
warnings.warn('{0} parameter is deprecated and will be removed in 2.3 version.\n'
'Please use **kwargs to pass {1} parameter.'.format(param_name, param_name[4:]),
LGBMDeprecationWarning)
if param_name[4:] not in kwargs:
kwargs[param_name[4:]] = param
if locals().get('strict'):
warnings.warn('old_strict parameter is deprecated and will be removed in 2.3 version.\n'
'Please use **kwargs to pass strict parameter.',
LGBMDeprecationWarning)
if 'strict' not in kwargs:
kwargs['strict'] = True
model = booster.dump_model() model = booster.dump_model()
tree_infos = model['tree_info'] tree_infos = model['tree_info']
if 'feature_names' in model: if 'feature_names' in model:
...@@ -388,17 +375,14 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, ...@@ -388,17 +375,14 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
if show_info is None: if show_info is None:
show_info = [] show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names, precision, graph = _to_graphviz(tree_info, show_info, feature_names, precision, **kwargs)
name=name, comment=comment, filename=filename, directory=directory,
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
return graph return graph
def plot_tree(booster, ax=None, tree_index=0, figsize=None, def plot_tree(booster, ax=None, tree_index=0, figsize=None,
graph_attr=None, node_attr=None, edge_attr=None, old_graph_attr=None, old_node_attr=None, old_edge_attr=None,
show_info=None, precision=None): show_info=None, precision=None, **kwargs):
"""Plot specified tree. """Plot specified tree.
Note Note
...@@ -417,20 +401,14 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -417,20 +401,14 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
The index of a target tree to plot. The index of a target tree to plot.
figsize : tuple of 2 elements or None, optional (default=None) figsize : tuple of 2 elements or None, optional (default=None)
Figure size. Figure size.
graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph.
All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes.
All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges.
All attributes and values must be strings or bytes-like objects.
show_info : list of strings or None, optional (default=None) show_info : list of strings or None, optional (default=None)
What information should be shown in nodes. What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'. Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None) precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision. Used to restrict the display of floating point values to a certain precision.
**kwargs : other parameters
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
Returns Returns
------- -------
...@@ -443,20 +421,22 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -443,20 +421,22 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
else: else:
raise ImportError('You must install matplotlib to plot tree.') raise ImportError('You must install matplotlib to plot tree.')
for param_name in ['old_graph_attr', 'old_node_attr', 'old_edge_attr']:
param = locals().get(param_name)
if param is not None:
warnings.warn('{0} parameter is deprecated and will be removed in 2.3 version.\n'
'Please use **kwargs to pass {1} parameter.'.format(param_name, param_name[4:]),
LGBMDeprecationWarning)
if param_name[4:] not in kwargs:
kwargs[param_name[4:]] = param
if ax is None: if ax is None:
if figsize is not None: if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize') check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize) _, ax = plt.subplots(1, 1, figsize=figsize)
graph = create_tree_digraph( graph = create_tree_digraph(booster=booster, tree_index=tree_index,
booster=booster, show_info=show_info, precision=precision, **kwargs)
tree_index=tree_index,
show_info=show_info,
precision=precision,
graph_attr=graph_attr,
node_attr=node_attr,
edge_attr=edge_attr
)
s = BytesIO() s = BytesIO()
s.write(graph.pipe(format='png')) s.write(graph.pipe(format='png'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment