Unverified Commit 8acf1ca2 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

extract common utils for prototype transform tests (#6552)

parent b5c961d4
import functools
import itertools
import PIL.Image
import pytest
import torch
import torch.testing
from torch.nn.functional import one_hot
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value
class ImagePair(TensorLikePair):
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
return super()._process_inputs(
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
id=id,
allow_subclasses=allow_subclasses,
)
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
class ArgsKwargs:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __iter__(self):
yield self.args
yield self.kwargs
def __str__(self):
def short_repr(obj, max=20):
repr_ = repr(obj)
if len(repr_) <= max:
return repr_
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
return ", ".join(
itertools.chain(
[short_repr(arg) for arg in self.args],
[f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
)
)
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
size = size or torch.randint(16, 33, (2,)).tolist()
try:
num_channels = {
features.ColorSpace.GRAY: 1,
features.ColorSpace.GRAY_ALPHA: 2,
features.ColorSpace.RGB: 3,
features.ColorSpace.RGB_ALPHA: 4,
}[color_space]
except KeyError as error:
raise pytest.UsageError() from error
shape = (*extra_dims, num_channels, *size)
max_value = get_max_value(dtype)
data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
data[..., -1, :, :] = max_value
return features.Image(data, color_space=color_space)
make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
def make_images(
sizes=((16, 16), (7, 33), (31, 9)),
color_spaces=(
features.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA,
),
dtypes=(torch.float32, torch.uint8),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
yield make_image(size, color_space=color_space, dtype=dtype)
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
low, high = torch.broadcast_tensors(
*[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
)
return torch.stack(
[
torch.randint(low_scalar, high_scalar, (), **kwargs)
for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
]
).reshape(low.shape)
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
if any(dim == 0 for dim in extra_dims):
return features.BoundingBox(torch.empty(*extra_dims, 4), format=format, image_size=image_size)
height, width = image_size
if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, extra_dims)
y1 = torch.randint(0, height // 2, extra_dims)
x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, extra_dims)
y = torch.randint(0, height // 2, extra_dims)
w = randint_with_tensor_bounds(1, width - x)
h = randint_with_tensor_bounds(1, height - y)
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
parts = (cx, cy, w, h)
else:
raise pytest.UsageError()
return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
def make_bounding_boxes(
formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
image_sizes=((32, 32),),
dtypes=(torch.int64, torch.float32),
extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)),
):
for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)
for format, extra_dims_ in itertools.product(formats, extra_dims):
yield make_bounding_box(format=format, extra_dims=extra_dims_)
def make_label(size=(), *, categories=("category0", "category1")):
return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories)
def make_one_hot_label(*args, **kwargs):
label = make_label(*args, **kwargs)
return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories)
def make_one_hot_labels(
*,
num_categories=(1, 2, 10),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
for num_categories_ in num_categories:
yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)])
for extra_dims_ in extra_dims:
yield make_one_hot_label(extra_dims_)
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
shape = (*extra_dims, num_objects, *size)
data = make_tensor(shape, low=0, high=2, dtype=dtype)
return features.SegmentationMask(data)
def make_segmentation_masks(
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.uint8,),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
num_objects=(1, 0, 10),
):
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
...@@ -7,7 +7,7 @@ import PIL.Image ...@@ -7,7 +7,7 @@ import PIL.Image
import pytest import pytest
import torch import torch
from common_utils import assert_equal, cpu_and_gpu from common_utils import assert_equal, cpu_and_gpu
from test_prototype_transforms_functional import ( from prototype_common_utils import (
make_bounding_box, make_bounding_box,
make_bounding_boxes, make_bounding_boxes,
make_image, make_image,
......
import enum import enum
import functools
import inspect import inspect
import itertools
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
import torch import torch
from test_prototype_transforms_functional import make_images from prototype_common_utils import ArgsKwargs, assert_equal, make_images
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision import transforms as legacy_transforms from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor from torchvision.prototype.transforms.functional import to_image_pil
class ImagePair(TensorLikePair):
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
return super()._process_inputs(
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
id=id,
allow_subclasses=allow_subclasses,
)
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
class ArgsKwargs:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __iter__(self):
yield self.args
yield self.kwargs
def __str__(self):
def short_repr(obj, max=20):
repr_ = repr(obj)
if len(repr_) <= max:
return repr_
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
return ", ".join(
itertools.chain(
[short_repr(arg) for arg in self.args],
[f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
)
)
class ConsistencyConfig: class ConsistencyConfig:
def __init__( def __init__(
self, self,
......
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import torch import torch
from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any from torchvision.prototype.transforms._utils import has_all, has_any
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment