Unverified Commit 23b0938f authored by Philip Meier, committed by GitHub

extract make_* functions out of make_*_loader (#7717)

parent cbc36eb4
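The diff below moves the actual sample generation out of the make_*_loader factories into standalone make_* helpers (make_image, make_image_tensor, make_image_pil, make_bounding_box, make_detection_mask, make_segmentation_mask, make_video); the loaders now delegate to those helpers, and the old make_* = from_loader(make_*_loader) aliases are removed. A minimal sketch of the new call style, assuming the helpers live in test/common_utils.py (which the imports in the last file of the diff suggest) and using illustrative values rather than the new defaults:

.. code::

    from common_utils import make_image, make_bounding_box, make_detection_mask

    # size and batch_dims below are arbitrary example values
    image = make_image((32, 24), color_space="RGB")  # datapoints.Image with shape (3, 32, 24)
    boxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), batch_dims=(6,))  # 6 boxes for that image
    masks = make_detection_mask((32, 24), num_objects=6)  # boolean Mask with shape (6, 32, 24)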
@@ -27,7 +27,7 @@ from PIL import Image
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import datapoints, io
from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_pil, to_image_tensor
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -399,6 +399,9 @@ class ArgsKwargs:
)
+# new v2 default
+DEFAULT_SIZE = (17, 11)
+# old v2 defaults
DEFAULT_SQUARE_SPATIAL_SIZE = 15
DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7, 33)
DEFAULT_PORTRAIT_SPATIAL_SIZE = (31, 9)
@@ -406,13 +409,12 @@ DEFAULT_SPATIAL_SIZES = (
    DEFAULT_LANDSCAPE_SPATIAL_SIZE,
    DEFAULT_PORTRAIT_SPATIAL_SIZE,
    DEFAULT_SQUARE_SPATIAL_SIZE,
-    "random",
)
def _parse_spatial_size(size, *, name="size"):
    if size == "random":
-        return tuple(torch.randint(15, 33, (2,)).tolist())
+        raise ValueError("This should never happen")
    elif isinstance(size, int) and size > 0:
        return (size, size)
    elif (
@@ -492,8 +494,40 @@ def get_num_channels(color_space):
    return num_channels
+def make_image(
+    size=DEFAULT_SIZE,
+    *,
+    color_space="RGB",
+    batch_dims=(),
+    dtype=None,
+    device="cpu",
+    memory_format=torch.contiguous_format,
+):
+    max_value = get_max_value(dtype)
+    data = torch.testing.make_tensor(
+        (*batch_dims, get_num_channels(color_space), *size),
+        low=0,
+        high=max_value,
+        dtype=dtype or torch.uint8,
+        device=device,
+        memory_format=memory_format,
+    )
+    if color_space in {"GRAY_ALPHA", "RGBA"}:
+        data[..., -1, :, :] = max_value
+    return datapoints.Image(data)
+def make_image_tensor(*args, **kwargs):
+    return make_image(*args, **kwargs).as_subclass(torch.Tensor)
+def make_image_pil(*args, **kwargs):
+    return to_image_pil(make_image(*args, **kwargs))
def make_image_loader(
-    size="random",
+    size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
    *,
    color_space="RGB",
    extra_dims=(),
@@ -501,24 +535,25 @@ def make_image_loader(
    constant_alpha=True,
    memory_format=torch.contiguous_format,
):
+    if not constant_alpha:
+        raise ValueError("This should never happen")
    size = _parse_spatial_size(size)
    num_channels = get_num_channels(color_space)
    def fn(shape, dtype, device, memory_format):
-        max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(
-            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        *batch_dims, _, height, width = shape
+        return make_image(
+            (height, width),
+            color_space=color_space,
+            batch_dims=batch_dims,
+            dtype=dtype,
+            device=device,
+            memory_format=memory_format,
        )
-        if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
-            data[..., -1, :, :] = max_value
-        return datapoints.Image(data)
    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)
-make_image = from_loader(make_image_loader)
def make_image_loaders(
    *,
    sizes=DEFAULT_SPATIAL_SIZES,
@@ -540,7 +575,7 @@ make_images = from_loaders(make_image_loaders)
def make_image_loader_for_interpolation(
-    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+    size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
):
    size = _parse_spatial_size(size)
    num_channels = get_num_channels(color_space)
@@ -589,76 +624,114 @@ class BoundingBoxLoader(TensorLoader):
    spatial_size: Tuple[int, int]
-def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
-    low, high = torch.broadcast_tensors(
-        *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
-    )
-    return torch.stack(
-        [
-            torch.randint(low_scalar, high_scalar, (), **kwargs)
-            for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
-        ]
-    ).reshape(low.shape)
+def make_bounding_box(
+    size=None,
+    *,
+    format=datapoints.BoundingBoxFormat.XYXY,
+    spatial_size=None,
+    batch_dims=(),
+    dtype=None,
+    device="cpu",
+):
+    """
+    size: Size of the actual bounding box, i.e.
+        - (box[3] - box[1], box[2] - box[0]) for XYXY
+        - (H, W) for XYWH and CXCYWH
+    spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on the
+        returned datapoints.BoundingBox.
+    To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
+    functions, e.g.
+    .. code::
+        image = make_image(size=size)
+        bounding_box = make_bounding_box(spatial_size=size)
+        assert F.get_spatial_size(bounding_box) == F.get_spatial_size(image)
+    For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value that size
+    defaults to for all other maker functions, e.g.
+    .. code::
+        image = make_image()
+        bounding_box = make_bounding_box()
+        assert F.get_spatial_size(bounding_box) == F.get_spatial_size(image)
+    """
+    def sample_position(values, max_value):
+        # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and
+        # high. However, if we have batch_dims, we need tensors as limits.
+        return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
+    if isinstance(format, str):
+        format = datapoints.BoundingBoxFormat[format]
+    if spatial_size is None:
+        if size is None:
+            spatial_size = DEFAULT_SIZE
+        else:
+            height, width = size
+            height_margin, width_margin = torch.randint(10, (2,)).tolist()
+            spatial_size = (height + height_margin, width + width_margin)
+    dtype = dtype or torch.float32
+    if any(dim == 0 for dim in batch_dims):
+        return datapoints.BoundingBox(
+            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
+        )
+    if size is None:
+        h, w = [torch.randint(1, s, batch_dims) for s in spatial_size]
+    else:
+        h, w = [torch.full(batch_dims, s, dtype=torch.int) for s in size]
+    y = sample_position(h, spatial_size[0])
+    x = sample_position(w, spatial_size[1])
+    if format is datapoints.BoundingBoxFormat.XYWH:
+        parts = (x, y, w, h)
+    elif format is datapoints.BoundingBoxFormat.XYXY:
+        x1, y1 = x, y
+        x2 = x1 + w
+        y2 = y1 + h
+        parts = (x1, y1, x2, y2)
+    elif format is datapoints.BoundingBoxFormat.CXCYWH:
+        cx = x + w / 2
+        cy = y + h / 2
+        parts = (cx, cy, w, h)
+    else:
+        raise ValueError(f"Format {format} is not supported")
+    return datapoints.BoundingBox(
+        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
+    )
-def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
    if isinstance(format, str):
        format = datapoints.BoundingBoxFormat[format]
-    if format not in {
-        datapoints.BoundingBoxFormat.XYXY,
-        datapoints.BoundingBoxFormat.XYWH,
-        datapoints.BoundingBoxFormat.CXCYWH,
-    }:
-        raise pytest.UsageError(f"Can't make bounding box in format {format}")
    spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")
    def fn(shape, dtype, device):
-        *extra_dims, num_coordinates = shape
+        *batch_dims, num_coordinates = shape
        if num_coordinates != 4:
            raise pytest.UsageError()
-        if any(dim == 0 for dim in extra_dims):
-            return datapoints.BoundingBox(
-                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
-            )
-        height, width = spatial_size
-        if format == datapoints.BoundingBoxFormat.XYXY:
-            x1 = torch.randint(0, width // 2, extra_dims)
-            y1 = torch.randint(0, height // 2, extra_dims)
-            x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
-            y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
-            parts = (x1, y1, x2, y2)
-        elif format == datapoints.BoundingBoxFormat.XYWH:
-            x = torch.randint(0, width // 2, extra_dims)
-            y = torch.randint(0, height // 2, extra_dims)
-            w = randint_with_tensor_bounds(1, width - x)
-            h = randint_with_tensor_bounds(1, height - y)
-            parts = (x, y, w, h)
-        else:  # format == features.BoundingBoxFormat.CXCYWH:
-            cx = torch.randint(1, width - 1, extra_dims)
-            cy = torch.randint(1, height - 1, extra_dims)
-            w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
-            h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
-            parts = (cx, cy, w, h)
-        return datapoints.BoundingBox(
-            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
+        return make_bounding_box(
+            format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
        )
    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
-make_bounding_box = from_loader(make_bounding_box_loader)
def make_bounding_box_loaders(
    *,
    extra_dims=DEFAULT_EXTRA_DIMS,
    formats=tuple(datapoints.BoundingBoxFormat),
-    spatial_size="random",
+    spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
    dtypes=(torch.float32, torch.float64, torch.int64),
):
    for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
@@ -672,24 +745,35 @@ class MaskLoader(TensorLoader):
    pass
-def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
+def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
+    """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
+    return datapoints.Mask(
+        torch.testing.make_tensor(
+            (*batch_dims, num_objects, *size),
+            low=0,
+            high=2,
+            dtype=dtype or torch.bool,
+            device=device,
+        )
+    )
+def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
    # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
    size = _parse_spatial_size(size)
-    num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects
    def fn(shape, dtype, device):
-        data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device)
-        return datapoints.Mask(data)
+        *batch_dims, num_objects, height, width = shape
+        return make_detection_mask(
+            (height, width), num_objects=num_objects, batch_dims=batch_dims, dtype=dtype, device=device
+        )
    return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype)
-make_detection_mask = from_loader(make_detection_mask_loader)
def make_detection_mask_loaders(
    sizes=DEFAULT_SPATIAL_SIZES,
-    num_objects=(1, 0, "random"),
+    num_objects=(1, 0, 5),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.uint8,),
):
@@ -700,25 +784,38 @@ def make_detection_mask_loaders(
make_detection_masks = from_loaders(make_detection_mask_loaders)
-def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
-    # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    size = _parse_spatial_size(size)
-    num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories
-    def fn(shape, dtype, device):
-        data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device)
-        return datapoints.Mask(data)
-    return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
-make_segmentation_mask = from_loader(make_segmentation_mask_loader)
+def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
+    """Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
+    return datapoints.Mask(
+        torch.testing.make_tensor(
+            (*batch_dims, *size),
+            low=0,
+            high=num_categories,
+            dtype=dtype or torch.uint8,
+            device=device,
+        )
+    )
+def make_segmentation_mask_loader(
+    size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
+):
+    # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
+    spatial_size = _parse_spatial_size(size)
+    def fn(shape, dtype, device):
+        *batch_dims, height, width = shape
+        return make_segmentation_mask(
+            (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
+        )
+    return MaskLoader(fn, shape=(*extra_dims, *spatial_size), dtype=dtype)
def make_segmentation_mask_loaders(
    *,
    sizes=DEFAULT_SPATIAL_SIZES,
-    num_categories=(1, 2, "random"),
+    num_categories=(1, 2, 10),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.uint8,),
):
@@ -732,8 +829,8 @@ make_segmentation_masks = from_loaders(make_segmentation_mask_loaders)
def make_mask_loaders(
    *,
    sizes=DEFAULT_SPATIAL_SIZES,
-    num_objects=(1, 0, "random"),
-    num_categories=(1, 2, "random"),
+    num_objects=(1, 0, 5),
+    num_categories=(1, 2, 10),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.uint8,),
):
@@ -750,29 +847,35 @@ class VideoLoader(ImageLoader):
    pass
+def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
+    return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
def make_video_loader(
-    size="random",
+    size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
    *,
    color_space="RGB",
-    num_frames="random",
+    num_frames=3,
    extra_dims=(),
    dtype=torch.uint8,
):
    size = _parse_spatial_size(size)
-    num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
    def fn(shape, dtype, device, memory_format):
-        video = make_image(
-            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        *batch_dims, num_frames, _, height, width = shape
+        return make_video(
+            (height, width),
+            num_frames=num_frames,
+            batch_dims=batch_dims,
+            color_space=color_space,
+            dtype=dtype,
+            device=device,
+            memory_format=memory_format,
        )
-        return datapoints.Video(video)
    return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
-make_video = from_loader(make_video_loader)
def make_video_loaders(
    *,
    sizes=DEFAULT_SPATIAL_SIZES,
@@ -780,7 +883,7 @@ def make_video_loaders(
        "GRAY",
        "RGB",
    ),
-    num_frames=(1, 0, "random"),
+    num_frames=(1, 0, 3),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.uint8, torch.float32, torch.float64),
):
...
@@ -216,7 +216,7 @@ class TestFixedSizeCrop:
        flat_inputs = [
            make_image(size=spatial_size, color_space="RGB"),
-            make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape),
+            make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=batch_shape),
        ]
        params = transform._get_params(flat_inputs)
@@ -312,9 +312,9 @@ class TestFixedSizeCrop:
        )
        bounding_boxes = make_bounding_box(
-            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,)
+            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
        )
-        masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,))
+        masks = make_detection_mask(size=spatial_size, batch_dims=(batch_size,))
        labels = make_label(extra_dims=(batch_size,))
        transform = transforms.FixedSizeCrop((-1, -1))
@@ -350,7 +350,7 @@ class TestFixedSizeCrop:
        )
        bounding_box = make_bounding_box(
-            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,)
+            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
        )
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
@@ -496,7 +496,7 @@ def test_fixed_sized_crop_against_detection_reference():
    pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
    target = {
-        "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+        "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
        "labels": make_label(extra_dims=(num_objects,), categories=80),
        "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
    }
@@ -505,7 +505,7 @@ def test_fixed_sized_crop_against_detection_reference():
    tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
    target = {
-        "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+        "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
        "labels": make_label(extra_dims=(num_objects,), categories=80),
        "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
    }
@@ -514,7 +514,7 @@ def test_fixed_sized_crop_against_detection_reference():
    datapoint_image = make_image(size=size, color_space="RGB")
    target = {
-        "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+        "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
        "labels": make_label(extra_dims=(num_objects,), categories=80),
        "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
    }
...
@@ -182,13 +182,13 @@ class TestSmoke:
            video_datapoint=make_video(size=spatial_size),
            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
            bounding_box_xyxy=make_bounding_box(
-                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
+                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(3,)
            ),
            bounding_box_xywh=make_bounding_box(
-                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
+                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, batch_dims=(4,)
            ),
            bounding_box_cxcywh=make_bounding_box(
-                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
+                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, batch_dims=(5,)
            ),
            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
                [
@@ -289,7 +289,7 @@ class TestSmoke:
            ],
            dtypes=[torch.uint8],
            extra_dims=[(), (4,)],
-            **(dict(num_frames=["random"]) if fn is make_videos else dict()),
+            **(dict(num_frames=[3]) if fn is make_videos else dict()),
        )
        for fn in [
            make_images,
@@ -1124,7 +1124,7 @@ class TestRandomIoUCrop:
        transform = transforms.RandomIoUCrop()
        image = datapoints.Image(torch.rand(3, 32, 24))
-        bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,))
+        bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), batch_dims=(6,))
        masks = make_detection_mask((32, 24), num_objects=6)
        sample = [image, bboxes, masks]
...
@@ -1090,7 +1090,7 @@ class TestRefDetTransforms:
        pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
        target = {
-            "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
@@ -1098,9 +1098,9 @@ class TestRefDetTransforms:
        yield (pil_image, target)
-        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
+        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
-            "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
@@ -1108,9 +1108,9 @@ class TestRefDetTransforms:
        yield (tensor_image, target)
-        datapoint_image = make_image(size=size, color_space="RGB")
+        datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
-            "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
...
@@ -665,163 +665,6 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    return true_matrix
-@pytest.mark.parametrize("angle", range(-90, 90, 56))
-@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
-def test_correctness_rotate_bounding_box(angle, expand, center):
-    def _compute_expected_bbox(bbox, angle_, expand_, center_):
-        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
-        affine_matrix = affine_matrix[:2, :]
-        height, width = bbox.spatial_size
-        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
-        points = np.array(
-            [
-                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
-                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
-                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
-                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
-                # image frame
-                [0.0, 0.0, 1.0],
-                [0.0, height, 1.0],
-                [width, height, 1.0],
-                [width, 0.0, 1.0],
-            ]
-        )
-        transformed_points = np.matmul(points, affine_matrix.T)
-        out_bbox = [
-            float(np.min(transformed_points[:4, 0])),
-            float(np.min(transformed_points[:4, 1])),
-            float(np.max(transformed_points[:4, 0])),
-            float(np.max(transformed_points[:4, 1])),
-        ]
-        if expand_:
-            tr_x = np.min(transformed_points[4:, 0])
-            tr_y = np.min(transformed_points[4:, 1])
-            out_bbox[0] -= tr_x
-            out_bbox[1] -= tr_y
-            out_bbox[2] -= tr_x
-            out_bbox[3] -= tr_y
-            height = int(height - 2 * tr_y)
-            width = int(width - 2 * tr_x)
-        out_bbox = datapoints.BoundingBox(
-            out_bbox,
-            format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=(height, width),
-            dtype=bbox.dtype,
-            device=bbox.device,
-        )
-        out_bbox = clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
-        return out_bbox, (height, width)
-    spatial_size = (32, 38)
-    for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
-        bboxes_format = bboxes.format
-        bboxes_spatial_size = bboxes.spatial_size
-        output_bboxes, output_spatial_size = F.rotate_bounding_box(
-            bboxes.as_subclass(torch.Tensor),
-            format=bboxes_format,
-            spatial_size=bboxes_spatial_size,
-            angle=angle,
-            expand=expand,
-            center=center,
-        )
-        center_ = center
-        if center_ is None:
-            center_ = [s * 0.5 for s in bboxes_spatial_size[::-1]]
-        if bboxes.ndim < 2:
-            bboxes = [bboxes]
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
-            expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_)
-            expected_bboxes.append(expected_bbox)
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
-        torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)
-        torch.testing.assert_close(output_spatial_size, expected_spatial_size, atol=1, rtol=0)
-@pytest.mark.parametrize("device", cpu_and_cuda())
-@pytest.mark.parametrize("expand", [False])  # expand=True does not match D2
-def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
-    # Check transformation against known expected output
-    format = datapoints.BoundingBoxFormat.XYXY
-    spatial_size = (64, 64)
-    # xyxy format
-    in_boxes = [
-        [1, 1, 5, 5],
-        [1, spatial_size[0] - 6, 5, spatial_size[0] - 2],
-        [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
-        [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
-    ]
-    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
-    # Tested parameters
-    angle = 45
-    center = None if expand else [12, 23]
-    # # Expected bboxes computed using Detectron2:
-    # from detectron2.data.transforms import RotationTransform, AugmentationList
-    # from detectron2.data.transforms import AugInput
-    # import cv2
-    # inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32"))
-    # augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ])
-    # out = augs(inpt)
-    # print(inpt.boxes)
-    if expand:
-        expected_bboxes = [
-            [1.65937957, 42.67157288, 7.31623382, 48.32842712],
-            [41.96446609, 82.9766594, 47.62132034, 88.63351365],
-            [82.26955262, 42.67157288, 87.92640687, 48.32842712],
-            [31.35786438, 31.35786438, 59.64213562, 59.64213562],
-        ]
-    else:
-        expected_bboxes = [
-            [-11.33452378, 12.39339828, -5.67766953, 18.05025253],
-            [28.97056275, 52.69848481, 34.627417, 58.35533906],
-            [69.27564928, 12.39339828, 74.93250353, 18.05025253],
-            [18.36396103, 1.07968978, 46.64823228, 29.36396103],
-        ]
-    expected_bboxes = clamp_bounding_box(
-        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
-    ).tolist()
-    output_boxes, _ = F.rotate_bounding_box(
-        in_boxes,
-        format=format,
-        spatial_size=spatial_size,
-        angle=angle,
-        expand=expand,
-        center=center,
-    )
-    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
-@pytest.mark.parametrize("device", cpu_and_cuda())
-def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
-    # Check transformation against known expected output and CPU/CUDA devices
-    # Create a fixed input segmentation mask with 2 square masks
-    # in top-left, bottom-left corners
-    mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
-    mask[0, 2:10, 2:10] = 1
-    mask[0, 32 - 9 : 32 - 3, 3:9] = 2
-    # Rotate 90 degrees
-    expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
-    out_mask = F.rotate_mask(mask, 90, expand=False)
-    torch.testing.assert_close(out_mask, expected_mask)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
...
This diff is collapsed.
@@ -11,6 +11,7 @@ import torchvision.transforms.v2.functional as F
from common_utils import (
    ArgsKwargs,
    combinations_grid,
+    DEFAULT_PORTRAIT_SPATIAL_SIZE,
    get_num_channels,
    ImageLoader,
    InfoBase,
@@ -260,6 +261,9 @@ KERNEL_INFOS.append(
        reference_fn=reference_convert_format_bounding_box,
        reference_inputs_fn=reference_inputs_convert_format_bounding_box,
        logs_usage=True,
+        closeness_kwargs={
+            (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
+        },
    ),
)
@@ -296,7 +300,7 @@ def sample_inputs_crop_bounding_box():
def sample_inputs_crop_mask():
-    for mask_loader in make_mask_loaders(sizes=[(16, 17)], num_categories=["random"], num_objects=["random"]):
+    for mask_loader in make_mask_loaders(sizes=[(16, 17)], num_categories=[10], num_objects=[5]):
        yield ArgsKwargs(mask_loader, top=4, left=3, height=7, width=8)
@@ -306,7 +310,7 @@ def reference_inputs_crop_mask():
def sample_inputs_crop_video():
-    for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=[3]):
        yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)
@@ -415,7 +419,7 @@ def sample_inputs_resized_crop_mask():
def sample_inputs_resized_crop_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0])
@@ -457,7 +461,7 @@ _PAD_PARAMS = combinations_grid(
def sample_inputs_pad_image_tensor():
    make_pad_image_loaders = functools.partial(
-        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
+        make_image_loaders, sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]
    )
    for image_loader, padding in itertools.product(
@@ -512,7 +516,7 @@ def sample_inputs_pad_bounding_box():
def sample_inputs_pad_mask():
-    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
+    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]):
        yield ArgsKwargs(mask_loader, padding=[1])
@@ -524,7 +528,7 @@ def reference_inputs_pad_mask():
def sample_inputs_pad_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, padding=[1])
@@ -620,7 +624,7 @@ _ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]
def sample_inputs_perspective_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"]):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(
                image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
@@ -672,7 +676,7 @@ def sample_inputs_perspective_bounding_box():
def sample_inputs_perspective_mask():
-    for mask_loader in make_mask_loaders(sizes=["random"]):
+    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
    yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
@@ -686,7 +690,7 @@ def reference_inputs_perspective_mask():
def sample_inputs_perspective_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
    yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
@@ -745,7 +749,7 @@ def _get_elastic_displacement(spatial_size):
def sample_inputs_elastic_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"]):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        displacement = _get_elastic_displacement(image_loader.spatial_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)
@@ -777,13 +781,13 @@ def sample_inputs_elastic_bounding_box():
def sample_inputs_elastic_mask():
-    for mask_loader in make_mask_loaders(sizes=["random"]):
+    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        displacement = _get_elastic_displacement(mask_loader.shape[-2:])
        yield ArgsKwargs(mask_loader, displacement=displacement)
def sample_inputs_elastic_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        displacement = _get_elastic_displacement(video_loader.shape[-2:])
        yield ArgsKwargs(video_loader, displacement=displacement)
@@ -854,7 +858,7 @@ def sample_inputs_center_crop_bounding_box():
def sample_inputs_center_crop_mask():
-    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
+    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]):
        height, width = mask_loader.shape[-2:]
        yield ArgsKwargs(mask_loader, output_size=(height // 2, width // 2))
@@ -867,7 +871,7 @@ def reference_inputs_center_crop_mask():
def sample_inputs_center_crop_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        height, width = video_loader.shape[-2:]
        yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2))
@@ -947,7 +951,7 @@ KERNEL_INFOS.extend(
def sample_inputs_equalize_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)
@@ -1008,7 +1012,7 @@ def reference_inputs_equalize_image_tensor():
def sample_inputs_equalize_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)
@@ -1031,7 +1035,7 @@ KERNEL_INFOS.extend(
def sample_inputs_invert_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)
@@ -1041,7 +1045,7 @@ def reference_inputs_invert_image_tensor():
def sample_inputs_invert_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)
@@ -1067,7 +1071,7 @@ _POSTERIZE_BITS = [1, 4, 8]
def sample_inputs_posterize_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])
@@ -1080,7 +1084,7 @@ def reference_inputs_posterize_image_tensor():
def sample_inputs_posterize_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])
@@ -1110,7 +1114,7 @@ def _get_solarize_thresholds(dtype):
def sample_inputs_solarize_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))
@@ -1125,7 +1129,7 @@ def uint8_to_float32_threshold_adapter(other_args, kwargs):
def sample_inputs_solarize_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))
@@ -1149,7 +1153,7 @@ KERNEL_INFOS.extend(
def sample_inputs_autocontrast_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)
@@ -1159,7 +1163,7 @@ def reference_inputs_autocontrast_image_tensor():
def sample_inputs_autocontrast_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)
@@ -1189,7 +1193,7 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_sharpness_image_tensor():
    for image_loader in make_image_loaders(
-        sizes=["random", (2, 2)],
+        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE, (2, 2)],
        color_spaces=("GRAY", "RGB"),
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
@@ -1204,7 +1208,7 @@ def reference_inputs_adjust_sharpness_image_tensor():
def sample_inputs_adjust_sharpness_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
@@ -1228,7 +1232,7 @@ KERNEL_INFOS.extend(
def sample_inputs_erase_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"]):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(image_loader.num_channels, h, w)
@@ -1236,7 +1240,7 @@ def sample_inputs_erase_image_tensor():
def sample_inputs_erase_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(video_loader.num_channels, h, w)
@@ -1261,7 +1265,7 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_brightness_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
@@ -1274,7 +1278,7 @@ def reference_inputs_adjust_brightness_image_tensor():
def sample_inputs_adjust_brightness_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
@@ -1301,7 +1305,7 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_contrast_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
@@ -1314,7 +1318,7 @@ def reference_inputs_adjust_contrast_image_tensor():
def sample_inputs_adjust_contrast_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
@@ -1353,7 +1357,7 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [
def sample_inputs_adjust_gamma_image_tensor():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
@@ -1367,7 +1371,7 @@ def reference_inputs_adjust_gamma_image_tensor():
def sample_inputs_adjust_gamma_video():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)
@@ -1397,7 +1401,7 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5]
def sample_inputs_adjust_hue_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
@@ -1410,7 +1414,7 @@ def reference_inputs_adjust_hue_image_tensor():
def sample_inputs_adjust_hue_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
@@ -1439,7 +1443,7 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_saturation_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
+    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
@@ -1452,7 +1456,7 @@ def reference_inputs_adjust_saturation_image_tensor():
def sample_inputs_adjust_saturation_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
@@ -1612,7 +1616,7 @@ _NORMALIZE_MEANS_STDS = [
def sample_inputs_normalize_image_tensor():
    for image_loader, (mean, std) in itertools.product(
-        make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
+        make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]),
        _NORMALIZE_MEANS_STDS,
    ):
        yield ArgsKwargs(image_loader, mean=mean, std=std)
@@ -1637,7 +1641,7 @@ def reference_inputs_normalize_image_tensor():
def sample_inputs_normalize_video():
    mean, std = _NORMALIZE_MEANS_STDS[0]
    for video_loader in make_video_loaders(
-        sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
+        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[3], dtypes=[torch.float32]
    ):
        yield ArgsKwargs(video_loader, mean=mean, std=std)
@@ -1671,7 +1675,9 @@ def sample_inputs_convert_dtype_image_tensor():
            # conversion cannot be performed safely
            continue
-        for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
+        for image_loader in make_image_loaders(
+            sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[input_dtype]
+        ):
            yield ArgsKwargs(image_loader, dtype=output_dtype)
@@ -1736,7 +1742,7 @@ def reference_inputs_convert_dtype_image_tensor():
def sample_inputs_convert_dtype_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)
@@ -1781,7 +1787,7 @@ KERNEL_INFOS.extend(
def sample_inputs_uniform_temporal_subsample_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
+    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]):
        yield ArgsKwargs(video_loader, num_samples=2)
@@ -1797,7 +1803,9 @@ def reference_uniform_temporal_subsample_video(x, num_samples):
def reference_inputs_uniform_temporal_subsample_video():
-    for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
+    for video_loader in make_video_loaders(
+        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[10]
+    ):
        for num_samples in range(1, video_loader.shape[-4] + 1):
            yield ArgsKwargs(video_loader, num_samples)
...
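The docstring added to make_bounding_box above describes how to build a joint sample whose parts agree on their spatial size. A minimal sketch of that pattern, again assuming the helpers are imported from common_utils and using an arbitrary example size:

.. code::

    from common_utils import make_image, make_bounding_box, make_detection_mask

    size = (32, 24)  # (height, width) shared by every part of the sample
    image = make_image(size=size, color_space="RGB")
    boxes = make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(4,))
    masks = make_detection_mask(size=size, num_objects=4)
    # image.shape[-2:], boxes.spatial_size, and masks.shape[-2:] all equal size,
    # so the three can be passed through a joint transform together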