"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "f975d3fafcdbb3739dbe4eac40dc2b7e1e3244d7"
Unverified Commit 5fe2bdd7 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] tests for plot tree functions and module_INSTALLED variables (#1438)

* removed excess import

* added tests for plotting trees in Python

* refined module_INSTALLED mechanism

* added note about that create_tree_digraph is better than plot_tree
parent c184852b
...@@ -48,7 +48,7 @@ if [[ $TASK == "if-else" ]]; then ...@@ -48,7 +48,7 @@ if [[ $TASK == "if-else" ]]; then
exit 0 exit 0
fi fi
conda install numpy nose scipy scikit-learn pandas matplotlib pytest conda install numpy nose scipy scikit-learn pandas matplotlib python-graphviz pytest
if [[ $TASK == "sdist" ]]; then if [[ $TASK == "sdist" ]]; then
cd $TRAVIS_BUILD_DIR/python-package && python setup.py sdist || exit -1 cd $TRAVIS_BUILD_DIR/python-package && python setup.py sdist || exit -1
...@@ -98,7 +98,6 @@ cd $TRAVIS_BUILD_DIR/python-package && python setup.py install --precompile || e ...@@ -98,7 +98,6 @@ cd $TRAVIS_BUILD_DIR/python-package && python setup.py install --precompile || e
pytest $TRAVIS_BUILD_DIR || exit -1 pytest $TRAVIS_BUILD_DIR || exit -1
if [[ $TASK == "regular" ]]; then if [[ $TASK == "regular" ]]; then
conda install python-graphviz
cd $TRAVIS_BUILD_DIR/examples/python-guide cd $TRAVIS_BUILD_DIR/examples/python-guide
sed -i'.bak' '/import lightgbm as lgb/a\ sed -i'.bak' '/import lightgbm as lgb/a\
import matplotlib\ import matplotlib\
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
import lightgbm as lgb import lightgbm as lgb
import pandas as pd import pandas as pd
try: if lgb.compat.MATPLOTLIB_INSTALLED:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: else:
raise ImportError('You need to install matplotlib for plot_example.py.') raise ImportError('You need to install matplotlib for plot_example.py.')
# load or create your dataset # load or create your dataset
......
...@@ -57,13 +57,30 @@ def json_default_with_numpy(obj): ...@@ -57,13 +57,30 @@ def json_default_with_numpy(obj):
"""pandas""" """pandas"""
try: try:
from pandas import Series, DataFrame from pandas import Series, DataFrame
PANDAS_INSTALLED = True
except ImportError: except ImportError:
PANDAS_INSTALLED = False
class Series(object): class Series(object):
pass pass
class DataFrame(object): class DataFrame(object):
pass pass
"""matplotlib"""
try:
import matplotlib
MATPLOTLIB_INSTALLED = True
except ImportError:
MATPLOTLIB_INSTALLED = False
"""graphviz"""
try:
import graphviz
GRAPHVIZ_INSTALLED = True
except ImportError:
GRAPHVIZ_INSTALLED = False
"""sklearn""" """sklearn"""
try: try:
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
......
...@@ -10,6 +10,7 @@ from io import BytesIO ...@@ -10,6 +10,7 @@ from io import BytesIO
import numpy as np import numpy as np
from .basic import Booster from .basic import Booster
from .compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from .sklearn import LGBMModel from .sklearn import LGBMModel
...@@ -69,9 +70,9 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -69,9 +70,9 @@ def plot_importance(booster, ax=None, height=0.2,
ax : matplotlib.axes.Axes ax : matplotlib.axes.Axes
The plot with model's feature importances. The plot with model's feature importances.
""" """
try: if MATPLOTLIB_INSTALLED:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: else:
raise ImportError('You must install matplotlib to plot importance.') raise ImportError('You must install matplotlib to plot importance.')
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
...@@ -173,9 +174,9 @@ def plot_metric(booster, metric=None, dataset_names=None, ...@@ -173,9 +174,9 @@ def plot_metric(booster, metric=None, dataset_names=None,
ax : matplotlib.axes.Axes ax : matplotlib.axes.Axes
The plot with metric's history over the training. The plot with metric's history over the training.
""" """
try: if MATPLOTLIB_INSTALLED:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: else:
raise ImportError('You must install matplotlib to plot metric.') raise ImportError('You must install matplotlib to plot metric.')
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
...@@ -261,9 +262,9 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, ...@@ -261,9 +262,9 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None,
See: See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph - http://graphviz.readthedocs.io/en/stable/api.html#digraph
""" """
try: if GRAPHVIZ_INSTALLED:
from graphviz import Digraph from graphviz import Digraph
except ImportError: else:
raise ImportError('You must install graphviz to plot tree.') raise ImportError('You must install graphviz to plot tree.')
def float2str(value, precision=None): def float2str(value, precision=None):
...@@ -399,6 +400,11 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -399,6 +400,11 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
show_info=None, precision=None): show_info=None, precision=None):
"""Plot specified tree. """Plot specified tree.
Note
----
It is preferable to use ``create_tree_digraph()`` because of its lossless quality
and returned objects can be also rendered and displayed directly inside a Jupyter notebook.
Parameters Parameters
---------- ----------
booster : Booster or LGBMModel booster : Booster or LGBMModel
...@@ -430,10 +436,10 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -430,10 +436,10 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
ax : matplotlib.axes.Axes ax : matplotlib.axes.Axes
The plot with single tree. The plot with single tree.
""" """
try: if MATPLOTLIB_INSTALLED:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.image as image import matplotlib.image as image
except ImportError: else:
raise ImportError('You must install matplotlib to plot tree.') raise ImportError('You must install matplotlib to plot tree.')
if ax is None: if ax is None:
......
# coding: utf-8 # coding: utf-8
# pylint: skip-file # pylint: skip-file
import os import os
import subprocess
import tempfile import tempfile
import unittest import unittest
......
...@@ -14,12 +14,6 @@ from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error ...@@ -14,12 +14,6 @@ from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split, TimeSeriesSplit from sklearn.model_selection import train_test_split, TimeSeriesSplit
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
try:
import pandas as pd
IS_PANDAS_INSTALLED = True
except ImportError:
IS_PANDAS_INSTALLED = False
try: try:
import cPickle as pickle import cPickle as pickle
except ImportError: except ImportError:
...@@ -478,8 +472,9 @@ class TestEngine(unittest.TestCase): ...@@ -478,8 +472,9 @@ class TestEngine(unittest.TestCase):
for ret in other_ret: for ret in other_ret:
self.assertAlmostEqual(ret_origin, ret, places=5) self.assertAlmostEqual(ret_origin, ret, places=5)
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed') @unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self): def test_pandas_categorical(self):
import pandas as pd
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int "B": np.random.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float "C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
......
...@@ -3,30 +3,31 @@ ...@@ -3,30 +3,31 @@
import unittest import unittest
import lightgbm as lgb import lightgbm as lgb
from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
try: if MATPLOTLIB_INSTALLED:
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
matplotlib_installed = True if GRAPHVIZ_INSTALLED:
except ImportError: import graphviz
matplotlib_installed = False
class TestBasic(unittest.TestCase): class TestBasic(unittest.TestCase):
@unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed') def setUp(self):
def test_plot_importance(self): self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1) self.train_data = lgb.Dataset(self.X_train, self.y_train)
train_data = lgb.Dataset(X_train, y_train) self.params = {
params = {
"objective": "binary", "objective": "binary",
"verbose": -1, "verbose": -1,
"num_leaves": 3 "num_leaves": 3
} }
gbm0 = lgb.train(params, train_data, num_boost_round=10)
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_importance(self):
gbm0 = lgb.train(self.params, self.train_data, num_boost_round=10)
ax0 = lgb.plot_importance(gbm0) ax0 = lgb.plot_importance(gbm0)
self.assertIsInstance(ax0, matplotlib.axes.Axes) self.assertIsInstance(ax0, matplotlib.axes.Axes)
self.assertEqual(ax0.get_title(), 'Feature importance') self.assertEqual(ax0.get_title(), 'Feature importance')
...@@ -35,7 +36,7 @@ class TestBasic(unittest.TestCase): ...@@ -35,7 +36,7 @@ class TestBasic(unittest.TestCase):
self.assertLessEqual(len(ax0.patches), 30) self.assertLessEqual(len(ax0.patches), 30)
gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
gbm1.fit(X_train, y_train) gbm1.fit(self.X_train, self.y_train)
ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y') ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
self.assertIsInstance(ax1, matplotlib.axes.Axes) self.assertIsInstance(ax1, matplotlib.axes.Axes)
...@@ -58,26 +59,55 @@ class TestBasic(unittest.TestCase): ...@@ -58,26 +59,55 @@ class TestBasic(unittest.TestCase):
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b
@unittest.skip('Graphviz are not executables on Travis') @unittest.skipIf(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, 'matplotlib or graphviz is not installed')
def test_plot_tree(self): def test_plot_tree(self):
pass gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
gbm.fit(self.X_train, self.y_train, verbose=False)
@unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed') self.assertRaises(IndexError, lgb.plot_tree, gbm, tree_index=83)
def test_plot_metrics(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train)
test_data = lgb.Dataset(X_test, y_test, reference=train_data)
params = { ax = lgb.plot_tree(gbm, tree_index=3, figsize=(15, 8), show_info=['split_gain'])
"objective": "binary", self.assertIsInstance(ax, matplotlib.axes.Axes)
"metric": {"binary_logloss", "binary_error"}, w, h = ax.axes.get_figure().get_size_inches()
"verbose": -1, self.assertEqual(int(w), 15)
"num_leaves": 3 self.assertEqual(int(h), 8)
}
@unittest.skipIf(not GRAPHVIZ_INSTALLED, 'graphviz is not installed')
def test_create_tree_digraph(self):
gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
gbm.fit(self.X_train, self.y_train, verbose=False)
self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83)
graph = lgb.create_tree_digraph(gbm, tree_index=3,
show_info=['split_gain', 'internal_value'],
name='Tree4', node_attr={'color': 'red'})
graph.render(view=False)
self.assertIsInstance(graph, graphviz.Digraph)
self.assertEqual(graph.name, 'Tree4')
self.assertEqual(graph.filename, 'Tree4.gv')
self.assertEqual(len(graph.node_attr), 1)
self.assertEqual(graph.node_attr['color'], 'red')
self.assertEqual(len(graph.graph_attr), 0)
self.assertEqual(len(graph.edge_attr), 0)
graph_body = ''.join(graph.body)
self.assertIn('threshold', graph_body)
self.assertIn('split_feature_name', graph_body)
self.assertNotIn('split_feature_index', graph_body)
self.assertIn('leaf_index', graph_body)
self.assertIn('split_gain', graph_body)
self.assertIn('internal_value', graph_body)
self.assertNotIn('internal_count', graph_body)
self.assertNotIn('leaf_count', graph_body)
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_metrics(self):
test_data = lgb.Dataset(self.X_test, self.y_test, reference=self.train_data)
self.params.update({"metric": {"binary_logloss", "binary_error"}})
evals_result0 = {} evals_result0 = {}
gbm0 = lgb.train(params, train_data, gbm0 = lgb.train(self.params, self.train_data,
valid_sets=[train_data, test_data], valid_sets=[self.train_data, test_data],
valid_names=['v1', 'v2'], valid_names=['v1', 'v2'],
num_boost_round=10, num_boost_round=10,
evals_result=evals_result0, evals_result=evals_result0,
...@@ -91,14 +121,14 @@ class TestBasic(unittest.TestCase): ...@@ -91,14 +121,14 @@ class TestBasic(unittest.TestCase):
ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2']) ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2'])
evals_result1 = {} evals_result1 = {}
gbm1 = lgb.train(params, train_data, gbm1 = lgb.train(self.params, self.train_data,
num_boost_round=10, num_boost_round=10,
evals_result=evals_result1, evals_result=evals_result1,
verbose_eval=False) verbose_eval=False)
self.assertRaises(ValueError, lgb.plot_metric, evals_result1) self.assertRaises(ValueError, lgb.plot_metric, evals_result1)
gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
gbm2.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False) gbm2.fit(self.X_train, self.y_train, eval_set=[(self.X_test, self.y_test)], verbose=False)
ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None) ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, matplotlib.axes.Axes) self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '') self.assertEqual(ax2.get_title(), '')
......
...@@ -19,11 +19,6 @@ try: ...@@ -19,11 +19,6 @@ try:
sklearn_at_least_019 = True sklearn_at_least_019 = True
except ImportError: except ImportError:
sklearn_at_least_019 = False sklearn_at_least_019 = False
try:
import pandas as pd
IS_PANDAS_INSTALLED = True
except ImportError:
IS_PANDAS_INSTALLED = False
def multi_error(y_true, y_pred): def multi_error(y_true, y_pred):
...@@ -182,26 +177,27 @@ class TestSklearn(unittest.TestCase): ...@@ -182,26 +177,27 @@ class TestSklearn(unittest.TestCase):
y_pred_2 = clf_2.fit(X_train, y_train).predict_proba(X_test) y_pred_2 = clf_2.fit(X_train, y_train).predict_proba(X_test)
np.testing.assert_allclose(y_pred_1, y_pred_2) np.testing.assert_allclose(y_pred_1, y_pred_2)
# sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1
@unittest.skipIf(not sklearn_at_least_019, 'scikit-learn version is less than 0.19')
def test_sklearn_integration(self): def test_sklearn_integration(self):
# sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1 # we cannot use `check_estimator` directly since there is no skip test mechanism
if sklearn_at_least_019: for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier),
# we cannot use `check_estimator` directly since there is no skip test mechanism (lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)):
for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier), check_parameters_default_constructible(name, estimator)
(lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)): check_no_fit_attributes_set_in_init(name, estimator)
check_parameters_default_constructible(name, estimator) # we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833)
check_no_fit_attributes_set_in_init(name, estimator) estimator = estimator(min_child_samples=1, min_data_in_bin=1)
# we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833) for check in _yield_all_checks(name, estimator):
estimator = estimator(min_child_samples=1, min_data_in_bin=1) if check.__name__ == 'check_estimators_nan_inf':
for check in _yield_all_checks(name, estimator): continue # skip test because LightGBM deals with nan
if check.__name__ == 'check_estimators_nan_inf': try:
continue # skip test because LightGBM deals with nan check(name, estimator)
try: except SkipTest as message:
check(name, estimator) warnings.warn(message, SkipTestWarning)
except SkipTest as message:
warnings.warn(message, SkipTestWarning) @unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self): def test_pandas_categorical(self):
import pandas as pd
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int "B": np.random.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float "C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment