"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "2c4e68e49b4f4bcb331558f6bcdd6ed4c21e05d1"
Unverified Commit b29ed34f authored by Yiwen Song's avatar Yiwen Song Committed by GitHub
Browse files

Port test/test_datasets.py to use pytest (#4215)

* Port test_datasets.py to use pytest

* A better replacement of self.assertSequenceEqual

* Migrate from equality check to identity check
parent e2dbadbf
...@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tupl ...@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tupl
import PIL import PIL
import PIL.Image import PIL.Image
import pytest
import torch import torch
import torchvision.datasets import torchvision.datasets
import torchvision.io import torchvision.io
...@@ -519,18 +520,18 @@ class DatasetTestCase(unittest.TestCase): ...@@ -519,18 +520,18 @@ class DatasetTestCase(unittest.TestCase):
yield mocks yield mocks
def test_not_found_or_corrupted(self): def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)): with pytest.raises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False): with self.create_dataset(inject_fake_data=False):
pass pass
def test_smoke(self): def test_smoke(self):
with self.create_dataset() as (dataset, _): with self.create_dataset() as (dataset, _):
self.assertIsInstance(dataset, torchvision.datasets.VisionDataset) assert isinstance(dataset, torchvision.datasets.VisionDataset)
@test_all_configs @test_all_configs
def test_str_smoke(self, config): def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _): with self.create_dataset(config) as (dataset, _):
self.assertIsInstance(str(dataset), str) assert isinstance(str(dataset), str)
@test_all_configs @test_all_configs
def test_feature_types(self, config): def test_feature_types(self, config):
...@@ -540,23 +541,21 @@ class DatasetTestCase(unittest.TestCase): ...@@ -540,23 +541,21 @@ class DatasetTestCase(unittest.TestCase):
if len(self.FEATURE_TYPES) > 1: if len(self.FEATURE_TYPES) > 1:
actual = len(example) actual = len(example)
expected = len(self.FEATURE_TYPES) expected = len(self.FEATURE_TYPES)
self.assertEqual( assert (
actual, actual == expected
expected, ), "The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
f"The number of the returned features does not match the the number of elements in FEATURE_TYPES: " f"{actual} != {expected}"
f"{actual} != {expected}",
)
else: else:
example = (example,) example = (example,)
for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)): for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx): with self.subTest(idx=idx):
self.assertIsInstance(feature, expected_feature_type) assert isinstance(feature, expected_feature_type)
@test_all_configs @test_all_configs
def test_num_examples(self, config): def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info): with self.create_dataset(config) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"]) assert len(dataset) == info["num_examples"]
@test_all_configs @test_all_configs
def test_transforms(self, config): def test_transforms(self, config):
......
...@@ -16,6 +16,7 @@ import zipfile ...@@ -16,6 +16,7 @@ import zipfile
import PIL import PIL
import datasets_utils import datasets_utils
import numpy as np import numpy as np
import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torchvision import datasets from torchvision import datasets
...@@ -88,20 +89,20 @@ class STL10TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -88,20 +89,20 @@ class STL10TestCase(datasets_utils.ImageDatasetTestCase):
def test_folds(self): def test_folds(self):
for fold in range(10): for fold in range(10):
with self.create_dataset(split="train", folds=fold) as (dataset, _): with self.create_dataset(split="train", folds=fold) as (dataset, _):
self.assertEqual(len(dataset), fold + 1) assert len(dataset) == fold + 1
def test_unlabeled(self): def test_unlabeled(self):
with self.create_dataset(split="unlabeled") as (dataset, _): with self.create_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))] labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all(label == -1 for label in labels)) assert all(label == -1 for label in labels)
def test_invalid_folds1(self): def test_invalid_folds1(self):
with self.assertRaises(ValueError): with pytest.raises(ValueError):
with self.create_dataset(folds=10): with self.create_dataset(folds=10):
pass pass
def test_invalid_folds2(self): def test_invalid_folds2(self):
with self.assertRaises(ValueError): with pytest.raises(ValueError):
with self.create_dataset(folds="0"): with self.create_dataset(folds="0"):
pass pass
...@@ -167,23 +168,19 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -167,23 +168,19 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
actual = len(individual_targets) actual = len(individual_targets)
expected = len(combined_targets) expected = len(combined_targets)
self.assertEqual( assert (
actual, actual == expected
expected, ), "The number of the returned combined targets does not match the the number targets if requested "
f"The number of the returned combined targets does not match the the number targets if requested " f"individually: {actual} != {expected}",
f"individually: {actual} != {expected}",
)
for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets): for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type): with self.subTest(target_type=target_type):
actual = type(combined_target) actual = type(combined_target)
expected = type(individual_target) expected = type(individual_target)
self.assertIs( assert (
actual, actual is expected
expected, ), "Type of the combined target does not match the type of the corresponding individual target: "
f"Type of the combined target does not match the type of the corresponding individual target: " f"{actual} is not {expected}",
f"{actual} is not {expected}",
)
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
...@@ -363,26 +360,26 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -363,26 +360,26 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset(target_type=target_types) as (dataset, _): with self.create_dataset(target_type=target_types) as (dataset, _):
output = dataset[0] output = dataset[0]
self.assertTrue(isinstance(output, tuple)) assert isinstance(output, tuple)
self.assertTrue(len(output) == 2) assert len(output) == 2
self.assertTrue(isinstance(output[0], PIL.Image.Image)) assert isinstance(output[0], PIL.Image.Image)
self.assertTrue(isinstance(output[1], tuple)) assert isinstance(output[1], tuple)
self.assertTrue(len(output[1]) == 3) assert len(output[1]) == 3
self.assertTrue(isinstance(output[1][0], PIL.Image.Image)) # semantic assert isinstance(output[1][0], PIL.Image.Image) # semantic
self.assertTrue(isinstance(output[1][1], dict)) # polygon assert isinstance(output[1][1], dict) # polygon
self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color assert isinstance(output[1][2], PIL.Image.Image) # color
def test_feature_types_target_color(self): def test_feature_types_target_color(self):
with self.create_dataset(target_type='color') as (dataset, _): with self.create_dataset(target_type='color') as (dataset, _):
color_img, color_target = dataset[0] color_img, color_target = dataset[0]
self.assertTrue(isinstance(color_img, PIL.Image.Image)) assert isinstance(color_img, PIL.Image.Image)
self.assertTrue(np.array(color_target).shape[2] == 4) assert np.array(color_target).shape[2] == 4
def test_feature_types_target_polygon(self): def test_feature_types_target_polygon(self):
with self.create_dataset(target_type='polygon') as (dataset, info): with self.create_dataset(target_type='polygon') as (dataset, info):
polygon_img, polygon_target = dataset[0] polygon_img, polygon_target = dataset[0]
self.assertTrue(isinstance(polygon_img, PIL.Image.Image)) assert isinstance(polygon_img, PIL.Image.Image)
self.assertEqual(polygon_target, info['expected_polygon_target']) (polygon_target, info['expected_polygon_target'])
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
...@@ -469,7 +466,7 @@ class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -469,7 +466,7 @@ class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset() as (dataset, info): with self.create_dataset() as (dataset, info):
expected = {category: label for label, category in enumerate(info["categories"])} expected = {category: label for label, category in enumerate(info["categories"])}
actual = dataset.class_to_idx actual = dataset.class_to_idx
self.assertEqual(actual, expected) assert actual == expected
class CIFAR100(CIFAR10TestCase): class CIFAR100(CIFAR10TestCase):
...@@ -573,33 +570,29 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase): ...@@ -573,33 +570,29 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
actual = len(individual_targets) actual = len(individual_targets)
expected = len(combined_targets) expected = len(combined_targets)
self.assertEqual( assert (
actual, actual == expected
expected, ), "The number of the returned combined targets does not match the the number targets if requested "
f"The number of the returned combined targets does not match the the number targets if requested " f"individually: {actual} != {expected}",
f"individually: {actual} != {expected}",
)
for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets): for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type): with self.subTest(target_type=target_type):
actual = type(combined_target) actual = type(combined_target)
expected = type(individual_target) expected = type(individual_target)
self.assertIs( assert (
actual, actual is expected
expected, ), "Type of the combined target does not match the type of the corresponding individual target: "
f"Type of the combined target does not match the type of the corresponding individual target: " f"{actual} is not {expected}",
f"{actual} is not {expected}",
)
def test_no_target(self): def test_no_target(self):
with self.create_dataset(target_type=[]) as (dataset, _): with self.create_dataset(target_type=[]) as (dataset, _):
_, target = dataset[0] _, target = dataset[0]
self.assertIsNone(target) assert target is None
def test_attr_names(self): def test_attr_names(self):
with self.create_dataset() as (dataset, info): with self.create_dataset() as (dataset, info):
self.assertEqual(tuple(dataset.attr_names), info["attr_names"]) assert tuple(dataset.attr_names) == info["attr_names"]
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
...@@ -704,16 +697,16 @@ class VOCDetectionTestCase(VOCSegmentationTestCase): ...@@ -704,16 +697,16 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
with self.create_dataset() as (dataset, info): with self.create_dataset() as (dataset, info):
_, target = dataset[0] _, target = dataset[0]
self.assertIn("annotation", target) assert "annotation" in target
annotation = target["annotation"] annotation = target["annotation"]
self.assertIn("object", annotation) assert "object" in annotation
objects = annotation["object"] objects = annotation["object"]
self.assertEqual(len(objects), 1) assert len(objects) == 1
object = objects[0] object = objects[0]
self.assertEqual(object, info["annotation"]) assert object == info["annotation"]
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
...@@ -789,7 +782,7 @@ class CocoCaptionsTestCase(CocoDetectionTestCase): ...@@ -789,7 +782,7 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
def test_captions(self): def test_captions(self):
with self.create_dataset() as (dataset, info): with self.create_dataset() as (dataset, info):
_, captions = dataset[0] _, captions = dataset[0]
self.assertEqual(tuple(captions), tuple(info["captions"])) assert tuple(captions) == tuple(info["captions"])
class UCF101TestCase(datasets_utils.VideoDatasetTestCase): class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
...@@ -940,7 +933,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -940,7 +933,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
def test_not_found_or_corrupted(self): def test_not_found_or_corrupted(self):
# LSUN does not raise built-in exception, but a custom one. It is expressive enough to not 'cast' it to # LSUN does not raise built-in exception, but a custom one. It is expressive enough to not 'cast' it to
# RuntimeError or FileNotFoundError that are normally checked by this test. # RuntimeError or FileNotFoundError that are normally checked by this test.
with self.assertRaises(datasets_utils.lazy_importer.lmdb.Error): with pytest.raises(datasets_utils.lazy_importer.lmdb.Error):
super().test_not_found_or_corrupted() super().test_not_found_or_corrupted()
...@@ -1369,7 +1362,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1369,7 +1362,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
def test_captions(self): def test_captions(self):
with self.create_dataset() as (dataset, info): with self.create_dataset() as (dataset, info):
_, captions = dataset[0] _, captions = dataset[0]
self.assertSequenceEqual(captions, info["captions"]) assert len(captions) == len(info["captions"])
assert all([a == b for a, b in zip(captions, info["captions"])])
class Flickr30kTestCase(Flickr8kTestCase): class Flickr30kTestCase(Flickr8kTestCase):
...@@ -1513,7 +1507,7 @@ class QMNISTTestCase(MNISTTestCase): ...@@ -1513,7 +1507,7 @@ class QMNISTTestCase(MNISTTestCase):
with self.create_dataset(what="test50k") as (dataset, info): with self.create_dataset(what="test50k") as (dataset, info):
# Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of # Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of
# created examples by this. # created examples by this.
self.assertEqual(len(dataset), info["num_examples"] - 10000) assert len(dataset) == info["num_examples"] - 10000
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
...@@ -1578,12 +1572,13 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1578,12 +1572,13 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset( with self.create_dataset(
config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
) as (dataset, info): ) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"]) assert len(dataset) == info["num_examples"]
@datasets_utils.test_all_configs @datasets_utils.test_all_configs
def test_classes(self, config): def test_classes(self, config):
with self.create_dataset(config) as (dataset, info): with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"]) assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
...@@ -1603,7 +1598,8 @@ class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1603,7 +1598,8 @@ class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
@datasets_utils.test_all_configs @datasets_utils.test_all_configs
def test_classes(self, config): def test_classes(self, config):
with self.create_dataset(config) as (dataset, info): with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"]) assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
class KittiTestCase(datasets_utils.ImageDatasetTestCase): class KittiTestCase(datasets_utils.ImageDatasetTestCase):
...@@ -1742,15 +1738,15 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1742,15 +1738,15 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase):
def test_classes(self): def test_classes(self):
classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT)) classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
with self.create_dataset() as (dataset, _): with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.classes, classes) assert dataset.classes == classes
def test_class_to_idx(self): def test_class_to_idx(self):
class_to_idx = dict(self._CATEGORIES_CONTENT) class_to_idx = dict(self._CATEGORIES_CONTENT)
with self.create_dataset() as (dataset, _): with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.class_to_idx, class_to_idx) assert dataset.class_to_idx == class_to_idx
def test_images_download_preexisting(self): def test_images_download_preexisting(self):
with self.assertRaises(RuntimeError): with pytest.raises(RuntimeError):
with self.create_dataset({'download': True}): with self.create_dataset({'download': True}):
pass pass
...@@ -1788,9 +1784,9 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1788,9 +1784,9 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset(target_type=target_types, version="2021_valid") as (dataset, _): with self.create_dataset(target_type=target_types, version="2021_valid") as (dataset, _):
items = [d[1] for d in dataset] items = [d[1] for d in dataset]
for i, item in enumerate(items): for i, item in enumerate(items):
self.assertEqual(dataset.category_name("kingdom", item[0]), "Akingdom") assert dataset.category_name("kingdom", item[0]) == "Akingdom"
self.assertEqual(dataset.category_name("phylum", item[1]), f"{i // 3}phylum") assert dataset.category_name("phylum", item[1]) == f"{i // 3}phylum"
self.assertEqual(item[6], i // 3) assert item[6] == i // 3
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment