Unverified Commit 8acf1ca2 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

extract common utils for prototype transform tests (#6552)

parent b5c961d4
import functools
import itertools
import PIL.Image
import pytest
import torch
import torch.testing
from torch.nn.functional import one_hot
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value
class ImagePair(TensorLikePair):
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
return super()._process_inputs(
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
id=id,
allow_subclasses=allow_subclasses,
)
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
class ArgsKwargs:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __iter__(self):
yield self.args
yield self.kwargs
def __str__(self):
def short_repr(obj, max=20):
repr_ = repr(obj)
if len(repr_) <= max:
return repr_
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
return ", ".join(
itertools.chain(
[short_repr(arg) for arg in self.args],
[f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
)
)
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
size = size or torch.randint(16, 33, (2,)).tolist()
try:
num_channels = {
features.ColorSpace.GRAY: 1,
features.ColorSpace.GRAY_ALPHA: 2,
features.ColorSpace.RGB: 3,
features.ColorSpace.RGB_ALPHA: 4,
}[color_space]
except KeyError as error:
raise pytest.UsageError() from error
shape = (*extra_dims, num_channels, *size)
max_value = get_max_value(dtype)
data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
data[..., -1, :, :] = max_value
return features.Image(data, color_space=color_space)
make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
def make_images(
sizes=((16, 16), (7, 33), (31, 9)),
color_spaces=(
features.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA,
),
dtypes=(torch.float32, torch.uint8),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
yield make_image(size, color_space=color_space, dtype=dtype)
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
low, high = torch.broadcast_tensors(
*[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
)
return torch.stack(
[
torch.randint(low_scalar, high_scalar, (), **kwargs)
for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
]
).reshape(low.shape)
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
if any(dim == 0 for dim in extra_dims):
return features.BoundingBox(torch.empty(*extra_dims, 4), format=format, image_size=image_size)
height, width = image_size
if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, extra_dims)
y1 = torch.randint(0, height // 2, extra_dims)
x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, extra_dims)
y = torch.randint(0, height // 2, extra_dims)
w = randint_with_tensor_bounds(1, width - x)
h = randint_with_tensor_bounds(1, height - y)
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
parts = (cx, cy, w, h)
else:
raise pytest.UsageError()
return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
def make_bounding_boxes(
formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
image_sizes=((32, 32),),
dtypes=(torch.int64, torch.float32),
extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)),
):
for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)
for format, extra_dims_ in itertools.product(formats, extra_dims):
yield make_bounding_box(format=format, extra_dims=extra_dims_)
def make_label(size=(), *, categories=("category0", "category1")):
return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories)
def make_one_hot_label(*args, **kwargs):
label = make_label(*args, **kwargs)
return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories)
def make_one_hot_labels(
*,
num_categories=(1, 2, 10),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
for num_categories_ in num_categories:
yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)])
for extra_dims_ in extra_dims:
yield make_one_hot_label(extra_dims_)
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
shape = (*extra_dims, num_objects, *size)
data = make_tensor(shape, low=0, high=2, dtype=dtype)
return features.SegmentationMask(data)
def make_segmentation_masks(
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.uint8,),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
num_objects=(1, 0, 10),
):
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
...@@ -7,7 +7,7 @@ import PIL.Image ...@@ -7,7 +7,7 @@ import PIL.Image
import pytest import pytest
import torch import torch
from common_utils import assert_equal, cpu_and_gpu from common_utils import assert_equal, cpu_and_gpu
from test_prototype_transforms_functional import ( from prototype_common_utils import (
make_bounding_box, make_bounding_box,
make_bounding_boxes, make_bounding_boxes,
make_image, make_image,
......
import enum import enum
import functools
import inspect import inspect
import itertools
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
import torch import torch
from test_prototype_transforms_functional import make_images from prototype_common_utils import ArgsKwargs, assert_equal, make_images
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision import transforms as legacy_transforms from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor from torchvision.prototype.transforms.functional import to_image_pil
class ImagePair(TensorLikePair):
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
return super()._process_inputs(
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
id=id,
allow_subclasses=allow_subclasses,
)
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
class ArgsKwargs:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __iter__(self):
yield self.args
yield self.kwargs
def __str__(self):
def short_repr(obj, max=20):
repr_ = repr(obj)
if len(repr_) <= max:
return repr_
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
return ", ".join(
itertools.chain(
[short_repr(arg) for arg in self.args],
[f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
)
)
class ConsistencyConfig: class ConsistencyConfig:
def __init__( def __init__(
self, self,
......
import functools
import itertools import itertools
import math import math
import os import os
...@@ -9,167 +8,12 @@ import pytest ...@@ -9,167 +8,12 @@ import pytest
import torch.testing import torch.testing
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
from common_utils import cpu_and_gpu from common_utils import cpu_and_gpu
from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks
from torch import jit from torch import jit
from torch.nn.functional import one_hot
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.functional_tensor import _max_value as get_max_value
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
size = size or torch.randint(16, 33, (2,)).tolist()
try:
num_channels = {
features.ColorSpace.GRAY: 1,
features.ColorSpace.GRAY_ALPHA: 2,
features.ColorSpace.RGB: 3,
features.ColorSpace.RGB_ALPHA: 4,
}[color_space]
except KeyError as error:
raise pytest.UsageError() from error
shape = (*extra_dims, num_channels, *size)
max_value = get_max_value(dtype)
data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
data[..., -1, :, :] = max_value
return features.Image(data, color_space=color_space)
make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
def make_images(
sizes=((16, 16), (7, 33), (31, 9)),
color_spaces=(
features.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA,
),
dtypes=(torch.float32, torch.uint8),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
yield make_image(size, color_space=color_space, dtype=dtype)
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
low, high = torch.broadcast_tensors(
*[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
)
return torch.stack(
[
torch.randint(low_scalar, high_scalar, (), **kwargs)
for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
]
).reshape(low.shape)
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
if any(dim == 0 for dim in extra_dims):
return features.BoundingBox(torch.empty(*extra_dims, 4), format=format, image_size=image_size)
height, width = image_size
if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, extra_dims)
y1 = torch.randint(0, height // 2, extra_dims)
x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, extra_dims)
y = torch.randint(0, height // 2, extra_dims)
w = randint_with_tensor_bounds(1, width - x)
h = randint_with_tensor_bounds(1, height - y)
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
parts = (cx, cy, w, h)
else:
raise pytest.UsageError()
return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
def make_bounding_boxes(
formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
image_sizes=((32, 32),),
dtypes=(torch.int64, torch.float32),
extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)),
):
for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)
for format, extra_dims_ in itertools.product(formats, extra_dims):
yield make_bounding_box(format=format, extra_dims=extra_dims_)
def make_label(size=(), *, categories=("category0", "category1")):
return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories)
def make_one_hot_label(*args, **kwargs):
label = make_label(*args, **kwargs)
return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories)
def make_one_hot_labels(
*,
num_categories=(1, 2, 10),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
for num_categories_ in num_categories:
yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)])
for extra_dims_ in extra_dims:
yield make_one_hot_label(extra_dims_)
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
shape = (*extra_dims, num_objects, *size)
data = make_tensor(shape, low=0, high=2, dtype=dtype)
return features.SegmentationMask(data)
def make_segmentation_masks(
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.uint8,),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
num_objects=(1, 0, 10),
):
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
class SampleInput:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class FunctionalInfo: class FunctionalInfo:
...@@ -182,7 +26,7 @@ class FunctionalInfo: ...@@ -182,7 +26,7 @@ class FunctionalInfo:
yield from self._sample_inputs_fn() yield from self._sample_inputs_fn()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput): if len(args) == 1 and not kwargs and isinstance(args[0], ArgsKwargs):
sample_input = args[0] sample_input = args[0]
return self.functional(*sample_input.args, **sample_input.kwargs) return self.functional(*sample_input.args, **sample_input.kwargs)
...@@ -200,37 +44,37 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): ...@@ -200,37 +44,37 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def horizontal_flip_image_tensor(): def horizontal_flip_image_tensor():
for image in make_images(): for image in make_images():
yield SampleInput(image) yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def horizontal_flip_bounding_box(): def horizontal_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) yield ArgsKwargs(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def horizontal_flip_segmentation_mask(): def horizontal_flip_segmentation_mask():
for mask in make_segmentation_masks(): for mask in make_segmentation_masks():
yield SampleInput(mask) yield ArgsKwargs(mask)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def vertical_flip_image_tensor(): def vertical_flip_image_tensor():
for image in make_images(): for image in make_images():
yield SampleInput(image) yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def vertical_flip_bounding_box(): def vertical_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) yield ArgsKwargs(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def vertical_flip_segmentation_mask(): def vertical_flip_segmentation_mask():
for mask in make_segmentation_masks(): for mask in make_segmentation_masks():
yield SampleInput(mask) yield ArgsKwargs(mask)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -252,7 +96,7 @@ def resize_image_tensor(): ...@@ -252,7 +96,7 @@ def resize_image_tensor():
]: ]:
if max_size is not None: if max_size is not None:
size = [size[0]] size = [size[0]]
yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) yield ArgsKwargs(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -268,7 +112,7 @@ def resize_bounding_box(): ...@@ -268,7 +112,7 @@ def resize_bounding_box():
]: ]:
if max_size is not None: if max_size is not None:
size = [size[0]] size = [size[0]]
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) yield ArgsKwargs(bounding_box, size=size, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -284,7 +128,7 @@ def resize_segmentation_mask(): ...@@ -284,7 +128,7 @@ def resize_segmentation_mask():
]: ]:
if max_size is not None: if max_size is not None:
size = [size[0]] size = [size[0]]
yield SampleInput(mask, size=size, max_size=max_size) yield ArgsKwargs(mask, size=size, max_size=max_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -296,7 +140,7 @@ def affine_image_tensor(): ...@@ -296,7 +140,7 @@ def affine_image_tensor():
[0.77, 1.27], # scale [0.77, 1.27], # scale
[0, 12], # shear [0, 12], # shear
): ):
yield SampleInput( yield ArgsKwargs(
image, image,
angle=angle, angle=angle,
translate=(translate, translate), translate=(translate, translate),
...@@ -315,7 +159,7 @@ def affine_bounding_box(): ...@@ -315,7 +159,7 @@ def affine_bounding_box():
[0.77, 1.27], # scale [0.77, 1.27], # scale
[0, 12], # shear [0, 12], # shear
): ):
yield SampleInput( yield ArgsKwargs(
bounding_box, bounding_box,
format=bounding_box.format, format=bounding_box.format,
image_size=bounding_box.image_size, image_size=bounding_box.image_size,
...@@ -335,7 +179,7 @@ def affine_segmentation_mask(): ...@@ -335,7 +179,7 @@ def affine_segmentation_mask():
[0.77, 1.27], # scale [0.77, 1.27], # scale
[0, 12], # shear [0, 12], # shear
): ):
yield SampleInput( yield ArgsKwargs(
mask, mask,
angle=angle, angle=angle,
translate=(translate, translate), translate=(translate, translate),
...@@ -357,7 +201,7 @@ def rotate_image_tensor(): ...@@ -357,7 +201,7 @@ def rotate_image_tensor():
# Skip warning: The provided center argument is ignored if expand is True # Skip warning: The provided center argument is ignored if expand is True
continue continue
yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill) yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -369,7 +213,7 @@ def rotate_bounding_box(): ...@@ -369,7 +213,7 @@ def rotate_bounding_box():
# Skip warning: The provided center argument is ignored if expand is True # Skip warning: The provided center argument is ignored if expand is True
continue continue
yield SampleInput( yield ArgsKwargs(
bounding_box, bounding_box,
format=bounding_box.format, format=bounding_box.format,
image_size=bounding_box.image_size, image_size=bounding_box.image_size,
...@@ -391,7 +235,7 @@ def rotate_segmentation_mask(): ...@@ -391,7 +235,7 @@ def rotate_segmentation_mask():
# Skip warning: The provided center argument is ignored if expand is True # Skip warning: The provided center argument is ignored if expand is True
continue continue
yield SampleInput( yield ArgsKwargs(
mask, mask,
angle=angle, angle=angle,
expand=expand, expand=expand,
...@@ -402,7 +246,7 @@ def rotate_segmentation_mask(): ...@@ -402,7 +246,7 @@ def rotate_segmentation_mask():
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def crop_image_tensor(): def crop_image_tensor():
for image, top, left, height, width in itertools.product(make_images(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]): for image, top, left, height, width in itertools.product(make_images(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]):
yield SampleInput( yield ArgsKwargs(
image, image,
top=top, top=top,
left=left, left=left,
...@@ -414,7 +258,7 @@ def crop_image_tensor(): ...@@ -414,7 +258,7 @@ def crop_image_tensor():
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def crop_bounding_box(): def crop_bounding_box():
for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]): for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]):
yield SampleInput( yield ArgsKwargs(
bounding_box, bounding_box,
format=bounding_box.format, format=bounding_box.format,
top=top, top=top,
...@@ -427,7 +271,7 @@ def crop_segmentation_mask(): ...@@ -427,7 +271,7 @@ def crop_segmentation_mask():
for mask, top, left, height, width in itertools.product( for mask, top, left, height, width in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20] make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
): ):
yield SampleInput( yield ArgsKwargs(
mask, mask,
top=top, top=top,
left=left, left=left,
...@@ -447,7 +291,7 @@ def resized_crop_image_tensor(): ...@@ -447,7 +291,7 @@ def resized_crop_image_tensor():
[(16, 18)], [(16, 18)],
[True, False], [True, False],
): ):
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias) yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -455,7 +299,7 @@ def resized_crop_bounding_box(): ...@@ -455,7 +299,7 @@ def resized_crop_bounding_box():
for bounding_box, top, left, height, width, size in itertools.product( for bounding_box, top, left, height, width, size in itertools.product(
make_bounding_boxes(), [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)] make_bounding_boxes(), [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)]
): ):
yield SampleInput( yield ArgsKwargs(
bounding_box, format=bounding_box.format, top=top, left=left, height=height, width=width, size=size bounding_box, format=bounding_box.format, top=top, left=left, height=height, width=width, size=size
) )
...@@ -465,7 +309,7 @@ def resized_crop_segmentation_mask(): ...@@ -465,7 +309,7 @@ def resized_crop_segmentation_mask():
for mask, top, left, height, width, size in itertools.product( for mask, top, left, height, width, size in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)] make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
): ):
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size) yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -476,7 +320,7 @@ def pad_image_tensor(): ...@@ -476,7 +320,7 @@ def pad_image_tensor():
[None, 12, 12.0], # fill [None, 12, 12.0], # fill
["constant", "symmetric", "edge", "reflect"], # padding mode, ["constant", "symmetric", "edge", "reflect"], # padding mode,
): ):
yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode) yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -486,7 +330,7 @@ def pad_segmentation_mask(): ...@@ -486,7 +330,7 @@ def pad_segmentation_mask():
[[1], [1, 1], [1, 1, 2, 2]], # padding [[1], [1, 1], [1, 1, 2, 2]], # padding
["constant", "symmetric", "edge", "reflect"], # padding mode, ["constant", "symmetric", "edge", "reflect"], # padding mode,
): ):
yield SampleInput(mask, padding=padding, padding_mode=padding_mode) yield ArgsKwargs(mask, padding=padding, padding_mode=padding_mode)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -495,7 +339,7 @@ def pad_bounding_box(): ...@@ -495,7 +339,7 @@ def pad_bounding_box():
make_bounding_boxes(), make_bounding_boxes(),
[[1], [1, 1], [1, 1, 2, 2]], [[1], [1, 1], [1, 1, 2, 2]],
): ):
yield SampleInput(bounding_box, padding=padding, format=bounding_box.format) yield ArgsKwargs(bounding_box, padding=padding, format=bounding_box.format)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -508,7 +352,7 @@ def perspective_image_tensor(): ...@@ -508,7 +352,7 @@ def perspective_image_tensor():
], ],
[None, [128], [12.0]], # fill [None, [128], [12.0]], # fill
): ):
yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill) yield ArgsKwargs(image, perspective_coeffs=perspective_coeffs, fill=fill)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -520,7 +364,7 @@ def perspective_bounding_box(): ...@@ -520,7 +364,7 @@ def perspective_bounding_box():
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
], ],
): ):
yield SampleInput( yield ArgsKwargs(
bounding_box, bounding_box,
format=bounding_box.format, format=bounding_box.format,
perspective_coeffs=perspective_coeffs, perspective_coeffs=perspective_coeffs,
...@@ -536,7 +380,7 @@ def perspective_segmentation_mask(): ...@@ -536,7 +380,7 @@ def perspective_segmentation_mask():
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
], ],
): ):
yield SampleInput( yield ArgsKwargs(
mask, mask,
perspective_coeffs=perspective_coeffs, perspective_coeffs=perspective_coeffs,
) )
...@@ -550,7 +394,7 @@ def elastic_image_tensor(): ...@@ -550,7 +394,7 @@ def elastic_image_tensor():
): ):
h, w = image.shape[-2:] h, w = image.shape[-2:]
displacement = torch.rand(1, h, w, 2) displacement = torch.rand(1, h, w, 2)
yield SampleInput(image, displacement=displacement, fill=fill) yield ArgsKwargs(image, displacement=displacement, fill=fill)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -558,7 +402,7 @@ def elastic_bounding_box(): ...@@ -558,7 +402,7 @@ def elastic_bounding_box():
for bounding_box in make_bounding_boxes(): for bounding_box in make_bounding_boxes():
h, w = bounding_box.image_size h, w = bounding_box.image_size
displacement = torch.rand(1, h, w, 2) displacement = torch.rand(1, h, w, 2)
yield SampleInput( yield ArgsKwargs(
bounding_box, bounding_box,
format=bounding_box.format, format=bounding_box.format,
displacement=displacement, displacement=displacement,
...@@ -570,7 +414,7 @@ def elastic_segmentation_mask(): ...@@ -570,7 +414,7 @@ def elastic_segmentation_mask():
for mask in make_segmentation_masks(extra_dims=((), (4,))): for mask in make_segmentation_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:] h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2) displacement = torch.rand(1, h, w, 2)
yield SampleInput( yield ArgsKwargs(
mask, mask,
displacement=displacement, displacement=displacement,
) )
...@@ -582,13 +426,13 @@ def center_crop_image_tensor(): ...@@ -582,13 +426,13 @@ def center_crop_image_tensor():
make_images(sizes=((16, 16), (7, 33), (31, 9))), make_images(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
): ):
yield SampleInput(mask, output_size) yield ArgsKwargs(mask, output_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def center_crop_bounding_box(): def center_crop_bounding_box():
for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]): for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
yield SampleInput( yield ArgsKwargs(
bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
) )
...@@ -599,7 +443,7 @@ def center_crop_segmentation_mask(): ...@@ -599,7 +443,7 @@ def center_crop_segmentation_mask():
make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))), make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
): ):
yield SampleInput(mask, output_size) yield ArgsKwargs(mask, output_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -609,7 +453,7 @@ def gaussian_blur_image_tensor(): ...@@ -609,7 +453,7 @@ def gaussian_blur_image_tensor():
[[3, 3]], [[3, 3]],
[None, [3.0, 3.0]], [None, [3.0, 3.0]],
): ):
yield SampleInput(image, kernel_size=kernel_size, sigma=sigma) yield ArgsKwargs(image, kernel_size=kernel_size, sigma=sigma)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -617,13 +461,13 @@ def equalize_image_tensor(): ...@@ -617,13 +461,13 @@ def equalize_image_tensor():
for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
if image.dtype != torch.uint8: if image.dtype != torch.uint8:
continue continue
yield SampleInput(image) yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def invert_image_tensor(): def invert_image_tensor():
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
yield SampleInput(image) yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -634,7 +478,7 @@ def posterize_image_tensor(): ...@@ -634,7 +478,7 @@ def posterize_image_tensor():
): ):
if image.dtype != torch.uint8: if image.dtype != torch.uint8:
continue continue
yield SampleInput(image, bits=bits) yield ArgsKwargs(image, bits=bits)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -645,13 +489,13 @@ def solarize_image_tensor(): ...@@ -645,13 +489,13 @@ def solarize_image_tensor():
): ):
if image.is_floating_point() and threshold > 1.0: if image.is_floating_point() and threshold > 1.0:
continue continue
yield SampleInput(image, threshold=threshold) yield ArgsKwargs(image, threshold=threshold)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def autocontrast_image_tensor(): def autocontrast_image_tensor():
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)): for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
yield SampleInput(image) yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -660,14 +504,14 @@ def adjust_sharpness_image_tensor(): ...@@ -660,14 +504,14 @@ def adjust_sharpness_image_tensor():
make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)), make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
[0.1, 0.5], [0.1, 0.5],
): ):
yield SampleInput(image, sharpness_factor=sharpness_factor) yield ArgsKwargs(image, sharpness_factor=sharpness_factor)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def erase_image_tensor(): def erase_image_tensor():
for image in make_images(): for image in make_images():
c = image.shape[-3] c = image.shape[-3]
yield SampleInput(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7)) yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import torch import torch
from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any from torchvision.prototype.transforms._utils import has_all, has_any
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment