Unverified commit 3da86585 authored by Philip Meier, committed by GitHub

port sample input smoke test (#7962)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent f33e387f
@@ -420,7 +420,7 @@ def make_bounding_boxes(
dtype = dtype or torch.float32
num_objects = 1
h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
h, w = [torch.randint(1, s, (num_objects,)) for s in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])
import itertools
import pathlib
import pickle
import random
import numpy as np
@@ -11,22 +9,11 @@ import torch
import torchvision.transforms.v2 as transforms
from common_utils import assert_equal, cpu_and_cuda
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
from transforms_v2_legacy_utils import (
make_bounding_boxes,
make_detection_mask,
make_image,
make_images,
make_multiple_bounding_boxes,
make_segmentation_mask,
make_video,
make_videos,
)
from torchvision.transforms.v2._utils import is_pure_tensor
from transforms_v2_legacy_utils import make_bounding_boxes, make_detection_mask, make_image, make_images, make_videos
def make_vanilla_tensor_images(*args, **kwargs):
@@ -41,11 +28,6 @@ def make_pil_images(*args, **kwargs):
yield to_pil_image(image)
def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
for bounding_boxes in make_multiple_bounding_boxes(*args, **kwargs):
yield bounding_boxes.data
def parametrize(transforms_with_inputs):
return pytest.mark.parametrize(
("transform", "input"),
@@ -61,218 +43,6 @@ def parametrize(transforms_with_inputs):
)
def auto_augment_adapter(transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
# AA transforms don't support bounding boxes or masks
continue
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
image_or_video_found = True
adapted_input[key] = value
return adapted_input
def linear_transformation_adapter(transform, input, device):
flat_inputs = list(input.values())
c, h, w = query_chw(
[
item
for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
if needs_transform
]
)
num_elements = c * h * w
transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
transform.mean_vector = torch.randn((num_elements,), device=device)
return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
def normalize_adapter(transform, input, device):
adapted_input = {}
for key, value in input.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
# normalize doesn't support integer images
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
return adapted_input
class TestSmoke:
@pytest.mark.parametrize(
("transform", "adapter"),
[
(transforms.RandomErasing(p=1.0), None),
(transforms.AugMix(), auto_augment_adapter),
(transforms.AutoAugment(), auto_augment_adapter),
(transforms.RandAugment(), auto_augment_adapter),
(transforms.TrivialAugmentWide(), auto_augment_adapter),
(transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
(transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
(transforms.RandomAutocontrast(p=1.0), None),
(transforms.RandomEqualize(p=1.0), None),
(transforms.RandomInvert(p=1.0), None),
(transforms.RandomChannelPermutation(), None),
(transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
(transforms.CenterCrop([16, 16]), None),
(transforms.ElasticTransform(sigma=1.0), None),
(transforms.Pad(4), None),
(transforms.RandomAffine(degrees=30.0), None),
(transforms.RandomCrop([16, 16], pad_if_needed=True), None),
(transforms.RandomHorizontalFlip(p=1.0), None),
(transforms.RandomPerspective(p=1.0), None),
(transforms.RandomResize(min_size=10, max_size=20, antialias=True), None),
(transforms.RandomResizedCrop([16, 16], antialias=True), None),
(transforms.RandomRotation(degrees=30), None),
(transforms.RandomShortestSize(min_size=10, antialias=True), None),
(transforms.RandomVerticalFlip(p=1.0), None),
(transforms.Resize([16, 16], antialias=True), None),
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
(transforms.ClampBoundingBoxes(), None),
(transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertImageDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
transforms.LinearTransformation(
# These are just dummy values that will be filled by the adapter. We can't define them upfront,
# because we know neither the spatial size nor the device at this point
transformation_matrix=torch.empty((1, 1)),
mean_vector=torch.empty((1,)),
),
linear_transformation_adapter,
),
(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
(transforms.ToDtype(torch.float64), None),
(transforms.UniformTemporalSubsample(num_samples=2), None),
],
ids=lambda transform: type(transform).__name__,
)
@pytest.mark.parametrize("container_type", [dict, list, tuple])
@pytest.mark.parametrize(
"image_or_video",
[
make_image(),
make_video(),
next(make_pil_images(color_spaces=["RGB"])),
next(make_vanilla_tensor_images()),
],
)
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
transform = de_serialize(transform)
canvas_size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
image_tv_tensor=make_image(size=canvas_size),
video_tv_tensor=make_video(size=canvas_size),
image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
bounding_boxes_xyxy=make_bounding_boxes(
format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
),
bounding_boxes_xywh=make_bounding_boxes(
format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
),
bounding_boxes_cxcywh=make_bounding_boxes(
format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
),
bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[2, 0, 1, 1], # x1 > x2, y1 < y2
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size,
),
bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=canvas_size,
),
bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.CXCYWH,
canvas_size=canvas_size,
),
detection_mask=make_detection_mask(size=canvas_size),
segmentation_mask=make_segmentation_mask(size=canvas_size),
int=0,
float=0.0,
bool=True,
none=None,
str="str",
path=pathlib.Path.cwd(),
object=object(),
tensor=torch.empty(5),
array=np.empty(5),
)
if adapter is not None:
input = adapter(transform, input, device)
if container_type in {tuple, list}:
input = container_type(input.values())
input_flat, input_spec = tree_flatten(input)
input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
input = tree_unflatten(input_flat, input_spec)
torch.manual_seed(0)
output = transform(input)
output_flat, output_spec = tree_flatten(output)
assert output_spec == input_spec
for output_item, input_item, should_be_transformed in zip(
output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
):
if should_be_transformed:
assert type(output_item) is type(input_item)
else:
assert output_item is input_item
if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
transform, transforms.ConvertBoundingBoxFormat
):
assert output_item.format == input_item.format
# Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
# transform that does this), back into a valid one.
# TODO: we should test that against all degenerate boxes above
for format in list(tv_tensors.BoundingBoxFormat):
sample = dict(
boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
labels=torch.tensor([3]),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
@pytest.mark.parametrize(
"flat_inputs",
itertools.permutations(
@@ -543,39 +313,6 @@ class TestRandomShortestSize:
assert shorter in min_size
class TestLinearTransformation:
def test_assertions(self):
with pytest.raises(ValueError, match="transformation_matrix should be square"):
transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))
with pytest.raises(ValueError, match="mean_vector should have the same length"):
transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))
@pytest.mark.parametrize(
"inpt",
[
122 * torch.ones(1, 3, 8, 8),
122.0 * torch.ones(1, 3, 8, 8),
tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
],
)
def test__transform(self, inpt):
v = 121 * torch.ones(3 * 8 * 8)
m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
transform = transforms.LinearTransformation(m, v)
if isinstance(inpt, PIL.Image.Image):
with pytest.raises(TypeError, match="does not support PIL images"):
transform(inpt)
else:
output = transform(inpt)
assert isinstance(output, torch.Tensor)
assert output.unique() == 3 * 8 * 8
assert output.dtype == inpt.dtype
class TestRandomResize:
def test__get_params(self):
min_size = 3
@@ -72,28 +72,6 @@ LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [
*[
ConsistencyConfig(
v2_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
),
supports_pil=False,
)
for matrix_dtype, image_dtype in [
(torch.float32, torch.float32),
(torch.float64, torch.float64),
(torch.float32, torch.uint8),
(torch.float64, torch.float32),
(torch.float32, torch.float64),
]
],
ConsistencyConfig(
v2_transforms.ToPILImage,
legacy_transforms.ToPILImage,
@@ -6,6 +6,7 @@ import itertools
import math
import pickle
import re
from copy import deepcopy
from pathlib import Path
from unittest import mock
@@ -37,13 +38,14 @@ from common_utils import (
from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
@@ -261,7 +263,123 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
_script(v1_transform)(input)
def check_transform(transform, input, check_v1_compatibility=True):
def _make_transform_sample(transform, *, image_or_video, adapter):
device = image_or_video.device if isinstance(image_or_video, torch.Tensor) else "cpu"
size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
image_tv_tensor=make_image(size, device=device),
video_tv_tensor=make_video(size, device=device),
image_pil=make_image_pil(size),
bounding_boxes_xyxy=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.XYXY, device=device),
bounding_boxes_xywh=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.XYWH, device=device),
bounding_boxes_cxcywh=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.CXCYWH, device=device),
bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[2, 0, 1, 1], # x1 > x2, y1 < y2
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=size,
device=device,
),
bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=size,
device=device,
),
bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.CXCYWH,
canvas_size=size,
device=device,
),
detection_mask=make_detection_mask(size, device=device),
segmentation_mask=make_segmentation_mask(size, device=device),
int=0,
float=0.0,
bool=True,
none=None,
str="str",
path=Path.cwd(),
object=object(),
tensor=torch.empty(5),
array=np.empty(5),
)
if adapter is not None:
input = adapter(transform, input, device)
return input
def _check_transform_sample_input_smoke(transform, input, *, adapter):
# This is a bunch of input / output convention checks, using a big sample with different parts as input.
if not check_type(input, (is_pure_tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video)):
return
sample = _make_transform_sample(
# adapter might change transform inplace
transform=transform if adapter is None else deepcopy(transform),
image_or_video=input,
adapter=adapter,
)
for container_type in [dict, list, tuple]:
if container_type is dict:
input = sample
else:
input = container_type(sample.values())
input_flat, input_spec = tree_flatten(input)
with freeze_rng_state():
torch.manual_seed(0)
output = transform(input)
output_flat, output_spec = tree_flatten(output)
assert output_spec == input_spec
for output_item, input_item, should_be_transformed in zip(
output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
):
if should_be_transformed:
assert type(output_item) is type(input_item)
else:
assert output_item is input_item
# Enforce that the transform does not turn a degenerate bounding box, e.g. marked by RandomIoUCrop (or any other
# future transform that does this), back into a valid one.
for degenerate_bounding_boxes in (
bounding_box
for name, bounding_box in sample.items()
if "degenerate" in name and isinstance(bounding_box, tv_tensors.BoundingBoxes)
):
sample = dict(
boxes=degenerate_bounding_boxes,
labels=torch.randint(10, (degenerate_bounding_boxes.shape[0],), device=degenerate_bounding_boxes.device),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
def check_transform(transform, input, check_v1_compatibility=True, check_sample_input=True):
pickle.loads(pickle.dumps(transform))
output = transform(input)
@@ -270,6 +388,11 @@ def check_transform(transform, input, check_v1_compatibility=True):
if isinstance(input, tv_tensors.BoundingBoxes) and not isinstance(transform, transforms.ConvertBoundingBoxFormat):
assert output.format == input.format
if check_sample_input:
_check_transform_sample_input_smoke(
transform, input, adapter=check_sample_input if callable(check_sample_input) else None
)
if check_v1_compatibility:
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))
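With this hunk, check_transform accepts check_sample_input in addition to check_v1_compatibility: True (the default) runs the sample-input smoke test on the full mixed sample, False skips it, and a callable is used as an adapter that prunes or converts sample entries the transform cannot handle before the smoke test runs. A minimal sketch of the calling convention, reusing names that appear in the hunks below; the concrete transform and the bare make_image() call are illustrative only:

# Default: the smoke test runs on the full sample built by _make_transform_sample.
check_transform(transforms.Pad(4), make_image())

# Skip the smoke test, e.g. for FiveCrop/TenCrop, whose outputs change the sample structure.
check_transform(transform, input, check_sample_input=False)

# Adapt the sample first, e.g. Normalize drops PIL images and converts integer images to float.
check_transform(transform, input, check_sample_input=self._sample_input_adapter)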
@@ -1758,7 +1881,7 @@ class TestToDtype:
input = make_input(dtype=input_dtype, device=device)
if as_dict:
output_dtype = {type(input): output_dtype}
check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input)
check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict)
def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
input_dtype = image.dtype
@@ -2559,9 +2682,13 @@ class TestCrop:
def test_transform(self, param, value, make_input):
input = make_input(self.INPUT_SIZE)
check_sample_input = True
if param == "fill":
if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
pytest.skip("F.pad_mask doesn't support non-scalar fill.")
if isinstance(value, (tuple, list)):
if isinstance(input, tv_tensors.Mask):
pytest.skip("F.pad_mask doesn't support non-scalar fill.")
else:
check_sample_input = False
kwargs = dict(
# 1. size is required
@@ -2576,6 +2703,7 @@
transforms.RandomCrop(**kwargs, pad_if_needed=True),
input,
check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
check_sample_input=check_sample_input,
)
@pytest.mark.parametrize("padding", [1, (1, 1), (1, 1, 1, 1)])
@@ -2761,9 +2889,13 @@ class TestErase:
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
input = make_input(device=device)
check_transform(
transforms.RandomErasing(p=1), input, check_v1_compatibility=not isinstance(input, PIL.Image.Image)
)
with pytest.warns(UserWarning, match="currently passing through inputs of type"):
check_transform(
transforms.RandomErasing(p=1),
input,
check_v1_compatibility=not isinstance(input, PIL.Image.Image),
)
def _reference_erase_image(self, image, *, i, j, h, w, v):
mask = torch.zeros_like(image, dtype=torch.bool)
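The pytest.warns wrapper is needed because the smoke-test sample always contains bounding boxes and masks, which RandomErasing only passes through while emitting a UserWarning; this is the behaviour the removed test_transform_passthrough below exercised directly. A minimal standalone reproduction of that warning, with illustrative tensor and box sizes:

import pytest
import torch
import torchvision.transforms.v2 as transforms
from torchvision import tv_tensors

image = tv_tensors.Image(torch.rand(3, 16, 16))
boxes = tv_tensors.BoundingBoxes([[0, 0, 4, 4]], format="XYXY", canvas_size=(16, 16))

with pytest.warns(UserWarning, match="currently passing through inputs of type"):
    # the image is erased, the boxes come back unchanged
    _, output = transforms.RandomErasing(p=1)(image, boxes)
assert output is boxes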
@@ -2835,18 +2967,6 @@ class TestErase:
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
transform._get_params([make_image()])
@pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
def test_transform_passthrough(self, make_input):
transform = transforms.RandomErasing(p=1)
input = make_input(self.INPUT_SIZE)
with pytest.warns(UserWarning, match="currently passing through inputs of type"):
# RandomErasing requires an image or video to be present
_, output = transform(make_image(self.INPUT_SIZE), input)
assert output is input
class TestGaussianBlur:
@pytest.mark.parametrize("kernel_size", [1, 3, (3, 1), [3, 5]])
@@ -3063,6 +3183,21 @@ class TestAutoAugmentTransforms:
else:
assert_close(actual, expected, rtol=0, atol=1)
def _sample_input_adapter(self, transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
# AA transforms don't support bounding boxes or masks
continue
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
image_or_video_found = True
adapted_input[key] = value
return adapted_input
@pytest.mark.parametrize(
"transform",
[transforms.AutoAugment(), transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AugMix()],
@@ -3087,7 +3222,9 @@ class TestAutoAugmentTransforms:
# For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1
# and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks
# here and only check if we can script the v2 transform and subsequently call the result.
check_transform(transform, input, check_v1_compatibility=False)
check_transform(
transform, input, check_v1_compatibility=False, check_sample_input=self._sample_input_adapter
)
if type(input) is torch.Tensor and dtype is torch.uint8:
_script(transform)(input)
@@ -4014,9 +4151,25 @@ class TestNormalize:
with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"):
F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std)
def _sample_input_adapter(self, transform, input, device):
adapted_input = {}
for key, value in input.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
elif check_type(value, (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)):
# normalize doesn't support integer images
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
return adapted_input
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
def test_transform(self, make_input):
check_transform(transforms.Normalize(mean=self.MEAN, std=self.STD), make_input(dtype=torch.float32))
check_transform(
transforms.Normalize(mean=self.MEAN, std=self.STD),
make_input(dtype=torch.float32),
check_sample_input=self._sample_input_adapter,
)
def _reference_normalize_image(self, image, *, mean, std):
image = image.numpy()
@@ -4543,7 +4696,11 @@ class TestFiveTenCrop:
)
@pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
def test_transform(self, make_input, transform_cls):
check_transform(self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), make_input(self.INPUT_SIZE))
check_transform(
self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)),
make_input(self.INPUT_SIZE),
check_sample_input=False,
)
@pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
@pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
@@ -4826,3 +4983,66 @@ class TestScaleJitter:
assert int(input_size[0] * r_min) <= height <= int(input_size[0] * r_max)
assert int(input_size[1] * r_min) <= width <= int(input_size[1] * r_max)
class TestLinearTransform:
def _make_matrix_and_vector(self, input, *, device=None):
device = device or input.device
numel = math.prod(F.get_dimensions(input))
transformation_matrix = torch.randn((numel, numel), device=device)
mean_vector = torch.randn((numel,), device=device)
return transformation_matrix, mean_vector
def _sample_input_adapter(self, transform, input, device):
return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, dtype, device):
input = make_input(dtype=dtype, device=device)
check_transform(
transforms.LinearTransformation(*self._make_matrix_and_vector(input)),
input,
check_sample_input=self._sample_input_adapter,
)
def test_transform_error(self):
with pytest.raises(ValueError, match="transformation_matrix should be square"):
transforms.LinearTransformation(transformation_matrix=torch.rand(2, 3), mean_vector=torch.rand(2))
with pytest.raises(ValueError, match="mean_vector should have the same length"):
transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(1))
for matrix_dtype, vector_dtype in [(torch.float32, torch.float64), (torch.float64, torch.float32)]:
with pytest.raises(ValueError, match="Input tensors should have the same dtype"):
transforms.LinearTransformation(
transformation_matrix=torch.rand(2, 2, dtype=matrix_dtype),
mean_vector=torch.rand(2, dtype=vector_dtype),
)
image = make_image()
transform = transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(2))
with pytest.raises(ValueError, match="Input tensor and transformation matrix have incompatible shape"):
transform(image)
transform = transforms.LinearTransformation(*self._make_matrix_and_vector(image))
with pytest.raises(TypeError, match="does not support PIL images"):
transform(F.to_pil_image(image))
@needs_cuda
def test_transform_error_cuda(self):
for matrix_device, vector_device in [("cuda", "cpu"), ("cpu", "cuda")]:
with pytest.raises(ValueError, match="Input tensors should be on the same device"):
transforms.LinearTransformation(
transformation_matrix=torch.rand(2, 2, device=matrix_device),
mean_vector=torch.rand(2, device=vector_device),
)
for input_device, param_device in [("cuda", "cpu"), ("cpu", "cuda")]:
input = make_image(device=input_device)
transform = transforms.LinearTransformation(*self._make_matrix_and_vector(input, device=param_device))
with pytest.raises(
ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector"
):
transform(input)
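For context on the shapes used in _make_matrix_and_vector above: the parameters are sized from the flattened input, numel = C * H * W, because LinearTransformation subtracts the mean vector from the flattened image and then multiplies by the square matrix. A small self-contained sketch of that arithmetic, with sizes chosen for illustration rather than taken from the tests:

import math
import torch

c, h, w = 3, 8, 8                                   # channels, height, width of a toy image
numel = math.prod((c, h, w))                        # 192 elements once the image is flattened
transformation_matrix = torch.randn(numel, numel)   # must be square: (C*H*W, C*H*W)
mean_vector = torch.randn(numel)                    # must match the flattened length: (C*H*W,)

image = torch.rand(c, h, w)
flat = image.reshape(1, -1)                         # shape (1, 192)
out = (flat - mean_vector) @ transformation_matrix  # roughly what LinearTransformation computes
print(out.reshape(c, h, w).shape)                   # torch.Size([3, 8, 8])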