Commit eef4d2d0 authored by wxchan's avatar wxchan Committed by Qiwei Ye
Browse files

refine plotting library (#282)

* refine plot

* use warnings

* refine  logic

* revert 'move to compat.py'
parent 9a3781fb
...@@ -946,8 +946,8 @@ The methods of each Class is in alphabetical order. ...@@ -946,8 +946,8 @@ The methods of each Class is in alphabetical order.
Parameters Parameters
---------- ----------
booster : Booster, LGBMModel or array booster : Booster or LGBMModel
Booster or LGBMModel instance, or array of feature importances. Booster or LGBMModel instance.
ax : matplotlib Axes ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created. Target axes instance. If None, new figure and axes will be created.
height : float height : float
......
...@@ -9,5 +9,6 @@ Documents ...@@ -9,5 +9,6 @@ Documents
* [Parameters Tuning](./Parameters-tuning.md) * [Parameters Tuning](./Parameters-tuning.md)
* [Python API Reference](./Python-API.md) * [Python API Reference](./Python-API.md)
* [Parallel Learning Guide](https://github.com/Microsoft/LightGBM/wiki/Parallel-Learning-Guide) * [Parallel Learning Guide](https://github.com/Microsoft/LightGBM/wiki/Parallel-Learning-Guide)
* [FAQ](./FAQ.md)
* [Development Guide](./development.md) * [Development Guide](./development.md)
...@@ -6,6 +6,7 @@ from __future__ import absolute_import ...@@ -6,6 +6,7 @@ from __future__ import absolute_import
import ctypes import ctypes
import os import os
import warnings
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import numpy as np import numpy as np
...@@ -223,6 +224,10 @@ PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', ...@@ -223,6 +224,10 @@ PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical): def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
if isinstance(data, DataFrame): if isinstance(data, DataFrame):
if feature_name == 'auto' or feature_name is None: if feature_name == 'auto' or feature_name is None:
if all([isinstance(name, integer_types + (np.integer, )) for name in data.columns]):
msg = """Using Pandas (default) integer column names, not column indexes. You can use indexes with DataFrame.values."""
warnings.filterwarnings('once')
warnings.warn(msg, stacklevel=5)
data = data.rename(columns=str) data = data.rename(columns=str)
cat_cols = data.select_dtypes(include=['category']).columns cat_cols = data.select_dtypes(include=['category']).columns
if pandas_categorical is None: # train dataset if pandas_categorical is None: # train dataset
......
...@@ -14,7 +14,7 @@ is_py3 = (sys.version_info[0] == 3) ...@@ -14,7 +14,7 @@ is_py3 = (sys.version_info[0] == 3)
if is_py3: if is_py3:
string_type = str string_type = str
numeric_types = (int, float, bool) numeric_types = (int, float, bool)
integer_types = int integer_types = (int, )
range_ = range range_ = range
def argc_(func): def argc_(func):
......
...@@ -3,12 +3,13 @@ ...@@ -3,12 +3,13 @@
"""Plotting Library.""" """Plotting Library."""
from __future__ import absolute_import from __future__ import absolute_import
import warnings
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
import numpy as np import numpy as np
from .basic import Booster, is_numpy_1d_array from .basic import Booster
from .sklearn import LGBMModel from .sklearn import LGBMModel
...@@ -27,8 +28,8 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -27,8 +28,8 @@ def plot_importance(booster, ax=None, height=0.2,
Parameters Parameters
---------- ----------
booster : Booster, LGBMModel or array booster : Booster or LGBMModel
Booster or LGBMModel instance, or array of feature importances Booster or LGBMModel instance
ax : matplotlib Axes ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created. Target axes instance. If None, new figure and axes will be created.
height : float height : float
...@@ -69,18 +70,17 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -69,18 +70,17 @@ def plot_importance(booster, ax=None, height=0.2,
raise ImportError('You must install matplotlib to plot importance.') raise ImportError('You must install matplotlib to plot importance.')
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
importance = booster.booster_.feature_importance(importance_type=importance_type) booster = booster.booster_
elif isinstance(booster, Booster): elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.')
importance = booster.feature_importance(importance_type=importance_type) importance = booster.feature_importance(importance_type=importance_type)
elif is_numpy_1d_array(booster) or isinstance(booster, list): feature_name = booster.feature_name()
importance = booster
else:
raise TypeError('booster must be Booster, LGBMModel or array instance.')
if not len(importance): if not len(importance):
raise ValueError('Booster feature_importances are empty.') raise ValueError('Booster feature_importances are empty.')
tuples = sorted(enumerate(importance), key=lambda x: x[1]) tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
if ignore_zero: if ignore_zero:
tuples = [x for x in tuples if x[1] > 0] tuples = [x for x in tuples if x[1] > 0]
if max_num_features is not None and max_num_features > 0: if max_num_features is not None and max_num_features > 0:
...@@ -196,7 +196,8 @@ def plot_metric(booster, metric=None, dataset_names=None, ...@@ -196,7 +196,8 @@ def plot_metric(booster, metric=None, dataset_names=None,
num_metric = len(metrics_for_one) num_metric = len(metrics_for_one)
if metric is None: if metric is None:
if num_metric > 1: if num_metric > 1:
print('Warning: more than one metric available, picking one to plot.') msg = """more than one metric available, picking one to plot."""
warnings.warn(msg, stacklevel=2)
metric, results = metrics_for_one.popitem() metric, results = metrics_for_one.popitem()
else: else:
if metric not in metrics_for_one: if metric not in metrics_for_one:
......
...@@ -46,8 +46,7 @@ class TestBasic(unittest.TestCase): ...@@ -46,8 +46,7 @@ class TestBasic(unittest.TestCase):
for patch in ax1.patches: for patch in ax1.patches:
self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.)) # red self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.)) # red
ax2 = lgb.plot_importance(gbm0.feature_importance(), ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'],
color=['r', 'y', 'g', 'b'],
title=None, xlabel=None, ylabel=None) title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, matplotlib.axes.Axes) self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '') self.assertEqual(ax2.get_title(), '')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment