Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
9713ff40
Unverified
Commit
9713ff40
authored
Feb 13, 2023
by
James Lamb
Committed by
GitHub
Feb 13, 2023
Browse files
[python-package] add type hints on plotting code (#5708)
parent
885ea3ad
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
13 deletions
+56
-13
python-package/lightgbm/plotting.py
python-package/lightgbm/plotting.py
+56
-13
No files found.
python-package/lightgbm/plotting.py
View file @
9713ff40
...
@@ -451,10 +451,10 @@ def _to_graphviz(
...
@@ -451,10 +451,10 @@ def _to_graphviz(
tree_info
:
Dict
[
str
,
Any
],
tree_info
:
Dict
[
str
,
Any
],
show_info
:
List
[
str
],
show_info
:
List
[
str
],
feature_names
:
Union
[
List
[
str
],
None
],
feature_names
:
Union
[
List
[
str
],
None
],
precision
:
Optional
[
int
]
=
3
,
precision
:
Optional
[
int
],
orientation
:
str
=
'horizontal'
,
orientation
:
str
,
constraints
:
Optional
[
List
[
int
]]
=
None
,
constraints
:
Optional
[
List
[
int
]],
example_case
:
Optional
[
Union
[
np
.
ndarray
,
pd_DataFrame
]]
=
None
,
example_case
:
Optional
[
Union
[
np
.
ndarray
,
pd_DataFrame
]],
**
kwargs
:
Any
**
kwargs
:
Any
)
->
Any
:
)
->
Any
:
"""Convert specified tree to graphviz instance.
"""Convert specified tree to graphviz instance.
...
@@ -467,7 +467,13 @@ def _to_graphviz(
...
@@ -467,7 +467,13 @@ def _to_graphviz(
else
:
else
:
raise
ImportError
(
'You must install graphviz and restart your session to plot tree.'
)
raise
ImportError
(
'You must install graphviz and restart your session to plot tree.'
)
def
add
(
root
,
total_count
,
parent
=
None
,
decision
=
None
,
highlight
=
False
):
def
add
(
root
:
Dict
[
str
,
Any
],
total_count
:
int
,
parent
:
Optional
[
str
],
decision
:
Optional
[
str
],
highlight
:
bool
)
->
None
:
"""Recursively add node or edge."""
"""Recursively add node or edge."""
fillcolor
=
'white'
fillcolor
=
'white'
style
=
''
style
=
''
...
@@ -496,10 +502,16 @@ def _to_graphviz(
...
@@ -496,10 +502,16 @@ def _to_graphviz(
direction
=
None
direction
=
None
if
example_case
is
not
None
:
if
example_case
is
not
None
:
if
root
[
'decision_type'
]
==
'=='
:
if
root
[
'decision_type'
]
==
'=='
:
direction
=
_determine_direction_for_categorical_split
(
example_case
[
split_feature
],
root
[
'threshold'
])
direction
=
_determine_direction_for_categorical_split
(
fval
=
example_case
[
split_feature
],
thresholds
=
root
[
'threshold'
]
)
else
:
else
:
direction
=
_determine_direction_for_numeric_split
(
direction
=
_determine_direction_for_numeric_split
(
example_case
[
split_feature
],
root
[
'threshold'
],
root
[
'missing_type'
],
root
[
'default_left'
]
fval
=
example_case
[
split_feature
],
threshold
=
root
[
'threshold'
],
missing_type_str
=
root
[
'missing_type'
],
default_left
=
root
[
'default_left'
]
)
)
label
+=
f
"<B>
{
_float2str
(
root
[
'threshold'
],
precision
)
}
</B>"
label
+=
f
"<B>
{
_float2str
(
root
[
'threshold'
],
precision
)
}
</B>"
for
info
in
[
'split_gain'
,
'internal_value'
,
'internal_weight'
,
"internal_count"
,
"data_percentage"
]:
for
info
in
[
'split_gain'
,
'internal_value'
,
'internal_weight'
,
"internal_count"
,
"data_percentage"
]:
...
@@ -519,8 +531,20 @@ def _to_graphviz(
...
@@ -519,8 +531,20 @@ def _to_graphviz(
fillcolor
=
"#ffdddd"
# light red
fillcolor
=
"#ffdddd"
# light red
style
=
"filled"
style
=
"filled"
label
=
f
"<
{
label
}
>"
label
=
f
"<
{
label
}
>"
add
(
root
[
'left_child'
],
total_count
,
name
,
l_dec
,
highlight
and
direction
==
"left"
)
add
(
add
(
root
[
'right_child'
],
total_count
,
name
,
r_dec
,
highlight
and
direction
==
"right"
)
root
=
root
[
'left_child'
],
total_count
=
total_count
,
parent
=
name
,
decision
=
l_dec
,
highlight
=
highlight
and
direction
==
"left"
)
add
(
root
=
root
[
'right_child'
],
total_count
=
total_count
,
parent
=
name
,
decision
=
r_dec
,
highlight
=
highlight
and
direction
==
"right"
)
else
:
# leaf
else
:
# leaf
shape
=
"ellipse"
shape
=
"ellipse"
name
=
f
"leaf
{
root
[
'leaf_index'
]
}
"
name
=
f
"leaf
{
root
[
'leaf_index'
]
}
"
...
@@ -541,7 +565,13 @@ def _to_graphviz(
...
@@ -541,7 +565,13 @@ def _to_graphviz(
rankdir
=
"LR"
if
orientation
==
"horizontal"
else
"TB"
rankdir
=
"LR"
if
orientation
==
"horizontal"
else
"TB"
graph
.
attr
(
"graph"
,
nodesep
=
"0.05"
,
ranksep
=
"0.3"
,
rankdir
=
rankdir
)
graph
.
attr
(
"graph"
,
nodesep
=
"0.05"
,
ranksep
=
"0.3"
,
rankdir
=
rankdir
)
if
"internal_count"
in
tree_info
[
'tree_structure'
]:
if
"internal_count"
in
tree_info
[
'tree_structure'
]:
add
(
tree_info
[
'tree_structure'
],
tree_info
[
'tree_structure'
][
"internal_count"
],
highlight
=
example_case
is
not
None
)
add
(
root
=
tree_info
[
'tree_structure'
],
total_count
=
tree_info
[
'tree_structure'
][
"internal_count"
],
parent
=
None
,
decision
=
None
,
highlight
=
example_case
is
not
None
)
else
:
else
:
raise
Exception
(
"Cannot plot trees with no split"
)
raise
Exception
(
"Cannot plot trees with no split"
)
...
@@ -653,11 +683,24 @@ def create_tree_digraph(
...
@@ -653,11 +683,24 @@ def create_tree_digraph(
if
example_case
.
shape
[
0
]
!=
1
:
if
example_case
.
shape
[
0
]
!=
1
:
raise
ValueError
(
'example_case must have a single row.'
)
raise
ValueError
(
'example_case must have a single row.'
)
if
isinstance
(
example_case
,
pd_DataFrame
):
if
isinstance
(
example_case
,
pd_DataFrame
):
example_case
=
_data_from_pandas
(
example_case
,
None
,
None
,
booster
.
pandas_categorical
)[
0
]
example_case
=
_data_from_pandas
(
data
=
example_case
,
feature_name
=
None
,
categorical_feature
=
None
,
pandas_categorical
=
booster
.
pandas_categorical
)[
0
]
example_case
=
example_case
[
0
]
example_case
=
example_case
[
0
]
graph
=
_to_graphviz
(
tree_info
,
show_info
,
feature_names
,
precision
,
graph
=
_to_graphviz
(
orientation
,
monotone_constraints
,
example_case
=
example_case
,
**
kwargs
)
tree_info
=
tree_info
,
show_info
=
show_info
,
feature_names
=
feature_names
,
precision
=
precision
,
orientation
=
orientation
,
constraints
=
monotone_constraints
,
example_case
=
example_case
,
**
kwargs
)
return
graph
return
graph
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment