Commit cc26cd81 authored by panning

merge v0.16.0

parents f78f29f5 fbb4cc54
File suppressed by a .gitattributes entry or the file's encoding is unsupported. (11 such files in this commit)
import collections.abc
import dataclasses
from typing import Optional, Sequence
import pytest
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import tv_tensors
from transforms_v2_legacy_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader
@dataclasses.dataclass
class LabelLoader(TensorLoader):
categories: Optional[Sequence[str]]
def _parse_categories(categories):
if categories is None:
num_categories = int(torch.randint(1, 11, ()))
elif isinstance(categories, int):
num_categories = categories
categories = [f"category{idx}" for idx in range(num_categories)]
elif isinstance(categories, collections.abc.Sequence) and all(isinstance(category, str) for category in categories):
categories = list(categories)
num_categories = len(categories)
else:
raise pytest.UsageError(
f"`categories` can either be `None` (default), an integer, or a sequence of strings, "
f"but got '{categories}' instead."
)
return categories, num_categories
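# The three accepted forms, illustrated (a hedged sketch of the branches above):
#   _parse_categories(None)           -> (None, <random int in [1, 10]>)
#   _parse_categories(3)              -> (["category0", "category1", "category2"], 3)
#   _parse_categories(["cat", "dog"]) -> (["cat", "dog"], 2)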
def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
categories, num_categories = _parse_categories(categories)
def fn(shape, dtype, device):
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
return tv_tensors.Label(data, categories=categories)
return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
make_label = from_loader(make_label_loader)
@dataclasses.dataclass
class OneHotLabelLoader(TensorLoader):
categories: Optional[Sequence[str]]
def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int64):
categories, num_categories = _parse_categories(categories)
def fn(shape, dtype, device):
if num_categories == 0:
data = torch.empty(shape, dtype=dtype, device=device)
else:
# The idiom `make_label_loader(..., dtype=torch.int64); ...; one_hot(...).to(dtype)` is intentional
# since `one_hot` only supports int64
label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
data = one_hot(label, num_classes=num_categories).to(dtype)
return tv_tensors.OneHotLabel(data, categories=categories)
return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
def make_one_hot_label_loaders(
*,
categories=(1, 0, None),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.int64, torch.float32),
):
for params in combinations_grid(categories=categories, extra_dims=extra_dims, dtype=dtypes):
yield make_one_hot_label_loader(**params)
make_one_hot_labels = from_loaders(make_one_hot_label_loaders)
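# Hedged usage sketch of the helpers above, assuming this module is importable as in
# the v2 transforms tests; `load(device)` materializes a loader, as already used in
# `make_one_hot_label_loader`. `combinations_grid(split=("train", "test"), fold=(1, 2))`
# is assumed to yield the four dicts of the cartesian product, e.g. {"split": "train", "fold": 1}.
loader = make_label_loader(extra_dims=(2,), categories=["cat", "dog"])
label = loader.load("cpu")  # tv_tensors.Label of shape (2,) with values in {0, 1}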
"""Run smoke tests"""
import sys
from pathlib import Path
import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms
from torchvision.io import decode_jpeg, read_file, read_image
from torchvision.models import resnet50, ResNet50_Weights
SCRIPT_DIR = Path(__file__).parent
def smoke_test_torchvision() -> None:
print(
"Is torchvision usable?",
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
)
def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
img_jpg = decode_jpeg(img_jpg_data, device=device)
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
def smoke_test_compile() -> None:
try:
model = resnet50().cuda()
model = torch.compile(model)
x = torch.randn(1, 3, 224, 224, device="cuda")
out = model(x)
print(f"torch.compile model output: {out.shape}")
except RuntimeError:
if sys.platform == "win32":
print("Successfully caught torch.compile RuntimeError on win")
elif sys.version_info >= (3, 11, 0):
print("Successfully caught torch.compile RuntimeError on Python 3.11")
else:
raise
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
expected_category = "German shepherd"
print(f"{category_name} ({device}): {100 * score:.1f}%")
if category_name != expected_category:
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
def main() -> None:
print(f"torchvision: {torchvision.__version__}")
print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
# Turn 1.11.0aHASH into 1.11 (major.minor only)
version = ".".join(torchvision.__version__.split(".")[:2])
if version >= "0.16":
print(f"{torch.ops.image._jpeg_version() = }")
assert torch.ops.image._is_compiled_against_turbo()
smoke_test_torchvision()
smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify()
smoke_test_torchvision_decode_jpeg()
if torch.cuda.is_available():
smoke_test_torchvision_decode_jpeg("cuda")
smoke_test_torchvision_resnet50_classify("cuda")
smoke_test_compile()
if torch.backends.mps.is_available():
smoke_test_torchvision_resnet50_classify("mps")
if __name__ == "__main__":
main()
@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = partition(x, partition_size)
x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)
assert torch.allclose(x, x_hat)
torch.testing.assert_close(x, x_hat)
def test_maxvit_grid_partition(self):
input_shape = (1, 3, 224, 224)
@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = post_swap(x_hat)
x_hat = departition(x_hat, n_partitions, partition_size, partition_size)
assert torch.allclose(x, x_hat)
torch.testing.assert_close(x, x_hat)
if __name__ == "__main__":
......
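# A self-contained sketch (names hypothetical, not the torchvision internals) of the
# window-partition round trip the MaxViT tests above assert: splitting a feature map
# into non-overlapping P x P windows and stitching it back must be lossless.
import torch

def partition(x, p):
    # (B, C, H, W) -> (B * num_windows, P * P, C)
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // p, p, w // p, p)
    return x.permute(0, 2, 4, 3, 5, 1).reshape(-1, p * p, c)

def departition(x, p, h_windows, w_windows):
    # exact inverse of `partition`
    _, _, c = x.shape
    x = x.reshape(-1, h_windows, w_windows, p, p, c)
    return x.permute(0, 5, 1, 3, 2, 4).reshape(-1, c, h_windows * p, w_windows * p)

x = torch.randn(1, 3, 224, 224)
torch.testing.assert_close(x, departition(partition(x, 7), 7, 224 // 7, 224 // 7))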
@@ -194,7 +194,7 @@ class TestFxFeatureExtraction:
assert n1 == n2
assert p1.equal(p2)
# And that ouputs match
# And that outputs match
with torch.no_grad():
ilg_out = ilg_model(self.inp)
fgn_out = fx_model(self.inp)
......
@@ -8,6 +8,7 @@ import os
import pathlib
import pickle
import random
import re
import shutil
import string
import unittest
@@ -21,12 +22,13 @@ import PIL
import pytest
import torch
import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "unlabeled", "train+unlabeled"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test", "unlabeled", "train+unlabeled"))
@staticmethod
def _make_binary_file(num_elements, root, name):
@@ -112,9 +114,7 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech101
FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
target_type=("category", "annotation", ["category", "annotation"])
)
ADDITIONAL_CONFIGS = combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
REQUIRED_PACKAGES = ("scipy",)
def inject_fake_data(self, tmpdir, config):
@@ -183,6 +183,10 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech256
@@ -190,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"
categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
num_images_per_category = 2
for idx, category in categories:
@@ -207,7 +211,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.WIDERFace
FEATURE_TYPES = (PIL.Image.Image, (dict, type(None))) # test split returns None as target
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
def inject_fake_data(self, tmpdir, config):
widerface_dir = pathlib.Path(tmpdir) / "widerface"
@@ -258,6 +262,10 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
return split_to_num_examples[config["split"]]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Cityscapes
@@ -268,8 +276,8 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
"color",
)
ADDITIONAL_CONFIGS = (
*datasets_utils.combinations_grid(mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES),
*datasets_utils.combinations_grid(
*combinations_grid(mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES),
*combinations_grid(
mode=("coarse",),
split=("train", "train_extra", "val"),
target_type=TARGET_TYPES,
@@ -382,11 +390,16 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
assert isinstance(polygon_img, PIL.Image.Image)
(polygon_target, info["expected_polygon_target"])
def test_transforms_v2_wrapper_spawn(self):
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageNet
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)
@@ -413,10 +426,14 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
return num_examples
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CIFAR10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
_VERSION_CONFIG = dict(
base_folder="cifar-10-batches-py",
@@ -489,7 +506,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CelebA
FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "valid", "test", "all"),
target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
)
@@ -607,15 +624,18 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
assert merged_imgs_names == all_imgs_names
def test_transforms_v2_wrapper_spawn(self):
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.VOCSegmentation
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
ADDITIONAL_CONFIGS = (
*datasets_utils.combinations_grid(
year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
),
*combinations_grid(year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")),
dict(year="2007", image_set="test"),
)
@@ -696,6 +716,10 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
return data
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class VOCDetectionTestCase(VOCSegmentationTestCase):
DATASET_CLASS = datasets.VOCDetection
@@ -716,6 +740,10 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
assert object == info["annotation"]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CocoDetection
@@ -763,11 +791,21 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
return info
def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid(
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
)
for id, annotation in enumerate(annotations):
annotation["id"] = id
annotations = []
annotation_id = 0
for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
annotations.append(
dict(
image_id=image_id,
id=annotation_id,
bbox=torch.rand(4).tolist(),
segmentation=[torch.rand(8).tolist()],
category_id=int(torch.randint(91, ())),
area=float(torch.rand(1)),
iscrowd=int(torch.randint(2, size=(1,))),
)
)
annotation_id += 1
return annotations, dict()
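# The islice/cycle idiom above assigns annotations to images round-robin, e.g. for
# image_ids=[0, 1] and num_annotations_per_image=2 it yields annotations for image
# ids [0, 1, 0, 1], i.e. two per image.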
def _create_json(self, root, name, content):
@@ -776,13 +814,17 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
json.dump(content, fh)
return file
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions
def _create_annotations(self, image_ids, num_annotations_per_image):
captions = [str(idx) for idx in range(num_annotations_per_image)]
annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions)
annotations = combinations_grid(image_id=image_ids, caption=captions)
for id, annotation in enumerate(annotations):
annotation["id"] = id
return annotations, dict(captions=captions)
@@ -792,11 +834,16 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
_, captions = dataset[0]
assert tuple(captions) == tuple(info["captions"])
def test_transforms_v2_wrapper_spawn(self):
# We need to define this method, because otherwise the test from the superclass
# would be run
pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.UCF101
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
ADDITIONAL_CONFIGS = combinations_grid(fold=(1, 2, 3), train=(True, False))
_VIDEO_FOLDER = "videos"
_ANNOTATIONS_FOLDER = "annotations"
@@ -857,9 +904,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.LSUN
REQUIRED_PACKAGES = ("lmdb",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
)
ADDITIONAL_CONFIGS = combinations_grid(classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"]))
_CATEGORIES = (
"bedroom",
@@ -944,7 +989,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.Kinetics
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"), num_classes=("400", "600", "700"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"), num_classes=("400", "600", "700"))
def inject_fake_data(self, tmpdir, config):
classes = ("Abseiling", "Zumba")
@@ -960,11 +1005,15 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
)
return num_videos_per_class * len(classes)
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
ADDITIONAL_CONFIGS = combinations_grid(fold=(1, 2, 3), train=(True, False))
_VIDEO_FOLDER = "videos"
_SPLITS_FOLDER = "splits"
@@ -1024,7 +1073,7 @@ class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Omniglot
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(background=(True, False))
ADDITIONAL_CONFIGS = combinations_grid(background=(True, False))
def inject_fake_data(self, tmpdir, config):
target_folder = (
@@ -1104,7 +1153,7 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
class USPSTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.USPS
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
def inject_fake_data(self, tmpdir, config):
num_images = 2 if config["train"] else 1
@@ -1126,7 +1175,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse")
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation")
)
@@ -1187,6 +1236,10 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
def _file_stem(self, idx):
return f"2008_{idx:06d}"
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FakeData
@@ -1212,7 +1265,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
_TRAIN_FEATURE_TYPES = (torch.Tensor,)
_TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)
datasets_utils.combinations_grid(train=(True, False))
combinations_grid(train=(True, False))
_NAME = "liberty"
@@ -1371,7 +1424,7 @@ class Flickr30kTestCase(Flickr8kTestCase):
class MNISTTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.MNIST
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
_MAGIC_DTYPES = {
torch.uint8: 8,
@@ -1441,7 +1494,7 @@ class EMNISTTestCase(MNISTTestCase):
DATASET_CLASS = datasets.EMNIST
DEFAULT_CONFIG = dict(split="byclass")
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"), train=(True, False)
)
@@ -1452,7 +1505,7 @@ class QMNISTTestCase(MNISTTestCase):
class QMNISTTestCase(MNISTTestCase):
DATASET_CLASS = datasets.QMNIST
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(what=("train", "test", "test10k", "nist"))
ADDITIONAL_CONFIGS = combinations_grid(what=("train", "test", "test10k", "nist"))
_LABELS_SIZE = (8,)
_LABELS_DTYPE = torch.int32
@@ -1494,30 +1547,51 @@ class QMNISTTestCase(MNISTTestCase):
assert len(dataset) == info["num_examples"] - 10000
class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
DATASET_CLASS = datasets.MovingMNIST
FEATURE_TYPES = (torch.Tensor,)
ADDITIONAL_CONFIGS = combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))
_NUM_FRAMES = 20
def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
os.makedirs(base_folder, exist_ok=True)
num_samples = 5
data = np.concatenate(
[
np.zeros((config["split_ratio"], num_samples, 64, 64)),
np.ones((self._NUM_FRAMES - config["split_ratio"], num_samples, 64, 64)),
]
)
np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
return num_samples
@datasets_utils.test_all_configs
def test_split(self, config):
with self.create_dataset(config) as (dataset, _):
if config["split"] == "train":
assert (dataset.data == 0).all()
elif config["split"] == "test":
assert (dataset.data == 1).all()
else:
assert dataset.data.size()[1] == self._NUM_FRAMES
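# Hedged sketch of the layout the fake data above mimics: the real `mnist_test_seq.npy`
# stores clips frames-first as (num_frames, num_clips, 64, 64), and `split_ratio` is
# assumed to cut along the frame axis:
#   train -> data[:split_ratio], test -> data[split_ratio:]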
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder
# The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader
# that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method.
FEATURE_TYPES = (str, int)
_IMAGE_EXTENSIONS = ("jpg", "png")
_VIDEO_EXTENSIONS = ("avi", "mp4")
_EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)
_EXTENSIONS = ("jpg", "png")
# DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. Exactly one of them is required.
# We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the
# 'test_is_valid_file()' method.
DEFAULT_CONFIG = dict(extensions=_EXTENSIONS)
ADDITIONAL_CONFIGS = (
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
dict(extensions=_IMAGE_EXTENSIONS),
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
dict(extensions=_VIDEO_EXTENSIONS),
)
ADDITIONAL_CONFIGS = combinations_grid(extensions=[(ext,) for ext in _EXTENSIONS])
def dataset_args(self, tmpdir, config):
return tmpdir, lambda x: x
return tmpdir, datasets.folder.pil_loader
def inject_fake_data(self, tmpdir, config):
extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])
@@ -1528,14 +1602,8 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
if ext not in extensions:
continue
create_example_folder = (
datasets_utils.create_image_folder
if ext in self._IMAGE_EXTENSIONS
else datasets_utils.create_video_folder
)
num_examples = torch.randint(1, 3, size=()).item()
create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)
datasets_utils.create_image_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)
num_examples_total += num_examples
classes.append(cls)
@@ -1589,7 +1657,7 @@ class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
class KittiTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Kitti
FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
def inject_fake_data(self, tmpdir, config):
kitti_dir = os.path.join(tmpdir, "Kitti", "raw")
@@ -1621,11 +1689,15 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
return split_to_num_examples[config["train"]]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SVHN
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "extra"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test", "extra"))
def inject_fake_data(self, tmpdir, config):
import scipy.io as sio
@@ -1646,7 +1718,7 @@ class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
class Places365TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Places365
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("train-standard", "train-challenge", "val"),
small=(False, True),
)
@@ -1738,7 +1810,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.INaturalist
FEATURE_TYPES = (PIL.Image.Image, (int, tuple))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]),
version=("2021_train",),
)
@@ -1775,7 +1847,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
DATASET_CLASS = datasets.LFWPeople
FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled")
)
_IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
@@ -1851,7 +1923,7 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Sintel
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final", "both"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"), pass_name=("clean", "final", "both"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
FLOW_H, FLOW_W = 3, 4
@@ -1919,7 +1991,7 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.KittiFlow
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config):
@@ -1979,7 +2051,7 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FlyingChairs
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
FLOW_H, FLOW_W = 3, 4
@@ -2034,7 +2106,7 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FlyingThings3D
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
)
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
@@ -2171,7 +2243,7 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Food101
FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir: str, config):
root_folder = pathlib.Path(tmpdir) / "food-101"
@@ -2206,7 +2278,7 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FGVCAircraft
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
)
@@ -2289,7 +2361,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DTD
FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "test", "val"),
# There is no need to test the whole matrix here, since each fold is treated exactly the same
partition=(1, 5, 10),
@@ -2323,7 +2395,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FER2013
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
@@ -2358,7 +2430,7 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB
FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir: str, config):
root_folder = os.path.join(tmpdir, "gtsrb")
@@ -2408,7 +2480,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CLEVRClassification
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
def inject_fake_data(self, tmpdir, config):
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
@@ -2440,7 +2512,7 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.OxfordIIITPet
FEATURE_TYPES = (PIL.Image.Image, (int, PIL.Image.Image, tuple, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("trainval", "test"),
target_types=("category", "segmentation", ["category", "segmentation"], []),
)
@@ -2495,11 +2567,15 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
breed_id = "-1"
return (image_id, class_id, species, breed_id)
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.StanfordCars
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir, config):
import scipy.io as io
@@ -2543,7 +2619,7 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
class Country211TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Country211
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))
def inject_fake_data(self, tmpdir: str, config):
split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
@@ -2570,7 +2646,7 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):
class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Flowers102
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("scipy",)
def inject_fake_data(self, tmpdir: str, config):
@@ -2606,7 +2682,7 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.PCAM
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("h5py",)
def inject_fake_data(self, tmpdir: str, config):
@@ -2628,7 +2704,7 @@ class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.RenderedSST2
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}
def inject_fake_data(self, tmpdir: str, config):
@@ -2650,7 +2726,7 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Kitti2012Stereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config):
@@ -2712,7 +2788,7 @@ class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Kitti2015Stereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config):
@@ -2850,7 +2926,7 @@ class CREStereoTestCase(datasets_utils.ImageDatasetTestCase):
class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FallingThingsStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(variant=("single", "mixed", "both"))
ADDITIONAL_CONFIGS = combinations_grid(variant=("single", "mixed", "both"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
@staticmethod
@@ -2924,7 +3000,7 @@ class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase):
class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SceneFlowStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
variant=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final", "both")
)
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
@@ -3011,7 +3087,7 @@ class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase):
class InStereo2k(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.InStereo2k
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
@staticmethod
def _make_scene_folder(root: str, name: str, size: Tuple[int, int]):
@@ -3053,7 +3129,7 @@ class InStereo2k(datasets_utils.ImageDatasetTestCase):
class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SintelStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(pass_name=("final", "clean", "both"))
ADDITIONAL_CONFIGS = combinations_grid(pass_name=("final", "clean", "both"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config):
@@ -3129,7 +3205,7 @@ class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase):
class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ETH3DStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
@staticmethod
@@ -3196,7 +3272,7 @@ class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase):
class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Middlebury2014Stereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "additional"),
calibration=("perfect", "imperfect", "both"),
use_ambient_views=(True, False),
@@ -3287,5 +3363,47 @@ class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
pass
class TestDatasetWrapper:
def test_unknown_type(self):
unknown_object = object()
with pytest.raises(
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
):
datasets.wrap_dataset_for_transforms_v2(unknown_object)
def test_unknown_dataset(self):
class MyVisionDataset(datasets.VisionDataset):
pass
dataset = MyVisionDataset("root")
with pytest.raises(TypeError, match="No wrapper exist"):
datasets.wrap_dataset_for_transforms_v2(dataset)
def test_missing_wrapper(self):
dataset = datasets.FakeData()
with pytest.raises(TypeError, match="please open an issue"):
datasets.wrap_dataset_for_transforms_v2(dataset)
def test_subclass(self, mocker):
from torchvision import tv_tensors
sentinel = object()
mocker.patch.dict(
tv_tensors._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
)
class MyFakeData(datasets.FakeData):
pass
dataset = MyFakeData()
wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
assert wrapped_dataset[0] is sentinel
if __name__ == "__main__":
unittest.main()
@@ -2,6 +2,7 @@ import contextlib
import itertools
import tempfile
import time
import traceback
import unittest.mock
import warnings
from datetime import datetime
@@ -13,13 +14,7 @@ from urllib.request import Request, urlopen
import pytest
from torchvision import datasets
from torchvision.datasets.utils import (
_get_redirect_url,
check_integrity,
download_file_from_google_drive,
download_url,
USER_AGENT,
)
from torchvision.datasets.utils import _get_redirect_url, USER_AGENT
def limit_requests_per_time(min_secs_between_requests=2.0):
@@ -83,63 +78,65 @@ urlopen = resolve_redirects()(urlopen)
@contextlib.contextmanager
def log_download_attempts(
urls_and_md5s=None,
file="utils",
patch=True,
mock_auxiliaries=None,
urls,
*,
dataset_module,
):
def add_mock(stack, name, file, **kwargs):
def maybe_add_mock(*, module, name, stack, lst=None):
patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}")
try:
return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
except AttributeError as error:
if file != "utils":
return add_mock(stack, name, "utils", **kwargs)
else:
raise pytest.UsageError from error
if urls_and_md5s is None:
urls_and_md5s = set()
if mock_auxiliaries is None:
mock_auxiliaries = patch
mock = stack.enter_context(patcher)
except AttributeError:
return
with contextlib.ExitStack() as stack:
url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
google_drive_mock = add_mock(
stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
)
if lst is not None:
lst.append(mock)
if mock_auxiliaries:
add_mock(stack, "extract_archive", file)
with contextlib.ExitStack() as stack:
download_url_mocks = []
download_file_from_google_drive_mocks = []
for module in [dataset_module, "utils"]:
maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks)
maybe_add_mock(
module=module,
name="download_file_from_google_drive",
stack=stack,
lst=download_file_from_google_drive_mocks,
)
maybe_add_mock(module=module, name="extract_archive", stack=stack)
try:
yield urls_and_md5s
yield
finally:
for args, kwargs in url_mock.call_args_list:
url = args[0]
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5))
for download_url_mock in download_url_mocks:
for args, kwargs in download_url_mock.call_args_list:
urls.append(args[0] if args else kwargs["url"])
for args, kwargs in google_drive_mock.call_args_list:
id = args[0]
url = f"https://drive.google.com/file/d/{id}"
md5 = args[3] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5))
for download_file_from_google_drive_mock in download_file_from_google_drive_mocks:
for args, kwargs in download_file_from_google_drive_mock.call_args_list:
file_id = args[0] if args else kwargs["file_id"]
urls.append(f"https://drive.google.com/file/d/{file_id}")
def retry(fn, times=1, wait=5.0):
msgs = []
tbs = []
for _ in range(times + 1):
try:
return fn()
except AssertionError as error:
msgs.append(str(error))
tbs.append("".join(traceback.format_exception(type(error), error, error.__traceback__)))
time.sleep(wait)
else:
raise AssertionError(
"\n".join(
(
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
*(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
"\n",
*[f"{'_' * 40} {idx:2d} {'_' * 40}\n\n{tb}" for idx, tb in enumerate(tbs, 1)],
(
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time. "
f"You can find the the full tracebacks above."
),
)
)
)
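# Note: the `else` above belongs to the `for` loop (for/else); it runs only when no
# attempt returned early, i.e. after all `times + 1` tries raised AssertionError.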
@@ -149,10 +146,12 @@ def retry(fn, times=1, wait=5.0):
def assert_server_response_ok():
try:
yield
except URLError as error:
raise AssertionError("The request timed out.") from error
except HTTPError as error:
raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
except URLError as error:
raise AssertionError(
"Connection not possible due to SSL." if "SSL" in str(error) else "The request timed out."
) from error
except RecursionError as error:
raise AssertionError(str(error)) from error
@@ -163,45 +162,14 @@ def assert_url_is_accessible(url, timeout=5.0):
urlopen(request, timeout=timeout)
def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
file = path.join(tmpdir, path.basename(url))
with assert_server_response_ok():
with open(file, "wb") as fh:
request = Request(url, headers={"User-Agent": USER_AGENT})
response = urlopen(request, timeout=timeout)
fh.write(response.read())
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
class DownloadConfig:
def __init__(self, url, md5=None, id=None):
self.url = url
self.md5 = md5
self.id = id or url
def collect_urls(dataset_cls, *args, **kwargs):
urls = []
with contextlib.suppress(Exception), log_download_attempts(
urls, dataset_module=dataset_cls.__module__.split(".")[-1]
):
dataset_cls(*args, **kwargs)
def __repr__(self) -> str:
return self.id
def make_download_configs(urls_and_md5s, name=None):
return [
DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
]
def collect_download_configs(dataset_loader, name=None, **kwargs):
urls_and_md5s = set()
try:
with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
dataset = dataset_loader()
except Exception:
dataset = None
if name is None and dataset is not None:
name = type(dataset).__name__
return make_download_configs(urls_and_md5s, name)
return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]
# This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a
@@ -216,12 +184,14 @@ def root():
def places365():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.Places365(ROOT, split=split, small=small, download=True),
name=f"Places365, {split}, {'small' if small else 'large'}",
file="places365",
return itertools.chain.from_iterable(
[
collect_urls(
datasets.Places365,
ROOT,
split=split,
small=small,
download=True,
)
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
]
@@ -229,30 +199,26 @@ def places365():
def caltech101():
return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")
return collect_urls(datasets.Caltech101, ROOT, download=True)
def caltech256():
return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")
return collect_urls(datasets.Caltech256, ROOT, download=True)
def cifar10():
return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")
return collect_urls(datasets.CIFAR10, ROOT, download=True)
def cifar100():
return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")
return collect_urls(datasets.CIFAR100, ROOT, download=True)
def voc():
# TODO: Also test the "2007-test" key
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
name=f"VOC, {year}",
file="voc",
)
return itertools.chain.from_iterable(
[
collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True)
for year in ("2007", "2008", "2009", "2010", "2011", "2012")
]
)
@@ -260,55 +226,42 @@ def voc():
def mnist():
with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")
return collect_urls(datasets.MNIST, ROOT, download=True)
def fashion_mnist():
return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")
return collect_urls(datasets.FashionMNIST, ROOT, download=True)
def kmnist():
return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")
return collect_urls(datasets.KMNIST, ROOT, download=True)
def emnist():
# the 'split' argument can be any valid one, since everything is downloaded anyway
return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")
return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True)
def qmnist():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.QMNIST(ROOT, what=what, download=True),
name=f"QMNIST, {what}",
file="mnist",
)
for what in ("train", "test", "nist")
]
return itertools.chain.from_iterable(
[collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")]
)
def moving_mnist():
return collect_urls(datasets.MovingMNIST, ROOT, download=True)
def omniglot():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.Omniglot(ROOT, background=background, download=True),
name=f"Omniglot, {'background' if background else 'evaluation'}",
)
for background in (True, False)
]
return itertools.chain.from_iterable(
[collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)]
)
def phototour():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.PhotoTour(ROOT, name=name, download=True),
name=f"PhotoTour, {name}",
file="phototour",
)
return itertools.chain.from_iterable(
[
collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
# The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
# requests time out from within CI. They are disabled until this is resolved.
for name in ("notredame", "yosemite", "liberty") # "notredame_harris", "yosemite_harris", "liberty_harris"
@@ -317,91 +270,51 @@ def phototour():
def sbdataset():
return collect_download_configs(
lambda: datasets.SBDataset(ROOT, download=True),
name="SBDataset",
file="voc",
)
return collect_urls(datasets.SBDataset, ROOT, download=True)
def sbu():
return collect_download_configs(
lambda: datasets.SBU(ROOT, download=True),
name="SBU",
file="sbu",
)
return collect_urls(datasets.SBU, ROOT, download=True)
def semeion():
return collect_download_configs(
lambda: datasets.SEMEION(ROOT, download=True),
name="SEMEION",
file="semeion",
)
return collect_urls(datasets.SEMEION, ROOT, download=True)
def stl10():
return collect_download_configs(
lambda: datasets.STL10(ROOT, download=True),
name="STL10",
)
return collect_urls(datasets.STL10, ROOT, download=True)
def svhn():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.SVHN(ROOT, split=split, download=True),
name=f"SVHN, {split}",
file="svhn",
)
for split in ("train", "test", "extra")
]
return itertools.chain.from_iterable(
[collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")]
)
def usps():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.USPS(ROOT, train=train, download=True),
name=f"USPS, {'train' if train else 'test'}",
file="usps",
)
for train in (True, False)
]
return itertools.chain.from_iterable(
[collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)]
)
def celeba():
return collect_download_configs(
lambda: datasets.CelebA(ROOT, download=True),
name="CelebA",
file="celeba",
)
return collect_urls(datasets.CelebA, ROOT, download=True)
def widerface():
return collect_download_configs(
lambda: datasets.WIDERFace(ROOT, download=True),
name="WIDERFace",
file="widerface",
)
return collect_urls(datasets.WIDERFace, ROOT, download=True)
def kinetics():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.Kinetics(
path.join(ROOT, f"Kinetics{num_classes}"),
frames_per_clip=1,
num_classes=num_classes,
split=split,
download=True,
),
name=f"Kinetics, {num_classes}, {split}",
file="kinetics",
return itertools.chain.from_iterable(
[
collect_urls(
datasets.Kinetics,
path.join(ROOT, f"Kinetics{num_classes}"),
frames_per_clip=1,
num_classes=num_classes,
split=split,
download=True,
)
for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
]
@@ -409,58 +322,55 @@ def kinetics():
def kitti():
return itertools.chain(
*[
collect_download_configs(
lambda train=train: datasets.Kitti(ROOT, train=train, download=True),
name=f"Kitti, {'train' if train else 'test'}",
file="kitti",
)
for train in (True, False)
]
return itertools.chain.from_iterable(
[collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)]
)
def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)
return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
@pytest.mark.parametrize(
**make_parametrize_kwargs(
itertools.chain(
caltech101(),
caltech256(),
cifar10(),
cifar100(),
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
# voc(),
mnist(),
fashion_mnist(),
kmnist(),
emnist(),
qmnist(),
omniglot(),
phototour(),
sbdataset(),
sbu(),
semeion(),
stl10(),
svhn(),
usps(),
celeba(),
widerface(),
kinetics(),
kitti(),
)
def stanford_cars():
return itertools.chain.from_iterable(
[collect_urls(datasets.StanfordCars, ROOT, split=split, download=True) for split in ["train", "test"]]
)
def url_parametrization(*dataset_urls_and_ids_fns):
return pytest.mark.parametrize(
"url",
[
pytest.param(url, id=id)
for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns
for url, id in sorted(set(dataset_urls_and_ids_fn()))
],
)
@url_parametrization(
caltech101,
caltech256,
cifar10,
cifar100,
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
# voc,
mnist,
fashion_mnist,
kmnist,
emnist,
qmnist,
omniglot,
phototour,
sbdataset,
semeion,
stl10,
svhn,
usps,
celeba,
widerface,
kinetics,
kitti,
places365,
sbu,
)
def test_url_is_accessible(url, md5):
def test_url_is_accessible(url):
"""
If you see this test failing, find the offending dataset in the parametrization and move it to
``test_url_is_not_accessible`` and link an issue detailing the problem.
@@ -468,15 +378,11 @@ def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url))
@pytest.mark.parametrize(
**make_parametrize_kwargs(
itertools.chain(
places365(), # https://github.com/pytorch/vision/issues/6268
)
)
@url_parametrization(
stanford_cars, # https://github.com/pytorch/vision/issues/7545
)
@pytest.mark.xfail
def test_url_is_not_accessible(url, md5):
def test_url_is_not_accessible(url):
"""
As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
@@ -486,8 +392,3 @@ def test_url_is_not_accessible(url, md5):
``test_url_is_accessible``.
"""
retry(lambda: assert_url_is_accessible(url))
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
retry(lambda: assert_file_downloads_correctly(url, md5))
@@ -7,7 +7,9 @@ import tarfile
import zipfile
import pytest
import torch
import torchvision.datasets.utils as utils
from common_utils import assert_equal
from torch._utils_internal import get_file_path_2
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
@@ -215,6 +217,24 @@ class TestDatasetsUtils:
pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
@pytest.mark.parametrize(
("dtype", "actual_hex", "expected_hex"),
[
(torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
(torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
(torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
(torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
],
)
def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
def to_tensor(hex):
return torch.frombuffer(bytes.fromhex(hex), dtype=dtype)
assert_equal(
utils._flip_byte_order(to_tensor(actual_hex)),
to_tensor(expected_hex),
)
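# A functionally equivalent sketch of the byte-order flip under test (hedged; the real
# helper is `torchvision.datasets.utils._flip_byte_order`): reverse the bytes within
# each element while keeping the element order, matching the hex pairs above.
def flip_byte_order_sketch(t):
    u8 = t.contiguous().view(torch.uint8).reshape(-1, t.element_size())
    return u8.flip(-1).reshape(-1).view(t.dtype).reshape(t.shape)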
@pytest.mark.parametrize(
("kwargs", "expected_error_msg"),
......
import copy
import os
import pickle
import pytest
import test_models as TM
import torch
from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
run_if_test_with_extended = pytest.mark.skipif(
os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
@@ -59,17 +62,59 @@ def test_get_model_weights(name, weight):
assert models.get_model_weights(name) == weight
@pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_copyable(copy_fn, name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert copy_fn(weights) is weights
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_deserializable(name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert pickle.loads(pickle.dumps(weights)) is weights
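# Self-contained illustration of the enum behavior referenced in the two tests
# above: members are singletons, so copy, deepcopy, and pickle round-trips all
# hand back the very same object (standard-library behavior, no torchvision needed).
import copy
import enum
import pickle

class Color(enum.Enum):
    RED = 1

assert copy.copy(Color.RED) is Color.RED
assert copy.deepcopy(Color.RED) is Color.RED
assert pickle.loads(pickle.dumps(Color.RED)) is Color.RED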
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]
@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module))
......@@ -77,6 +122,65 @@ def test_list_models(module):
assert a == b
@pytest.mark.parametrize(
"include_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
"*not-existing-model-for-test?",
["*resnet*", "*alexnet*"],
["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
("*resnet*", "*alexnet*"),
set(["*resnet*", "*alexnet*"]),
],
)
@pytest.mark.parametrize(
"exclude_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
["*not-existing-model-for-test?"],
["resnet34", "*not-existing-model-for-test?"],
["resnet34", "*resnet1*"],
("resnet34", "*resnet1*"),
set(["resnet34", "*resnet1*"]),
],
)
def test_list_models_filters(include_filters, exclude_filters):
actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
classification_models = set(get_models_from_module(models))
if isinstance(include_filters, str):
include_filters = [include_filters]
if isinstance(exclude_filters, str):
exclude_filters = [exclude_filters]
if include_filters:
expected = set()
for include_f in include_filters:
include_f = include_f.strip("*?")
expected = expected | set(x for x in classification_models if include_f in x)
else:
expected = classification_models
if exclude_filters:
for exclude_f in exclude_filters:
exclude_f = exclude_f.strip("*?")
if exclude_f != "":
a_exclude = set(x for x in classification_models if exclude_f in x)
expected = expected - a_exclude
assert expected == actual
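# Hedged usage sketch of the include/exclude filtering exercised above; the
# patterns are glob-style, matching the "*resnet*" cases in the parametrization.
from torchvision import models

every_classification_model = models.list_models(module=models)
resnets_only = models.list_models(module=models, include="*resnet*")
resnets_minus_small = models.list_models(module=models, include="*resnet*", exclude=["resnet18", "resnet34"])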
@pytest.mark.parametrize(
"name, weight",
[
......@@ -111,6 +215,22 @@ def test_naming_conventions(model_fn):
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
detection_models_input_dims = {
"fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
"fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
"fasterrcnn_resnet50_fpn": (800, 800),
"fasterrcnn_resnet50_fpn_v2": (800, 800),
"fcos_resnet50_fpn": (800, 800),
"keypointrcnn_resnet50_fpn": (1333, 1333),
"maskrcnn_resnet50_fpn": (800, 800),
"maskrcnn_resnet50_fpn_v2": (800, 800),
"retinanet_resnet50_fpn": (800, 800),
"retinanet_resnet50_fpn_v2": (800, 800),
"ssd300_vgg16": (300, 300),
"ssdlite320_mobilenet_v3_large": (320, 320),
}
@pytest.mark.parametrize(
"model_fn",
TM.list_model_fns(models)
......@@ -122,6 +242,9 @@ def test_naming_conventions(model_fn):
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
if model_fn.__name__ == "maskrcnn_resnet50_fpn_v2":
pytest.skip(reason="FIXME https://github.com/pytorch/vision/issues/7349")
# list of all possible supported high-level fields for weights meta-data
permitted_fields = {
"backend",
......@@ -135,11 +258,13 @@ def test_schema_meta_validation(model_fn):
"recipe",
"unquantized",
"_docs",
"_ops",
"_file_size",
}
# mandatory fields for each computer vision task
classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
defaults = {
"all": {"_metrics", "min_size", "num_params", "recipe", "_docs", "_file_size", "_ops"},
"models": classification_fields,
"detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
"quantization": classification_fields | {"backend", "unquantized"},
......@@ -160,7 +285,7 @@ def test_schema_meta_validation(model_fn):
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
problematic_weights = {}
incorrect_meta = []
bad_names = []
for w in weights_enum:
actual_fields = set(w.meta.keys())
......@@ -173,24 +298,47 @@ def test_schema_meta_validation(model_fn):
unsupported_fields = set(w.meta.keys()) - permitted_fields
if missing_fields or unsupported_fields:
problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
if w == weights_enum.DEFAULT or any(w.meta[k] != weights_enum.DEFAULT.meta[k] for k in ["num_params", "_ops"]):
if module_name == "quantization":
# parameters() count doesn't work well with quantization, so we check against the non-quantized
unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None:
if w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
incorrect_meta.append((w, "num_params"))
# the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs
# instead
if w.meta["_ops"] != unquantized_w.meta.get("_ops"):
incorrect_meta.append((w, "_ops"))
else:
# loading the model and using it for parameter and ops verification
model = model_fn(weights=w)
if w.meta.get("num_params") != sum(p.numel() for p in model.parameters()):
incorrect_meta.append((w, "num_params"))
kwargs = {}
if model_name in detection_models_input_dims:
# detection models have non default height and width
height, width = detection_models_input_dims[model_name]
kwargs = {"height": height, "width": width}
if not model_fn.__name__.startswith("vit"):
# FIXME: https://github.com/pytorch/vision/issues/7871
calculated_ops = get_ops(model=model, weight=w, **kwargs)
if calculated_ops != w.meta["_ops"]:
incorrect_meta.append((w, "_ops"))
if not w.name.isupper():
bad_names.append(w)
if get_file_size_mb(w) != w.meta.get("_file_size"):
incorrect_meta.append((w, "_file_size"))
assert not problematic_weights
assert not incorrect_meta
assert not bad_names
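# A sketch of the metadata contract validated above, for one concrete weights
# enum; the field names are real, the comments describe what the test cross-checks.
from torchvision.models import ResNet50_Weights

meta = ResNet50_Weights.DEFAULT.meta
meta["num_params"]  # checked against sum(p.numel() for p in model.parameters())
meta["_ops"]        # checked against get_ops(model=model, weight=w, ...)
meta["_file_size"]  # checked against get_file_size_mb(w)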
......@@ -343,7 +491,11 @@ class TestHandleLegacyInterface:
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow)
+ [
lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
],
)
@run_if_test_with_extended
def test_pretrained_deprecation(self, model_fn):
......
......@@ -2,17 +2,18 @@ import colorsys
import itertools
import math
import os
import re
import warnings
from functools import partial
from typing import Sequence
import numpy as np
import PIL.Image
import pytest
import torch
import torchvision.transforms as T
import torchvision.transforms._functional_pil as F_pil
import torchvision.transforms._functional_tensor as F_t
import torchvision.transforms.functional as F
from common_utils import (
_assert_approx_equal_tensor_to_pil,
_assert_equal_tensor_to_pil,
......@@ -20,15 +21,20 @@ from common_utils import (
_create_data_batch,
_test_fn_on_batch,
assert_equal,
cpu_and_cuda,
needs_cuda,
)
from torchvision.transforms import InterpolationMode
NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
InterpolationMode.NEAREST,
InterpolationMode.NEAREST_EXACT,
InterpolationMode.BILINEAR,
InterpolationMode.BICUBIC,
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
def test_image_sizes(device, fn):
script_F = torch.jit.script(fn)
......@@ -66,7 +72,7 @@ class TestRotate:
scripted_rotate = torch.jit.script(F.rotate)
IMG_W = 26
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)])
@pytest.mark.parametrize(
"center",
......@@ -125,7 +131,7 @@ class TestRotate:
f"{out_pil_tensor[0, :7, :7]}"
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", ALL_DTYPES)
def test_rotate_batch(self, device, dt):
if dt == torch.float16 and device == "cpu":
......@@ -141,17 +147,9 @@ class TestRotate:
def test_rotate_interpolation_type(self):
tensor, _ = _create_data(26, 26)
res1 = F.rotate(tensor, 45, interpolation=PIL.Image.BILINEAR)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
class TestAffine:
......@@ -159,7 +157,7 @@ class TestAffine:
ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
scripted_affine = torch.jit.script(F.affine)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES)
def test_identity_map(self, device, height, width, dt):
......@@ -182,7 +180,7 @@ class TestAffine:
)
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize(
......@@ -226,7 +224,7 @@ class TestAffine:
# Tolerance : less than 6% of different pixels
assert ratio_diff_pixels < 0.06
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
......@@ -260,7 +258,7 @@ class TestAffine:
# Tolerance : less than 3% of different pixels
assert ratio_diff_pixels < 0.03
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize("t", [[10, 12], (-12, -13)])
......@@ -285,7 +283,7 @@ class TestAffine:
_assert_equal_tensor_to_pil(out_tensor, out_pil_img)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize(
......@@ -295,24 +293,8 @@ class TestAffine:
(33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
(45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
(33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
(85, (10, -10), 0.7, [0.0, 0.0], [1]),
(0, [0, 0], 1.0, [35.0], (2.0,)),
(-25, [0, 0], 1.2, [0.0, 15.0], None),
(-45, [-10, 0], 0.7, [2.0, 5.0], None),
(-45, [-10, -10], 1.2, [4.0, 5.0], None),
......@@ -346,7 +328,7 @@ class TestAffine:
tol = 0.06 if device == "cuda" else 0.05
assert ratio_diff_pixels < tol
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", ALL_DTYPES)
def test_batches(self, device, dt):
if dt == torch.float16 and device == "cpu":
......@@ -359,21 +341,13 @@ class TestAffine:
_test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_interpolation_type(self, device):
tensor, pil_img = _create_data(26, 26, device=device)
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=PIL.Image.BILINEAR)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
assert_equal(res1, res2)
def _get_data_dims_and_points_for_perspective():
......@@ -399,22 +373,10 @@ def _get_data_dims_and_points_for_perspective():
return dims_and_points
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("fill", (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1], (2.0,)))
@pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)])
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
......@@ -445,7 +407,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
assert ratio_diff_pixels < 0.05
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
def test_perspective_batch(device, dims_and_points, dt):
......@@ -473,40 +435,21 @@ def test_perspective_batch(device, dims_and_points, dt):
)
def test_perspective_interpolation_type():
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
tensor = torch.randint(0, 256, (3, 26, 26))
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=PIL.Image.BILINEAR)
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
assert_equal(res1, res2)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [32, 26, [32], [32, 32], (32, 32), [26, 35]])
@pytest.mark.parametrize("max_size", [None, 34, 40, 1000])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
def test_resize(device, dt, size, max_size, interpolation):
if dt == torch.float16 and device == "cpu":
......@@ -526,14 +469,12 @@ def test_resize(device, dt, size, max_size, interpolation):
tensor = tensor.to(dt)
batch_tensors = batch_tensors.to(dt)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
if interpolation != NEAREST:
# We can not check values if mode = NEAREST, as results are different
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
......@@ -543,36 +484,27 @@ def test_resize(device, dt, size, max_size, interpolation):
resized_tensor_f = resized_tensor_f.to(torch.float)
# Pay attention to high tolerance for MAE
_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=3.0)
if isinstance(size, int):
script_size = [size]
else:
script_size = size
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True)
assert_equal(resized_tensor, resize_result)
_test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_resize_asserts(device):
tensor, pil_img = _create_data(26, 36, device=device)
res1 = F.resize(tensor, size=32, interpolation=PIL.Image.BILINEAR)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
assert_equal(res1, res2)
......@@ -584,7 +516,7 @@ def test_resize_asserts(device):
F.resize(img, size=32, max_size=32)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
......@@ -603,7 +535,7 @@ def test_resize_antialias(device, dt, size, interpolation):
tensor = tensor.to(dt)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True)
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
......@@ -637,38 +569,21 @@ def test_resize_antialias(device, dt, size, interpolation):
assert_equal(resized_tensor, resize_result)
def test_resize_antialias_default_warning():
img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
match = "The default value of the antialias"
with pytest.warns(UserWarning, match=match):
F.resize(img, size=(20, 20))
with pytest.warns(UserWarning, match=match):
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20))
# For modes that aren't bicubic or bilinear, don't throw a warning
with warnings.catch_warnings():
warnings.simplefilter("error")
F.resize(img, size=(20, 20), interpolation=NEAREST)
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20), interpolation=NEAREST)
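# To opt out of the default-antialias warning checked above, pass antialias
# explicitly; either value silences it (sketch, any uint8 image works):
silent = F.resize(torch.randint(0, 256, (3, 44, 56), dtype=torch.uint8), size=(20, 20), antialias=True)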
def check_functional_vs_PIL_vs_scripted(
......@@ -708,7 +623,7 @@ def check_functional_vs_PIL_vs_scripted(
_test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
@pytest.mark.parametrize("channels", [1, 3])
......@@ -724,7 +639,7 @@ def test_adjust_brightness(device, dtype, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_invert(device, dtype, channels):
......@@ -733,7 +648,7 @@ def test_invert(device, dtype, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)])
@pytest.mark.parametrize("channels", [1, 3])
def test_posterize(device, config, channels):
......@@ -750,7 +665,7 @@ def test_posterize(device, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
@pytest.mark.parametrize("channels", [1, 3])
def test_solarize1(device, config, channels):
......@@ -767,7 +682,7 @@ def test_solarize1(device, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
@pytest.mark.parametrize("channels", [1, 3])
......@@ -785,37 +700,45 @@ def test_solarize2(device, dtype, config, channels):
)
@pytest.mark.parametrize(
("dtype", "threshold"),
[
*[
(dtype, threshold)
for dtype, threshold in itertools.product(
[torch.float32, torch.float16],
[0.0, 0.25, 0.5, 0.75, 1.0],
)
],
*[(torch.uint8, threshold) for threshold in [0, 64, 128, 192, 255]],
*[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]],
],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_solarize_threshold_within_bound(threshold, dtype, device):
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
img = make_img((3, 12, 23), dtype=dtype, device=device)
F_t.solarize(img, threshold)
@pytest.mark.parametrize(
("dtype", "threshold"),
[
(torch.float32, 1.5),
(torch.float16, 1.5),
(torch.uint8, 260),
(torch.int64, 2**64),
],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_solarize_threshold_above_bound(threshold, dtype, device):
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
img = make_img((3, 12, 23), dtype=dtype, device=device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold)
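# The semantics under test, as a one-liner sketch (not F_t.solarize itself),
# assuming float images (bound 1.0) or uint8 images (bound 255): pixels at or
# above the threshold are inverted about the dtype's bound.
def solarize_sketch(img: torch.Tensor, threshold: float) -> torch.Tensor:
    bound = 1.0 if img.is_floating_point() else 255
    return torch.where(img >= threshold, bound - img, img)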
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
......@@ -831,7 +754,7 @@ def test_adjust_sharpness(device, dtype, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_autocontrast(device, dtype, channels):
......@@ -840,7 +763,7 @@ def test_autocontrast(device, dtype, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_autocontrast_equal_minmax(device, dtype, channels):
......@@ -852,7 +775,7 @@ def test_autocontrast_equal_minmax(device, dtype, channels):
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("channels", [1, 3])
def test_equalize(device, channels):
torch.use_deterministic_algorithms(False)
......@@ -869,7 +792,7 @@ def test_equalize(device, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
......@@ -879,7 +802,7 @@ def test_adjust_contrast(device, dtype, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
......@@ -889,7 +812,7 @@ def test_adjust_saturation(device, dtype, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
@pytest.mark.parametrize("channels", [1, 3])
......@@ -899,7 +822,7 @@ def test_adjust_hue(device, dtype, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
@pytest.mark.parametrize("channels", [1, 3])
......@@ -915,7 +838,7 @@ def test_adjust_gamma(device, dtype, config, channels):
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
@pytest.mark.parametrize(
......@@ -965,14 +888,16 @@ def test_pad(device, dt, pad, config):
_test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("mode", [NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC])
def test_resized_crop(device, mode):
# test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity
tensor, _ = _create_data(26, 36, device=device)
out_tensor = F.resized_crop(
tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True
)
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
# 2) resize by half and crop a TL corner
......@@ -987,11 +912,18 @@ def test_resized_crop(device, mode):
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
_test_fn_on_batch(
batch_tensors,
F.resized_crop,
top=1,
left=2,
height=20,
width=30,
size=[10, 15],
interpolation=NEAREST,
)
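# resized_crop composes crop and resize; a self-contained sketch of the
# equivalence that the identity case above relies on:
t = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)
via_compose = F.resize(F.crop(t, 0, 0, 20, 30), [10, 15], antialias=True)
direct = F.resized_crop(t, top=0, left=0, height=20, width=30, size=[10, 15], antialias=True)
assert torch.equal(via_compose, direct)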
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
"func, args",
[
......@@ -1024,7 +956,7 @@ def test_assert_image_tensor(device, func, args):
func(tensor, *args)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_vflip(device):
script_vflip = torch.jit.script(F.vflip)
......@@ -1041,7 +973,7 @@ def test_vflip(device):
_test_fn_on_batch(batch_tensors, F.vflip)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_hflip(device):
script_hflip = torch.jit.script(F.hflip)
......@@ -1058,7 +990,7 @@ def test_hflip(device):
_test_fn_on_batch(batch_tensors, F.hflip)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
"top, left, height, width",
[
......@@ -1087,7 +1019,7 @@ def test_crop(device, top, left, height, width):
_test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("image_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
......@@ -1141,7 +1073,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_hsv2rgb(device):
scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150)
......@@ -1172,7 +1104,7 @@ def test_hsv2rgb(device):
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_rgb2hsv(device):
scripted_fn = torch.jit.script(F_t._rgb2hsv)
shape = (3, 150, 100)
......@@ -1211,7 +1143,7 @@ def test_rgb2hsv(device):
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_output_channels", (3, 1))
def test_rgb_to_grayscale(device, num_output_channels):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
......@@ -1230,7 +1162,7 @@ def test_rgb_to_grayscale(device, num_output_channels):
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_center_crop(device):
script_center_crop = torch.jit.script(F.center_crop)
......@@ -1248,7 +1180,7 @@ def test_center_crop(device):
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_five_crop(device):
script_five_crop = torch.jit.script(F.five_crop)
......@@ -1282,7 +1214,7 @@ def test_five_crop(device):
assert_equal(transformed_batch, s_transformed_batch)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_ten_crop(device):
script_ten_crop = torch.jit.script(F.ten_crop)
......@@ -1328,7 +1260,7 @@ def test_elastic_transform_asserts():
_ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
......