Unverified Commit b29ed34f authored by Yiwen Song's avatar Yiwen Song Committed by GitHub
Browse files

Port test/test_datasets.py to use pytest (#4215)

* Port test_datasets.py to use pytest

* A better replacement of self.assertSequenceEqual

* Migrate from equality check to identity check
parent e2dbadbf
......@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tupl
import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
......@@ -519,18 +520,18 @@ class DatasetTestCase(unittest.TestCase):
yield mocks
def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)):
with pytest.raises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False):
pass
def test_smoke(self):
with self.create_dataset() as (dataset, _):
self.assertIsInstance(dataset, torchvision.datasets.VisionDataset)
assert isinstance(dataset, torchvision.datasets.VisionDataset)
@test_all_configs
def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _):
self.assertIsInstance(str(dataset), str)
assert isinstance(str(dataset), str)
@test_all_configs
def test_feature_types(self, config):
......@@ -540,23 +541,21 @@ class DatasetTestCase(unittest.TestCase):
if len(self.FEATURE_TYPES) > 1:
actual = len(example)
expected = len(self.FEATURE_TYPES)
self.assertEqual(
actual,
expected,
f"The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
f"{actual} != {expected}",
)
assert (
actual == expected
), "The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
f"{actual} != {expected}"
else:
example = (example,)
for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
self.assertIsInstance(feature, expected_feature_type)
assert isinstance(feature, expected_feature_type)
@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])
assert len(dataset) == info["num_examples"]
@test_all_configs
def test_transforms(self, config):
......
......@@ -16,6 +16,7 @@ import zipfile
import PIL
import datasets_utils
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from torchvision import datasets
......@@ -88,20 +89,20 @@ class STL10TestCase(datasets_utils.ImageDatasetTestCase):
def test_folds(self):
for fold in range(10):
with self.create_dataset(split="train", folds=fold) as (dataset, _):
self.assertEqual(len(dataset), fold + 1)
assert len(dataset) == fold + 1
def test_unlabeled(self):
with self.create_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all(label == -1 for label in labels))
assert all(label == -1 for label in labels)
def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
with self.create_dataset(folds=10):
pass
def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
with self.create_dataset(folds="0"):
pass
......@@ -167,23 +168,19 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
assert (
actual == expected
), "The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
assert (
actual is expected
), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
......@@ -363,26 +360,26 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset(target_type=target_types) as (dataset, _):
output = dataset[0]
self.assertTrue(isinstance(output, tuple))
self.assertTrue(len(output) == 2)
self.assertTrue(isinstance(output[0], PIL.Image.Image))
self.assertTrue(isinstance(output[1], tuple))
self.assertTrue(len(output[1]) == 3)
self.assertTrue(isinstance(output[1][0], PIL.Image.Image)) # semantic
self.assertTrue(isinstance(output[1][1], dict)) # polygon
self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], PIL.Image.Image)
assert isinstance(output[1], tuple)
assert len(output[1]) == 3
assert isinstance(output[1][0], PIL.Image.Image) # semantic
assert isinstance(output[1][1], dict) # polygon
assert isinstance(output[1][2], PIL.Image.Image) # color
def test_feature_types_target_color(self):
with self.create_dataset(target_type='color') as (dataset, _):
color_img, color_target = dataset[0]
self.assertTrue(isinstance(color_img, PIL.Image.Image))
self.assertTrue(np.array(color_target).shape[2] == 4)
assert isinstance(color_img, PIL.Image.Image)
assert np.array(color_target).shape[2] == 4
def test_feature_types_target_polygon(self):
with self.create_dataset(target_type='polygon') as (dataset, info):
polygon_img, polygon_target = dataset[0]
self.assertTrue(isinstance(polygon_img, PIL.Image.Image))
self.assertEqual(polygon_target, info['expected_polygon_target'])
assert isinstance(polygon_img, PIL.Image.Image)
(polygon_target, info['expected_polygon_target'])
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
......@@ -469,7 +466,7 @@ class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset() as (dataset, info):
expected = {category: label for label, category in enumerate(info["categories"])}
actual = dataset.class_to_idx
self.assertEqual(actual, expected)
assert actual == expected
class CIFAR100(CIFAR10TestCase):
......@@ -573,33 +570,29 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
assert (
actual == expected
), "The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
assert (
actual is expected
), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
def test_no_target(self):
with self.create_dataset(target_type=[]) as (dataset, _):
_, target = dataset[0]
self.assertIsNone(target)
assert target is None
def test_attr_names(self):
with self.create_dataset() as (dataset, info):
self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
assert tuple(dataset.attr_names) == info["attr_names"]
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
......@@ -704,16 +697,16 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
with self.create_dataset() as (dataset, info):
_, target = dataset[0]
self.assertIn("annotation", target)
assert "annotation" in target
annotation = target["annotation"]
self.assertIn("object", annotation)
assert "object" in annotation
objects = annotation["object"]
self.assertEqual(len(objects), 1)
assert len(objects) == 1
object = objects[0]
self.assertEqual(object, info["annotation"])
assert object == info["annotation"]
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
......@@ -789,7 +782,7 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertEqual(tuple(captions), tuple(info["captions"]))
assert tuple(captions) == tuple(info["captions"])
class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
......@@ -940,7 +933,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
def test_not_found_or_corrupted(self):
# LSUN does not raise built-in exception, but a custom one. It is expressive enough to not 'cast' it to
# RuntimeError or FileNotFoundError that are normally checked by this test.
with self.assertRaises(datasets_utils.lazy_importer.lmdb.Error):
with pytest.raises(datasets_utils.lazy_importer.lmdb.Error):
super().test_not_found_or_corrupted()
......@@ -1369,7 +1362,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertSequenceEqual(captions, info["captions"])
assert len(captions) == len(info["captions"])
assert all([a == b for a, b in zip(captions, info["captions"])])
class Flickr30kTestCase(Flickr8kTestCase):
......@@ -1513,7 +1507,7 @@ class QMNISTTestCase(MNISTTestCase):
with self.create_dataset(what="test50k") as (dataset, info):
# Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of
# created examples by this.
self.assertEqual(len(dataset), info["num_examples"] - 10000)
assert len(dataset) == info["num_examples"] - 10000
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
......@@ -1578,12 +1572,13 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset(
config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])
assert len(dataset) == info["num_examples"]
@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
......@@ -1603,7 +1598,8 @@ class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
class KittiTestCase(datasets_utils.ImageDatasetTestCase):
......@@ -1742,15 +1738,15 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase):
def test_classes(self):
classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.classes, classes)
assert dataset.classes == classes
def test_class_to_idx(self):
class_to_idx = dict(self._CATEGORIES_CONTENT)
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.class_to_idx, class_to_idx)
assert dataset.class_to_idx == class_to_idx
def test_images_download_preexisting(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
with self.create_dataset({'download': True}):
pass
......@@ -1788,9 +1784,9 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset(target_type=target_types, version="2021_valid") as (dataset, _):
items = [d[1] for d in dataset]
for i, item in enumerate(items):
self.assertEqual(dataset.category_name("kingdom", item[0]), "Akingdom")
self.assertEqual(dataset.category_name("phylum", item[1]), f"{i // 3}phylum")
self.assertEqual(item[6], i // 3)
assert dataset.category_name("kingdom", item[0]) == "Akingdom"
assert dataset.category_name("phylum", item[1]) == f"{i // 3}phylum"
assert item[6] == i // 3
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment