Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
af7a2544
Commit
af7a2544
authored
Sep 02, 2018
by
Nikita Titov
Committed by
Tsukasa OMOTO
Sep 02, 2018
Browse files
[python] use kwargs in tree plotting functions (#1630)
* use kwargs in tree plotting functions * relaxed version
parent
b0087754
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
65 deletions
+45
-65
python-package/lightgbm/plotting.py
python-package/lightgbm/plotting.py
+45
-65
No files found.
python-package/lightgbm/plotting.py
View file @
af7a2544
...
...
@@ -10,7 +10,7 @@ from io import BytesIO
import
numpy
as
np
from
.basic
import
Booster
from
.compat
import
MATPLOTLIB_INSTALLED
,
GRAPHVIZ_INSTALLED
,
range_
,
string_type
from
.compat
import
MATPLOTLIB_INSTALLED
,
GRAPHVIZ_INSTALLED
,
LGBMDeprecationWarning
,
range_
,
string_type
from
.sklearn
import
LGBMModel
...
...
@@ -253,14 +253,11 @@ def plot_metric(booster, metric=None, dataset_names=None,
return
ax
def
_to_graphviz
(
tree_info
,
show_info
,
feature_names
,
precision
=
None
,
name
=
None
,
comment
=
None
,
filename
=
None
,
directory
=
None
,
format
=
None
,
engine
=
None
,
encoding
=
None
,
graph_attr
=
None
,
node_attr
=
None
,
edge_attr
=
None
,
body
=
None
,
strict
=
False
):
def
_to_graphviz
(
tree_info
,
show_info
,
feature_names
,
precision
=
None
,
**
kwargs
):
"""Convert specified tree to graphviz instance.
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
- http
s
://graphviz.readthedocs.io/en/stable/api.html#digraph
"""
if
GRAPHVIZ_INSTALLED
:
from
graphviz
import
Digraph
...
...
@@ -304,24 +301,22 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None,
if
parent
is
not
None
:
graph
.
edge
(
parent
,
name
,
decision
)
graph
=
Digraph
(
name
=
name
,
comment
=
comment
,
filename
=
filename
,
directory
=
directory
,
format
=
format
,
engine
=
engine
,
encoding
=
encoding
,
graph_attr
=
graph_attr
,
node_attr
=
node_attr
,
edge_attr
=
edge_attr
,
body
=
body
,
strict
=
strict
)
graph
=
Digraph
(
**
kwargs
)
add
(
tree_info
[
'tree_structure'
])
return
graph
def
create_tree_digraph
(
booster
,
tree_index
=
0
,
show_info
=
None
,
precision
=
None
,
name
=
None
,
comment
=
None
,
filename
=
None
,
directory
=
None
,
format
=
None
,
engine
=
None
,
encoding
=
None
,
graph_attr
=
None
,
node_attr
=
None
,
edge_attr
=
None
,
body
=
None
,
strict
=
False
):
old_
name
=
None
,
old_
comment
=
None
,
old_
filename
=
None
,
old_
directory
=
None
,
old_
format
=
None
,
old_
engine
=
None
,
old_
encoding
=
None
,
old_
graph_attr
=
None
,
old_
node_attr
=
None
,
old_
edge_attr
=
None
,
old_
body
=
None
,
old_
strict
=
False
,
**
kwargs
):
"""Create a digraph representation of specified tree.
Note
----
For more information please visit
http://graphviz.readthedocs.io/en/stable/api.html#digraph.
http
s
://graphviz.readthedocs.io/en/stable/api.html#digraph.
Parameters
----------
...
...
@@ -334,34 +329,9 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
name : string or None, optional (default=None)
Graph name used in the source code.
comment : string or None, optional (default=None)
Comment added to the first line of the source.
filename : string or None, optional (default=None)
Filename for saving the source.
If None, ``name`` + '.gv' is used.
directory : string or None, optional (default=None)
(Sub)directory for source saving and rendering.
format : string or None, optional (default=None)
Rendering output format ('pdf', 'png', ...).
engine : string or None, optional (default=None)
Layout command used ('dot', 'neato', ...).
encoding : string or None, optional (default=None)
Encoding for saving the source.
graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph.
All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes.
All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges.
All attributes and values must be strings or bytes-like objects.
body : list of strings or None, optional (default=None)
Lines to add to the graph body.
strict : bool, optional (default=False)
Whether rendering should merge multi-edges.
**kwargs : other parameters
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
Returns
-------
...
...
@@ -373,6 +343,23 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
elif
not
isinstance
(
booster
,
Booster
):
raise
TypeError
(
'booster must be Booster or LGBMModel.'
)
for
param_name
in
[
'old_name'
,
'old_comment'
,
'old_filename'
,
'old_directory'
,
'old_format'
,
'old_engine'
,
'old_encoding'
,
'old_graph_attr'
,
'old_node_attr'
,
'old_edge_attr'
,
'old_body'
]:
param
=
locals
().
get
(
param_name
)
if
param
is
not
None
:
warnings
.
warn
(
'{0} parameter is deprecated and will be removed in 2.3 version.
\n
'
'Please use **kwargs to pass {1} parameter.'
.
format
(
param_name
,
param_name
[
4
:]),
LGBMDeprecationWarning
)
if
param_name
[
4
:]
not
in
kwargs
:
kwargs
[
param_name
[
4
:]]
=
param
if
locals
().
get
(
'strict'
):
warnings
.
warn
(
'old_strict parameter is deprecated and will be removed in 2.3 version.
\n
'
'Please use **kwargs to pass strict parameter.'
,
LGBMDeprecationWarning
)
if
'strict'
not
in
kwargs
:
kwargs
[
'strict'
]
=
True
model
=
booster
.
dump_model
()
tree_infos
=
model
[
'tree_info'
]
if
'feature_names'
in
model
:
...
...
@@ -388,17 +375,14 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
if
show_info
is
None
:
show_info
=
[]
graph
=
_to_graphviz
(
tree_info
,
show_info
,
feature_names
,
precision
,
name
=
name
,
comment
=
comment
,
filename
=
filename
,
directory
=
directory
,
format
=
format
,
engine
=
engine
,
encoding
=
encoding
,
graph_attr
=
graph_attr
,
node_attr
=
node_attr
,
edge_attr
=
edge_attr
,
body
=
body
,
strict
=
strict
)
graph
=
_to_graphviz
(
tree_info
,
show_info
,
feature_names
,
precision
,
**
kwargs
)
return
graph
def
plot_tree
(
booster
,
ax
=
None
,
tree_index
=
0
,
figsize
=
None
,
graph_attr
=
None
,
node_attr
=
None
,
edge_attr
=
None
,
show_info
=
None
,
precision
=
None
):
old_
graph_attr
=
None
,
old_
node_attr
=
None
,
old_
edge_attr
=
None
,
show_info
=
None
,
precision
=
None
,
**
kwargs
):
"""Plot specified tree.
Note
...
...
@@ -417,20 +401,14 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
The index of a target tree to plot.
figsize : tuple of 2 elements or None, optional (default=None)
Figure size.
graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph.
All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes.
All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges.
All attributes and values must be strings or bytes-like objects.
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs : other parameters
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
Returns
-------
...
...
@@ -443,20 +421,22 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
else
:
raise
ImportError
(
'You must install matplotlib to plot tree.'
)
for
param_name
in
[
'old_graph_attr'
,
'old_node_attr'
,
'old_edge_attr'
]:
param
=
locals
().
get
(
param_name
)
if
param
is
not
None
:
warnings
.
warn
(
'{0} parameter is deprecated and will be removed in 2.3 version.
\n
'
'Please use **kwargs to pass {1} parameter.'
.
format
(
param_name
,
param_name
[
4
:]),
LGBMDeprecationWarning
)
if
param_name
[
4
:]
not
in
kwargs
:
kwargs
[
param_name
[
4
:]]
=
param
if
ax
is
None
:
if
figsize
is
not
None
:
check_not_tuple_of_2_elements
(
figsize
,
'figsize'
)
_
,
ax
=
plt
.
subplots
(
1
,
1
,
figsize
=
figsize
)
graph
=
create_tree_digraph
(
booster
=
booster
,
tree_index
=
tree_index
,
show_info
=
show_info
,
precision
=
precision
,
graph_attr
=
graph_attr
,
node_attr
=
node_attr
,
edge_attr
=
edge_attr
)
graph
=
create_tree_digraph
(
booster
=
booster
,
tree_index
=
tree_index
,
show_info
=
show_info
,
precision
=
precision
,
**
kwargs
)
s
=
BytesIO
()
s
.
write
(
graph
.
pipe
(
format
=
'png'
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment