Unverified Commit 3da86585 authored by Philip Meier, committed by GitHub

port sample input smoke test (#7962)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent f33e387f
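For orientation: the diff below removes the "sample input" smoke test from the legacy TestSmoke class and re-implements it as a check_sample_input hook on check_transform. The following is a minimal, hypothetical sketch of how that hook is driven from an individual test. check_transform, _make_transform_sample, and the (transform, input, device) adapter signature are taken from the diff itself; the concrete transforms, values, and the drop_pil_adapter helper are illustrative only and assume the test helpers defined further down are in scope.

# Hypothetical usage sketch (not part of the diff). Assumes check_transform and the
# other helpers added in this commit are importable / in scope.
import PIL.Image
import torch
import torchvision.transforms.v2 as transforms


def drop_pil_adapter(transform, input, device):
    # Adapters receive the full sample dict built by _make_transform_sample and
    # return only the entries the transform under test supports (here: drop PIL images).
    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}


image = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)

# check_sample_input=True (the default) runs the sample input smoke test on a big
# sample containing images, videos, bounding boxes, masks, and plain Python objects;
# passing a callable uses it as the adapter; False skips the smoke test entirely.
check_transform(transforms.RandomHorizontalFlip(p=1.0), image)
check_transform(
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    image.to(torch.float32) / 255,
    check_sample_input=drop_pil_adapter,
)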
@@ -420,7 +420,7 @@ def make_bounding_boxes(
dtype = dtype or torch.float32
num_objects = 1
-h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
+h, w = [torch.randint(1, s, (num_objects,)) for s in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])
...
import itertools
import pathlib
import pickle
import random
import numpy as np
@@ -11,22 +9,11 @@ import torch
import torchvision.transforms.v2 as transforms
from common_utils import assert_equal, cpu_and_cuda
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
-from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
-from transforms_v2_legacy_utils import (
-    make_bounding_boxes,
-    make_detection_mask,
-    make_image,
-    make_images,
-    make_multiple_bounding_boxes,
-    make_segmentation_mask,
-    make_video,
-    make_videos,
-)
+from torchvision.transforms.v2._utils import is_pure_tensor
+from transforms_v2_legacy_utils import make_bounding_boxes, make_detection_mask, make_image, make_images, make_videos
def make_vanilla_tensor_images(*args, **kwargs):
@@ -41,11 +28,6 @@ def make_pil_images(*args, **kwargs):
yield to_pil_image(image)
def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
for bounding_boxes in make_multiple_bounding_boxes(*args, **kwargs):
yield bounding_boxes.data
def parametrize(transforms_with_inputs):
return pytest.mark.parametrize(
("transform", "input"),
@@ -61,218 +43,6 @@ def parametrize(transforms_with_inputs):
)
def auto_augment_adapter(transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
# AA transforms don't support bounding boxes or masks
continue
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
image_or_video_found = True
adapted_input[key] = value
return adapted_input
def linear_transformation_adapter(transform, input, device):
flat_inputs = list(input.values())
c, h, w = query_chw(
[
item
for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
if needs_transform
]
)
num_elements = c * h * w
transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
transform.mean_vector = torch.randn((num_elements,), device=device)
return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
def normalize_adapter(transform, input, device):
adapted_input = {}
for key, value in input.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
# normalize doesn't support integer images
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
return adapted_input
class TestSmoke:
@pytest.mark.parametrize(
("transform", "adapter"),
[
(transforms.RandomErasing(p=1.0), None),
(transforms.AugMix(), auto_augment_adapter),
(transforms.AutoAugment(), auto_augment_adapter),
(transforms.RandAugment(), auto_augment_adapter),
(transforms.TrivialAugmentWide(), auto_augment_adapter),
(transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
(transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
(transforms.RandomAutocontrast(p=1.0), None),
(transforms.RandomEqualize(p=1.0), None),
(transforms.RandomInvert(p=1.0), None),
(transforms.RandomChannelPermutation(), None),
(transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
(transforms.CenterCrop([16, 16]), None),
(transforms.ElasticTransform(sigma=1.0), None),
(transforms.Pad(4), None),
(transforms.RandomAffine(degrees=30.0), None),
(transforms.RandomCrop([16, 16], pad_if_needed=True), None),
(transforms.RandomHorizontalFlip(p=1.0), None),
(transforms.RandomPerspective(p=1.0), None),
(transforms.RandomResize(min_size=10, max_size=20, antialias=True), None),
(transforms.RandomResizedCrop([16, 16], antialias=True), None),
(transforms.RandomRotation(degrees=30), None),
(transforms.RandomShortestSize(min_size=10, antialias=True), None),
(transforms.RandomVerticalFlip(p=1.0), None),
(transforms.Resize([16, 16], antialias=True), None),
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
(transforms.ClampBoundingBoxes(), None),
(transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertImageDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
transforms.LinearTransformation(
# These are just dummy values that will be filled by the adapter. We can't define them upfront,
# because we know neither the spatial size nor the device at this point
transformation_matrix=torch.empty((1, 1)),
mean_vector=torch.empty((1,)),
),
linear_transformation_adapter,
),
(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
(transforms.ToDtype(torch.float64), None),
(transforms.UniformTemporalSubsample(num_samples=2), None),
],
ids=lambda transform: type(transform).__name__,
)
@pytest.mark.parametrize("container_type", [dict, list, tuple])
@pytest.mark.parametrize(
"image_or_video",
[
make_image(),
make_video(),
next(make_pil_images(color_spaces=["RGB"])),
next(make_vanilla_tensor_images()),
],
)
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
transform = de_serialize(transform)
canvas_size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
image_tv_tensor=make_image(size=canvas_size),
video_tv_tensor=make_video(size=canvas_size),
image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
bounding_boxes_xyxy=make_bounding_boxes(
format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
),
bounding_boxes_xywh=make_bounding_boxes(
format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
),
bounding_boxes_cxcywh=make_bounding_boxes(
format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
),
bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[2, 0, 1, 1], # x1 > x2, y1 < y2
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size,
),
bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=canvas_size,
),
bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.CXCYWH,
canvas_size=canvas_size,
),
detection_mask=make_detection_mask(size=canvas_size),
segmentation_mask=make_segmentation_mask(size=canvas_size),
int=0,
float=0.0,
bool=True,
none=None,
str="str",
path=pathlib.Path.cwd(),
object=object(),
tensor=torch.empty(5),
array=np.empty(5),
)
if adapter is not None:
input = adapter(transform, input, device)
if container_type in {tuple, list}:
input = container_type(input.values())
input_flat, input_spec = tree_flatten(input)
input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
input = tree_unflatten(input_flat, input_spec)
torch.manual_seed(0)
output = transform(input)
output_flat, output_spec = tree_flatten(output)
assert output_spec == input_spec
for output_item, input_item, should_be_transformed in zip(
output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
):
if should_be_transformed:
assert type(output_item) is type(input_item)
else:
assert output_item is input_item
if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
transform, transforms.ConvertBoundingBoxFormat
):
assert output_item.format == input_item.format
# Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
# transform that does this), back into a valid one.
# TODO: we should test that against all degenerate boxes above
for format in list(tv_tensors.BoundingBoxFormat):
sample = dict(
boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
labels=torch.tensor([3]),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
@pytest.mark.parametrize(
"flat_inputs",
itertools.permutations(
@@ -543,39 +313,6 @@ class TestRandomShortestSize:
assert shorter in min_size
class TestLinearTransformation:
def test_assertions(self):
with pytest.raises(ValueError, match="transformation_matrix should be square"):
transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))
with pytest.raises(ValueError, match="mean_vector should have the same length"):
transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))
@pytest.mark.parametrize(
"inpt",
[
122 * torch.ones(1, 3, 8, 8),
122.0 * torch.ones(1, 3, 8, 8),
tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
],
)
def test__transform(self, inpt):
v = 121 * torch.ones(3 * 8 * 8)
m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
transform = transforms.LinearTransformation(m, v)
if isinstance(inpt, PIL.Image.Image):
with pytest.raises(TypeError, match="does not support PIL images"):
transform(inpt)
else:
output = transform(inpt)
assert isinstance(output, torch.Tensor)
assert output.unique() == 3 * 8 * 8
assert output.dtype == inpt.dtype
class TestRandomResize:
def test__get_params(self):
min_size = 3
...
@@ -72,28 +72,6 @@ LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [
*[
ConsistencyConfig(
v2_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
),
supports_pil=False,
)
for matrix_dtype, image_dtype in [
(torch.float32, torch.float32),
(torch.float64, torch.float64),
(torch.float32, torch.uint8),
(torch.float64, torch.float32),
(torch.float32, torch.float64),
]
],
ConsistencyConfig(
v2_transforms.ToPILImage,
legacy_transforms.ToPILImage,
...
@@ -6,6 +6,7 @@ import itertools
import math
import pickle
import re
from copy import deepcopy
from pathlib import Path
from unittest import mock
@@ -37,13 +38,14 @@ from common_utils import (
from torch import nn
from torch.testing import assert_close
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
@@ -261,7 +263,123 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
_script(v1_transform)(input)
-def check_transform(transform, input, check_v1_compatibility=True):
+def _make_transform_sample(transform, *, image_or_video, adapter):
device = image_or_video.device if isinstance(image_or_video, torch.Tensor) else "cpu"
size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
image_tv_tensor=make_image(size, device=device),
video_tv_tensor=make_video(size, device=device),
image_pil=make_image_pil(size),
bounding_boxes_xyxy=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.XYXY, device=device),
bounding_boxes_xywh=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.XYWH, device=device),
bounding_boxes_cxcywh=make_bounding_boxes(size, format=tv_tensors.BoundingBoxFormat.CXCYWH, device=device),
bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[2, 0, 1, 1], # x1 > x2, y1 < y2
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=size,
device=device,
),
bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=size,
device=device,
),
bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=tv_tensors.BoundingBoxFormat.CXCYWH,
canvas_size=size,
device=device,
),
detection_mask=make_detection_mask(size, device=device),
segmentation_mask=make_segmentation_mask(size, device=device),
int=0,
float=0.0,
bool=True,
none=None,
str="str",
path=Path.cwd(),
object=object(),
tensor=torch.empty(5),
array=np.empty(5),
)
if adapter is not None:
input = adapter(transform, input, device)
return input
def _check_transform_sample_input_smoke(transform, input, *, adapter):
# This is a bunch of input / output convention checks, using a big sample with different parts as input.
if not check_type(input, (is_pure_tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video)):
return
sample = _make_transform_sample(
# adapter might change transform inplace
transform=transform if adapter is None else deepcopy(transform),
image_or_video=input,
adapter=adapter,
)
for container_type in [dict, list, tuple]:
if container_type is dict:
input = sample
else:
input = container_type(sample.values())
input_flat, input_spec = tree_flatten(input)
with freeze_rng_state():
torch.manual_seed(0)
output = transform(input)
output_flat, output_spec = tree_flatten(output)
assert output_spec == input_spec
for output_item, input_item, should_be_transformed in zip(
output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
):
if should_be_transformed:
assert type(output_item) is type(input_item)
else:
assert output_item is input_item
# Enforce that the transform does not turn a degenerate bounding box, e.g. marked by RandomIoUCrop (or any other
# future transform that does this), back into a valid one.
for degenerate_bounding_boxes in (
bounding_box
for name, bounding_box in sample.items()
if "degenerate" in name and isinstance(bounding_box, tv_tensors.BoundingBoxes)
):
sample = dict(
boxes=degenerate_bounding_boxes,
labels=torch.randint(10, (degenerate_bounding_boxes.shape[0],), device=degenerate_bounding_boxes.device),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
def check_transform(transform, input, check_v1_compatibility=True, check_sample_input=True):
pickle.loads(pickle.dumps(transform))
output = transform(input)
@@ -270,6 +388,11 @@ def check_transform(transform, input, check_v1_compatibility=True):
if isinstance(input, tv_tensors.BoundingBoxes) and not isinstance(transform, transforms.ConvertBoundingBoxFormat):
assert output.format == input.format
if check_sample_input:
_check_transform_sample_input_smoke(
transform, input, adapter=check_sample_input if callable(check_sample_input) else None
)
if check_v1_compatibility:
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))
@@ -1758,7 +1881,7 @@ class TestToDtype:
input = make_input(dtype=input_dtype, device=device)
if as_dict:
output_dtype = {type(input): output_dtype}
-check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input)
+check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict)
def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
input_dtype = image.dtype
@@ -2559,9 +2682,13 @@ class TestCrop:
def test_transform(self, param, value, make_input):
input = make_input(self.INPUT_SIZE)
check_sample_input = True
if param == "fill":
-if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
+if isinstance(value, (tuple, list)):
if isinstance(input, tv_tensors.Mask):
pytest.skip("F.pad_mask doesn't support non-scalar fill.")
else:
check_sample_input = False
kwargs = dict(
# 1. size is required
@@ -2576,6 +2703,7 @@ class TestCrop:
transforms.RandomCrop(**kwargs, pad_if_needed=True),
input,
check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
check_sample_input=check_sample_input,
)
@pytest.mark.parametrize("padding", [1, (1, 1), (1, 1, 1, 1)]) @pytest.mark.parametrize("padding", [1, (1, 1), (1, 1, 1, 1)])
...@@ -2761,8 +2889,12 @@ class TestErase: ...@@ -2761,8 +2889,12 @@ class TestErase:
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device): def test_transform(self, make_input, device):
input = make_input(device=device) input = make_input(device=device)
with pytest.warns(UserWarning, match="currently passing through inputs of type"):
check_transform( check_transform(
transforms.RandomErasing(p=1), input, check_v1_compatibility=not isinstance(input, PIL.Image.Image) transforms.RandomErasing(p=1),
input,
check_v1_compatibility=not isinstance(input, PIL.Image.Image),
) )
def _reference_erase_image(self, image, *, i, j, h, w, v):
@@ -2835,18 +2967,6 @@ class TestErase:
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
transform._get_params([make_image()])
@pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
def test_transform_passthrough(self, make_input):
transform = transforms.RandomErasing(p=1)
input = make_input(self.INPUT_SIZE)
with pytest.warns(UserWarning, match="currently passing through inputs of type"):
# RandomErasing requires an image or video to be present
_, output = transform(make_image(self.INPUT_SIZE), input)
assert output is input
class TestGaussianBlur:
@pytest.mark.parametrize("kernel_size", [1, 3, (3, 1), [3, 5]])
@@ -3063,6 +3183,21 @@ class TestAutoAugmentTransforms:
else:
assert_close(actual, expected, rtol=0, atol=1)
def _sample_input_adapter(self, transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
# AA transforms don't support bounding boxes or masks
continue
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
image_or_video_found = True
adapted_input[key] = value
return adapted_input
@pytest.mark.parametrize(
"transform",
[transforms.AutoAugment(), transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AugMix()],
@@ -3087,7 +3222,9 @@ class TestAutoAugmentTransforms:
# For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1
# and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks
# here and only check if we can script the v2 transform and subsequently call the result.
-check_transform(transform, input, check_v1_compatibility=False)
+check_transform(
transform, input, check_v1_compatibility=False, check_sample_input=self._sample_input_adapter
)
if type(input) is torch.Tensor and dtype is torch.uint8:
_script(transform)(input)
@@ -4014,9 +4151,25 @@ class TestNormalize:
with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"):
F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std)
def _sample_input_adapter(self, transform, input, device):
adapted_input = {}
for key, value in input.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
elif check_type(value, (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)):
# normalize doesn't support integer images
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
return adapted_input
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
def test_transform(self, make_input): def test_transform(self, make_input):
check_transform(transforms.Normalize(mean=self.MEAN, std=self.STD), make_input(dtype=torch.float32)) check_transform(
transforms.Normalize(mean=self.MEAN, std=self.STD),
make_input(dtype=torch.float32),
check_sample_input=self._sample_input_adapter,
)
def _reference_normalize_image(self, image, *, mean, std):
image = image.numpy()
@@ -4543,7 +4696,11 @@ class TestFiveTenCrop:
)
@pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
def test_transform(self, make_input, transform_cls):
-check_transform(self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)), make_input(self.INPUT_SIZE))
+check_transform(
self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE)),
make_input(self.INPUT_SIZE),
check_sample_input=False,
)
@pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask]) @pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
@pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop]) @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
...@@ -4826,3 +4983,66 @@ class TestScaleJitter: ...@@ -4826,3 +4983,66 @@ class TestScaleJitter:
assert int(input_size[0] * r_min) <= height <= int(input_size[0] * r_max) assert int(input_size[0] * r_min) <= height <= int(input_size[0] * r_max)
assert int(input_size[1] * r_min) <= width <= int(input_size[1] * r_max) assert int(input_size[1] * r_min) <= width <= int(input_size[1] * r_max)
class TestLinearTransform:
def _make_matrix_and_vector(self, input, *, device=None):
device = device or input.device
numel = math.prod(F.get_dimensions(input))
transformation_matrix = torch.randn((numel, numel), device=device)
mean_vector = torch.randn((numel,), device=device)
return transformation_matrix, mean_vector
def _sample_input_adapter(self, transform, input, device):
return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, dtype, device):
input = make_input(dtype=dtype, device=device)
check_transform(
transforms.LinearTransformation(*self._make_matrix_and_vector(input)),
input,
check_sample_input=self._sample_input_adapter,
)
def test_transform_error(self):
with pytest.raises(ValueError, match="transformation_matrix should be square"):
transforms.LinearTransformation(transformation_matrix=torch.rand(2, 3), mean_vector=torch.rand(2))
with pytest.raises(ValueError, match="mean_vector should have the same length"):
transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(1))
for matrix_dtype, vector_dtype in [(torch.float32, torch.float64), (torch.float64, torch.float32)]:
with pytest.raises(ValueError, match="Input tensors should have the same dtype"):
transforms.LinearTransformation(
transformation_matrix=torch.rand(2, 2, dtype=matrix_dtype),
mean_vector=torch.rand(2, dtype=vector_dtype),
)
image = make_image()
transform = transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(2))
with pytest.raises(ValueError, match="Input tensor and transformation matrix have incompatible shape"):
transform(image)
transform = transforms.LinearTransformation(*self._make_matrix_and_vector(image))
with pytest.raises(TypeError, match="does not support PIL images"):
transform(F.to_pil_image(image))
@needs_cuda
def test_transform_error_cuda(self):
for matrix_device, vector_device in [("cuda", "cpu"), ("cpu", "cuda")]:
with pytest.raises(ValueError, match="Input tensors should be on the same device"):
transforms.LinearTransformation(
transformation_matrix=torch.rand(2, 2, device=matrix_device),
mean_vector=torch.rand(2, device=vector_device),
)
for input_device, param_device in [("cuda", "cpu"), ("cpu", "cuda")]:
input = make_image(device=input_device)
transform = transforms.LinearTransformation(*self._make_matrix_and_vector(input, device=param_device))
with pytest.raises(
ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector"
):
transform(input)