Unverified Commit 8acf1ca2 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

extract common utils for prototype transform tests (#6552)

parent b5c961d4
import functools
import itertools
import PIL.Image
import pytest
import torch
import torch.testing
from torch.nn.functional import one_hot
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value
class ImagePair(TensorLikePair):
    """TensorLikePair that converts PIL images to image tensors before comparison."""

    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
        # Normalize both sides to tensors so a PIL image can be compared to a tensor.
        converted = [
            to_image_tensor(candidate) if isinstance(candidate, PIL.Image.Image) else candidate
            for candidate in (actual, expected)
        ]
        return super()._process_inputs(*converted, id=id, allow_subclasses=allow_subclasses)


# Exact comparison (rtol=atol=0) that additionally understands PIL images via ImagePair.
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
class ArgsKwargs:
    """Bundle of positional and keyword arguments for a deferred call.

    Iterating yields ``args`` then ``kwargs``, so instances unpack as
    ``args, kwargs = args_kwargs``.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        return iter((self.args, self.kwargs))

    def __str__(self):
        def short_repr(obj, max=20):
            # Abbreviate long reprs to keep parametrized test ids readable.
            repr_ = repr(obj)
            if len(repr_) <= max:
                return repr_
            return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"

        positional = [short_repr(arg) for arg in self.args]
        keyword = [f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()]
        return ", ".join(positional + keyword)
# All test tensors are created on CPU so the suite does not require a GPU.
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
    """Create a random ``features.Image`` for the given color space.

    ``size`` defaults to a random (height, width) in [16, 32]; ``extra_dims`` are
    prepended as batch dimensions. With ``constant_alpha``, the alpha channel (if
    the color space has one) is set to the dtype's maximum value.
    """
    if not size:
        size = torch.randint(16, 33, (2,)).tolist()

    channel_counts = {
        features.ColorSpace.GRAY: 1,
        features.ColorSpace.GRAY_ALPHA: 2,
        features.ColorSpace.RGB: 3,
        features.ColorSpace.RGB_ALPHA: 4,
    }
    try:
        num_channels = channel_counts[color_space]
    except KeyError as error:
        raise pytest.UsageError() from error

    max_value = get_max_value(dtype)
    data = make_tensor((*extra_dims, num_channels, *size), low=0, high=max_value, dtype=dtype)
    if constant_alpha and color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA}:
        # A deterministic alpha channel avoids spurious differences in comparisons.
        data[..., -1, :, :] = max_value
    return features.Image(data, color_space=color_space)
# Shorthand factories for the two most common single-image cases.
make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
def make_images(
    sizes=((16, 16), (7, 33), (31, 9)),
    color_spaces=(
        features.ColorSpace.GRAY,
        features.ColorSpace.GRAY_ALPHA,
        features.ColorSpace.RGB,
        features.ColorSpace.RGB_ALPHA,
    ),
    dtypes=(torch.float32, torch.uint8),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
    """Yield images over the cross products of sizes/color spaces/dtypes, then of
    color spaces/dtypes/extra dims (using the first size)."""
    for size_, space, dtype_ in itertools.product(sizes, color_spaces, dtypes):
        yield make_image(size_, color_space=space, dtype=dtype_)

    for space, dtype_, batch_dims in itertools.product(color_spaces, dtypes, extra_dims):
        yield make_image(size=sizes[0], color_space=space, extra_dims=batch_dims, dtype=dtype_)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
    """Like ``torch.randint``, but accepting (broadcastable) tensor bounds.

    With one argument the range is ``[0, arg1)``; with two, ``[arg1, arg2)``.
    Returns an integer tensor with the broadcast shape of the bounds.
    """
    if arg2 is None:
        low_arg, high_arg = 0, arg1
    else:
        low_arg, high_arg = arg1, arg2
    low, high = torch.broadcast_tensors(torch.as_tensor(low_arg), torch.as_tensor(high_arg))

    # Draw one scalar per (low, high) pair, then restore the broadcast shape.
    samples = [
        torch.randint(lo, hi, (), **kwargs)
        for lo, hi in zip(low.flatten().tolist(), high.flatten().tolist())
    ]
    return torch.stack(samples).reshape(low.shape)
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
    """Create a random ``features.BoundingBox`` in the given format.

    Args:
        format: ``features.BoundingBoxFormat`` or its name as a string.
        image_size: (height, width) bounding the sampled coordinates.
        extra_dims: leading batch dimensions; any zero dimension yields an empty tensor.
        dtype: dtype of the returned coordinates.
    """
    if isinstance(format, str):
        format = features.BoundingBoxFormat[format]

    if any(dim == 0 for dim in extra_dims):
        # Fix: honor the requested dtype for empty boxes (torch.empty defaults to float32).
        return features.BoundingBox(
            torch.empty(*extra_dims, 4, dtype=dtype), format=format, image_size=image_size
        )

    height, width = image_size

    if format == features.BoundingBoxFormat.XYXY:
        x1 = torch.randint(0, width // 2, extra_dims)
        y1 = torch.randint(0, height // 2, extra_dims)
        x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
        y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
        parts = (x1, y1, x2, y2)
    elif format == features.BoundingBoxFormat.XYWH:
        x = torch.randint(0, width // 2, extra_dims)
        y = torch.randint(0, height // 2, extra_dims)
        w = randint_with_tensor_bounds(1, width - x)
        h = randint_with_tensor_bounds(1, height - y)
        parts = (x, y, w, h)
    elif format == features.BoundingBoxFormat.CXCYWH:
        # Fix: sample with extra_dims like the other branches; previously this branch
        # used shape () and silently ignored extra_dims.
        cx = torch.randint(1, width - 1, extra_dims)
        cy = torch.randint(1, height - 1, extra_dims)
        # Keep the box inside the image: w/2 <= min(cx, width - cx), same for h.
        w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
        parts = (cx, cy, w, h)
    else:
        raise pytest.UsageError(f"Can't make bounding box in format {format}")

    return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
# Convenience alias for the most common bounding box format.
make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
def make_bounding_boxes(
    formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
    image_sizes=((32, 32),),
    dtypes=(torch.int64, torch.float32),
    extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)),
):
    """Yield bounding boxes over the cross products of formats/image sizes/dtypes,
    then of formats/extra dims."""
    for format_, image_size, dtype_ in itertools.product(formats, image_sizes, dtypes):
        yield make_bounding_box(format=format_, image_size=image_size, dtype=dtype_)

    for format_, batch_dims in itertools.product(formats, extra_dims):
        yield make_bounding_box(format=format_, extra_dims=batch_dims)
def make_label(size=(), *, categories=("category0", "category1")):
    """Create a random ``features.Label``; with no categories, labels are drawn from [0, 10)."""
    num_classes = len(categories) if categories else 10
    return features.Label(torch.randint(0, num_classes, size), categories=categories)
def make_one_hot_label(*args, **kwargs):
    """Create a random ``features.OneHotLabel`` by one-hot encoding a random label."""
    label = make_label(*args, **kwargs)
    encoded = one_hot(label, num_classes=len(label.categories))
    return features.OneHotLabel(encoded, categories=label.categories)
def make_one_hot_labels(
    *,
    num_categories=(1, 2, 10),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
    """Yield one-hot labels varying the number of categories, then the batch dimensions."""
    for count in num_categories:
        yield make_one_hot_label(categories=[f"category{idx}" for idx in range(count)])

    for batch_dims in extra_dims:
        yield make_one_hot_label(batch_dims)
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
    """Create a random binary ``features.SegmentationMask`` with one channel per object."""
    if size is None:
        # Random (height, width) in [16, 32].
        size = torch.randint(16, 33, (2,)).tolist()
    if num_objects is None:
        num_objects = int(torch.randint(1, 11, ()))
    data = make_tensor((*extra_dims, num_objects, *size), low=0, high=2, dtype=dtype)
    return features.SegmentationMask(data)
def make_segmentation_masks(
    sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.uint8,),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
    num_objects=(1, 0, 10),
):
    """Yield masks over the cross products of sizes/dtypes/extra dims, then of
    dtypes/extra dims/object counts (using the first size)."""
    for size_, dtype_, batch_dims in itertools.product(sizes, dtypes, extra_dims):
        yield make_segmentation_mask(size=size_, dtype=dtype_, extra_dims=batch_dims)

    for dtype_, batch_dims, objects in itertools.product(dtypes, extra_dims, num_objects):
        yield make_segmentation_mask(size=sizes[0], num_objects=objects, dtype=dtype_, extra_dims=batch_dims)
......@@ -7,7 +7,7 @@ import PIL.Image
import pytest
import torch
from common_utils import assert_equal, cpu_and_gpu
from test_prototype_transforms_functional import (
from prototype_common_utils import (
make_bounding_box,
make_bounding_boxes,
make_image,
......
import enum
import functools
import inspect
import itertools
import numpy as np
import PIL.Image
import pytest
import torch
from test_prototype_transforms_functional import make_images
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from prototype_common_utils import ArgsKwargs, assert_equal, make_images
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor
class ImagePair(TensorLikePair):
    """TensorLikePair that converts PIL images to image tensors before comparison."""

    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
        # Normalize both sides to tensors so a PIL image can be compared to a tensor.
        converted = [
            to_image_tensor(candidate) if isinstance(candidate, PIL.Image.Image) else candidate
            for candidate in (actual, expected)
        ]
        return super()._process_inputs(*converted, id=id, allow_subclasses=allow_subclasses)


# Exact comparison (rtol=atol=0) that additionally understands PIL images via ImagePair.
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
from torchvision.prototype.transforms.functional import to_image_pil
# Restrict the consistency checks to batched RGB images by default.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
class ArgsKwargs:
    """Bundle of positional and keyword arguments for a deferred call.

    Iterating yields ``args`` then ``kwargs``, so instances unpack as
    ``args, kwargs = args_kwargs``.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        return iter((self.args, self.kwargs))

    def __str__(self):
        def short_repr(obj, max=20):
            # Abbreviate long reprs to keep parametrized test ids readable.
            repr_ = repr(obj)
            if len(repr_) <= max:
                return repr_
            return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"

        positional = [short_repr(arg) for arg in self.args]
        keyword = [f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()]
        return ", ".join(positional + keyword)
class ConsistencyConfig:
def __init__(
self,
......
import functools
import itertools
import math
import os
......@@ -9,167 +8,12 @@ import pytest
import torch.testing
import torchvision.prototype.transforms.functional as F
from common_utils import cpu_and_gpu
from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks
from torch import jit
from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.functional_tensor import _max_value as get_max_value
# All test tensors are created on CPU so the suite does not require a GPU.
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
    """Create a random ``features.Image`` for the given color space.

    ``size`` defaults to a random (height, width) in [16, 32]; ``extra_dims`` are
    prepended as batch dimensions. With ``constant_alpha``, the alpha channel (if
    the color space has one) is set to the dtype's maximum value.
    """
    if not size:
        size = torch.randint(16, 33, (2,)).tolist()

    channel_counts = {
        features.ColorSpace.GRAY: 1,
        features.ColorSpace.GRAY_ALPHA: 2,
        features.ColorSpace.RGB: 3,
        features.ColorSpace.RGB_ALPHA: 4,
    }
    try:
        num_channels = channel_counts[color_space]
    except KeyError as error:
        raise pytest.UsageError() from error

    max_value = get_max_value(dtype)
    data = make_tensor((*extra_dims, num_channels, *size), low=0, high=max_value, dtype=dtype)
    if constant_alpha and color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA}:
        # A deterministic alpha channel avoids spurious differences in comparisons.
        data[..., -1, :, :] = max_value
    return features.Image(data, color_space=color_space)
# Shorthand factories for the two most common single-image cases.
make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
def make_images(
    sizes=((16, 16), (7, 33), (31, 9)),
    color_spaces=(
        features.ColorSpace.GRAY,
        features.ColorSpace.GRAY_ALPHA,
        features.ColorSpace.RGB,
        features.ColorSpace.RGB_ALPHA,
    ),
    dtypes=(torch.float32, torch.uint8),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
    """Yield images over the cross products of sizes/color spaces/dtypes, then of
    color spaces/dtypes/extra dims (using the first size)."""
    for size_, space, dtype_ in itertools.product(sizes, color_spaces, dtypes):
        yield make_image(size_, color_space=space, dtype=dtype_)

    for space, dtype_, batch_dims in itertools.product(color_spaces, dtypes, extra_dims):
        yield make_image(size=sizes[0], color_space=space, extra_dims=batch_dims, dtype=dtype_)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
    """Like ``torch.randint``, but accepting (broadcastable) tensor bounds.

    With one argument the range is ``[0, arg1)``; with two, ``[arg1, arg2)``.
    Returns an integer tensor with the broadcast shape of the bounds.
    """
    if arg2 is None:
        low_arg, high_arg = 0, arg1
    else:
        low_arg, high_arg = arg1, arg2
    low, high = torch.broadcast_tensors(torch.as_tensor(low_arg), torch.as_tensor(high_arg))

    # Draw one scalar per (low, high) pair, then restore the broadcast shape.
    samples = [
        torch.randint(lo, hi, (), **kwargs)
        for lo, hi in zip(low.flatten().tolist(), high.flatten().tolist())
    ]
    return torch.stack(samples).reshape(low.shape)
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
    """Create a random ``features.BoundingBox`` in the given format.

    Args:
        format: ``features.BoundingBoxFormat`` or its name as a string.
        image_size: (height, width) bounding the sampled coordinates.
        extra_dims: leading batch dimensions; any zero dimension yields an empty tensor.
        dtype: dtype of the returned coordinates.
    """
    if isinstance(format, str):
        format = features.BoundingBoxFormat[format]

    if any(dim == 0 for dim in extra_dims):
        # Fix: honor the requested dtype for empty boxes (torch.empty defaults to float32).
        return features.BoundingBox(
            torch.empty(*extra_dims, 4, dtype=dtype), format=format, image_size=image_size
        )

    height, width = image_size

    if format == features.BoundingBoxFormat.XYXY:
        x1 = torch.randint(0, width // 2, extra_dims)
        y1 = torch.randint(0, height // 2, extra_dims)
        x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
        y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
        parts = (x1, y1, x2, y2)
    elif format == features.BoundingBoxFormat.XYWH:
        x = torch.randint(0, width // 2, extra_dims)
        y = torch.randint(0, height // 2, extra_dims)
        w = randint_with_tensor_bounds(1, width - x)
        h = randint_with_tensor_bounds(1, height - y)
        parts = (x, y, w, h)
    elif format == features.BoundingBoxFormat.CXCYWH:
        # Fix: sample with extra_dims like the other branches; previously this branch
        # used shape () and silently ignored extra_dims.
        cx = torch.randint(1, width - 1, extra_dims)
        cy = torch.randint(1, height - 1, extra_dims)
        # Keep the box inside the image: w/2 <= min(cx, width - cx), same for h.
        w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
        parts = (cx, cy, w, h)
    else:
        raise pytest.UsageError(f"Can't make bounding box in format {format}")

    return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
# Convenience alias for the most common bounding box format.
make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
def make_bounding_boxes(
    formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
    image_sizes=((32, 32),),
    dtypes=(torch.int64, torch.float32),
    extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)),
):
    """Yield bounding boxes over the cross products of formats/image sizes/dtypes,
    then of formats/extra dims."""
    for format_, image_size, dtype_ in itertools.product(formats, image_sizes, dtypes):
        yield make_bounding_box(format=format_, image_size=image_size, dtype=dtype_)

    for format_, batch_dims in itertools.product(formats, extra_dims):
        yield make_bounding_box(format=format_, extra_dims=batch_dims)
def make_label(size=(), *, categories=("category0", "category1")):
    """Create a random ``features.Label``; with no categories, labels are drawn from [0, 10)."""
    num_classes = len(categories) if categories else 10
    return features.Label(torch.randint(0, num_classes, size), categories=categories)
def make_one_hot_label(*args, **kwargs):
    """Create a random ``features.OneHotLabel`` by one-hot encoding a random label."""
    label = make_label(*args, **kwargs)
    encoded = one_hot(label, num_classes=len(label.categories))
    return features.OneHotLabel(encoded, categories=label.categories)
def make_one_hot_labels(
    *,
    num_categories=(1, 2, 10),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
    """Yield one-hot labels varying the number of categories, then the batch dimensions."""
    for count in num_categories:
        yield make_one_hot_label(categories=[f"category{idx}" for idx in range(count)])

    for batch_dims in extra_dims:
        yield make_one_hot_label(batch_dims)
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
    """Create a random binary ``features.SegmentationMask`` with one channel per object."""
    if size is None:
        # Random (height, width) in [16, 32].
        size = torch.randint(16, 33, (2,)).tolist()
    if num_objects is None:
        num_objects = int(torch.randint(1, 11, ()))
    data = make_tensor((*extra_dims, num_objects, *size), low=0, high=2, dtype=dtype)
    return features.SegmentationMask(data)
def make_segmentation_masks(
    sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.uint8,),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
    num_objects=(1, 0, 10),
):
    """Yield masks over the cross products of sizes/dtypes/extra dims, then of
    dtypes/extra dims/object counts (using the first size)."""
    for size_, dtype_, batch_dims in itertools.product(sizes, dtypes, extra_dims):
        yield make_segmentation_mask(size=size_, dtype=dtype_, extra_dims=batch_dims)

    for dtype_, batch_dims, objects in itertools.product(dtypes, extra_dims, num_objects):
        yield make_segmentation_mask(size=sizes[0], num_objects=objects, dtype=dtype_, extra_dims=batch_dims)
class SampleInput:
    """Bundle of positional and keyword arguments for invoking a kernel under test."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
class FunctionalInfo:
......@@ -182,7 +26,7 @@ class FunctionalInfo:
yield from self._sample_inputs_fn()
def __call__(self, *args, **kwargs):
if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
if len(args) == 1 and not kwargs and isinstance(args[0], ArgsKwargs):
sample_input = args[0]
return self.functional(*sample_input.args, **sample_input.kwargs)
......@@ -200,37 +44,37 @@ def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_image_tensor():
for image in make_images():
yield SampleInput(image)
yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
yield ArgsKwargs(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_segmentation_mask():
for mask in make_segmentation_masks():
yield SampleInput(mask)
yield ArgsKwargs(mask)
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_image_tensor():
for image in make_images():
yield SampleInput(image)
yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
yield ArgsKwargs(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_segmentation_mask():
for mask in make_segmentation_masks():
yield SampleInput(mask)
yield ArgsKwargs(mask)
@register_kernel_info_from_sample_inputs_fn
......@@ -252,7 +96,7 @@ def resize_image_tensor():
]:
if max_size is not None:
size = [size[0]]
yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
yield ArgsKwargs(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@register_kernel_info_from_sample_inputs_fn
......@@ -268,7 +112,7 @@ def resize_bounding_box():
]:
if max_size is not None:
size = [size[0]]
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
yield ArgsKwargs(bounding_box, size=size, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn
......@@ -284,7 +128,7 @@ def resize_segmentation_mask():
]:
if max_size is not None:
size = [size[0]]
yield SampleInput(mask, size=size, max_size=max_size)
yield ArgsKwargs(mask, size=size, max_size=max_size)
@register_kernel_info_from_sample_inputs_fn
......@@ -296,7 +140,7 @@ def affine_image_tensor():
[0.77, 1.27], # scale
[0, 12], # shear
):
yield SampleInput(
yield ArgsKwargs(
image,
angle=angle,
translate=(translate, translate),
......@@ -315,7 +159,7 @@ def affine_bounding_box():
[0.77, 1.27], # scale
[0, 12], # shear
):
yield SampleInput(
yield ArgsKwargs(
bounding_box,
format=bounding_box.format,
image_size=bounding_box.image_size,
......@@ -335,7 +179,7 @@ def affine_segmentation_mask():
[0.77, 1.27], # scale
[0, 12], # shear
):
yield SampleInput(
yield ArgsKwargs(
mask,
angle=angle,
translate=(translate, translate),
......@@ -357,7 +201,7 @@ def rotate_image_tensor():
# Skip warning: The provided center argument is ignored if expand is True
continue
yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill)
yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill)
@register_kernel_info_from_sample_inputs_fn
......@@ -369,7 +213,7 @@ def rotate_bounding_box():
# Skip warning: The provided center argument is ignored if expand is True
continue
yield SampleInput(
yield ArgsKwargs(
bounding_box,
format=bounding_box.format,
image_size=bounding_box.image_size,
......@@ -391,7 +235,7 @@ def rotate_segmentation_mask():
# Skip warning: The provided center argument is ignored if expand is True
continue
yield SampleInput(
yield ArgsKwargs(
mask,
angle=angle,
expand=expand,
......@@ -402,7 +246,7 @@ def rotate_segmentation_mask():
@register_kernel_info_from_sample_inputs_fn
def crop_image_tensor():
for image, top, left, height, width in itertools.product(make_images(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]):
yield SampleInput(
yield ArgsKwargs(
image,
top=top,
left=left,
......@@ -414,7 +258,7 @@ def crop_image_tensor():
@register_kernel_info_from_sample_inputs_fn
def crop_bounding_box():
for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]):
yield SampleInput(
yield ArgsKwargs(
bounding_box,
format=bounding_box.format,
top=top,
......@@ -427,7 +271,7 @@ def crop_segmentation_mask():
for mask, top, left, height, width in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
):
yield SampleInput(
yield ArgsKwargs(
mask,
top=top,
left=left,
......@@ -447,7 +291,7 @@ def resized_crop_image_tensor():
[(16, 18)],
[True, False],
):
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias)
yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias)
@register_kernel_info_from_sample_inputs_fn
......@@ -455,7 +299,7 @@ def resized_crop_bounding_box():
for bounding_box, top, left, height, width, size in itertools.product(
make_bounding_boxes(), [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)]
):
yield SampleInput(
yield ArgsKwargs(
bounding_box, format=bounding_box.format, top=top, left=left, height=height, width=width, size=size
)
......@@ -465,7 +309,7 @@ def resized_crop_segmentation_mask():
for mask, top, left, height, width, size in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
):
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)
@register_kernel_info_from_sample_inputs_fn
......@@ -476,7 +320,7 @@ def pad_image_tensor():
[None, 12, 12.0], # fill
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode)
yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode)
@register_kernel_info_from_sample_inputs_fn
......@@ -486,7 +330,7 @@ def pad_segmentation_mask():
[[1], [1, 1], [1, 1, 2, 2]], # padding
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
yield ArgsKwargs(mask, padding=padding, padding_mode=padding_mode)
@register_kernel_info_from_sample_inputs_fn
......@@ -495,7 +339,7 @@ def pad_bounding_box():
make_bounding_boxes(),
[[1], [1, 1], [1, 1, 2, 2]],
):
yield SampleInput(bounding_box, padding=padding, format=bounding_box.format)
yield ArgsKwargs(bounding_box, padding=padding, format=bounding_box.format)
@register_kernel_info_from_sample_inputs_fn
......@@ -508,7 +352,7 @@ def perspective_image_tensor():
],
[None, [128], [12.0]], # fill
):
yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill)
yield ArgsKwargs(image, perspective_coeffs=perspective_coeffs, fill=fill)
@register_kernel_info_from_sample_inputs_fn
......@@ -520,7 +364,7 @@ def perspective_bounding_box():
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
],
):
yield SampleInput(
yield ArgsKwargs(
bounding_box,
format=bounding_box.format,
perspective_coeffs=perspective_coeffs,
......@@ -536,7 +380,7 @@ def perspective_segmentation_mask():
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
],
):
yield SampleInput(
yield ArgsKwargs(
mask,
perspective_coeffs=perspective_coeffs,
)
......@@ -550,7 +394,7 @@ def elastic_image_tensor():
):
h, w = image.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(image, displacement=displacement, fill=fill)
yield ArgsKwargs(image, displacement=displacement, fill=fill)
@register_kernel_info_from_sample_inputs_fn
......@@ -558,7 +402,7 @@ def elastic_bounding_box():
for bounding_box in make_bounding_boxes():
h, w = bounding_box.image_size
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
yield ArgsKwargs(
bounding_box,
format=bounding_box.format,
displacement=displacement,
......@@ -570,7 +414,7 @@ def elastic_segmentation_mask():
for mask in make_segmentation_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
yield ArgsKwargs(
mask,
displacement=displacement,
)
......@@ -582,13 +426,13 @@ def center_crop_image_tensor():
make_images(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield SampleInput(mask, output_size)
yield ArgsKwargs(mask, output_size)
@register_kernel_info_from_sample_inputs_fn
def center_crop_bounding_box():
for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
yield SampleInput(
yield ArgsKwargs(
bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
)
......@@ -599,7 +443,7 @@ def center_crop_segmentation_mask():
make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield SampleInput(mask, output_size)
yield ArgsKwargs(mask, output_size)
@register_kernel_info_from_sample_inputs_fn
......@@ -609,7 +453,7 @@ def gaussian_blur_image_tensor():
[[3, 3]],
[None, [3.0, 3.0]],
):
yield SampleInput(image, kernel_size=kernel_size, sigma=sigma)
yield ArgsKwargs(image, kernel_size=kernel_size, sigma=sigma)
@register_kernel_info_from_sample_inputs_fn
......@@ -617,13 +461,13 @@ def equalize_image_tensor():
for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
if image.dtype != torch.uint8:
continue
yield SampleInput(image)
yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn
def invert_image_tensor():
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
yield SampleInput(image)
yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn
......@@ -634,7 +478,7 @@ def posterize_image_tensor():
):
if image.dtype != torch.uint8:
continue
yield SampleInput(image, bits=bits)
yield ArgsKwargs(image, bits=bits)
@register_kernel_info_from_sample_inputs_fn
......@@ -645,13 +489,13 @@ def solarize_image_tensor():
):
if image.is_floating_point() and threshold > 1.0:
continue
yield SampleInput(image, threshold=threshold)
yield ArgsKwargs(image, threshold=threshold)
@register_kernel_info_from_sample_inputs_fn
def autocontrast_image_tensor():
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
yield SampleInput(image)
yield ArgsKwargs(image)
@register_kernel_info_from_sample_inputs_fn
......@@ -660,14 +504,14 @@ def adjust_sharpness_image_tensor():
make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
[0.1, 0.5],
):
yield SampleInput(image, sharpness_factor=sharpness_factor)
yield ArgsKwargs(image, sharpness_factor=sharpness_factor)
@register_kernel_info_from_sample_inputs_fn
def erase_image_tensor():
for image in make_images():
c = image.shape[-3]
yield SampleInput(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
@pytest.mark.parametrize(
......
......@@ -3,7 +3,7 @@ import pytest
import torch
from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask
from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment