Unverified commit 39421265, authored by Nikita Titov, committed by GitHub
Browse files

add 'auto' value for `importance_type` param in plotting (#4570)

parent 32445aba
...@@ -32,7 +32,7 @@ def plot_importance( ...@@ -32,7 +32,7 @@ def plot_importance(
title: Optional[str] = 'Feature importance', title: Optional[str] = 'Feature importance',
xlabel: Optional[str] = 'Feature importance', xlabel: Optional[str] = 'Feature importance',
ylabel: Optional[str] = 'Features', ylabel: Optional[str] = 'Features',
importance_type: str = 'split', importance_type: str = 'auto',
max_num_features: Optional[int] = None, max_num_features: Optional[int] = None,
ignore_zero: bool = True, ignore_zero: bool = True,
figsize: Optional[Tuple[float, float]] = None, figsize: Optional[Tuple[float, float]] = None,
...@@ -65,8 +65,9 @@ def plot_importance( ...@@ -65,8 +65,9 @@ def plot_importance(
ylabel : str or None, optional (default="Features") ylabel : str or None, optional (default="Features")
Y-axis title label. Y-axis title label.
If None, title is disabled. If None, title is disabled.
importance_type : str, optional (default="split") importance_type : str, optional (default="auto")
How the importance is calculated. How the importance is calculated.
If "auto", if ``booster`` parameter is LGBMModel, ``booster.importance_type`` attribute is used; "split" otherwise.
If "split", result contains numbers of times the feature is used in a model. If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature. If "gain", result contains total gains of splits which use the feature.
max_num_features : int or None, optional (default=None) max_num_features : int or None, optional (default=None)
...@@ -96,8 +97,13 @@ def plot_importance( ...@@ -96,8 +97,13 @@ def plot_importance(
raise ImportError('You must install matplotlib and restart your session to plot importance.') raise ImportError('You must install matplotlib and restart your session to plot importance.')
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
if importance_type == "auto":
importance_type = booster.importance_type
booster = booster.booster_ booster = booster.booster_
elif not isinstance(booster, Booster): elif isinstance(booster, Booster):
if importance_type == "auto":
importance_type = "split"
else:
raise TypeError('booster must be Booster or LGBMModel.') raise TypeError('booster must be Booster or LGBMModel.')
importance = booster.feature_importance(importance_type=importance_type) importance = booster.feature_importance(importance_type=importance_type)
......
...@@ -57,8 +57,7 @@ def test_plot_importance(params, breast_cancer_split, train_data): ...@@ -57,8 +57,7 @@ def test_plot_importance(params, breast_cancer_split, train_data):
for patch in ax1.patches: for patch in ax1.patches:
assert patch.get_facecolor() == (1., 0, 0, 1.) # red assert patch.get_facecolor() == (1., 0, 0, 1.) # red
ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'], ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'], title=None, xlabel=None, ylabel=None)
title=None, xlabel=None, ylabel=None)
assert isinstance(ax2, matplotlib.axes.Axes) assert isinstance(ax2, matplotlib.axes.Axes)
assert ax2.get_title() == '' assert ax2.get_title() == ''
assert ax2.get_xlabel() == '' assert ax2.get_xlabel() == ''
...@@ -69,6 +68,25 @@ def test_plot_importance(params, breast_cancer_split, train_data): ...@@ -69,6 +68,25 @@ def test_plot_importance(params, breast_cancer_split, train_data):
assert ax2.patches[2].get_facecolor() == (0, .5, 0, 1.) # g assert ax2.patches[2].get_facecolor() == (0, .5, 0, 1.) # g
assert ax2.patches[3].get_facecolor() == (0, 0, 1., 1.) # b assert ax2.patches[3].get_facecolor() == (0, 0, 1., 1.) # b
gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, importance_type="gain")
gbm2.fit(X_train, y_train)
def get_bounds_of_first_patch(axes):
return axes.patches[0].get_extents().bounds
first_bar1 = get_bounds_of_first_patch(lgb.plot_importance(gbm1))
first_bar2 = get_bounds_of_first_patch(lgb.plot_importance(gbm1, importance_type="split"))
first_bar3 = get_bounds_of_first_patch(lgb.plot_importance(gbm1, importance_type="gain"))
first_bar4 = get_bounds_of_first_patch(lgb.plot_importance(gbm2))
first_bar5 = get_bounds_of_first_patch(lgb.plot_importance(gbm2, importance_type="split"))
first_bar6 = get_bounds_of_first_patch(lgb.plot_importance(gbm2, importance_type="gain"))
assert first_bar1 == first_bar2
assert first_bar1 == first_bar5
assert first_bar3 == first_bar4
assert first_bar3 == first_bar6
assert first_bar1 != first_bar3
@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed') @pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed')
def test_plot_split_value_histogram(params, breast_cancer_split, train_data): def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment