Unverified Commit 9bacf03c authored by Thomas J. Fan's avatar Thomas J. Fan Committed by GitHub
Browse files

[python][tests] Migrates test_basic.py to use pytest (#3764)

* TST Migrates test_basic.py to use pytest

* STY Linting

* CI Force CI to run
parent 53639f4a
# coding: utf-8
import os
import tempfile
import unittest
import lightgbm as lgb
import numpy as np
import pytest
from scipy import sparse
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
......@@ -13,9 +13,7 @@ from sklearn.model_selection import train_test_split
from .utils import load_breast_cancer
class TestBasic(unittest.TestCase):
def test(self):
def test_basic(tmp_path):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True),
test_size=0.1, random_state=2)
train_data = lgb.Dataset(X_train, label=y_train)
......@@ -39,25 +37,24 @@ class TestBasic(unittest.TestCase):
if i % 10 == 0:
print(bst.eval_train(), bst.eval_valid())
self.assertEqual(bst.current_iteration(), 20)
self.assertEqual(bst.num_trees(), 20)
self.assertEqual(bst.num_model_per_iteration(), 1)
self.assertAlmostEqual(bst.lower_bound(), -2.9040190126976606)
self.assertAlmostEqual(bst.upper_bound(), 3.3182142872462883)
assert bst.current_iteration() == 20
assert bst.num_trees() == 20
assert bst.num_model_per_iteration() == 1
assert bst.lower_bound() == pytest.approx(-2.9040190126976606)
assert bst.upper_bound() == pytest.approx(3.3182142872462883)
tname = str(tmp_path / "svm_light.dat")
model_file = str(tmp_path / "model.txt")
bst.save_model("model.txt")
bst.save_model(model_file)
pred_from_matr = bst.predict(X_test)
with tempfile.NamedTemporaryFile() as f:
tname = f.name
with open(tname, "w+b") as f:
dump_svmlight_file(X_test, y_test, f)
pred_from_file = bst.predict(tname)
os.remove(tname)
np.testing.assert_allclose(pred_from_matr, pred_from_file)
# check saved model persistence
bst = lgb.Booster(params, model_file="model.txt")
os.remove("model.txt")
bst = lgb.Booster(params, model_file=model_file)
pred_from_model_file = bst.predict(X_test)
# we need to check the consistency of model file here, so test for exact equal
np.testing.assert_array_equal(pred_from_matr, pred_from_model_file)
......@@ -85,10 +82,11 @@ class TestBasic(unittest.TestCase):
dump_svmlight_file(X_test, y_test, f, zero_based=False)
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
bst.predict, tname)
os.remove(tname)
def test_chunked_dataset(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2)
def test_chunked_dataset():
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
random_state=2)
chunk_size = X_train.shape[0] // 10 + 1
X_train = [X_train[i * chunk_size:(i + 1) * chunk_size, :] for i in range(X_train.shape[0] // chunk_size + 1)]
......@@ -99,7 +97,8 @@ class TestBasic(unittest.TestCase):
train_data.construct()
valid_data.construct()
def test_chunked_dataset_linear(self):
def test_chunked_dataset_linear():
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
random_state=2)
chunk_size = X_train.shape[0] // 10 + 1
......@@ -111,7 +110,8 @@ class TestBasic(unittest.TestCase):
train_data.construct()
valid_data.construct()
def test_save_and_load_linear(self):
def test_save_and_load_linear(tmp_path):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
random_state=2)
X_train = np.concatenate([np.ones((X_train.shape[0], 1)), X_train], 1)
......@@ -121,55 +121,62 @@ class TestBasic(unittest.TestCase):
train_data_1 = lgb.Dataset(X_train, label=y_train, params=params)
est_1 = lgb.train(params, train_data_1, num_boost_round=10, categorical_feature=[0])
pred_1 = est_1.predict(X_train)
train_data_1.save_binary('temp_dataset.bin')
train_data_2 = lgb.Dataset('temp_dataset.bin')
tmp_dataset = str(tmp_path / 'temp_dataset.bin')
train_data_1.save_binary(tmp_dataset)
train_data_2 = lgb.Dataset(tmp_dataset)
est_2 = lgb.train(params, train_data_2, num_boost_round=10)
pred_2 = est_2.predict(X_train)
np.testing.assert_allclose(pred_1, pred_2)
est_2.save_model('model.txt')
est_3 = lgb.Booster(model_file='model.txt')
model_file = str(tmp_path / 'model.txt')
est_2.save_model(model_file)
est_3 = lgb.Booster(model_file=model_file)
pred_3 = est_3.predict(X_train)
np.testing.assert_allclose(pred_2, pred_3)
def test_subset_group(self):
def test_subset_group():
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train'))
q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train.query'))
lgb_train = lgb.Dataset(X_train, y_train, group=q_train)
self.assertEqual(len(lgb_train.get_group()), 201)
assert len(lgb_train.get_group()) == 201
subset = lgb_train.subset(list(range(10))).construct()
subset_group = subset.get_group()
self.assertEqual(len(subset_group), 2)
self.assertEqual(subset_group[0], 1)
self.assertEqual(subset_group[1], 9)
assert len(subset_group) == 2
assert subset_group[0] == 1
assert subset_group[1] == 9
def test_add_features_throws_if_num_data_unequal(self):
def test_add_features_throws_if_num_data_unequal():
X1 = np.random.random((100, 1))
X2 = np.random.random((10, 1))
d1 = lgb.Dataset(X1).construct()
d2 = lgb.Dataset(X2).construct()
with self.assertRaises(lgb.basic.LightGBMError):
with pytest.raises(lgb.basic.LightGBMError):
d1.add_features_from(d2)
def test_add_features_throws_if_datasets_unconstructed(self):
def test_add_features_throws_if_datasets_unconstructed():
X1 = np.random.random((100, 1))
X2 = np.random.random((100, 1))
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
d1 = lgb.Dataset(X1)
d2 = lgb.Dataset(X2)
d1.add_features_from(d2)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
d1 = lgb.Dataset(X1).construct()
d2 = lgb.Dataset(X2)
d1.add_features_from(d2)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
d1 = lgb.Dataset(X1)
d2 = lgb.Dataset(X2).construct()
d1.add_features_from(d2)
def test_add_features_equal_data_on_alternating_used_unused(self):
self.maxDiff = None
def test_add_features_equal_data_on_alternating_used_unused(tmp_path):
X = np.random.random((100, 5))
X[:, [1, 3]] = 0
names = ['col_%d' % i for i in range(5)]
......@@ -177,23 +184,19 @@ class TestBasic(unittest.TestCase):
d1 = lgb.Dataset(X[:, :j], feature_name=names[:j]).construct()
d2 = lgb.Dataset(X[:, j:], feature_name=names[j:]).construct()
d1.add_features_from(d2)
with tempfile.NamedTemporaryFile() as f:
d1name = f.name
d1name = str(tmp_path / "d1.txt")
d1._dump_text(d1name)
d = lgb.Dataset(X, feature_name=names).construct()
with tempfile.NamedTemporaryFile() as f:
dname = f.name
dname = str(tmp_path / "d.txt")
d._dump_text(dname)
with open(d1name, 'rt') as d1f:
d1txt = d1f.read()
with open(dname, 'rt') as df:
dtxt = df.read()
os.remove(dname)
os.remove(d1name)
self.assertEqual(dtxt, d1txt)
assert dtxt == d1txt
def test_add_features_same_booster_behaviour(self):
self.maxDiff = None
def test_add_features_same_booster_behaviour(tmp_path):
X = np.random.random((100, 5))
X[:, [1, 3]] = 0
names = ['col_%d' % i for i in range(5)]
......@@ -210,21 +213,19 @@ class TestBasic(unittest.TestCase):
for k in range(10):
b.update()
b1.update()
with tempfile.NamedTemporaryFile() as df:
dname = df.name
with tempfile.NamedTemporaryFile() as d1f:
d1name = d1f.name
dname = str(tmp_path / "d.txt")
d1name = str(tmp_path / "d1.txt")
b1.save_model(d1name)
b.save_model(dname)
with open(dname, 'rt') as df:
dtxt = df.read()
with open(d1name, 'rt') as d1f:
d1txt = d1f.read()
self.assertEqual(dtxt, d1txt)
assert dtxt == d1txt
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_add_features_from_different_sources(self):
import pandas as pd
def test_add_features_from_different_sources():
pd = pytest.importorskip("pandas")
n_row = 100
n_col = 5
X = np.random.random((n_row, n_col))
......@@ -235,14 +236,14 @@ class TestBasic(unittest.TestCase):
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=True).construct()
d2 = lgb.Dataset(x_1, feature_name=names, free_raw_data=True).construct()
d1.add_features_from(d2)
self.assertIsNone(d1.data)
assert d1.data is None
# test that method works but sets raw data to None in case of immergeable data types
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
d2 = lgb.Dataset([X[:n_row // 2, :], X[n_row // 2:, :]],
feature_name=names, free_raw_data=False).construct()
d1.add_features_from(d2)
self.assertIsNone(d1.data)
assert d1.data is None
# test that method works for different data types
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
......@@ -251,12 +252,13 @@ class TestBasic(unittest.TestCase):
original_type = type(d1.get_data())
d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct()
d1.add_features_from(d2)
self.assertIsInstance(d1.get_data(), original_type)
self.assertTupleEqual(d1.get_data().shape, (n_row, n_col * idx))
assert isinstance(d1.get_data(), original_type)
assert d1.get_data().shape == (n_row, n_col * idx)
res_feature_names += ['D{}_{}'.format(idx, name) for name in names]
self.assertListEqual(d1.feature_name, res_feature_names)
assert d1.feature_name == res_feature_names
def test_cegb_affects_behavior(self):
def test_cegb_affects_behavior(tmp_path):
X = np.random.random((100, 5))
X[:, [1, 3]] = 0
y = np.random.random(100)
......@@ -279,14 +281,14 @@ class TestBasic(unittest.TestCase):
booster = lgb.Booster(train_set=ds, params=case)
for k in range(10):
booster.update()
with tempfile.NamedTemporaryFile() as f:
casename = f.name
casename = str(tmp_path / "casename.txt")
booster.save_model(casename)
with open(casename, 'rt') as f:
casetxt = f.read()
self.assertNotEqual(basetxt, casetxt)
assert basetxt != casetxt
def test_cegb_scaling_equalities(self):
def test_cegb_scaling_equalities(tmp_path):
X = np.random.random((100, 5))
X[:, [1, 3]] = 0
y = np.random.random(100)
......@@ -306,40 +308,38 @@ class TestBasic(unittest.TestCase):
for k in range(10):
booster1.update()
booster2.update()
with tempfile.NamedTemporaryFile() as f:
p1name = f.name
p1name = str(tmp_path / "p1.txt")
# Reset booster1's parameters to p2, so the parameter section of the file matches.
booster1.reset_parameter(p2)
booster1.save_model(p1name)
with open(p1name, 'rt') as f:
p1txt = f.read()
with tempfile.NamedTemporaryFile() as f:
p2name = f.name
p2name = str(tmp_path / "p2.txt")
booster2.save_model(p2name)
with open(p2name, 'rt') as f:
p2txt = f.read()
self.maxDiff = None
self.assertEqual(p1txt, p2txt)
assert p1txt == p2txt
def test_consistent_state_for_dataset_fields(self):
def test_consistent_state_for_dataset_fields():
def check_asserts(data):
np.testing.assert_allclose(data.label, data.get_label())
np.testing.assert_allclose(data.label, data.get_field('label'))
self.assertFalse(np.isnan(data.label[0]))
self.assertFalse(np.isinf(data.label[1]))
assert not np.isnan(data.label[0])
assert not np.isinf(data.label[1])
np.testing.assert_allclose(data.weight, data.get_weight())
np.testing.assert_allclose(data.weight, data.get_field('weight'))
self.assertFalse(np.isnan(data.weight[0]))
self.assertFalse(np.isinf(data.weight[1]))
assert not np.isnan(data.weight[0])
assert not np.isinf(data.weight[1])
np.testing.assert_allclose(data.init_score, data.get_init_score())
np.testing.assert_allclose(data.init_score, data.get_field('init_score'))
self.assertFalse(np.isnan(data.init_score[0]))
self.assertFalse(np.isinf(data.init_score[1]))
self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]],
data.label[0])))
self.assertAlmostEqual(data.label[1], data.weight[1])
self.assertListEqual(data.feature_name, data.get_feature_name())
assert not np.isnan(data.init_score[0])
assert not np.isinf(data.init_score[1])
assert np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]],
data.label[0]))
assert data.label[1] == pytest.approx(data.weight[1])
assert data.feature_name == data.get_feature_name()
X, y = load_breast_cancer(return_X_y=True)
sequence = np.ones(y.shape[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment