Unverified Commit 1f4a9846 authored by Philip Meier, committed by GitHub

add proper smoke test for prototype transforms (#7238)

parent 0bdd01a7
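
The new test_common below feeds every transform a single dictionary of heterogeneous inputs (image, video, PIL image, bounding boxes in all formats, degenerate boxes, masks, plus plain Python values that must pass through untouched). Transforms that cannot digest the full sample get an adapter callable that rewrites the dictionary first. A minimal sketch of that hook, using a hypothetical drop_pil_adapter standing in for the adapters defined in the diff:

import PIL.Image
import torch

def drop_pil_adapter(transform, input, device):
    # Same (transform, input, device) signature as the adapters in the diff:
    # return a pruned copy of the sample, here dropping PIL images.
    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}

def run_case(transform, input, adapter=None, device="cpu"):
    # Mirrors the test flow: adapt the sample if needed, then apply the transform once.
    if adapter is not None:
        input = adapter(transform, input, device)
    torch.manual_seed(0)
    return transform(input)
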
 import itertools
+import pathlib
 import re
 import warnings
 from collections import defaultdict
@@ -20,15 +21,16 @@ from prototype_common_utils import (
     make_image,
     make_images,
     make_label,
-    make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
     make_video,
     make_videos,
 )
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -66,53 +68,201 @@ def parametrize(transforms_with_inputs):
     )


-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
-                continue
-            else:
-                if output is inputs[0]:
-                    continue
-
-            transforms_with_inputs.append((transform, inputs))
-
-    return parametrize(transforms_with_inputs)
+def auto_augment_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    for key, value in input.items():
+        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+            # AA transforms don't support bounding boxes or masks
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
+                continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
+
+
+def linear_transformation_adapter(transform, input, device):
+    flat_inputs = list(input.values())
+    c, h, w = query_chw(
+        [
+            item
+            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
+            if needs_transform
+        ]
+    )
+    num_elements = c * h * w
+    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
+    transform.mean_vector = torch.randn((num_elements,), device=device)
+    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
+
+
+def normalize_adapter(transform, input, device):
+    adapted_input = {}
+    for key, value in input.items():
+        if isinstance(value, PIL.Image.Image):
+            # normalize doesn't support PIL images
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+            # normalize doesn't support integer images
+            value = F.convert_dtype(value, torch.float32)
+        adapted_input[key] = value
+    return adapted_input


 class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
+    @pytest.mark.parametrize(
+        ("transform", "adapter"),
+        [
+            (transforms.RandomErasing(p=1.0), None),
+            (transforms.AugMix(), auto_augment_adapter),
+            (transforms.AutoAugment(), auto_augment_adapter),
+            (transforms.RandAugment(), auto_augment_adapter),
+            (transforms.TrivialAugmentWide(), auto_augment_adapter),
+            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
+            (transforms.Grayscale(), None),
+            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
+            (transforms.RandomAutocontrast(p=1.0), None),
+            (transforms.RandomEqualize(p=1.0), None),
+            (transforms.RandomGrayscale(p=1.0), None),
+            (transforms.RandomInvert(p=1.0), None),
+            (transforms.RandomPhotometricDistort(p=1.0), None),
+            (transforms.RandomPosterize(bits=4, p=1.0), None),
+            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
+            (transforms.CenterCrop([16, 16]), None),
+            (transforms.ElasticTransform(sigma=1.0), None),
+            (transforms.Pad(4), None),
+            (transforms.RandomAffine(degrees=30.0), None),
+            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
+            (transforms.RandomHorizontalFlip(p=1.0), None),
+            (transforms.RandomPerspective(p=1.0), None),
+            (transforms.RandomResize(min_size=10, max_size=20), None),
+            (transforms.RandomResizedCrop([16, 16]), None),
+            (transforms.RandomRotation(degrees=30), None),
+            (transforms.RandomShortestSize(min_size=10), None),
+            (transforms.RandomVerticalFlip(p=1.0), None),
+            (transforms.RandomZoomOut(p=1.0), None),
+            (transforms.Resize([16, 16], antialias=True), None),
+            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
+            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+            (transforms.ConvertDtype(), None),
+            (transforms.GaussianBlur(kernel_size=3), None),
+            (
+                transforms.LinearTransformation(
+                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
+                    # because for we neither know the spatial size nor the device at this point
+                    transformation_matrix=torch.empty((1, 1)),
+                    mean_vector=torch.empty((1,)),
+                ),
+                linear_transformation_adapter,
+            ),
+            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
+            (transforms.ToDtype(torch.float64), None),
+            (transforms.UniformTemporalSubsample(num_samples=2), None),
+        ],
+        ids=lambda transform: type(transform).__name__,
     )
-    def test_common(self, transform, input):
-        transform(input)
+    @pytest.mark.parametrize("container_type", [dict, list, tuple])
+    @pytest.mark.parametrize(
+        "image_or_video",
+        [
+            make_image(),
+            make_video(),
+            next(make_pil_images(color_spaces=["RGB"])),
+            next(make_vanilla_tensor_images()),
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_common(self, transform, adapter, container_type, image_or_video, device):
+        spatial_size = F.get_spatial_size(image_or_video)
+        input = dict(
+            image_or_video=image_or_video,
+            image_datapoint=make_image(size=spatial_size),
+            video_datapoint=make_video(size=spatial_size),
+            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
+            bounding_box_xyxy=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
+            ),
+            bounding_box_xywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
+            ),
+            bounding_box_cxcywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
+            ),
+            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
+                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
+                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
+                ],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_xywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.XYWH,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.CXCYWH,
+                spatial_size=spatial_size,
+            ),
+            detection_mask=make_detection_mask(size=spatial_size),
+            segmentation_mask=make_segmentation_mask(size=spatial_size),
+            int=0,
+            float=0.0,
+            bool=True,
+            none=None,
+            str="str",
+            path=pathlib.Path.cwd(),
+            object=object(),
+            tensor=torch.empty(5),
+            array=np.empty(5),
+        )
+
+        if adapter is not None:
+            input = adapter(transform, input, device)
+
+        if container_type in {tuple, list}:
+            input = container_type(input.values())
+
+        input_flat, input_spec = tree_flatten(input)
+        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
+        input = tree_unflatten(input_flat, input_spec)
+
+        torch.manual_seed(0)
+        output = transform(input)
+        output_flat, output_spec = tree_flatten(output)
+
+        assert output_spec == input_spec
+
+        for output_item, input_item, should_be_transformed in zip(
+            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
+        ):
+            if should_be_transformed:
+                assert type(output_item) is type(input_item)
+            else:
+                assert output_item is input_item

     @parametrize(
         [
...
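
The pass-through check at the end of test_common relies on torch.utils._pytree: the sample is flattened, tensors are moved to the target device, the container is rebuilt, and after the transform the output spec must equal the input spec while non-transformable leaves must come back as the very same objects. A small self-contained sketch of that flatten/rebuild round trip (toy sample, not taken from the test):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": torch.rand(3, 16, 16), "label": 7, "meta": {"path": "a.jpg"}}
leaves, spec = tree_flatten(sample)
# Move only tensors; ints, strings, etc. are left untouched, as in the test.
leaves = [leaf.to("cpu") if isinstance(leaf, torch.Tensor) else leaf for leaf in leaves]
rebuilt = tree_unflatten(leaves, spec)
assert tree_flatten(rebuilt)[1] == spec  # container structure is preserved
assert rebuilt["label"] is sample["label"]  # non-tensor leaves pass through unchanged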