Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
39421265
Unverified
Commit
39421265
authored
Sep 01, 2021
by
Nikita Titov
Committed by
GitHub
Aug 31, 2021
Browse files
add 'auto' value for `importance_type` param in plotting (#4570)
parent
32445aba
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
5 deletions
+29
-5
python-package/lightgbm/plotting.py
python-package/lightgbm/plotting.py
+9
-3
tests/python_package_test/test_plotting.py
tests/python_package_test/test_plotting.py
+20
-2
No files found.
python-package/lightgbm/plotting.py
View file @
39421265
...
...
@@ -32,7 +32,7 @@ def plot_importance(
title
:
Optional
[
str
]
=
'Feature importance'
,
xlabel
:
Optional
[
str
]
=
'Feature importance'
,
ylabel
:
Optional
[
str
]
=
'Features'
,
importance_type
:
str
=
'
split
'
,
importance_type
:
str
=
'
auto
'
,
max_num_features
:
Optional
[
int
]
=
None
,
ignore_zero
:
bool
=
True
,
figsize
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
...
...
@@ -65,8 +65,9 @@ def plot_importance(
ylabel : str or None, optional (default="Features")
Y-axis title label.
If None, title is disabled.
importance_type : str, optional (default="
split
")
importance_type : str, optional (default="
auto
")
How the importance is calculated.
If "auto", if ``booster`` parameter is LGBMModel, ``booster.importance_type`` attribute is used; "split" otherwise.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
max_num_features : int or None, optional (default=None)
...
...
@@ -96,8 +97,13 @@ def plot_importance(
raise
ImportError
(
'You must install matplotlib and restart your session to plot importance.'
)
if
isinstance
(
booster
,
LGBMModel
):
if
importance_type
==
"auto"
:
importance_type
=
booster
.
importance_type
booster
=
booster
.
booster_
elif
not
isinstance
(
booster
,
Booster
):
elif
isinstance
(
booster
,
Booster
):
if
importance_type
==
"auto"
:
importance_type
=
"split"
else
:
raise
TypeError
(
'booster must be Booster or LGBMModel.'
)
importance
=
booster
.
feature_importance
(
importance_type
=
importance_type
)
...
...
tests/python_package_test/test_plotting.py
View file @
39421265
...
...
@@ -57,8 +57,7 @@ def test_plot_importance(params, breast_cancer_split, train_data):
for
patch
in
ax1
.
patches
:
assert
patch
.
get_facecolor
()
==
(
1.
,
0
,
0
,
1.
)
# red
ax2
=
lgb
.
plot_importance
(
gbm0
,
color
=
[
'r'
,
'y'
,
'g'
,
'b'
],
title
=
None
,
xlabel
=
None
,
ylabel
=
None
)
ax2
=
lgb
.
plot_importance
(
gbm0
,
color
=
[
'r'
,
'y'
,
'g'
,
'b'
],
title
=
None
,
xlabel
=
None
,
ylabel
=
None
)
assert
isinstance
(
ax2
,
matplotlib
.
axes
.
Axes
)
assert
ax2
.
get_title
()
==
''
assert
ax2
.
get_xlabel
()
==
''
...
...
@@ -69,6 +68,25 @@ def test_plot_importance(params, breast_cancer_split, train_data):
assert
ax2
.
patches
[
2
].
get_facecolor
()
==
(
0
,
.
5
,
0
,
1.
)
# g
assert
ax2
.
patches
[
3
].
get_facecolor
()
==
(
0
,
0
,
1.
,
1.
)
# b
gbm2
=
lgb
.
LGBMClassifier
(
n_estimators
=
10
,
num_leaves
=
3
,
silent
=
True
,
importance_type
=
"gain"
)
gbm2
.
fit
(
X_train
,
y_train
)
def
get_bounds_of_first_patch
(
axes
):
return
axes
.
patches
[
0
].
get_extents
().
bounds
first_bar1
=
get_bounds_of_first_patch
(
lgb
.
plot_importance
(
gbm1
))
first_bar2
=
get_bounds_of_first_patch
(
lgb
.
plot_importance
(
gbm1
,
importance_type
=
"split"
))
first_bar3
=
get_bounds_of_first_patch
(
lgb
.
plot_importance
(
gbm1
,
importance_type
=
"gain"
))
first_bar4
=
get_bounds_of_first_patch
(
lgb
.
plot_importance
(
gbm2
))
first_bar5
=
get_bounds_of_first_patch
(
lgb
.
plot_importance
(
gbm2
,
importance_type
=
"split"
))
first_bar6
=
get_bounds_of_first_patch
(
lgb
.
plot_importance
(
gbm2
,
importance_type
=
"gain"
))
assert
first_bar1
==
first_bar2
assert
first_bar1
==
first_bar5
assert
first_bar3
==
first_bar4
assert
first_bar3
==
first_bar6
assert
first_bar1
!=
first_bar3
@
pytest
.
mark
.
skipif
(
not
MATPLOTLIB_INSTALLED
,
reason
=
'matplotlib is not installed'
)
def
test_plot_split_value_histogram
(
params
,
breast_cancer_split
,
train_data
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment