Commit eef4d2d0 authored by wxchan's avatar wxchan Committed by Qiwei Ye
Browse files

refine plotting library (#282)

* refine plot

* use warnings

* refine  logic

* revert 'move to compat.py'
parent 9a3781fb
......@@ -946,8 +946,8 @@ The methods of each Class is in alphabetical order.
Parameters
----------
booster : Booster, LGBMModel or array
Booster or LGBMModel instance, or array of feature importances.
booster : Booster or LGBMModel
Booster or LGBMModel instance.
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
height : float
......
......@@ -9,5 +9,6 @@ Documents
* [Parameters Tuning](./Parameters-tuning.md)
* [Python API Reference](./Python-API.md)
* [Parallel Learning Guide](https://github.com/Microsoft/LightGBM/wiki/Parallel-Learning-Guide)
* [FAQ](./FAQ.md)
* [Development Guide](./development.md)
......@@ -6,6 +6,7 @@ from __future__ import absolute_import
import ctypes
import os
import warnings
from tempfile import NamedTemporaryFile
import numpy as np
......@@ -223,6 +224,10 @@ PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
if isinstance(data, DataFrame):
if feature_name == 'auto' or feature_name is None:
if all([isinstance(name, integer_types + (np.integer, )) for name in data.columns]):
msg = """Using Pandas (default) integer column names, not column indexes. You can use indexes with DataFrame.values."""
warnings.filterwarnings('once')
warnings.warn(msg, stacklevel=5)
data = data.rename(columns=str)
cat_cols = data.select_dtypes(include=['category']).columns
if pandas_categorical is None: # train dataset
......
......@@ -14,7 +14,7 @@ is_py3 = (sys.version_info[0] == 3)
if is_py3:
string_type = str
numeric_types = (int, float, bool)
integer_types = int
integer_types = (int, )
range_ = range
def argc_(func):
......
......@@ -3,12 +3,13 @@
"""Plotting Library."""
from __future__ import absolute_import
import warnings
from copy import deepcopy
from io import BytesIO
import numpy as np
from .basic import Booster, is_numpy_1d_array
from .basic import Booster
from .sklearn import LGBMModel
......@@ -27,8 +28,8 @@ def plot_importance(booster, ax=None, height=0.2,
Parameters
----------
booster : Booster, LGBMModel or array
Booster or LGBMModel instance, or array of feature importances
booster : Booster or LGBMModel
Booster or LGBMModel instance
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
height : float
......@@ -69,18 +70,17 @@ def plot_importance(booster, ax=None, height=0.2,
raise ImportError('You must install matplotlib to plot importance.')
if isinstance(booster, LGBMModel):
importance = booster.booster_.feature_importance(importance_type=importance_type)
elif isinstance(booster, Booster):
importance = booster.feature_importance(importance_type=importance_type)
elif is_numpy_1d_array(booster) or isinstance(booster, list):
importance = booster
else:
raise TypeError('booster must be Booster, LGBMModel or array instance.')
booster = booster.booster_
elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.')
importance = booster.feature_importance(importance_type=importance_type)
feature_name = booster.feature_name()
if not len(importance):
raise ValueError('Booster feature_importances are empty.')
tuples = sorted(enumerate(importance), key=lambda x: x[1])
tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
if ignore_zero:
tuples = [x for x in tuples if x[1] > 0]
if max_num_features is not None and max_num_features > 0:
......@@ -196,7 +196,8 @@ def plot_metric(booster, metric=None, dataset_names=None,
num_metric = len(metrics_for_one)
if metric is None:
if num_metric > 1:
print('Warning: more than one metric available, picking one to plot.')
msg = """more than one metric available, picking one to plot."""
warnings.warn(msg, stacklevel=2)
metric, results = metrics_for_one.popitem()
else:
if metric not in metrics_for_one:
......
......@@ -46,8 +46,7 @@ class TestBasic(unittest.TestCase):
for patch in ax1.patches:
self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.)) # red
ax2 = lgb.plot_importance(gbm0.feature_importance(),
color=['r', 'y', 'g', 'b'],
ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'],
title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment