Commit 03c4d455 (unverified), authored by James Lamb, committed by GitHub
Browse files

[tests][python] reduce unnecessary data loading in tests (#3486)



* [ci] [python] reduce unnecessary data loading in tests

* add profiling files to gitignore

* just use cache()

* default on cache size

* patch lru_cache on Python 2.7

* linting

* reduce duplicated code

* missing warnings

* fix imports

* fix lru_cache backport

* missing kwargs

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* reduce duplicated code

* cache in test_plotting
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 5cc9e671
...@@ -318,6 +318,8 @@ htmlcov/ ...@@ -318,6 +318,8 @@ htmlcov/
.coverage.* .coverage.*
.cache .cache
nosetests.xml nosetests.xml
prof/
*.prof
coverage.xml coverage.xml
*,cover *,cover
.hypothesis/ .hypothesis/
......
...@@ -7,9 +7,11 @@ import lightgbm as lgb ...@@ -7,9 +7,11 @@ import lightgbm as lgb
import numpy as np import numpy as np
from scipy import sparse from scipy import sparse
from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from .utils import load_breast_cancer
class TestBasic(unittest.TestCase): class TestBasic(unittest.TestCase):
......
...@@ -10,8 +10,7 @@ import unittest ...@@ -10,8 +10,7 @@ import unittest
import lightgbm as lgb import lightgbm as lgb
import numpy as np import numpy as np
from scipy.sparse import csr_matrix, isspmatrix_csr, isspmatrix_csc from scipy.sparse import csr_matrix, isspmatrix_csr, isspmatrix_csc
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, from sklearn.datasets import load_svmlight_file, make_multilabel_classification
load_iris, load_svmlight_file, make_multilabel_classification)
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score, average_precision_score from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold
...@@ -20,6 +19,8 @@ try: ...@@ -20,6 +19,8 @@ try:
except ImportError: except ImportError:
import pickle import pickle
from .utils import load_boston, load_breast_cancer, load_digits, load_iris
decreasing_generator = itertools.count(0, -1) decreasing_generator = itertools.count(0, -1)
...@@ -2524,6 +2525,7 @@ class TestEngine(unittest.TestCase): ...@@ -2524,6 +2525,7 @@ class TestEngine(unittest.TestCase):
sklearn_ap = average_precision_score(y, pred) sklearn_ap = average_precision_score(y, pred)
self.assertAlmostEqual(ap, sklearn_ap) self.assertAlmostEqual(ap, sklearn_ap)
# test that average precision is 1 where model predicts perfectly # test that average precision is 1 where model predicts perfectly
y = y.copy()
y[:] = 1 y[:] = 1
lgb_X = lgb.Dataset(X, label=y) lgb_X = lgb.Dataset(X, label=y)
lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res) lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res)
......
...@@ -3,7 +3,6 @@ import unittest ...@@ -3,7 +3,6 @@ import unittest
import lightgbm as lgb import lightgbm as lgb
from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
if MATPLOTLIB_INSTALLED: if MATPLOTLIB_INSTALLED:
...@@ -12,6 +11,8 @@ if MATPLOTLIB_INSTALLED: ...@@ -12,6 +11,8 @@ if MATPLOTLIB_INSTALLED:
if GRAPHVIZ_INSTALLED: if GRAPHVIZ_INSTALLED:
import graphviz import graphviz
from .utils import load_breast_cancer
class TestBasic(unittest.TestCase): class TestBasic(unittest.TestCase):
......
...@@ -10,9 +10,7 @@ import lightgbm as lgb ...@@ -10,9 +10,7 @@ import lightgbm as lgb
import numpy as np import numpy as np
from sklearn import __version__ as sk_version from sklearn import __version__ as sk_version
from sklearn.base import clone from sklearn.base import clone
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, from sklearn.datasets import load_svmlight_file, make_multilabel_classification
load_iris, load_linnerud, load_svmlight_file,
make_multilabel_classification)
from sklearn.exceptions import SkipTestWarning from sklearn.exceptions import SkipTestWarning
from sklearn.metrics import log_loss, mean_squared_error from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
...@@ -22,6 +20,8 @@ from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest, ...@@ -22,6 +20,8 @@ from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest,
check_parameters_default_constructible) check_parameters_default_constructible)
from sklearn.utils.validation import check_is_fitted from sklearn.utils.validation import check_is_fitted
from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud
decreasing_generator = itertools.count(0, -1) decreasing_generator = itertools.count(0, -1)
......
# coding: utf-8
import sklearn.datasets
try:
    from functools import lru_cache
except ImportError:
    # functools.lru_cache is unavailable (e.g. Python 2.7): warn and fall back
    # to a minimal, unbounded memoizing decorator with the same call shape.
    import functools
    import warnings

    warnings.warn("Could not import functools.lru_cache", RuntimeWarning)

    def lru_cache(maxsize=None):
        """Minimal stand-in for ``functools.lru_cache``.

        Parameters
        ----------
        maxsize : int or None, optional (default=None)
            Accepted for signature compatibility only; it is ignored and
            the cache is never evicted.

        Returns
        -------
        decorator : callable
            A decorator that memoizes the function it wraps, keyed on the
            positional and keyword arguments of each call. All arguments
            must be hashable, otherwise ``TypeError`` is raised.
        """
        def _lru_wrapper(user_function):
            # One cache per decorated function, so that re-using the same
            # decorator object on several functions cannot mix their results.
            cache = {}

            @functools.wraps(user_function)
            def wrapper(*args, **kwargs):
                # Sort the keyword items so the key does not depend on dict
                # iteration order. Keyword names are unique strings, so the
                # values themselves are never compared while sorting.
                arg_key = (args, tuple(sorted(kwargs.items())))
                try:
                    return cache[arg_key]
                except KeyError:
                    pass
                result = cache[arg_key] = user_function(*args, **kwargs)
                return result

            return wrapper

        return _lru_wrapper
@lru_cache(maxsize=None)
def load_boston(**kwargs):
    """Return ``sklearn.datasets.load_boston(**kwargs)``, memoized per set of keyword arguments.

    Repeated calls with the same keyword arguments return the very same cached
    object, so callers should copy it before modifying it in place.
    """
    loader = sklearn.datasets.load_boston
    return loader(**kwargs)
@lru_cache(maxsize=None)
def load_breast_cancer(**kwargs):
    """Return ``sklearn.datasets.load_breast_cancer(**kwargs)``, memoized per set of keyword arguments.

    Repeated calls with the same keyword arguments return the very same cached
    object, so callers should copy it before modifying it in place.
    """
    loader = sklearn.datasets.load_breast_cancer
    return loader(**kwargs)
@lru_cache(maxsize=None)
def load_digits(**kwargs):
    """Return ``sklearn.datasets.load_digits(**kwargs)``, memoized per set of keyword arguments.

    Repeated calls with the same keyword arguments return the very same cached
    object, so callers should copy it before modifying it in place.
    """
    loader = sklearn.datasets.load_digits
    return loader(**kwargs)
@lru_cache(maxsize=None)
def load_iris(**kwargs):
    """Return ``sklearn.datasets.load_iris(**kwargs)``, memoized per set of keyword arguments.

    Repeated calls with the same keyword arguments return the very same cached
    object, so callers should copy it before modifying it in place.
    """
    loader = sklearn.datasets.load_iris
    return loader(**kwargs)
@lru_cache(maxsize=None)
def load_linnerud(**kwargs):
    """Return ``sklearn.datasets.load_linnerud(**kwargs)``, memoized per set of keyword arguments.

    Repeated calls with the same keyword arguments return the very same cached
    object, so callers should copy it before modifying it in place.
    """
    loader = sklearn.datasets.load_linnerud
    return loader(**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment