Commit cc26cd81 authored by panning's avatar panning
Browse files

merge v0.16.0

parents f78f29f5 fbb4cc54
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
import collections.abc
import dataclasses
from typing import Optional, Sequence
import pytest
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import tv_tensors
from transforms_v2_legacy_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader
@dataclasses.dataclass
class LabelLoader(TensorLoader):
categories: Optional[Sequence[str]]
def _parse_categories(categories):
if categories is None:
num_categories = int(torch.randint(1, 11, ()))
elif isinstance(categories, int):
num_categories = categories
categories = [f"category{idx}" for idx in range(num_categories)]
elif isinstance(categories, collections.abc.Sequence) and all(isinstance(category, str) for category in categories):
categories = list(categories)
num_categories = len(categories)
else:
raise pytest.UsageError(
f"`categories` can either be `None` (default), an integer, or a sequence of strings, "
f"but got '{categories}' instead."
)
return categories, num_categories
def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
categories, num_categories = _parse_categories(categories)
def fn(shape, dtype, device):
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
return tv_tensors.Label(data, categories=categories)
return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
make_label = from_loader(make_label_loader)
@dataclasses.dataclass
class OneHotLabelLoader(TensorLoader):
categories: Optional[Sequence[str]]
def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int64):
categories, num_categories = _parse_categories(categories)
def fn(shape, dtype, device):
if num_categories == 0:
data = torch.empty(shape, dtype=dtype, device=device)
else:
# The idiom `make_label_loader(..., dtype=torch.int64); ...; one_hot(...).to(dtype)` is intentional
# since `one_hot` only supports int64
label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
data = one_hot(label, num_classes=num_categories).to(dtype)
return tv_tensors.OneHotLabel(data, categories=categories)
return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
def make_one_hot_label_loaders(
*,
categories=(1, 0, None),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.int64, torch.float32),
):
for params in combinations_grid(categories=categories, extra_dims=extra_dims, dtype=dtypes):
yield make_one_hot_label_loader(**params)
make_one_hot_labels = from_loaders(make_one_hot_label_loaders)
"""Run smoke tests"""
import sys
from pathlib import Path
import torch import torch
import torchvision import torchvision
import torchvision.datasets as dset from torchvision.io import decode_jpeg, read_file, read_image
import torchvision.transforms from torchvision.models import resnet50, ResNet50_Weights
SCRIPT_DIR = Path(__file__).parent
def smoke_test_torchvision() -> None:
print(
"Is torchvision usable?",
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
)
def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
img_jpg = decode_jpeg(img_jpg_data, device=device)
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
def smoke_test_compile() -> None:
try:
model = resnet50().cuda()
model = torch.compile(model)
x = torch.randn(1, 3, 224, 224, device="cuda")
out = model(x)
print(f"torch.compile model output: {out.shape}")
except RuntimeError:
if sys.platform == "win32":
print("Successfully caught torch.compile RuntimeError on win")
elif sys.version_info >= (3, 11, 0):
print("Successfully caught torch.compile RuntimeError on Python 3.11")
else:
raise
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
expected_category = "German shepherd"
print(f"{category_name} ({device}): {100 * score:.1f}%")
if category_name != expected_category:
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
def main() -> None:
print(f"torchvision: {torchvision.__version__}")
print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
# Turn 1.11.0aHASH into 1.11 (major.minor only)
version = ".".join(torchvision.__version__.split(".")[:2])
if version >= "0.16":
print(f"{torch.ops.image._jpeg_version() = }")
assert torch.ops.image._is_compiled_against_turbo()
smoke_test_torchvision()
smoke_test_torchvision_read_decode()
smoke_test_torchvision_resnet50_classify()
smoke_test_torchvision_decode_jpeg()
if torch.cuda.is_available():
smoke_test_torchvision_decode_jpeg("cuda")
smoke_test_torchvision_resnet50_classify("cuda")
smoke_test_compile()
if torch.backends.mps.is_available():
smoke_test_torchvision_resnet50_classify("mps")
if __name__ == "__main__":
main()
...@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase): ...@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = partition(x, partition_size) x_hat = partition(x, partition_size)
x_hat = departition(x_hat, partition_size, n_partitions, n_partitions) x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)
assert torch.allclose(x, x_hat) torch.testing.assert_close(x, x_hat)
def test_maxvit_grid_partition(self): def test_maxvit_grid_partition(self):
input_shape = (1, 3, 224, 224) input_shape = (1, 3, 224, 224)
...@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase): ...@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = post_swap(x_hat) x_hat = post_swap(x_hat)
x_hat = departition(x_hat, n_partitions, partition_size, partition_size) x_hat = departition(x_hat, n_partitions, partition_size, partition_size)
assert torch.allclose(x, x_hat) torch.testing.assert_close(x, x_hat)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -194,7 +194,7 @@ class TestFxFeatureExtraction: ...@@ -194,7 +194,7 @@ class TestFxFeatureExtraction:
assert n1 == n2 assert n1 == n2
assert p1.equal(p2) assert p1.equal(p2)
# And that ouputs match # And that outputs match
with torch.no_grad(): with torch.no_grad():
ilg_out = ilg_model(self.inp) ilg_out = ilg_model(self.inp)
fgn_out = fx_model(self.inp) fgn_out = fx_model(self.inp)
......
...@@ -8,6 +8,7 @@ import os ...@@ -8,6 +8,7 @@ import os
import pathlib import pathlib
import pickle import pickle
import random import random
import re
import shutil import shutil
import string import string
import unittest import unittest
...@@ -21,12 +22,13 @@ import PIL ...@@ -21,12 +22,13 @@ import PIL
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets from torchvision import datasets
class STL10TestCase(datasets_utils.ImageDatasetTestCase): class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10 DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "unlabeled", "train+unlabeled")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test", "unlabeled", "train+unlabeled"))
@staticmethod @staticmethod
def _make_binary_file(num_elements, root, name): def _make_binary_file(num_elements, root, name):
...@@ -112,9 +114,7 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -112,9 +114,7 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech101 DATASET_CLASS = datasets.Caltech101
FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple)) FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
target_type=("category", "annotation", ["category", "annotation"])
)
REQUIRED_PACKAGES = ("scipy",) REQUIRED_PACKAGES = ("scipy",)
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
...@@ -183,6 +183,10 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -183,6 +183,10 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
), "Type of the combined target does not match the type of the corresponding individual target: " ), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}", f"{actual} is not {expected}",
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech256 DATASET_CLASS = datasets.Caltech256
...@@ -190,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -190,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories" tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"
categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter")) categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
num_images_per_category = 2 num_images_per_category = 2
for idx, category in categories: for idx, category in categories:
...@@ -207,7 +211,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -207,7 +211,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase): class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.WIDERFace DATASET_CLASS = datasets.WIDERFace
FEATURE_TYPES = (PIL.Image.Image, (dict, type(None))) # test split returns None as target FEATURE_TYPES = (PIL.Image.Image, (dict, type(None))) # test split returns None as target
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
widerface_dir = pathlib.Path(tmpdir) / "widerface" widerface_dir = pathlib.Path(tmpdir) / "widerface"
...@@ -258,6 +262,10 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -258,6 +262,10 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
return split_to_num_examples[config["split"]] return split_to_num_examples[config["split"]]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Cityscapes DATASET_CLASS = datasets.Cityscapes
...@@ -268,8 +276,8 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -268,8 +276,8 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
"color", "color",
) )
ADDITIONAL_CONFIGS = ( ADDITIONAL_CONFIGS = (
*datasets_utils.combinations_grid(mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES), *combinations_grid(mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES),
*datasets_utils.combinations_grid( *combinations_grid(
mode=("coarse",), mode=("coarse",),
split=("train", "train_extra", "val"), split=("train", "train_extra", "val"),
target_type=TARGET_TYPES, target_type=TARGET_TYPES,
...@@ -382,11 +390,16 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -382,11 +390,16 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
assert isinstance(polygon_img, PIL.Image.Image) assert isinstance(polygon_img, PIL.Image.Image)
(polygon_target, info["expected_polygon_target"]) (polygon_target, info["expected_polygon_target"])
def test_transforms_v2_wrapper_spawn(self):
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageNet DATASET_CLASS = datasets.ImageNet
REQUIRED_PACKAGES = ("scipy",) REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) tmpdir = pathlib.Path(tmpdir)
...@@ -413,10 +426,14 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -413,10 +426,14 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
torch.save((wnid_to_classes, None), tmpdir / "meta.bin") torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
return num_examples return num_examples
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CIFAR10 DATASET_CLASS = datasets.CIFAR10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
_VERSION_CONFIG = dict( _VERSION_CONFIG = dict(
base_folder="cifar-10-batches-py", base_folder="cifar-10-batches-py",
...@@ -489,7 +506,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase): ...@@ -489,7 +506,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CelebA DATASET_CLASS = datasets.CelebA
FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None))) FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "valid", "test", "all"), split=("train", "valid", "test", "all"),
target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]), target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
) )
...@@ -607,15 +624,18 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase): ...@@ -607,15 +624,18 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
assert merged_imgs_names == all_imgs_names assert merged_imgs_names == all_imgs_names
def test_transforms_v2_wrapper_spawn(self):
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.VOCSegmentation DATASET_CLASS = datasets.VOCSegmentation
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
ADDITIONAL_CONFIGS = ( ADDITIONAL_CONFIGS = (
*datasets_utils.combinations_grid( *combinations_grid(year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")),
year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
),
dict(year="2007", image_set="test"), dict(year="2007", image_set="test"),
) )
...@@ -696,6 +716,10 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -696,6 +716,10 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
return data return data
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class VOCDetectionTestCase(VOCSegmentationTestCase): class VOCDetectionTestCase(VOCSegmentationTestCase):
DATASET_CLASS = datasets.VOCDetection DATASET_CLASS = datasets.VOCDetection
...@@ -716,6 +740,10 @@ class VOCDetectionTestCase(VOCSegmentationTestCase): ...@@ -716,6 +740,10 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
assert object == info["annotation"] assert object == info["annotation"]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CocoDetection DATASET_CLASS = datasets.CocoDetection
...@@ -763,11 +791,21 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -763,11 +791,21 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
return info return info
def _create_annotations(self, image_ids, num_annotations_per_image): def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid( annotations = []
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image annotion_id = 0
) for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
for id, annotation in enumerate(annotations): annotations.append(
annotation["id"] = id dict(
image_id=image_id,
id=annotion_id,
bbox=torch.rand(4).tolist(),
segmentation=[torch.rand(8).tolist()],
category_id=int(torch.randint(91, ())),
area=float(torch.rand(1)),
iscrowd=int(torch.randint(2, size=(1,))),
)
)
annotion_id += 1
return annotations, dict() return annotations, dict()
def _create_json(self, root, name, content): def _create_json(self, root, name, content):
...@@ -776,13 +814,17 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -776,13 +814,17 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
json.dump(content, fh) json.dump(content, fh)
return file return file
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CocoCaptionsTestCase(CocoDetectionTestCase): class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions DATASET_CLASS = datasets.CocoCaptions
def _create_annotations(self, image_ids, num_annotations_per_image): def _create_annotations(self, image_ids, num_annotations_per_image):
captions = [str(idx) for idx in range(num_annotations_per_image)] captions = [str(idx) for idx in range(num_annotations_per_image)]
annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions) annotations = combinations_grid(image_id=image_ids, caption=captions)
for id, annotation in enumerate(annotations): for id, annotation in enumerate(annotations):
annotation["id"] = id annotation["id"] = id
return annotations, dict(captions=captions) return annotations, dict(captions=captions)
...@@ -792,11 +834,16 @@ class CocoCaptionsTestCase(CocoDetectionTestCase): ...@@ -792,11 +834,16 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
_, captions = dataset[0] _, captions = dataset[0]
assert tuple(captions) == tuple(info["captions"]) assert tuple(captions) == tuple(info["captions"])
def test_transforms_v2_wrapper_spawn(self):
# We need to define this method, because otherwise the test from the super class will
# be run
pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
class UCF101TestCase(datasets_utils.VideoDatasetTestCase): class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.UCF101 DATASET_CLASS = datasets.UCF101
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False)) ADDITIONAL_CONFIGS = combinations_grid(fold=(1, 2, 3), train=(True, False))
_VIDEO_FOLDER = "videos" _VIDEO_FOLDER = "videos"
_ANNOTATIONS_FOLDER = "annotations" _ANNOTATIONS_FOLDER = "annotations"
...@@ -857,9 +904,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -857,9 +904,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.LSUN DATASET_CLASS = datasets.LSUN
REQUIRED_PACKAGES = ("lmdb",) REQUIRED_PACKAGES = ("lmdb",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"]))
classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
)
_CATEGORIES = ( _CATEGORIES = (
"bedroom", "bedroom",
...@@ -944,7 +989,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -944,7 +989,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
class KineticsTestCase(datasets_utils.VideoDatasetTestCase): class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.Kinetics DATASET_CLASS = datasets.Kinetics
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"), num_classes=("400", "600", "700")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"), num_classes=("400", "600", "700"))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
classes = ("Abseiling", "Zumba") classes = ("Abseiling", "Zumba")
...@@ -960,11 +1005,15 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase): ...@@ -960,11 +1005,15 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
) )
return num_videos_per_class * len(classes) return num_videos_per_class * len(classes)
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51 DATASET_CLASS = datasets.HMDB51
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False)) ADDITIONAL_CONFIGS = combinations_grid(fold=(1, 2, 3), train=(True, False))
_VIDEO_FOLDER = "videos" _VIDEO_FOLDER = "videos"
_SPLITS_FOLDER = "splits" _SPLITS_FOLDER = "splits"
...@@ -1024,7 +1073,7 @@ class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): ...@@ -1024,7 +1073,7 @@ class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
class OmniglotTestCase(datasets_utils.ImageDatasetTestCase): class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Omniglot DATASET_CLASS = datasets.Omniglot
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(background=(True, False)) ADDITIONAL_CONFIGS = combinations_grid(background=(True, False))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
target_folder = ( target_folder = (
...@@ -1104,7 +1153,7 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1104,7 +1153,7 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
class USPSTestCase(datasets_utils.ImageDatasetTestCase): class USPSTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.USPS DATASET_CLASS = datasets.USPS
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
num_images = 2 if config["train"] else 1 num_images = 2 if config["train"] else 1
...@@ -1126,7 +1175,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1126,7 +1175,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse") REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse")
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation") image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation")
) )
...@@ -1187,6 +1236,10 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1187,6 +1236,10 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
def _file_stem(self, idx): def _file_stem(self, idx):
return f"2008_{idx:06d}" return f"2008_{idx:06d}"
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class FakeDataTestCase(datasets_utils.ImageDatasetTestCase): class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FakeData DATASET_CLASS = datasets.FakeData
...@@ -1212,7 +1265,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1212,7 +1265,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
_TRAIN_FEATURE_TYPES = (torch.Tensor,) _TRAIN_FEATURE_TYPES = (torch.Tensor,)
_TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor) _TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)
datasets_utils.combinations_grid(train=(True, False)) combinations_grid(train=(True, False))
_NAME = "liberty" _NAME = "liberty"
...@@ -1371,7 +1424,7 @@ class Flickr30kTestCase(Flickr8kTestCase): ...@@ -1371,7 +1424,7 @@ class Flickr30kTestCase(Flickr8kTestCase):
class MNISTTestCase(datasets_utils.ImageDatasetTestCase): class MNISTTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.MNIST DATASET_CLASS = datasets.MNIST
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
_MAGIC_DTYPES = { _MAGIC_DTYPES = {
torch.uint8: 8, torch.uint8: 8,
...@@ -1441,7 +1494,7 @@ class EMNISTTestCase(MNISTTestCase): ...@@ -1441,7 +1494,7 @@ class EMNISTTestCase(MNISTTestCase):
DATASET_CLASS = datasets.EMNIST DATASET_CLASS = datasets.EMNIST
DEFAULT_CONFIG = dict(split="byclass") DEFAULT_CONFIG = dict(split="byclass")
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"), train=(True, False) split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"), train=(True, False)
) )
...@@ -1452,7 +1505,7 @@ class EMNISTTestCase(MNISTTestCase): ...@@ -1452,7 +1505,7 @@ class EMNISTTestCase(MNISTTestCase):
class QMNISTTestCase(MNISTTestCase): class QMNISTTestCase(MNISTTestCase):
DATASET_CLASS = datasets.QMNIST DATASET_CLASS = datasets.QMNIST
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(what=("train", "test", "test10k", "nist")) ADDITIONAL_CONFIGS = combinations_grid(what=("train", "test", "test10k", "nist"))
_LABELS_SIZE = (8,) _LABELS_SIZE = (8,)
_LABELS_DTYPE = torch.int32 _LABELS_DTYPE = torch.int32
...@@ -1494,30 +1547,51 @@ class QMNISTTestCase(MNISTTestCase): ...@@ -1494,30 +1547,51 @@ class QMNISTTestCase(MNISTTestCase):
assert len(dataset) == info["num_examples"] - 10000 assert len(dataset) == info["num_examples"] - 10000
class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
DATASET_CLASS = datasets.MovingMNIST
FEATURE_TYPES = (torch.Tensor,)
ADDITIONAL_CONFIGS = combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))
_NUM_FRAMES = 20
def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
os.makedirs(base_folder, exist_ok=True)
num_samples = 5
data = np.concatenate(
[
np.zeros((config["split_ratio"], num_samples, 64, 64)),
np.ones((self._NUM_FRAMES - config["split_ratio"], num_samples, 64, 64)),
]
)
np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
return num_samples
@datasets_utils.test_all_configs
def test_split(self, config):
with self.create_dataset(config) as (dataset, _):
if config["split"] == "train":
assert (dataset.data == 0).all()
elif config["split"] == "test":
assert (dataset.data == 1).all()
else:
assert dataset.data.size()[1] == self._NUM_FRAMES
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder DATASET_CLASS = datasets.DatasetFolder
# The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader _EXTENSIONS = ("jpg", "png")
# that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method.
FEATURE_TYPES = (str, int)
_IMAGE_EXTENSIONS = ("jpg", "png")
_VIDEO_EXTENSIONS = ("avi", "mp4")
_EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)
# DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required. # DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required.
# We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the # We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the
# 'test_is_valid_file()' method. # 'test_is_valid_file()' method.
DEFAULT_CONFIG = dict(extensions=_EXTENSIONS) DEFAULT_CONFIG = dict(extensions=_EXTENSIONS)
ADDITIONAL_CONFIGS = ( ADDITIONAL_CONFIGS = combinations_grid(extensions=[(ext,) for ext in _EXTENSIONS])
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
dict(extensions=_IMAGE_EXTENSIONS),
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
dict(extensions=_VIDEO_EXTENSIONS),
)
def dataset_args(self, tmpdir, config): def dataset_args(self, tmpdir, config):
return tmpdir, lambda x: x return tmpdir, datasets.folder.pil_loader
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])
...@@ -1528,14 +1602,8 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1528,14 +1602,8 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
if ext not in extensions: if ext not in extensions:
continue continue
create_example_folder = (
datasets_utils.create_image_folder
if ext in self._IMAGE_EXTENSIONS
else datasets_utils.create_video_folder
)
num_examples = torch.randint(1, 3, size=()).item() num_examples = torch.randint(1, 3, size=()).item()
create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples) datasets_utils.create_image_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)
num_examples_total += num_examples num_examples_total += num_examples
classes.append(cls) classes.append(cls)
...@@ -1589,7 +1657,7 @@ class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1589,7 +1657,7 @@ class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
class KittiTestCase(datasets_utils.ImageDatasetTestCase): class KittiTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Kitti DATASET_CLASS = datasets.Kitti
FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) ADDITIONAL_CONFIGS = combinations_grid(train=(True, False))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
kitti_dir = os.path.join(tmpdir, "Kitti", "raw") kitti_dir = os.path.join(tmpdir, "Kitti", "raw")
...@@ -1621,11 +1689,15 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1621,11 +1689,15 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
return split_to_num_examples[config["train"]] return split_to_num_examples[config["train"]]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class SvhnTestCase(datasets_utils.ImageDatasetTestCase): class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SVHN DATASET_CLASS = datasets.SVHN
REQUIRED_PACKAGES = ("scipy",) REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "extra")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test", "extra"))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
import scipy.io as sio import scipy.io as sio
...@@ -1646,7 +1718,7 @@ class SvhnTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1646,7 +1718,7 @@ class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
class Places365TestCase(datasets_utils.ImageDatasetTestCase): class Places365TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Places365 DATASET_CLASS = datasets.Places365
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("train-standard", "train-challenge", "val"), split=("train-standard", "train-challenge", "val"),
small=(False, True), small=(False, True),
) )
...@@ -1738,7 +1810,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1738,7 +1810,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.INaturalist DATASET_CLASS = datasets.INaturalist
FEATURE_TYPES = (PIL.Image.Image, (int, tuple)) FEATURE_TYPES = (PIL.Image.Image, (int, tuple))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]), target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]),
version=("2021_train",), version=("2021_train",),
) )
...@@ -1775,7 +1847,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1775,7 +1847,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
class LFWPeopleTestCase(datasets_utils.DatasetTestCase): class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
DATASET_CLASS = datasets.LFWPeople DATASET_CLASS = datasets.LFWPeople
FEATURE_TYPES = (PIL.Image.Image, int) FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled") split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled")
) )
_IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"} _IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
...@@ -1851,7 +1923,7 @@ class LFWPairsTestCase(LFWPeopleTestCase): ...@@ -1851,7 +1923,7 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class SintelTestCase(datasets_utils.ImageDatasetTestCase): class SintelTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Sintel DATASET_CLASS = datasets.Sintel
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final", "both")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"), pass_name=("clean", "final", "both"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
FLOW_H, FLOW_W = 3, 4 FLOW_H, FLOW_W = 3, 4
...@@ -1919,7 +1991,7 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1919,7 +1991,7 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase): class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.KittiFlow DATASET_CLASS = datasets.KittiFlow
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
...@@ -1979,7 +2051,7 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1979,7 +2051,7 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase): class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FlyingChairs DATASET_CLASS = datasets.FlyingChairs
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
FLOW_H, FLOW_W = 3, 4 FLOW_H, FLOW_W = 3, 4
...@@ -2034,7 +2106,7 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2034,7 +2106,7 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase): class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FlyingThings3D DATASET_CLASS = datasets.FlyingThings3D
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both") split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
) )
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
...@@ -2171,7 +2243,7 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2171,7 +2243,7 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Food101 DATASET_CLASS = datasets.Food101
FEATURE_TYPES = (PIL.Image.Image, int) FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
root_folder = pathlib.Path(tmpdir) / "food-101" root_folder = pathlib.Path(tmpdir) / "food-101"
...@@ -2206,7 +2278,7 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2206,7 +2278,7 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase): class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FGVCAircraft DATASET_CLASS = datasets.FGVCAircraft
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer") split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
) )
...@@ -2289,7 +2361,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2289,7 +2361,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DTD DATASET_CLASS = datasets.DTD
FEATURE_TYPES = (PIL.Image.Image, int) FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "test", "val"), split=("train", "test", "val"),
# There is no need to test the whole matrix here, since each fold is treated exactly the same # There is no need to test the whole matrix here, since each fold is treated exactly the same
partition=(1, 5, 10), partition=(1, 5, 10),
...@@ -2323,7 +2395,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2323,7 +2395,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
class FER2013TestCase(datasets_utils.ImageDatasetTestCase): class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FER2013 DATASET_CLASS = datasets.FER2013
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, (int, type(None))) FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
...@@ -2358,7 +2430,7 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2358,7 +2430,7 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB DATASET_CLASS = datasets.GTSRB
FEATURE_TYPES = (PIL.Image.Image, int) FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
root_folder = os.path.join(tmpdir, "gtsrb") root_folder = os.path.join(tmpdir, "gtsrb")
...@@ -2408,7 +2480,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2408,7 +2480,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CLEVRClassification DATASET_CLASS = datasets.CLEVRClassification
FEATURE_TYPES = (PIL.Image.Image, (int, type(None))) FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0" data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
...@@ -2440,7 +2512,7 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2440,7 +2512,7 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.OxfordIIITPet DATASET_CLASS = datasets.OxfordIIITPet
FEATURE_TYPES = (PIL.Image.Image, (int, PIL.Image.Image, tuple, type(None))) FEATURE_TYPES = (PIL.Image.Image, (int, PIL.Image.Image, tuple, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("trainval", "test"), split=("trainval", "test"),
target_types=("category", "segmentation", ["category", "segmentation"], []), target_types=("category", "segmentation", ["category", "segmentation"], []),
) )
...@@ -2495,11 +2567,15 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2495,11 +2567,15 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
breed_id = "-1" breed_id = "-1"
return (image_id, class_id, species, breed_id) return (image_id, class_id, species, breed_id)
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.StanfordCars DATASET_CLASS = datasets.StanfordCars
REQUIRED_PACKAGES = ("scipy",) REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
import scipy.io as io import scipy.io as io
...@@ -2543,7 +2619,7 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2543,7 +2619,7 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
class Country211TestCase(datasets_utils.ImageDatasetTestCase): class Country211TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Country211 DATASET_CLASS = datasets.Country211
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
split_folder = pathlib.Path(tmpdir) / "country211" / config["split"] split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
...@@ -2570,7 +2646,7 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2570,7 +2646,7 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):
class Flowers102TestCase(datasets_utils.ImageDatasetTestCase): class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Flowers102 DATASET_CLASS = datasets.Flowers102
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("scipy",) REQUIRED_PACKAGES = ("scipy",)
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
...@@ -2606,7 +2682,7 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2606,7 +2682,7 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
class PCAMTestCase(datasets_utils.ImageDatasetTestCase): class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.PCAM DATASET_CLASS = datasets.PCAM
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("h5py",) REQUIRED_PACKAGES = ("h5py",)
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
...@@ -2628,7 +2704,7 @@ class PCAMTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2628,7 +2704,7 @@ class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase): class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.RenderedSST2 DATASET_CLASS = datasets.RenderedSST2
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"} SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
...@@ -2650,7 +2726,7 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2650,7 +2726,7 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase): class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Kitti2012Stereo DATASET_CLASS = datasets.Kitti2012Stereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
...@@ -2712,7 +2788,7 @@ class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2712,7 +2788,7 @@ class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase): class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Kitti2015Stereo DATASET_CLASS = datasets.Kitti2015Stereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
...@@ -2850,7 +2926,7 @@ class CREStereoTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2850,7 +2926,7 @@ class CREStereoTestCase(datasets_utils.ImageDatasetTestCase):
class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase): class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FallingThingsStereo DATASET_CLASS = datasets.FallingThingsStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(variant=("single", "mixed", "both")) ADDITIONAL_CONFIGS = combinations_grid(variant=("single", "mixed", "both"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
@staticmethod @staticmethod
...@@ -2924,7 +3000,7 @@ class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2924,7 +3000,7 @@ class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase):
class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase): class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SceneFlowStereo DATASET_CLASS = datasets.SceneFlowStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
variant=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final", "both") variant=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final", "both")
) )
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
...@@ -3011,7 +3087,7 @@ class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -3011,7 +3087,7 @@ class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase):
class InStereo2k(datasets_utils.ImageDatasetTestCase): class InStereo2k(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.InStereo2k DATASET_CLASS = datasets.InStereo2k
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
@staticmethod @staticmethod
def _make_scene_folder(root: str, name: str, size: Tuple[int, int]): def _make_scene_folder(root: str, name: str, size: Tuple[int, int]):
...@@ -3053,7 +3129,7 @@ class InStereo2k(datasets_utils.ImageDatasetTestCase): ...@@ -3053,7 +3129,7 @@ class InStereo2k(datasets_utils.ImageDatasetTestCase):
class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase): class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SintelStereo DATASET_CLASS = datasets.SintelStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(pass_name=("final", "clean", "both")) ADDITIONAL_CONFIGS = combinations_grid(pass_name=("final", "clean", "both"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
...@@ -3129,7 +3205,7 @@ class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -3129,7 +3205,7 @@ class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase):
class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase): class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ETH3DStereo DATASET_CLASS = datasets.ETH3DStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
@staticmethod @staticmethod
...@@ -3196,7 +3272,7 @@ class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase): ...@@ -3196,7 +3272,7 @@ class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase):
class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase): class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Middlebury2014Stereo DATASET_CLASS = datasets.Middlebury2014Stereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "additional"), split=("train", "additional"),
calibration=("perfect", "imperfect", "both"), calibration=("perfect", "imperfect", "both"),
use_ambient_views=(True, False), use_ambient_views=(True, False),
...@@ -3287,5 +3363,47 @@ class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -3287,5 +3363,47 @@ class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
pass pass
class TestDatasetWrapper:
def test_unknown_type(self):
unknown_object = object()
with pytest.raises(
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
):
datasets.wrap_dataset_for_transforms_v2(unknown_object)
def test_unknown_dataset(self):
class MyVisionDataset(datasets.VisionDataset):
pass
dataset = MyVisionDataset("root")
with pytest.raises(TypeError, match="No wrapper exist"):
datasets.wrap_dataset_for_transforms_v2(dataset)
def test_missing_wrapper(self):
dataset = datasets.FakeData()
with pytest.raises(TypeError, match="please open an issue"):
datasets.wrap_dataset_for_transforms_v2(dataset)
def test_subclass(self, mocker):
from torchvision import tv_tensors
sentinel = object()
mocker.patch.dict(
tv_tensors._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
)
class MyFakeData(datasets.FakeData):
pass
dataset = MyFakeData()
wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
assert wrapped_dataset[0] is sentinel
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -2,6 +2,7 @@ import contextlib ...@@ -2,6 +2,7 @@ import contextlib
import itertools import itertools
import tempfile import tempfile
import time import time
import traceback
import unittest.mock import unittest.mock
import warnings import warnings
from datetime import datetime from datetime import datetime
...@@ -13,13 +14,7 @@ from urllib.request import Request, urlopen ...@@ -13,13 +14,7 @@ from urllib.request import Request, urlopen
import pytest import pytest
from torchvision import datasets from torchvision import datasets
from torchvision.datasets.utils import ( from torchvision.datasets.utils import _get_redirect_url, USER_AGENT
_get_redirect_url,
check_integrity,
download_file_from_google_drive,
download_url,
USER_AGENT,
)
def limit_requests_per_time(min_secs_between_requests=2.0): def limit_requests_per_time(min_secs_between_requests=2.0):
...@@ -83,63 +78,65 @@ urlopen = resolve_redirects()(urlopen) ...@@ -83,63 +78,65 @@ urlopen = resolve_redirects()(urlopen)
@contextlib.contextmanager @contextlib.contextmanager
def log_download_attempts( def log_download_attempts(
urls_and_md5s=None, urls,
file="utils", *,
patch=True, dataset_module,
mock_auxiliaries=None,
): ):
def add_mock(stack, name, file, **kwargs): def maybe_add_mock(*, module, name, stack, lst=None):
patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}")
try: try:
return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs)) mock = stack.enter_context(patcher)
except AttributeError as error: except AttributeError:
if file != "utils": return
return add_mock(stack, name, "utils", **kwargs)
else:
raise pytest.UsageError from error
if urls_and_md5s is None:
urls_and_md5s = set()
if mock_auxiliaries is None:
mock_auxiliaries = patch
with contextlib.ExitStack() as stack: if lst is not None:
url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url) lst.append(mock)
google_drive_mock = add_mock(
stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
)
if mock_auxiliaries: with contextlib.ExitStack() as stack:
add_mock(stack, "extract_archive", file) download_url_mocks = []
download_file_from_google_drive_mocks = []
for module in [dataset_module, "utils"]:
maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks)
maybe_add_mock(
module=module,
name="download_file_from_google_drive",
stack=stack,
lst=download_file_from_google_drive_mocks,
)
maybe_add_mock(module=module, name="extract_archive", stack=stack)
try: try:
yield urls_and_md5s yield
finally: finally:
for args, kwargs in url_mock.call_args_list: for download_url_mock in download_url_mocks:
url = args[0] for args, kwargs in download_url_mock.call_args_list:
md5 = args[-1] if len(args) == 4 else kwargs.get("md5") urls.append(args[0] if args else kwargs["url"])
urls_and_md5s.add((url, md5))
for args, kwargs in google_drive_mock.call_args_list: for download_file_from_google_drive_mock in download_file_from_google_drive_mocks:
id = args[0] for args, kwargs in download_file_from_google_drive_mock.call_args_list:
url = f"https://drive.google.com/file/d/{id}" file_id = args[0] if args else kwargs["file_id"]
md5 = args[3] if len(args) == 4 else kwargs.get("md5") urls.append(f"https://drive.google.com/file/d/{file_id}")
urls_and_md5s.add((url, md5))
def retry(fn, times=1, wait=5.0): def retry(fn, times=1, wait=5.0):
msgs = [] tbs = []
for _ in range(times + 1): for _ in range(times + 1):
try: try:
return fn() return fn()
except AssertionError as error: except AssertionError as error:
msgs.append(str(error)) tbs.append("".join(traceback.format_exception(type(error), error, error.__traceback__)))
time.sleep(wait) time.sleep(wait)
else: else:
raise AssertionError( raise AssertionError(
"\n".join( "\n".join(
( (
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n", "\n",
*(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)), *[f"{'_' * 40} {idx:2d} {'_' * 40}\n\n{tb}" for idx, tb in enumerate(tbs, 1)],
(
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time. "
f"You can find the the full tracebacks above."
),
) )
) )
) )
...@@ -149,10 +146,12 @@ def retry(fn, times=1, wait=5.0): ...@@ -149,10 +146,12 @@ def retry(fn, times=1, wait=5.0):
def assert_server_response_ok(): def assert_server_response_ok():
try: try:
yield yield
except URLError as error:
raise AssertionError("The request timed out.") from error
except HTTPError as error: except HTTPError as error:
raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
except URLError as error:
raise AssertionError(
"Connection not possible due to SSL." if "SSL" in str(error) else "The request timed out."
) from error
except RecursionError as error: except RecursionError as error:
raise AssertionError(str(error)) from error raise AssertionError(str(error)) from error
...@@ -163,45 +162,14 @@ def assert_url_is_accessible(url, timeout=5.0): ...@@ -163,45 +162,14 @@ def assert_url_is_accessible(url, timeout=5.0):
urlopen(request, timeout=timeout) urlopen(request, timeout=timeout)
def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0): def collect_urls(dataset_cls, *args, **kwargs):
file = path.join(tmpdir, path.basename(url)) urls = []
with assert_server_response_ok(): with contextlib.suppress(Exception), log_download_attempts(
with open(file, "wb") as fh: urls, dataset_module=dataset_cls.__module__.split(".")[-1]
request = Request(url, headers={"User-Agent": USER_AGENT}) ):
response = urlopen(request, timeout=timeout) dataset_cls(*args, **kwargs)
fh.write(response.read())
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
class DownloadConfig:
def __init__(self, url, md5=None, id=None):
self.url = url
self.md5 = md5
self.id = id or url
def __repr__(self) -> str: return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]
return self.id
def make_download_configs(urls_and_md5s, name=None):
return [
DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
]
def collect_download_configs(dataset_loader, name=None, **kwargs):
urls_and_md5s = set()
try:
with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
dataset = dataset_loader()
except Exception:
dataset = None
if name is None and dataset is not None:
name = type(dataset).__name__
return make_download_configs(urls_and_md5s, name)
# This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a # This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a
...@@ -216,12 +184,14 @@ def root(): ...@@ -216,12 +184,14 @@ def root():
def places365(): def places365():
return itertools.chain( return itertools.chain.from_iterable(
*[ [
collect_download_configs( collect_urls(
lambda: datasets.Places365(ROOT, split=split, small=small, download=True), datasets.Places365,
name=f"Places365, {split}, {'small' if small else 'large'}", ROOT,
file="places365", split=split,
small=small,
download=True,
) )
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)) for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
] ]
...@@ -229,30 +199,26 @@ def places365(): ...@@ -229,30 +199,26 @@ def places365():
def caltech101(): def caltech101():
return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101") return collect_urls(datasets.Caltech101, ROOT, download=True)
def caltech256(): def caltech256():
return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256") return collect_urls(datasets.Caltech256, ROOT, download=True)
def cifar10(): def cifar10():
return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10") return collect_urls(datasets.CIFAR10, ROOT, download=True)
def cifar100(): def cifar100():
return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100") return collect_urls(datasets.CIFAR100, ROOT, download=True)
def voc(): def voc():
# TODO: Also test the "2007-test" key # TODO: Also test the "2007-test" key
return itertools.chain( return itertools.chain.from_iterable(
*[ [
collect_download_configs( collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True)
lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
name=f"VOC, {year}",
file="voc",
)
for year in ("2007", "2008", "2009", "2010", "2011", "2012") for year in ("2007", "2008", "2009", "2010", "2011", "2012")
] ]
) )
...@@ -260,55 +226,42 @@ def voc(): ...@@ -260,55 +226,42 @@ def voc():
def mnist(): def mnist():
with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]): with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST") return collect_urls(datasets.MNIST, ROOT, download=True)
def fashion_mnist(): def fashion_mnist():
return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST") return collect_urls(datasets.FashionMNIST, ROOT, download=True)
def kmnist(): def kmnist():
return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST") return collect_urls(datasets.KMNIST, ROOT, download=True)
def emnist(): def emnist():
# the 'split' argument can be any valid one, since everything is downloaded anyway # the 'split' argument can be any valid one, since everything is downloaded anyway
return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST") return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True)
def qmnist(): def qmnist():
return itertools.chain( return itertools.chain.from_iterable(
*[ [collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")]
collect_download_configs(
lambda: datasets.QMNIST(ROOT, what=what, download=True),
name=f"QMNIST, {what}",
file="mnist",
)
for what in ("train", "test", "nist")
]
) )
def moving_mnist():
return collect_urls(datasets.MovingMNIST, ROOT, download=True)
def omniglot(): def omniglot():
return itertools.chain( return itertools.chain.from_iterable(
*[ [collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)]
collect_download_configs(
lambda: datasets.Omniglot(ROOT, background=background, download=True),
name=f"Omniglot, {'background' if background else 'evaluation'}",
)
for background in (True, False)
]
) )
def phototour(): def phototour():
return itertools.chain( return itertools.chain.from_iterable(
*[ [
collect_download_configs( collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
lambda: datasets.PhotoTour(ROOT, name=name, download=True),
name=f"PhotoTour, {name}",
file="phototour",
)
# The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
# requests timeout from within CI. They are disabled until this is resolved. # requests timeout from within CI. They are disabled until this is resolved.
for name in ("notredame", "yosemite", "liberty") # "notredame_harris", "yosemite_harris", "liberty_harris" for name in ("notredame", "yosemite", "liberty") # "notredame_harris", "yosemite_harris", "liberty_harris"
...@@ -317,91 +270,51 @@ def phototour(): ...@@ -317,91 +270,51 @@ def phototour():
def sbdataset(): def sbdataset():
return collect_download_configs( return collect_urls(datasets.SBDataset, ROOT, download=True)
lambda: datasets.SBDataset(ROOT, download=True),
name="SBDataset",
file="voc",
)
def sbu(): def sbu():
return collect_download_configs( return collect_urls(datasets.SBU, ROOT, download=True)
lambda: datasets.SBU(ROOT, download=True),
name="SBU",
file="sbu",
)
def semeion(): def semeion():
return collect_download_configs( return collect_urls(datasets.SEMEION, ROOT, download=True)
lambda: datasets.SEMEION(ROOT, download=True),
name="SEMEION",
file="semeion",
)
def stl10(): def stl10():
return collect_download_configs( return collect_urls(datasets.STL10, ROOT, download=True)
lambda: datasets.STL10(ROOT, download=True),
name="STL10",
)
def svhn(): def svhn():
return itertools.chain( return itertools.chain.from_iterable(
*[ [collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")]
collect_download_configs(
lambda: datasets.SVHN(ROOT, split=split, download=True),
name=f"SVHN, {split}",
file="svhn",
)
for split in ("train", "test", "extra")
]
) )
def usps(): def usps():
return itertools.chain( return itertools.chain.from_iterable(
*[ [collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)]
collect_download_configs(
lambda: datasets.USPS(ROOT, train=train, download=True),
name=f"USPS, {'train' if train else 'test'}",
file="usps",
)
for train in (True, False)
]
) )
def celeba(): def celeba():
return collect_download_configs( return collect_urls(datasets.CelebA, ROOT, download=True)
lambda: datasets.CelebA(ROOT, download=True),
name="CelebA",
file="celeba",
)
def widerface(): def widerface():
return collect_download_configs( return collect_urls(datasets.WIDERFace, ROOT, download=True)
lambda: datasets.WIDERFace(ROOT, download=True),
name="WIDERFace",
file="widerface",
)
def kinetics(): def kinetics():
return itertools.chain( return itertools.chain.from_iterable(
*[ [
collect_download_configs( collect_urls(
lambda: datasets.Kinetics( datasets.Kinetics,
path.join(ROOT, f"Kinetics{num_classes}"), path.join(ROOT, f"Kinetics{num_classes}"),
frames_per_clip=1, frames_per_clip=1,
num_classes=num_classes, num_classes=num_classes,
split=split, split=split,
download=True, download=True,
),
name=f"Kinetics, {num_classes}, {split}",
file="kinetics",
) )
for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val")) for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
] ]
...@@ -409,58 +322,55 @@ def kinetics(): ...@@ -409,58 +322,55 @@ def kinetics():
def kitti(): def kitti():
return itertools.chain( return itertools.chain.from_iterable(
*[ [collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)]
collect_download_configs(
lambda train=train: datasets.Kitti(ROOT, train=train, download=True),
name=f"Kitti, {'train' if train else 'test'}",
file="kitti",
)
for train in (True, False)
]
) )
def make_parametrize_kwargs(download_configs): def stanford_cars():
argvalues = [] return itertools.chain.from_iterable(
ids = [] [collect_urls(datasets.StanfordCars, ROOT, split=split, download=True) for split in ["train", "test"]]
for config in download_configs: )
argvalues.append((config.url, config.md5))
ids.append(config.id)
def url_parametrization(*dataset_urls_and_ids_fns):
return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids) return pytest.mark.parametrize(
"url",
[
@pytest.mark.parametrize( pytest.param(url, id=id)
**make_parametrize_kwargs( for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns
itertools.chain( for url, id in sorted(set(dataset_urls_and_ids_fn()))
caltech101(), ],
caltech256(),
cifar10(),
cifar100(),
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
# voc(),
mnist(),
fashion_mnist(),
kmnist(),
emnist(),
qmnist(),
omniglot(),
phototour(),
sbdataset(),
sbu(),
semeion(),
stl10(),
svhn(),
usps(),
celeba(),
widerface(),
kinetics(),
kitti(),
)
) )
@url_parametrization(
caltech101,
caltech256,
cifar10,
cifar100,
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
# voc,
mnist,
fashion_mnist,
kmnist,
emnist,
qmnist,
omniglot,
phototour,
sbdataset,
semeion,
stl10,
svhn,
usps,
celeba,
widerface,
kinetics,
kitti,
places365,
sbu,
) )
def test_url_is_accessible(url, md5): def test_url_is_accessible(url):
""" """
If you see this test failing, find the offending dataset in the parametrization and move it to If you see this test failing, find the offending dataset in the parametrization and move it to
``test_url_is_not_accessible`` and link an issue detailing the problem. ``test_url_is_not_accessible`` and link an issue detailing the problem.
...@@ -468,15 +378,11 @@ def test_url_is_accessible(url, md5): ...@@ -468,15 +378,11 @@ def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url)) retry(lambda: assert_url_is_accessible(url))
@pytest.mark.parametrize( @url_parametrization(
**make_parametrize_kwargs( stanford_cars, # https://github.com/pytorch/vision/issues/7545
itertools.chain(
places365(), # https://github.com/pytorch/vision/issues/6268
)
)
) )
@pytest.mark.xfail @pytest.mark.xfail
def test_url_is_not_accessible(url, md5): def test_url_is_not_accessible(url):
""" """
As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
...@@ -486,8 +392,3 @@ def test_url_is_not_accessible(url, md5): ...@@ -486,8 +392,3 @@ def test_url_is_not_accessible(url, md5):
``test_url_is_accessible``. ``test_url_is_accessible``.
""" """
retry(lambda: assert_url_is_accessible(url)) retry(lambda: assert_url_is_accessible(url))
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
retry(lambda: assert_file_downloads_correctly(url, md5))
...@@ -7,7 +7,9 @@ import tarfile ...@@ -7,7 +7,9 @@ import tarfile
import zipfile import zipfile
import pytest import pytest
import torch
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
from common_utils import assert_equal
from torch._utils_internal import get_file_path_2 from torch._utils_internal import get_file_path_2
from torchvision.datasets.folder import make_dataset from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
...@@ -215,6 +217,24 @@ class TestDatasetsUtils: ...@@ -215,6 +217,24 @@ class TestDatasetsUtils:
pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
@pytest.mark.parametrize(
("dtype", "actual_hex", "expected_hex"),
[
(torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
(torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
(torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
(torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
],
)
def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
def to_tensor(hex):
return torch.frombuffer(bytes.fromhex(hex), dtype=dtype)
assert_equal(
utils._flip_byte_order(to_tensor(actual_hex)),
to_tensor(expected_hex),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("kwargs", "expected_error_msg"), ("kwargs", "expected_error_msg"),
......
import copy
import os import os
import pickle
import pytest import pytest
import test_models as TM import test_models as TM
import torch import torch
from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models from torchvision import models
from torchvision.models._api import get_model_weights, Weights, WeightsEnum from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface from torchvision.models._utils import handle_legacy_interface
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
run_if_test_with_extended = pytest.mark.skipif( run_if_test_with_extended = pytest.mark.skipif(
os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1", os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
...@@ -59,17 +62,59 @@ def test_get_model_weights(name, weight): ...@@ -59,17 +62,59 @@ def test_get_model_weights(name, weight):
assert models.get_model_weights(name) == weight assert models.get_model_weights(name) == weight
@pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_copyable(copy_fn, name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert copy_fn(weights) is weights
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_deserializable(name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert pickle.loads(pickle.dumps(weights)) is weights
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow] "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
) )
def test_list_models(module): def test_list_models(module):
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]
a = set(get_models_from_module(module)) a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module)) b = set(x.replace("quantized_", "") for x in models.list_models(module))
...@@ -77,6 +122,65 @@ def test_list_models(module): ...@@ -77,6 +122,65 @@ def test_list_models(module):
assert a == b assert a == b
@pytest.mark.parametrize(
"include_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
"*not-existing-model-for-test?",
["*resnet*", "*alexnet*"],
["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
("*resnet*", "*alexnet*"),
set(["*resnet*", "*alexnet*"]),
],
)
@pytest.mark.parametrize(
"exclude_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
["*not-existing-model-for-test?"],
["resnet34", "*not-existing-model-for-test?"],
["resnet34", "*resnet1*"],
("resnet34", "*resnet1*"),
set(["resnet34", "*resnet1*"]),
],
)
def test_list_models_filters(include_filters, exclude_filters):
actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
classification_models = set(get_models_from_module(models))
if isinstance(include_filters, str):
include_filters = [include_filters]
if isinstance(exclude_filters, str):
exclude_filters = [exclude_filters]
if include_filters:
expected = set()
for include_f in include_filters:
include_f = include_f.strip("*?")
expected = expected | set(x for x in classification_models if include_f in x)
else:
expected = classification_models
if exclude_filters:
for exclude_f in exclude_filters:
exclude_f = exclude_f.strip("*?")
if exclude_f != "":
a_exclude = set(x for x in classification_models if exclude_f in x)
expected = expected - a_exclude
assert expected == actual
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, weight", "name, weight",
[ [
...@@ -111,6 +215,22 @@ def test_naming_conventions(model_fn): ...@@ -111,6 +215,22 @@ def test_naming_conventions(model_fn):
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT") assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
detection_models_input_dims = {
"fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
"fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
"fasterrcnn_resnet50_fpn": (800, 800),
"fasterrcnn_resnet50_fpn_v2": (800, 800),
"fcos_resnet50_fpn": (800, 800),
"keypointrcnn_resnet50_fpn": (1333, 1333),
"maskrcnn_resnet50_fpn": (800, 800),
"maskrcnn_resnet50_fpn_v2": (800, 800),
"retinanet_resnet50_fpn": (800, 800),
"retinanet_resnet50_fpn_v2": (800, 800),
"ssd300_vgg16": (300, 300),
"ssdlite320_mobilenet_v3_large": (320, 320),
}
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn", "model_fn",
TM.list_model_fns(models) TM.list_model_fns(models)
...@@ -122,6 +242,9 @@ def test_naming_conventions(model_fn): ...@@ -122,6 +242,9 @@ def test_naming_conventions(model_fn):
) )
@run_if_test_with_extended @run_if_test_with_extended
def test_schema_meta_validation(model_fn): def test_schema_meta_validation(model_fn):
if model_fn.__name__ == "maskrcnn_resnet50_fpn_v2":
pytest.skip(reason="FIXME https://github.com/pytorch/vision/issues/7349")
# list of all possible supported high-level fields for weights meta-data # list of all possible supported high-level fields for weights meta-data
permitted_fields = { permitted_fields = {
"backend", "backend",
...@@ -135,11 +258,13 @@ def test_schema_meta_validation(model_fn): ...@@ -135,11 +258,13 @@ def test_schema_meta_validation(model_fn):
"recipe", "recipe",
"unquantized", "unquantized",
"_docs", "_docs",
"_ops",
"_file_size",
} }
# mandatory fields for each computer vision task # mandatory fields for each computer vision task
classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")} classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
defaults = { defaults = {
"all": {"_metrics", "min_size", "num_params", "recipe", "_docs"}, "all": {"_metrics", "min_size", "num_params", "recipe", "_docs", "_file_size", "_ops"},
"models": classification_fields, "models": classification_fields,
"detection": {"categories", ("_metrics", "COCO-val2017", "box_map")}, "detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
"quantization": classification_fields | {"backend", "unquantized"}, "quantization": classification_fields | {"backend", "unquantized"},
...@@ -160,7 +285,7 @@ def test_schema_meta_validation(model_fn): ...@@ -160,7 +285,7 @@ def test_schema_meta_validation(model_fn):
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
problematic_weights = {} problematic_weights = {}
incorrect_params = [] incorrect_meta = []
bad_names = [] bad_names = []
for w in weights_enum: for w in weights_enum:
actual_fields = set(w.meta.keys()) actual_fields = set(w.meta.keys())
...@@ -173,24 +298,47 @@ def test_schema_meta_validation(model_fn): ...@@ -173,24 +298,47 @@ def test_schema_meta_validation(model_fn):
unsupported_fields = set(w.meta.keys()) - permitted_fields unsupported_fields = set(w.meta.keys()) - permitted_fields
if missing_fields or unsupported_fields: if missing_fields or unsupported_fields:
problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields} problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
if w == weights_enum.DEFAULT:
if w == weights_enum.DEFAULT or any(w.meta[k] != weights_enum.DEFAULT.meta[k] for k in ["num_params", "_ops"]):
if module_name == "quantization": if module_name == "quantization":
# parameters() count doesn't work well with quantization, so we check against the non-quantized # parameters() count doesn't work well with quantization, so we check against the non-quantized
unquantized_w = w.meta.get("unquantized") unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"): if unquantized_w is not None:
incorrect_params.append(w) if w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
incorrect_meta.append((w, "num_params"))
# the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs
# instead
if w.meta["_ops"] != unquantized_w.meta.get("_ops"):
incorrect_meta.append((w, "_ops"))
else: else:
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()): # loading the model and using it for parameter and ops verification
incorrect_params.append(w) model = model_fn(weights=w)
else:
if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"): if w.meta.get("num_params") != sum(p.numel() for p in model.parameters()):
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()): incorrect_meta.append((w, "num_params"))
incorrect_params.append(w)
kwargs = {}
if model_name in detection_models_input_dims:
# detection models have non default height and width
height, width = detection_models_input_dims[model_name]
kwargs = {"height": height, "width": width}
if not model_fn.__name__.startswith("vit"):
# FIXME: https://github.com/pytorch/vision/issues/7871
calculated_ops = get_ops(model=model, weight=w, **kwargs)
if calculated_ops != w.meta["_ops"]:
incorrect_meta.append((w, "_ops"))
if not w.name.isupper(): if not w.name.isupper():
bad_names.append(w) bad_names.append(w)
if get_file_size_mb(w) != w.meta.get("_file_size"):
incorrect_meta.append((w, "_file_size"))
assert not problematic_weights assert not problematic_weights
assert not incorrect_params assert not incorrect_meta
assert not bad_names assert not bad_names
...@@ -343,7 +491,11 @@ class TestHandleLegacyInterface: ...@@ -343,7 +491,11 @@ class TestHandleLegacyInterface:
+ TM.list_model_fns(models.quantization) + TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation) + TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video) + TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow), + TM.list_model_fns(models.optical_flow)
+ [
lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
],
) )
@run_if_test_with_extended @run_if_test_with_extended
def test_pretrained_deprecation(self, model_fn): def test_pretrained_deprecation(self, model_fn):
......
...@@ -2,17 +2,18 @@ import colorsys ...@@ -2,17 +2,18 @@ import colorsys
import itertools import itertools
import math import math
import os import os
import re import warnings
from functools import partial from functools import partial
from typing import Sequence from typing import Sequence
import numpy as np import numpy as np
import PIL.Image
import pytest import pytest
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
import torchvision.transforms._functional_pil as F_pil
import torchvision.transforms._functional_tensor as F_t
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional_tensor as F_t
from common_utils import ( from common_utils import (
_assert_approx_equal_tensor_to_pil, _assert_approx_equal_tensor_to_pil,
_assert_equal_tensor_to_pil, _assert_equal_tensor_to_pil,
...@@ -20,15 +21,20 @@ from common_utils import ( ...@@ -20,15 +21,20 @@ from common_utils import (
_create_data_batch, _create_data_batch,
_test_fn_on_batch, _test_fn_on_batch,
assert_equal, assert_equal,
cpu_and_gpu, cpu_and_cuda,
needs_cuda, needs_cuda,
) )
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
InterpolationMode.NEAREST,
InterpolationMode.NEAREST_EXACT,
InterpolationMode.BILINEAR,
InterpolationMode.BICUBIC,
)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions]) @pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
def test_image_sizes(device, fn): def test_image_sizes(device, fn):
script_F = torch.jit.script(fn) script_F = torch.jit.script(fn)
...@@ -66,7 +72,7 @@ class TestRotate: ...@@ -66,7 +72,7 @@ class TestRotate:
scripted_rotate = torch.jit.script(F.rotate) scripted_rotate = torch.jit.script(F.rotate)
IMG_W = 26 IMG_W = 26
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)]) @pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"center", "center",
...@@ -125,7 +131,7 @@ class TestRotate: ...@@ -125,7 +131,7 @@ class TestRotate:
f"{out_pil_tensor[0, :7, :7]}" f"{out_pil_tensor[0, :7, :7]}"
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
def test_rotate_batch(self, device, dt): def test_rotate_batch(self, device, dt):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
...@@ -141,17 +147,9 @@ class TestRotate: ...@@ -141,17 +147,9 @@ class TestRotate:
def test_rotate_interpolation_type(self): def test_rotate_interpolation_type(self):
tensor, _ = _create_data(26, 26) tensor, _ = _create_data(26, 26)
# assert changed type warning res1 = F.rotate(tensor, 45, interpolation=PIL.Image.BILINEAR)
with pytest.warns( res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
UserWarning, assert_equal(res1, res2)
match=re.escape(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
),
):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
class TestAffine: class TestAffine:
...@@ -159,7 +157,7 @@ class TestAffine: ...@@ -159,7 +157,7 @@ class TestAffine:
ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16] ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
scripted_affine = torch.jit.script(F.affine) scripted_affine = torch.jit.script(F.affine)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
def test_identity_map(self, device, height, width, dt): def test_identity_map(self, device, height, width, dt):
...@@ -182,7 +180,7 @@ class TestAffine: ...@@ -182,7 +180,7 @@ class TestAffine:
) )
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}") assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26)]) @pytest.mark.parametrize("height, width", [(26, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -226,7 +224,7 @@ class TestAffine: ...@@ -226,7 +224,7 @@ class TestAffine:
# Tolerance : less than 6% of different pixels # Tolerance : less than 6% of different pixels
assert ratio_diff_pixels < 0.06 assert ratio_diff_pixels < 0.06
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(32, 26)]) @pytest.mark.parametrize("height, width", [(32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120]) @pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
...@@ -260,7 +258,7 @@ class TestAffine: ...@@ -260,7 +258,7 @@ class TestAffine:
# Tolerance : less than 3% of different pixels # Tolerance : less than 3% of different pixels
assert ratio_diff_pixels < 0.03 assert ratio_diff_pixels < 0.03
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize("t", [[10, 12], (-12, -13)]) @pytest.mark.parametrize("t", [[10, 12], (-12, -13)])
...@@ -285,7 +283,7 @@ class TestAffine: ...@@ -285,7 +283,7 @@ class TestAffine:
_assert_equal_tensor_to_pil(out_tensor, out_pil_img) _assert_equal_tensor_to_pil(out_tensor, out_pil_img)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -295,24 +293,8 @@ class TestAffine: ...@@ -295,24 +293,8 @@ class TestAffine:
(33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]), (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
(45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)), (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
(33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]), (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
( (85, (10, -10), 0.7, [0.0, 0.0], [1]),
85, (0, [0, 0], 1.0, [35.0], (2.0,)),
(10, -10),
0.7,
[0.0, 0.0],
[
1,
],
),
(
0,
[0, 0],
1.0,
[
35.0,
],
(2.0,),
),
(-25, [0, 0], 1.2, [0.0, 15.0], None), (-25, [0, 0], 1.2, [0.0, 15.0], None),
(-45, [-10, 0], 0.7, [2.0, 5.0], None), (-45, [-10, 0], 0.7, [2.0, 5.0], None),
(-45, [-10, -10], 1.2, [4.0, 5.0], None), (-45, [-10, -10], 1.2, [4.0, 5.0], None),
...@@ -346,7 +328,7 @@ class TestAffine: ...@@ -346,7 +328,7 @@ class TestAffine:
tol = 0.06 if device == "cuda" else 0.05 tol = 0.06 if device == "cuda" else 0.05
assert ratio_diff_pixels < tol assert ratio_diff_pixels < tol
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
def test_batches(self, device, dt): def test_batches(self, device, dt):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
...@@ -359,21 +341,13 @@ class TestAffine: ...@@ -359,21 +341,13 @@ class TestAffine:
_test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]) _test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0])
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_warnings(self, device): def test_interpolation_type(self, device):
tensor, pil_img = _create_data(26, 26, device=device) tensor, pil_img = _create_data(26, 26, device=device)
# assert changed type warning res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=PIL.Image.BILINEAR)
with pytest.warns( res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
UserWarning, assert_equal(res1, res2)
match=re.escape(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
),
):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
assert_equal(res1, res2)
def _get_data_dims_and_points_for_perspective(): def _get_data_dims_and_points_for_perspective():
...@@ -399,22 +373,10 @@ def _get_data_dims_and_points_for_perspective(): ...@@ -399,22 +373,10 @@ def _get_data_dims_and_points_for_perspective():
return dims_and_points return dims_and_points
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) @pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize( @pytest.mark.parametrize("fill", (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1], (2.0,)))
"fill",
(
None,
[0, 0, 0],
[1, 2, 3],
[255, 255, 255],
[
1,
],
(2.0,),
),
)
@pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)]) @pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)])
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
...@@ -445,7 +407,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): ...@@ -445,7 +407,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
assert ratio_diff_pixels < 0.05 assert ratio_diff_pixels < 0.05
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) @pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
def test_perspective_batch(device, dims_and_points, dt): def test_perspective_batch(device, dims_and_points, dt):
...@@ -473,40 +435,21 @@ def test_perspective_batch(device, dims_and_points, dt): ...@@ -473,40 +435,21 @@ def test_perspective_batch(device, dims_and_points, dt):
) )
def test_perspective_interpolation_warning(): def test_perspective_interpolation_type():
# assert changed type warning
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]] spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
tensor = torch.randint(0, 256, (3, 26, 26)) tensor = torch.randint(0, 256, (3, 26, 26))
with pytest.warns(
UserWarning, res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=PIL.Image.BILINEAR)
match=re.escape( res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " assert_equal(res1, res2)
"Please use InterpolationMode enum."
),
):
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
assert_equal(res1, res2)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize( @pytest.mark.parametrize("size", [32, 26, [32], [32, 32], (32, 32), [26, 35]])
"size",
[
32,
26,
[
32,
],
[32, 32],
(32, 32),
[26, 35],
],
)
@pytest.mark.parametrize("max_size", [None, 34, 40, 1000]) @pytest.mark.parametrize("max_size", [None, 34, 40, 1000])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST]) @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
def test_resize(device, dt, size, max_size, interpolation): def test_resize(device, dt, size, max_size, interpolation):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
...@@ -526,14 +469,12 @@ def test_resize(device, dt, size, max_size, interpolation): ...@@ -526,14 +469,12 @@ def test_resize(device, dt, size, max_size, interpolation):
tensor = tensor.to(dt) tensor = tensor.to(dt)
batch_tensors = batch_tensors.to(dt) batch_tensors = batch_tensors.to(dt)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size) resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size) resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1] assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
if interpolation not in [ if interpolation != NEAREST:
NEAREST,
]:
# We can not check values if mode = NEAREST, as results are different # We can not check values if mode = NEAREST, as results are different
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
...@@ -543,36 +484,27 @@ def test_resize(device, dt, size, max_size, interpolation): ...@@ -543,36 +484,27 @@ def test_resize(device, dt, size, max_size, interpolation):
resized_tensor_f = resized_tensor_f.to(torch.float) resized_tensor_f = resized_tensor_f.to(torch.float)
# Pay attention to high tolerance for MAE # Pay attention to high tolerance for MAE
_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0) _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=3.0)
if isinstance(size, int): if isinstance(size, int):
script_size = [ script_size = [size]
size,
]
else: else:
script_size = size script_size = size
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size) resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True)
assert_equal(resized_tensor, resize_result) assert_equal(resized_tensor, resize_result)
_test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size) _test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True
)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_resize_asserts(device): def test_resize_asserts(device):
tensor, pil_img = _create_data(26, 36, device=device) tensor, pil_img = _create_data(26, 36, device=device)
# assert changed type warning res1 = F.resize(tensor, size=32, interpolation=PIL.Image.BILINEAR)
with pytest.warns(
UserWarning,
match=re.escape(
"Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
"Please use InterpolationMode enum."
),
):
res1 = F.resize(tensor, size=32, interpolation=2)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR) res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
assert_equal(res1, res2) assert_equal(res1, res2)
...@@ -584,7 +516,7 @@ def test_resize_asserts(device): ...@@ -584,7 +516,7 @@ def test_resize_asserts(device):
F.resize(img, size=32, max_size=32) F.resize(img, size=32, max_size=32)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]]) @pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC]) @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
...@@ -603,7 +535,7 @@ def test_resize_antialias(device, dt, size, interpolation): ...@@ -603,7 +535,7 @@ def test_resize_antialias(device, dt, size, interpolation):
tensor = tensor.to(dt) tensor = tensor.to(dt)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True) resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation) resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True)
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1] assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
...@@ -637,38 +569,21 @@ def test_resize_antialias(device, dt, size, interpolation): ...@@ -637,38 +569,21 @@ def test_resize_antialias(device, dt, size, interpolation):
assert_equal(resized_tensor, resize_result) assert_equal(resized_tensor, resize_result)
@needs_cuda def test_resize_antialias_default_warning():
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
def test_assert_resize_antialias(interpolation):
# Checks implementation on very large scales
# and catch TORCH_CHECK inside PyTorch implementation
torch.manual_seed(12)
tensor, _ = _create_data(1000, 1000, device="cuda")
# Error message is not yet updated in pytorch nightly
# with pytest.raises(RuntimeError, match=r"Provided interpolation parameters can not be handled"):
with pytest.raises(RuntimeError, match=r"Too much shared memory required"):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(device, dt, size, interpolation):
if dt == torch.float16 and device == "cpu": img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
# skip float16 on CPU case
return
torch.manual_seed(12) match = "The default value of the antialias"
x = (torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),) with pytest.warns(UserWarning, match=match):
resize = partial(F.resize, size=size, interpolation=interpolation, antialias=True) F.resize(img, size=(20, 20))
assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False) with pytest.warns(UserWarning, match=match):
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20))
x = (torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),) # For modes that aren't bicubic or bilinear, don't throw a warning
assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False) with warnings.catch_warnings():
warnings.simplefilter("error")
F.resize(img, size=(20, 20), interpolation=NEAREST)
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20), interpolation=NEAREST)
def check_functional_vs_PIL_vs_scripted( def check_functional_vs_PIL_vs_scripted(
...@@ -708,7 +623,7 @@ def check_functional_vs_PIL_vs_scripted( ...@@ -708,7 +623,7 @@ def check_functional_vs_PIL_vs_scripted(
_test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config) _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)]) @pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
...@@ -724,7 +639,7 @@ def test_adjust_brightness(device, dtype, config, channels): ...@@ -724,7 +639,7 @@ def test_adjust_brightness(device, dtype, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_invert(device, dtype, channels): def test_invert(device, dtype, channels):
...@@ -733,7 +648,7 @@ def test_invert(device, dtype, channels): ...@@ -733,7 +648,7 @@ def test_invert(device, dtype, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)]) @pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_posterize(device, config, channels): def test_posterize(device, config, channels):
...@@ -750,7 +665,7 @@ def test_posterize(device, config, channels): ...@@ -750,7 +665,7 @@ def test_posterize(device, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]]) @pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_solarize1(device, config, channels): def test_solarize1(device, config, channels):
...@@ -767,7 +682,7 @@ def test_solarize1(device, config, channels): ...@@ -767,7 +682,7 @@ def test_solarize1(device, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]]) @pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
...@@ -785,37 +700,45 @@ def test_solarize2(device, dtype, config, channels): ...@@ -785,37 +700,45 @@ def test_solarize2(device, dtype, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize(
@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0]) ("dtype", "threshold"),
def test_solarize_threshold1_bound(threshold, device): [
img = torch.rand((3, 12, 23)).to(device) *[
F_t.solarize(img, threshold) (dtype, threshold)
for dtype, threshold in itertools.product(
[torch.float32, torch.float16],
@pytest.mark.parametrize("device", cpu_and_gpu()) [0.0, 0.25, 0.5, 0.75, 1.0],
@pytest.mark.parametrize("threshold", [1.5]) )
def test_solarize_threshold1_upper_bound(threshold, device): ],
img = torch.rand((3, 12, 23)).to(device) *[(torch.uint8, threshold) for threshold in [0, 64, 128, 192, 255]],
with pytest.raises(TypeError, match="Threshold should be less than bound of img."): *[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]],
F_t.solarize(img, threshold) ],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_gpu()) def test_solarize_threshold_within_bound(threshold, dtype, device):
@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255]) make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
def test_solarize_threshold2_bound(threshold, device): img = make_img((3, 12, 23), dtype=dtype, device=device)
img = torch.randint(0, 256, (3, 12, 23)).to(device)
F_t.solarize(img, threshold) F_t.solarize(img, threshold)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize(
@pytest.mark.parametrize("threshold", [260]) ("dtype", "threshold"),
def test_solarize_threshold2_upper_bound(threshold, device): [
img = torch.randint(0, 256, (3, 12, 23)).to(device) (torch.float32, 1.5),
(torch.float16, 1.5),
(torch.uint8, 260),
(torch.int64, 2**64),
],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_solarize_threshold_above_bound(threshold, dtype, device):
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
img = make_img((3, 12, 23), dtype=dtype, device=device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."): with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold) F_t.solarize(img, threshold)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) @pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
...@@ -831,7 +754,7 @@ def test_adjust_sharpness(device, dtype, config, channels): ...@@ -831,7 +754,7 @@ def test_adjust_sharpness(device, dtype, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_autocontrast(device, dtype, channels): def test_autocontrast(device, dtype, channels):
...@@ -840,7 +763,7 @@ def test_autocontrast(device, dtype, channels): ...@@ -840,7 +763,7 @@ def test_autocontrast(device, dtype, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_autocontrast_equal_minmax(device, dtype, channels): def test_autocontrast_equal_minmax(device, dtype, channels):
...@@ -852,7 +775,7 @@ def test_autocontrast_equal_minmax(device, dtype, channels): ...@@ -852,7 +775,7 @@ def test_autocontrast_equal_minmax(device, dtype, channels):
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all() assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_equalize(device, channels): def test_equalize(device, channels):
torch.use_deterministic_algorithms(False) torch.use_deterministic_algorithms(False)
...@@ -869,7 +792,7 @@ def test_equalize(device, channels): ...@@ -869,7 +792,7 @@ def test_equalize(device, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) @pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
...@@ -879,7 +802,7 @@ def test_adjust_contrast(device, dtype, config, channels): ...@@ -879,7 +802,7 @@ def test_adjust_contrast(device, dtype, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]) @pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
...@@ -889,7 +812,7 @@ def test_adjust_saturation(device, dtype, config, channels): ...@@ -889,7 +812,7 @@ def test_adjust_saturation(device, dtype, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]]) @pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
...@@ -899,7 +822,7 @@ def test_adjust_hue(device, dtype, config, channels): ...@@ -899,7 +822,7 @@ def test_adjust_hue(device, dtype, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]) @pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
@pytest.mark.parametrize("channels", [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
...@@ -915,7 +838,7 @@ def test_adjust_gamma(device, dtype, config, channels): ...@@ -915,7 +838,7 @@ def test_adjust_gamma(device, dtype, config, channels):
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]]) @pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -965,14 +888,16 @@ def test_pad(device, dt, pad, config): ...@@ -965,14 +888,16 @@ def test_pad(device, dt, pad, config):
_test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config) _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("mode", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("mode", [NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC])
def test_resized_crop(device, mode): def test_resized_crop(device, mode):
# test values of F.resized_crop in several cases: # test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity # 1) resize to the same size, crop to the same size => should be identity
tensor, _ = _create_data(26, 36, device=device) tensor, _ = _create_data(26, 36, device=device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) out_tensor = F.resized_crop(
tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True
)
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}") assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
# 2) resize by half and crop a TL corner # 2) resize by half and crop a TL corner
...@@ -987,11 +912,18 @@ def test_resized_crop(device, mode): ...@@ -987,11 +912,18 @@ def test_resized_crop(device, mode):
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device) batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
_test_fn_on_batch( _test_fn_on_batch(
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST batch_tensors,
F.resized_crop,
top=1,
left=2,
height=20,
width=30,
size=[10, 15],
interpolation=NEAREST,
) )
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"func, args", "func, args",
[ [
...@@ -1024,7 +956,7 @@ def test_assert_image_tensor(device, func, args): ...@@ -1024,7 +956,7 @@ def test_assert_image_tensor(device, func, args):
func(tensor, *args) func(tensor, *args)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_vflip(device): def test_vflip(device):
script_vflip = torch.jit.script(F.vflip) script_vflip = torch.jit.script(F.vflip)
...@@ -1041,7 +973,7 @@ def test_vflip(device): ...@@ -1041,7 +973,7 @@ def test_vflip(device):
_test_fn_on_batch(batch_tensors, F.vflip) _test_fn_on_batch(batch_tensors, F.vflip)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_hflip(device): def test_hflip(device):
script_hflip = torch.jit.script(F.hflip) script_hflip = torch.jit.script(F.hflip)
...@@ -1058,7 +990,7 @@ def test_hflip(device): ...@@ -1058,7 +990,7 @@ def test_hflip(device):
_test_fn_on_batch(batch_tensors, F.hflip) _test_fn_on_batch(batch_tensors, F.hflip)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"top, left, height, width", "top, left, height, width",
[ [
...@@ -1087,7 +1019,7 @@ def test_crop(device, top, left, height, width): ...@@ -1087,7 +1019,7 @@ def test_crop(device, top, left, height, width):
_test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width) _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("image_size", ("small", "large")) @pytest.mark.parametrize("image_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
...@@ -1141,7 +1073,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): ...@@ -1141,7 +1073,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_hsv2rgb(device): def test_hsv2rgb(device):
scripted_fn = torch.jit.script(F_t._hsv2rgb) scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150) shape = (3, 100, 150)
...@@ -1172,7 +1104,7 @@ def test_hsv2rgb(device): ...@@ -1172,7 +1104,7 @@ def test_hsv2rgb(device):
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb) _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_rgb2hsv(device): def test_rgb2hsv(device):
scripted_fn = torch.jit.script(F_t._rgb2hsv) scripted_fn = torch.jit.script(F_t._rgb2hsv)
shape = (3, 150, 100) shape = (3, 150, 100)
...@@ -1211,7 +1143,7 @@ def test_rgb2hsv(device): ...@@ -1211,7 +1143,7 @@ def test_rgb2hsv(device):
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv) _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_output_channels", (3, 1)) @pytest.mark.parametrize("num_output_channels", (3, 1))
def test_rgb_to_grayscale(device, num_output_channels): def test_rgb_to_grayscale(device, num_output_channels):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
...@@ -1230,7 +1162,7 @@ def test_rgb_to_grayscale(device, num_output_channels): ...@@ -1230,7 +1162,7 @@ def test_rgb_to_grayscale(device, num_output_channels):
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_center_crop(device): def test_center_crop(device):
script_center_crop = torch.jit.script(F.center_crop) script_center_crop = torch.jit.script(F.center_crop)
...@@ -1248,7 +1180,7 @@ def test_center_crop(device): ...@@ -1248,7 +1180,7 @@ def test_center_crop(device):
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11]) _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_five_crop(device): def test_five_crop(device):
script_five_crop = torch.jit.script(F.five_crop) script_five_crop = torch.jit.script(F.five_crop)
...@@ -1282,7 +1214,7 @@ def test_five_crop(device): ...@@ -1282,7 +1214,7 @@ def test_five_crop(device):
assert_equal(transformed_batch, s_transformed_batch) assert_equal(transformed_batch, s_transformed_batch)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_ten_crop(device): def test_ten_crop(device):
script_ten_crop = torch.jit.script(F.ten_crop) script_ten_crop = torch.jit.script(F.ten_crop)
...@@ -1328,7 +1260,7 @@ def test_elastic_transform_asserts(): ...@@ -1328,7 +1260,7 @@ def test_elastic_transform_asserts():
_ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2)) _ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2))
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment