Unverified Commit 312c3d32 authored by Philip Meier, committed by GitHub

remove spatial_size (#7734)

parent bdf16222
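The change throughout: the `spatial_size` keyword/attribute on `datapoints.BoundingBoxes` (and on the bounding-box kernels, query helpers, and test utilities) becomes `canvas_size`, and the size queries `F.get_spatial_size` / `query_spatial_size` become `F.get_size` / `query_size`. The new name stresses that the value is the (height, width) of the image canvas the boxes live on, not the size of the boxes themselves. A minimal before/after sketch, assuming the `torchvision.datapoints` API as of this commit (tensor and box values are illustrative only):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    image = datapoints.Image(torch.rand(3, 256, 256))

    # before: spatial_size=F.get_spatial_size(image)
    # after:
    boxes = datapoints.BoundingBoxes(
        [[17, 16, 64, 95]],
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=F.get_size(image),  # the [H, W] of the reference image
    )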
@@ -80,7 +80,7 @@ print(image.shape, image.dtype)
 # corresponding image alongside the actual values:
 bounding_box = datapoints.BoundingBoxes(
-    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
 )
 print(bounding_box)
@@ -108,7 +108,7 @@ class PennFudanDataset(torch.utils.data.Dataset):
         target["boxes"] = datapoints.BoundingBoxes(
             boxes,
             format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
         )
         target["labels"] = labels
         target["masks"] = datapoints.Mask(masks)
@@ -129,7 +129,7 @@ class WrapPennFudanDataset:
         target["boxes"] = datapoints.BoundingBoxes(
             target["boxes"],
             format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
         )
         target["masks"] = datapoints.Mask(target["masks"])
         return img, target
...
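The PennFudan snippets above all follow one pattern: build the detection target with boxes whose canvas_size is read off the image via F.get_size. A self-contained sketch of that pattern (the helper name wrap_target and the toy values are illustrative, not part of the commit):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    def wrap_target(img, boxes, labels, masks):
        # canvas_size is derived from the image, so boxes and image stay in sync
        return {
            "boxes": datapoints.BoundingBoxes(
                boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=F.get_size(img)
            ),
            "labels": labels,
            "masks": datapoints.Mask(masks),
        }

    img = datapoints.Image(torch.rand(3, 120, 160))
    target = wrap_target(
        img,
        boxes=torch.tensor([[10.0, 10.0, 50.0, 60.0]]),
        labels=torch.tensor([1]),
        masks=torch.zeros(1, 120, 160, dtype=torch.uint8),
    )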
@@ -30,7 +30,7 @@ def load_data():
     masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))
     bounding_boxes = datapoints.BoundingBoxes(
-        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
     )
     return path, image, bounding_boxes, masks, labels
...
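In load_data above, the boxes are derived from instance masks, and the canvas is again just the trailing (H, W) of the image tensor. A runnable toy version, assuming torchvision.ops.masks_to_boxes (which returns XYXY boxes for (N, H, W) boolean masks):

    import torch
    from torchvision import datapoints
    from torchvision.ops import masks_to_boxes

    image = torch.rand(3, 96, 128)
    masks = torch.zeros(2, 96, 128, dtype=torch.bool)
    masks[0, 10:30, 20:40] = True    # object 1
    masks[1, 50:80, 60:100] = True   # object 2

    bounding_boxes = datapoints.BoundingBoxes(
        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
    )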
@@ -412,7 +412,7 @@ DEFAULT_SPATIAL_SIZES = (
 )
 
 
-def _parse_spatial_size(size, *, name="size"):
+def _parse_canvas_size(size, *, name="size"):
     if size == "random":
         raise ValueError("This should never happen")
     elif isinstance(size, int) and size > 0:
@@ -467,12 +467,13 @@ class TensorLoader:
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
     memory_format: torch.memory_format = torch.contiguous_format
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
 
     def __post_init__(self):
-        self.spatial_size = self.shape[-2:]
+        self.canvas_size = self.canvas_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
     def load(self, device):
@@ -538,7 +539,7 @@ def make_image_loader(
 ):
     if not constant_alpha:
         raise ValueError("This should never happen")
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
@@ -578,7 +579,7 @@ make_images = from_loaders(make_image_loaders)
 def make_image_loader_for_interpolation(
     size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
 ):
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
@@ -623,43 +624,20 @@ def make_image_loaders_for_interpolation(
 class BoundingBoxesLoader(TensorLoader):
     format: datapoints.BoundingBoxFormat
     spatial_size: Tuple[int, int]
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
+
+    def __post_init__(self):
+        self.canvas_size = self.spatial_size
 
 
 def make_bounding_box(
-    size=None,
+    canvas_size=DEFAULT_SIZE,
     *,
     format=datapoints.BoundingBoxFormat.XYXY,
-    spatial_size=None,
     batch_dims=(),
     dtype=None,
     device="cpu",
 ):
-    """
-    size: Size of the actual bounding box, i.e.
-        - (box[3] - box[1], box[2] - box[0]) for XYXY
-        - (H, W) for XYWH and CXCYWH
-    spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on
-        returned datapoints.BoundingBoxes
-
-    To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
-    functions, e.g.
-
-    .. code::
-
-        image = make_image=(size=size)
-        bounding_boxes = make_bounding_box(spatial_size=size)
-        assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
-
-    For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value as size for all
-    other maker functions, e.g.
-
-    .. code::
-
-        image = make_image=()
-        bounding_boxes = make_bounding_box()
-        assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
-    """
-
     def sample_position(values, max_value):
         # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
         # However, if we have batch_dims, we need tensors as limits.
@@ -668,28 +646,16 @@ def make_bounding_box(
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]
 
-    if spatial_size is None:
-        if size is None:
-            spatial_size = DEFAULT_SIZE
-        else:
-            height, width = size
-            height_margin, width_margin = torch.randint(10, (2,)).tolist()
-            spatial_size = (height + height_margin, width + width_margin)
-
     dtype = dtype or torch.float32
 
     if any(dim == 0 for dim in batch_dims):
         return datapoints.BoundingBoxes(
-            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
+            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
         )
 
-    if size is None:
-        h, w = [torch.randint(1, s, batch_dims) for s in spatial_size]
-    else:
-        h, w = [torch.full(batch_dims, s, dtype=torch.int) for s in size]
-    y = sample_position(h, spatial_size[0])
-    x = sample_position(w, spatial_size[1])
+    h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
+    y = sample_position(h, canvas_size[0])
+    x = sample_position(w, canvas_size[1])
 
     if format is datapoints.BoundingBoxFormat.XYWH:
         parts = (x, y, w, h)
@@ -706,15 +672,15 @@ def make_bounding_box(
         raise ValueError(f"Format {format} is not supported")
 
     return datapoints.BoundingBoxes(
-        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
+        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
     )
 
 
-def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]
-    spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")
+    canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")
 
     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
@@ -722,21 +688,21 @@ def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
             raise pytest.UsageError()
 
         return make_bounding_box(
-            format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
+            format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
         )
 
-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)
 
 
 def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
-    spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
+    canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, spatial_size=spatial_size)
+        yield make_bounding_box_loader(**params, canvas_size=canvas_size)
 
 
 make_bounding_boxes = from_loaders(make_bounding_box_loaders)
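Note the asymmetry above: make_bounding_box_loader forwards the value as spatial_size=canvas_size, because BoundingBoxesLoader deliberately keeps its old constructor field and exposes the new name as a derived alias (see its __post_init__ earlier in this hunk). The underlying dataclass idiom, reduced to a standalone sketch:

    import dataclasses
    from typing import Tuple

    @dataclasses.dataclass
    class Loader:
        spatial_size: Tuple[int, int]  # legacy name, still accepted by callers
        canvas_size: Tuple[int, int] = dataclasses.field(init=False)  # new name, not an __init__ arg

        def __post_init__(self):
            self.canvas_size = self.spatial_size  # alias the old value under the new name

    loader = Loader(spatial_size=(10, 10))
    assert loader.canvas_size == (10, 10)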
@@ -761,7 +727,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=torch.uint8):
 def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
 
     def fn(shape, dtype, device):
         *batch_dims, num_objects, height, width = shape
@@ -802,7 +768,7 @@ def make_segmentation_mask_loader(
     size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
 ):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    spatial_size = _parse_spatial_size(size)
+    canvas_size = _parse_canvas_size(size)
 
     def fn(shape, dtype, device):
         *batch_dims, height, width = shape
@@ -810,7 +776,7 @@ def make_segmentation_mask_loader(
             (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
         )
 
-    return MaskLoader(fn, shape=(*extra_dims, *spatial_size), dtype=dtype)
+    return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)
 
 
 def make_segmentation_mask_loaders(
@@ -860,7 +826,7 @@ def make_video_loader(
     extra_dims=(),
     dtype=torch.uint8,
 ):
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
 
     def fn(shape, dtype, device, memory_format):
         *batch_dims, num_frames, _, height, width = shape
...
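With the old size/spatial_size pair gone from make_bounding_box, generating a consistent joint sample is now just sharing one canvas_size value across the maker functions. This replaces the example in the deleted docstring; make_image and make_bounding_box are the test helpers defined above, and F is torchvision.transforms.v2.functional:

    canvas_size = (24, 32)
    image = make_image(size=canvas_size)
    bounding_boxes = make_bounding_box(canvas_size=canvas_size, batch_dims=(5,))
    assert F.get_size(bounding_boxes) == F.get_size(image)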
@@ -27,7 +27,7 @@ def test_mask_instance(data):
     "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
 )
 def test_bbox_instance(data, format):
-    bboxes = datapoints.BoundingBoxes(data, format=format, spatial_size=(32, 32))
+    bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
     assert isinstance(bboxes, torch.Tensor)
     assert bboxes.ndim == 2 and bboxes.shape[1] == 4
     if isinstance(format, str):
@@ -164,7 +164,7 @@ def test_wrap_like():
     [
         datapoints.Image(torch.rand(3, 16, 16)),
         datapoints.Video(torch.rand(2, 3, 16, 16)),
-        datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
+        datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)),
         datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
     ],
 )
...
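test_wrap_like above feeds each datapoint type through the wrap_like machinery; for boxes, the point is that format and canvas_size are metadata that must survive re-wrapping a plain tensor. A sketch of the intended behavior, assuming BoundingBoxes.wrap_like(reference, tensor) copies metadata from its reference, which is what that test exercises:

    import torch
    from torchvision import datapoints

    bbox = datapoints.BoundingBoxes(
        [0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)
    )
    shifted = datapoints.BoundingBoxes.wrap_like(bbox, bbox.as_subclass(torch.Tensor) + 1)
    assert shifted.format == bbox.format and shifted.canvas_size == bbox.canvas_size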
@@ -164,7 +164,7 @@ class TestSimpleCopyPaste:
         labels = torch.nn.functional.one_hot(labels, num_classes=5)
         target = {
             "boxes": BoundingBoxes(
-                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
+                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", canvas_size=(32, 32)
             ),
             "masks": Mask(masks),
             "labels": label_type(labels),
@@ -179,7 +179,7 @@ class TestSimpleCopyPaste:
         paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
         paste_target = {
             "boxes": BoundingBoxes(
-                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
+                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", canvas_size=(32, 32)
             ),
             "masks": Mask(paste_masks),
             "labels": label_type(paste_labels),
@@ -210,13 +210,13 @@ class TestFixedSizeCrop:
     def test__get_params(self, mocker):
         crop_size = (7, 7)
         batch_shape = (10,)
-        spatial_size = (11, 5)
+        canvas_size = (11, 5)
 
         transform = transforms.FixedSizeCrop(size=crop_size)
 
         flat_inputs = [
-            make_image(size=spatial_size, color_space="RGB"),
-            make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=batch_shape),
+            make_image(size=canvas_size, color_space="RGB"),
+            make_bounding_box(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape),
         ]
         params = transform._get_params(flat_inputs)
@@ -295,7 +295,7 @@ class TestFixedSizeCrop:
     def test__transform_culling(self, mocker):
         batch_size = 10
-        spatial_size = (10, 10)
+        canvas_size = (10, 10)
 
         is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
         mocker.patch(
@@ -304,17 +304,17 @@
                 needs_crop=True,
                 top=0,
                 left=0,
-                height=spatial_size[0],
-                width=spatial_size[1],
+                height=canvas_size[0],
+                width=canvas_size[1],
                 is_valid=is_valid,
                 needs_pad=False,
             ),
         )
 
         bounding_boxes = make_bounding_box(
-            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
+            format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
         )
-        masks = make_detection_mask(size=spatial_size, batch_dims=(batch_size,))
+        masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,))
         labels = make_label(extra_dims=(batch_size,))
 
         transform = transforms.FixedSizeCrop((-1, -1))
@@ -334,7 +334,7 @@
     def test__transform_bounding_boxes_clamping(self, mocker):
         batch_size = 3
-        spatial_size = (10, 10)
+        canvas_size = (10, 10)
 
         mocker.patch(
             "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
@@ -342,15 +342,15 @@
                 needs_crop=True,
                 top=0,
                 left=0,
-                height=spatial_size[0],
-                width=spatial_size[1],
+                height=canvas_size[0],
+                width=canvas_size[1],
                 is_valid=torch.full((batch_size,), fill_value=True),
                 needs_pad=False,
             ),
         )
 
         bounding_boxes = make_bounding_box(
-            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
+            format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
         )
 
         mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")
@@ -496,7 +496,7 @@ def test_fixed_sized_crop_against_detection_reference():
         pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
         target = {
-            "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+            "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
             "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
         }
@@ -505,7 +505,7 @@ def test_fixed_sized_crop_against_detection_reference():
         tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
         target = {
-            "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+            "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
             "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
         }
@@ -514,7 +514,7 @@ def test_fixed_sized_crop_against_detection_reference():
         datapoint_image = make_image(size=size, color_space="RGB")
         target = {
-            "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+            "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
             "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
        }
...
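A recurring edit in the transform tests that follow: MagicMock(spec=datapoints.Image) stand-ins carrying a hand-set .spatial_size are replaced by small real images from make_image. After the rename, _get_params implementations read sizes through the query helpers (F.get_size / query_size) rather than a .spatial_size attribute, so a real tensor-backed image is both simpler and closer to production behavior. The replacement pattern, sketched with the helpers these tests import:

    # before: image = mocker.MagicMock(spec=datapoints.Image); image.spatial_size = (24, 32)
    # after:
    h, w = size = (24, 32)
    image = make_image(size)            # a real datapoints.Image backed by a tensor
    assert F.get_size(image) == [h, w]  # what the v2 transforms now query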
@@ -174,20 +174,20 @@ class TestSmoke:
     )
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_common(self, transform, adapter, container_type, image_or_video, device):
-        spatial_size = F.get_spatial_size(image_or_video)
+        canvas_size = F.get_size(image_or_video)
         input = dict(
             image_or_video=image_or_video,
-            image_datapoint=make_image(size=spatial_size),
-            video_datapoint=make_video(size=spatial_size),
-            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
+            image_datapoint=make_image(size=canvas_size),
+            video_datapoint=make_video(size=canvas_size),
+            image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
             bounding_boxes_xyxy=make_bounding_box(
-                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(3,)
+                format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
             ),
             bounding_boxes_xywh=make_bounding_box(
-                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, batch_dims=(4,)
+                format=datapoints.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
             ),
             bounding_boxes_cxcywh=make_bounding_box(
-                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, batch_dims=(5,)
+                format=datapoints.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
             ),
             bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes(
                 [
@@ -199,7 +199,7 @@ class TestSmoke:
                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
                ],
                format=datapoints.BoundingBoxFormat.XYXY,
-                spatial_size=spatial_size,
+                canvas_size=canvas_size,
            ),
            bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes(
                [
@@ -211,7 +211,7 @@ class TestSmoke:
                    [0, 0, -1, -1],  # negative height and width
                ],
                format=datapoints.BoundingBoxFormat.XYWH,
-                spatial_size=spatial_size,
+                canvas_size=canvas_size,
            ),
            bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes(
                [
@@ -223,10 +223,10 @@ class TestSmoke:
                    [0, 0, -1, -1],  # negative height and width
                ],
                format=datapoints.BoundingBoxFormat.CXCYWH,
-                spatial_size=spatial_size,
+                canvas_size=canvas_size,
            ),
-            detection_mask=make_detection_mask(size=spatial_size),
-            segmentation_mask=make_segmentation_mask(size=spatial_size),
+            detection_mask=make_detection_mask(size=canvas_size),
+            segmentation_mask=make_segmentation_mask(size=canvas_size),
             int=0,
             float=0.0,
             bool=True,
@@ -271,7 +271,7 @@ class TestSmoke:
         # TODO: we should test that against all degenerate boxes above
         for format in list(datapoints.BoundingBoxFormat):
             sample = dict(
-                boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
+                boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
                 labels=torch.tensor([3]),
             )
             assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
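The degenerate-box check above is the contract of SanitizeBoundingBoxes in miniature: a zero-area box (here [0, 0, 0, 0] on a (224, 244) canvas) is dropped, leaving an empty (0, 4) box tensor, with labels filtered alongside. Condensed from the test:

    sample = dict(
        boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format="XYXY", canvas_size=(224, 244)),
        labels=torch.tensor([3]),
    )
    out = transforms.SanitizeBoundingBoxes()(sample)
    assert out["boxes"].shape == (0, 4)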
@@ -473,11 +473,11 @@ class TestRandomZoomOut:
     @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
     @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
-    def test__get_params(self, fill, side_range, mocker):
+    def test__get_params(self, fill, side_range):
         transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
 
-        image = mocker.MagicMock(spec=datapoints.Image)
-        h, w = image.spatial_size = (24, 32)
+        h, w = size = (24, 32)
+        image = make_image(size)
 
         params = transform._get_params([image])
@@ -490,9 +490,7 @@
     @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
     @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
     def test__transform(self, fill, side_range, mocker):
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.spatial_size = (24, 32)
+        inpt = make_image((24, 32))
 
         transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
@@ -559,11 +557,9 @@ class TestRandomCrop:
     @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
     @pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)])
-    def test__get_params(self, padding, pad_if_needed, size, mocker):
-        image = mocker.MagicMock(spec=datapoints.Image)
-        image.num_channels = 3
-        image.spatial_size = (24, 32)
-        h, w = image.spatial_size
+    def test__get_params(self, padding, pad_if_needed, size):
+        h, w = size = (24, 32)
+        image = make_image(size)
 
         transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed)
         params = transform._get_params([image])
@@ -613,21 +609,16 @@ class TestRandomCrop:
             output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
         )
 
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.spatial_size = (32, 32)
+        h, w = size = (32, 32)
+        inpt = make_image(size)
 
-        expected = mocker.MagicMock(spec=datapoints.Image)
-        expected.num_channels = 3
         if isinstance(padding, int):
-            expected.spatial_size = (inpt.spatial_size[0] + padding, inpt.spatial_size[1] + padding)
+            new_size = (h + padding, w + padding)
         elif isinstance(padding, list):
-            expected.spatial_size = (
-                inpt.spatial_size[0] + sum(padding[0::2]),
-                inpt.spatial_size[1] + sum(padding[1::2]),
-            )
+            new_size = (h + sum(padding[0::2]), w + sum(padding[1::2]))
         else:
-            expected.spatial_size = inpt.spatial_size
+            new_size = size
+        expected = make_image(new_size)
 
         _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
         fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
@@ -703,7 +694,7 @@ class TestGaussianBlur:
         fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
         inpt = mocker.MagicMock(spec=datapoints.Image)
         inpt.num_channels = 3
-        inpt.spatial_size = (24, 32)
+        inpt.canvas_size = (24, 32)
 
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
@@ -749,16 +740,14 @@ class TestRandomPerspective:
         with pytest.raises(TypeError, match="Got inappropriate fill arg"):
             transforms.RandomPerspective(0.5, fill="abc")
 
-    def test__get_params(self, mocker):
+    def test__get_params(self):
         dscale = 0.5
         transform = transforms.RandomPerspective(dscale)
-        image = mocker.MagicMock(spec=datapoints.Image)
-        image.num_channels = 3
-        image.spatial_size = (24, 32)
+
+        image = make_image((24, 32))
 
         params = transform._get_params([image])
 
-        h, w = image.spatial_size
         assert "coefficients" in params
         assert len(params["coefficients"]) == 8
@@ -769,9 +758,9 @@
         transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
 
         fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.spatial_size = (24, 32)
+
+        inpt = make_image((24, 32))
 
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         # Otherwise, we can mock transform._get_params
@@ -809,17 +798,16 @@ class TestElasticTransform:
         with pytest.raises(TypeError, match="Got inappropriate fill arg"):
             transforms.ElasticTransform(1.0, 2.0, fill="abc")
 
-    def test__get_params(self, mocker):
+    def test__get_params(self):
         alpha = 2.0
         sigma = 3.0
         transform = transforms.ElasticTransform(alpha, sigma)
-        image = mocker.MagicMock(spec=datapoints.Image)
-        image.num_channels = 3
-        image.spatial_size = (24, 32)
+
+        h, w = size = (24, 32)
+        image = make_image(size)
 
         params = transform._get_params([image])
 
-        h, w = image.spatial_size
         displacement = params["displacement"]
         assert displacement.shape == (1, h, w, 2)
         assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
@@ -845,7 +833,7 @@
         fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
         inpt = mocker.MagicMock(spec=datapoints.Image)
         inpt.num_channels = 3
-        inpt.spatial_size = (24, 32)
+        inpt.canvas_size = (24, 32)
 
         # Let's mock transform._get_params to control the output:
         transform._get_params = mocker.MagicMock()
@@ -856,7 +844,7 @@
 
 class TestRandomErasing:
-    def test_assertions(self, mocker):
+    def test_assertions(self):
         with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
             transforms.RandomErasing(value={})
@@ -872,9 +860,7 @@ class TestRandomErasing:
         with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
             transforms.RandomErasing(scale=[-1, 2])
 
-        image = mocker.MagicMock(spec=datapoints.Image)
-        image.num_channels = 3
-        image.spatial_size = (24, 32)
+        image = make_image((24, 32))
 
         transform = transforms.RandomErasing(value=[1, 2, 3, 4])
@@ -882,10 +868,9 @@
             transform._get_params([image])
 
     @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
-    def test__get_params(self, value, mocker):
-        image = mocker.MagicMock(spec=datapoints.Image)
-        image.num_channels = 3
-        image.spatial_size = (24, 32)
+    def test__get_params(self, value):
+        image = make_image((24, 32))
+        num_channels, height, width = F.get_dimensions(image)
 
         transform = transforms.RandomErasing(value=value)
         params = transform._get_params([image])
@@ -895,14 +880,14 @@
         i, j = params["i"], params["j"]
         assert isinstance(v, torch.Tensor)
         if value == "random":
-            assert v.shape == (image.num_channels, h, w)
+            assert v.shape == (num_channels, h, w)
         elif isinstance(value, (int, float)):
             assert v.shape == (1, 1, 1)
         elif isinstance(value, (list, tuple)):
-            assert v.shape == (image.num_channels, 1, 1)
+            assert v.shape == (num_channels, 1, 1)
 
-        assert 0 <= i <= image.spatial_size[0] - h
-        assert 0 <= j <= image.spatial_size[1] - w
+        assert 0 <= i <= height - h
+        assert 0 <= j <= width - w
 
     @pytest.mark.parametrize("p", [0, 1])
     def test__transform(self, mocker, p):
@@ -1061,14 +1046,13 @@ class TestRandomChoice:
 
 class TestRandomIoUCrop:
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
-    def test__get_params(self, device, options, mocker):
-        image = mocker.MagicMock(spec=datapoints.Image)
-        image.num_channels = 3
-        image.spatial_size = (24, 32)
+    def test__get_params(self, device, options):
+        orig_h, orig_w = size = (24, 32)
+        image = make_image(size)
         bboxes = datapoints.BoundingBoxes(
             torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
             format="XYXY",
-            spatial_size=image.spatial_size,
+            canvas_size=size,
             device=device,
         )
         sample = [image, bboxes]
@@ -1087,8 +1071,6 @@
         assert len(params["is_within_crop_area"]) > 0
         assert params["is_within_crop_area"].dtype == torch.bool
 
-        orig_h = image.spatial_size[0]
-        orig_w = image.spatial_size[1]
         assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
         assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)
@@ -1103,7 +1085,7 @@
     def test__transform_empty_params(self, mocker):
         transform = transforms.RandomIoUCrop(sampler_options=[2.0])
         image = datapoints.Image(torch.rand(1, 3, 4, 4))
-        bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
+        bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
         label = torch.tensor([1])
         sample = [image, bboxes, label]
         # Let's mock transform._get_params to control the output:
@@ -1122,9 +1104,10 @@
     def test__transform(self, mocker):
         transform = transforms.RandomIoUCrop()
-        image = datapoints.Image(torch.rand(3, 32, 24))
-        bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), batch_dims=(6,))
-        masks = make_detection_mask((32, 24), num_objects=6)
+        size = (32, 24)
+        image = make_image(size)
+        bboxes = make_bounding_box(format="XYXY", canvas_size=size, batch_dims=(6,))
+        masks = make_detection_mask(size, num_objects=6)
 
         sample = [image, bboxes, masks]
@@ -1155,13 +1138,14 @@
 class TestScaleJitter:
-    def test__get_params(self, mocker):
-        spatial_size = (24, 32)
+    def test__get_params(self):
+        canvas_size = (24, 32)
         target_size = (16, 12)
         scale_range = (0.5, 1.5)
 
         transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
-        sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size)
+
+        sample = make_image(canvas_size)
 
         n_samples = 5
         for _ in range(n_samples):
@@ -1174,11 +1158,11 @@
         assert isinstance(size, tuple) and len(size) == 2
         height, width = size
 
-        r_min = min(target_size[1] / spatial_size[0], target_size[0] / spatial_size[1]) * scale_range[0]
-        r_max = min(target_size[1] / spatial_size[0], target_size[0] / spatial_size[1]) * scale_range[1]
+        r_min = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[0]
+        r_max = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[1]
 
-        assert int(spatial_size[0] * r_min) <= height <= int(spatial_size[0] * r_max)
-        assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)
+        assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
+        assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)
 
     def test__transform(self, mocker):
         interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
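The ScaleJitter bounds above are easier to read with the test's numbers plugged in: with target_size=(16, 12), canvas_size=(24, 32), and scale_range=(0.5, 1.5), the common ratio is min(12/24, 16/32) = 0.5, so r_min = 0.25 and r_max = 0.75, and the sampled output must satisfy height in [int(24*0.25), int(24*0.75)] = [6, 18] and width in [int(32*0.25), int(32*0.75)] = [8, 24]:

    target_size, canvas_size, scale_range = (16, 12), (24, 32), (0.5, 1.5)
    r_min = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[0]  # 0.25
    r_max = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[1]  # 0.75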
@@ -1206,12 +1190,12 @@
 class TestRandomShortestSize:
     @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
-    def test__get_params(self, min_size, max_size, mocker):
-        spatial_size = (3, 10)
+    def test__get_params(self, min_size, max_size):
+        canvas_size = (3, 10)
 
         transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size, antialias=True)
 
-        sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size)
+        sample = make_image(canvas_size)
         params = transform._get_params([sample])
 
         assert "size" in params
@@ -1523,7 +1507,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
     boxes[:, 2:] += boxes[:, :2]
     boxes = boxes.clamp(min=0, max=min(H, W))
-    boxes = datapoints.BoundingBoxes(boxes, format="XYXY", spatial_size=(H, W))
+    boxes = datapoints.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
 
     masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
@@ -1597,7 +1581,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     boxes = datapoints.BoundingBoxes(
         boxes,
         format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=(H, W),
+        canvas_size=(H, W),
     )
     masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
@@ -1651,7 +1635,7 @@ def test_sanitize_bounding_boxes_errors():
     good_bbox = datapoints.BoundingBoxes(
         [[0, 0, 10, 10]],
         format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=(20, 20),
+        canvas_size=(20, 20),
    )
 
     with pytest.raises(ValueError, match="min_size must be >= 1"):
@@ -1678,7 +1662,7 @@
            [[0, 0, 10, 10]],
        ],
        format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=(20, 20),
+        canvas_size=(20, 20),
     )
     different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
     transforms.SanitizeBoundingBoxes()(different_sizes)
...
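query_size (formerly query_spatial_size) is the glue for heterogeneous samples: it scans the inputs and returns the single (height, width) they agree on, and the PadIfSmaller helper in the consistency tests below uses it to size its padding. A minimal sketch, assuming it accepts a flat list of datapoints as in that helper:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2.utils import query_size

    image = datapoints.Image(torch.rand(3, 480, 512))
    boxes = datapoints.BoundingBoxes(torch.zeros(0, 4), format="XYXY", canvas_size=(480, 512))
    height, width = query_size([image, boxes])  # -> 480, 512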
@@ -31,7 +31,7 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
 from torchvision.transforms.v2.functional import to_image_pil
-from torchvision.transforms.v2.utils import query_spatial_size
+from torchvision.transforms.v2.utils import query_size
 
 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
@@ -1090,7 +1090,7 @@ class TestRefDetTransforms:
             pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
             target = {
-                "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+                "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
                 "labels": make_label(extra_dims=(num_objects,), categories=80),
             }
             if with_mask:
@@ -1100,7 +1100,7 @@ class TestRefDetTransforms:
             tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
             target = {
-                "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+                "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
                 "labels": make_label(extra_dims=(num_objects,), categories=80),
             }
             if with_mask:
@@ -1110,7 +1110,7 @@ class TestRefDetTransforms:
             datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
             target = {
-                "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+                "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
                 "labels": make_label(extra_dims=(num_objects,), categories=80),
             }
             if with_mask:
@@ -1172,7 +1172,7 @@ class PadIfSmaller(v2_transforms.Transform):
         self.fill = v2_transforms._geometry._setup_fill_arg(fill)
 
     def _get_params(self, sample):
-        height, width = query_spatial_size(sample)
+        height, width = query_size(sample)
         padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
         needs_padding = any(padding)
         return dict(padding=padding, needs_padding=needs_padding)
...
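In the functional tests below, the bounding-box kernels return the new canvas along with the transformed boxes, so assertions unpack a (boxes, canvas_size) pair. Condensed from test_correctness_crop_bounding_boxes; the positional signature matches the call in that test, and the values are illustrative:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    format = datapoints.BoundingBoxFormat.XYXY
    in_boxes = torch.tensor([[10.0, 15.0, 25.0, 35.0]])
    top, left, height, width = 0, 5, 32, 42

    output_boxes, output_canvas_size = F.crop_bounding_boxes(in_boxes, format, top, left, height, width)
    # output_canvas_size == (32, 42): the crop region becomes the new canvas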
...@@ -351,7 +351,7 @@ class TestDispatchers: ...@@ -351,7 +351,7 @@ class TestDispatchers:
F.get_image_size, F.get_image_size,
F.get_num_channels, F.get_num_channels,
F.get_num_frames, F.get_num_frames,
F.get_spatial_size, F.get_size,
F.rgb_to_grayscale, F.rgb_to_grayscale,
F.uniform_temporal_subsample, F.uniform_temporal_subsample,
], ],
...@@ -568,27 +568,27 @@ class TestClampBoundingBoxes: ...@@ -568,27 +568,27 @@ class TestClampBoundingBoxes:
[ [
dict(), dict(),
dict(format=datapoints.BoundingBoxFormat.XYXY), dict(format=datapoints.BoundingBoxFormat.XYXY),
dict(spatial_size=(1, 1)), dict(canvas_size=(1, 1)),
], ],
) )
def test_simple_tensor_insufficient_metadata(self, metadata): def test_simple_tensor_insufficient_metadata(self, metadata):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")): with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
F.clamp_bounding_boxes(simple_tensor, **metadata) F.clamp_bounding_boxes(simple_tensor, **metadata)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"metadata", "metadata",
[ [
dict(format=datapoints.BoundingBoxFormat.XYXY), dict(format=datapoints.BoundingBoxFormat.XYXY),
dict(spatial_size=(1, 1)), dict(canvas_size=(1, 1)),
dict(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(1, 1)), dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
], ],
) )
def test_datapoint_explicit_metadata(self, metadata): def test_datapoint_explicit_metadata(self, metadata):
datapoint = next(make_bounding_boxes()) datapoint = next(make_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")): with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
F.clamp_bounding_boxes(datapoint, **metadata) F.clamp_bounding_boxes(datapoint, **metadata)
...@@ -673,7 +673,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt ...@@ -673,7 +673,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
# expected_bboxes.append(out_box) # expected_bboxes.append(out_box)
format = datapoints.BoundingBoxFormat.XYXY format = datapoints.BoundingBoxFormat.XYXY
spatial_size = (64, 76) canvas_size = (64, 76)
in_boxes = [ in_boxes = [
[10.0, 15.0, 25.0, 35.0], [10.0, 15.0, 25.0, 35.0],
[50.0, 5.0, 70.0, 22.0], [50.0, 5.0, 70.0, 22.0],
...@@ -684,23 +684,23 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt ...@@ -684,23 +684,23 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
expected_bboxes = clamp_bounding_boxes( expected_bboxes = clamp_bounding_boxes(
datapoints.BoundingBoxes(expected_bboxes, format="XYXY", spatial_size=spatial_size) datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
).tolist() ).tolist()
output_boxes, output_spatial_size = F.crop_bounding_boxes( output_boxes, output_canvas_size = F.crop_bounding_boxes(
in_boxes, in_boxes,
format, format,
top, top,
left, left,
spatial_size[0], canvas_size[0],
spatial_size[1], canvas_size[1],
) )
if format != datapoints.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_spatial_size, spatial_size) torch.testing.assert_close(output_canvas_size, canvas_size)
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
...@@ -737,7 +737,7 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig ...@@ -737,7 +737,7 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
return bbox return bbox
format = datapoints.BoundingBoxFormat.XYXY format = datapoints.BoundingBoxFormat.XYXY
spatial_size = (100, 100) canvas_size = (100, 100)
in_boxes = [ in_boxes = [
[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0],
[5.0, 10.0, 15.0, 20.0], [5.0, 10.0, 15.0, 20.0],
...@@ -748,18 +748,18 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig ...@@ -748,18 +748,18 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
expected_bboxes = torch.tensor(expected_bboxes, device=device) expected_bboxes = torch.tensor(expected_bboxes, device=device)
in_boxes = datapoints.BoundingBoxes( in_boxes = datapoints.BoundingBoxes(
in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
) )
if format != datapoints.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
output_boxes, output_spatial_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size) output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
if format != datapoints.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_spatial_size, size) torch.testing.assert_close(output_canvas_size, size)
def _parse_padding(padding): def _parse_padding(padding):
...@@ -798,28 +798,28 @@ def test_correctness_pad_bounding_boxes(device, padding): ...@@ -798,28 +798,28 @@ def test_correctness_pad_bounding_boxes(device, padding):
bbox = bbox.to(dtype) bbox = bbox.to(dtype)
return bbox return bbox
def _compute_expected_spatial_size(bbox, padding_): def _compute_expected_canvas_size(bbox, padding_):
pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_) pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
height, width = bbox.spatial_size height, width = bbox.canvas_size
return height + pad_up + pad_down, width + pad_left + pad_right return height + pad_up + pad_down, width + pad_left + pad_right
for bboxes in make_bounding_boxes(): for bboxes in make_bounding_boxes():
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_spatial_size = bboxes.spatial_size bboxes_canvas_size = bboxes.canvas_size
output_boxes, output_spatial_size = F.pad_bounding_boxes( output_boxes, output_canvas_size = F.pad_bounding_boxes(
bboxes, format=bboxes_format, spatial_size=bboxes_spatial_size, padding=padding bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
) )
torch.testing.assert_close(output_spatial_size, _compute_expected_spatial_size(bboxes, padding)) torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))
if bboxes.ndim < 2 or bboxes.shape[0] == 0: if bboxes.ndim < 2 or bboxes.shape[0] == 0:
bboxes = [bboxes] bboxes = [bboxes]
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
expected_bboxes.append(_compute_expected_bbox(bbox, padding)) expected_bboxes.append(_compute_expected_bbox(bbox, padding))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
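Padding never rescales boxes; it only shifts them and grows the canvas. Assuming torchvision's [left, top, right, bottom] padding order, the helper's arithmetic works out like this (all numbers invented):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = torch.tensor([[2.0, 3.0, 10.0, 12.0]])  # a single XYXY box
    _, new_canvas = F.pad_bounding_boxes(
        boxes,
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(32, 38),
        padding=[1, 2, 3, 4],
    )
    # new_canvas == (38, 42): height 32 + 2 + 4, width 38 + 1 + 3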
...@@ -887,24 +887,24 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): ...@@ -887,24 +887,24 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
out_bbox = datapoints.BoundingBoxes( out_bbox = datapoints.BoundingBoxes(
out_bbox, out_bbox,
format=datapoints.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=bbox.spatial_size, canvas_size=bbox.canvas_size,
dtype=bbox.dtype, dtype=bbox.dtype,
device=bbox.device, device=bbox.device,
) )
return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format)) return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))
spatial_size = (32, 38) canvas_size = (32, 38)
pcoeffs = _get_perspective_coeffs(startpoints, endpoints) pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints) inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)): for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)):
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
output_bboxes = F.perspective_bounding_boxes( output_bboxes = F.perspective_bounding_boxes(
bboxes.as_subclass(torch.Tensor), bboxes.as_subclass(torch.Tensor),
format=bboxes.format, format=bboxes.format,
spatial_size=bboxes.spatial_size, canvas_size=bboxes.canvas_size,
startpoints=None, startpoints=None,
endpoints=None, endpoints=None,
coefficients=pcoeffs, coefficients=pcoeffs,
...@@ -915,7 +915,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): ...@@ -915,7 +915,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size) bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
...@@ -932,15 +932,15 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): ...@@ -932,15 +932,15 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_boxes(device, output_size): def test_correctness_center_crop_bounding_boxes(device, output_size):
def _compute_expected_bbox(bbox, output_size_): def _compute_expected_bbox(bbox, output_size_):
format_ = bbox.format format_ = bbox.format
spatial_size_ = bbox.spatial_size canvas_size_ = bbox.canvas_size
dtype = bbox.dtype dtype = bbox.dtype
bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH) bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
if len(output_size_) == 1: if len(output_size_) == 1:
output_size_.append(output_size_[-1]) output_size_.append(output_size_[-1])
cy = int(round((spatial_size_[0] - output_size_[0]) * 0.5)) cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
cx = int(round((spatial_size_[1] - output_size_[1]) * 0.5)) cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
out_bbox = [ out_bbox = [
bbox[0].item() - cx, bbox[0].item() - cx,
bbox[1].item() - cy, bbox[1].item() - cy,
...@@ -949,16 +949,16 @@ def test_correctness_center_crop_bounding_boxes(device, output_size): ...@@ -949,16 +949,16 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
] ]
out_bbox = torch.tensor(out_bbox) out_bbox = torch.tensor(out_bbox)
out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_) out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
out_bbox = clamp_bounding_boxes(out_bbox, format=format_, spatial_size=output_size) out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
return out_bbox.to(dtype=dtype, device=bbox.device) return out_bbox.to(dtype=dtype, device=bbox.device)
for bboxes in make_bounding_boxes(extra_dims=((4,),)): for bboxes in make_bounding_boxes(extra_dims=((4,),)):
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_spatial_size = bboxes.spatial_size bboxes_canvas_size = bboxes.canvas_size
output_boxes, output_spatial_size = F.center_crop_bounding_boxes( output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
bboxes, bboxes_format, bboxes_spatial_size, output_size bboxes, bboxes_format, bboxes_canvas_size, output_size
) )
if bboxes.ndim < 2: if bboxes.ndim < 2:
...@@ -966,7 +966,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size): ...@@ -966,7 +966,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -975,7 +975,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size): ...@@ -975,7 +975,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
expected_bboxes = expected_bboxes[0] expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
torch.testing.assert_close(output_spatial_size, output_size) torch.testing.assert_close(output_canvas_size, output_size)
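The offsets in _compute_expected_bbox above are just the centered placement of the output window; with a hypothetical (32, 38) canvas cropped to (18, 18):

    canvas_size = (32, 38)
    output_size = [18, 18]
    cy = int(round((canvas_size[0] - output_size[0]) * 0.5))  # 7
    cx = int(round((canvas_size[1] - output_size[1]) * 0.5))  # 10
    # every box is shifted by (-cx, -cy) in XYWH space, then clamped to output_size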
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
...@@ -1003,11 +1003,11 @@ def test_correctness_center_crop_mask(device, output_size): ...@@ -1003,11 +1003,11 @@ def test_correctness_center_crop_mask(device, output_size):
# Copied from test/test_functional_tensor.py # Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("spatial_size", ("small", "large")) @pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) @pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, sigma): def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
fn = F.gaussian_blur_image_tensor fn = F.gaussian_blur_image_tensor
# true_cv2_results = { # true_cv2_results = {
...@@ -1027,7 +1027,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, ...@@ -1027,7 +1027,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt") p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
true_cv2_results = torch.load(p) true_cv2_results = torch.load(p)
if spatial_size == "small": if canvas_size == "small":
tensor = ( tensor = (
torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device) torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
) )
......
...@@ -392,7 +392,7 @@ def assert_warns_antialias_default_value(): ...@@ -392,7 +392,7 @@ def assert_warns_antialias_default_value():
yield yield
def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_size, affine_matrix): def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size, affine_matrix):
def transform(bbox): def transform(bbox):
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype in_dtype = bbox.dtype
...@@ -426,7 +426,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_si ...@@ -426,7 +426,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_si
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
) )
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64 # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
out_bbox = F.clamp_bounding_boxes(out_bbox, format=format, spatial_size=spatial_size) out_bbox = F.clamp_bounding_boxes(out_bbox, format=format, canvas_size=canvas_size)
out_bbox = out_bbox.to(dtype=in_dtype) out_bbox = out_bbox.to(dtype=in_dtype)
return out_bbox return out_bbox
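The clamp-before-cast step relies on the renamed clamp kernel; a hedged sketch of its effect on coordinates that fall outside the canvas (values invented):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = torch.tensor([[-5.0, 10.0, 400.0, 200.0]])  # XYXY, partially outside
    clamped = F.clamp_bounding_boxes(
        boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(180, 320)
    )
    # x coordinates are clamped to [0, 320], y coordinates to [0, 180]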
...@@ -514,14 +514,14 @@ class TestResize: ...@@ -514,14 +514,14 @@ class TestResize:
bounding_boxes = make_bounding_box( bounding_boxes = make_bounding_box(
format=format, format=format,
spatial_size=self.INPUT_SIZE, canvas_size=self.INPUT_SIZE,
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
check_kernel( check_kernel(
F.resize_bounding_boxes, F.resize_bounding_boxes,
bounding_boxes, bounding_boxes,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
size=size, size=size,
**max_size_kwarg, **max_size_kwarg,
check_scripted_vs_eager=not isinstance(size, int), check_scripted_vs_eager=not isinstance(size, int),
...@@ -588,8 +588,8 @@ class TestResize: ...@@ -588,8 +588,8 @@ class TestResize:
check_transform(transforms.Resize, make_input(self.INPUT_SIZE, device=device), size=size, antialias=True) check_transform(transforms.Resize, make_input(self.INPUT_SIZE, device=device), size=size, antialias=True)
def _check_output_size(self, input, output, *, size, max_size): def _check_output_size(self, input, output, *, size, max_size):
assert tuple(F.get_spatial_size(output)) == self._compute_output_size( assert tuple(F.get_size(output)) == self._compute_output_size(
input_size=F.get_spatial_size(input), size=size, max_size=max_size input_size=F.get_size(input), size=size, max_size=max_size
) )
@pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("size", OUTPUT_SIZES)
...@@ -613,9 +613,9 @@ class TestResize: ...@@ -613,9 +613,9 @@ class TestResize:
torch.testing.assert_close(actual, expected, atol=1, rtol=0) torch.testing.assert_close(actual, expected, atol=1, rtol=0)
def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None): def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None):
old_height, old_width = bounding_boxes.spatial_size old_height, old_width = bounding_boxes.canvas_size
new_height, new_width = self._compute_output_size( new_height, new_width = self._compute_output_size(
input_size=bounding_boxes.spatial_size, size=size, max_size=max_size input_size=bounding_boxes.canvas_size, size=size, max_size=max_size
) )
if (old_height, old_width) == (new_height, new_width): if (old_height, old_width) == (new_height, new_width):
...@@ -632,10 +632,10 @@ class TestResize: ...@@ -632,10 +632,10 @@ class TestResize:
expected_bboxes = reference_affine_bounding_boxes_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, bounding_boxes,
format=bounding_boxes.format, format=bounding_boxes.format,
spatial_size=(new_height, new_width), canvas_size=(new_height, new_width),
affine_matrix=affine_matrix, affine_matrix=affine_matrix,
) )
return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes, spatial_size=(new_height, new_width)) return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes, canvas_size=(new_height, new_width))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("size", OUTPUT_SIZES)
...@@ -645,7 +645,7 @@ class TestResize: ...@@ -645,7 +645,7 @@ class TestResize:
if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
return return
bounding_boxes = make_bounding_box(format=format, spatial_size=self.INPUT_SIZE) bounding_boxes = make_bounding_box(format=format, canvas_size=self.INPUT_SIZE)
actual = fn(bounding_boxes, size=size, **max_size_kwarg) actual = fn(bounding_boxes, size=size, **max_size_kwarg)
expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg) expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg)
...@@ -762,7 +762,7 @@ class TestResize: ...@@ -762,7 +762,7 @@ class TestResize:
def test_noop(self, size, make_input): def test_noop(self, size, make_input):
input = make_input(self.INPUT_SIZE) input = make_input(self.INPUT_SIZE)
output = F.resize(input, size=F.get_spatial_size(input), antialias=True) output = F.resize(input, size=F.get_size(input), antialias=True)
# This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
# is a good reason to break this, feel free to downgrade to an equality check. # is a good reason to break this, feel free to downgrade to an equality check.
...@@ -792,11 +792,11 @@ class TestResize: ...@@ -792,11 +792,11 @@ class TestResize:
input = make_input(self.INPUT_SIZE) input = make_input(self.INPUT_SIZE)
size = min(F.get_spatial_size(input)) size = min(F.get_size(input))
max_size = size + 1 max_size = size + 1
output = F.resize(input, size=size, max_size=max_size, antialias=True) output = F.resize(input, size=size, max_size=max_size, antialias=True)
assert max(F.get_spatial_size(output)) == max_size assert max(F.get_size(output)) == max_size
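F.get_size returns [height, width] for any supported input, which is what the noop and max_size assertions above lean on. A small sketch with an assumed CHW shape:

    import torch
    from torchvision.transforms.v2 import functional as F

    img = torch.rand(3, 17, 11)
    size = min(F.get_size(img))  # 11, the short side
    out = F.resize(img, size=size, max_size=size + 1, antialias=True)
    # max(F.get_size(out)) == 12: the long side is capped at max_size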
class TestHorizontalFlip: class TestHorizontalFlip:
...@@ -814,7 +814,7 @@ class TestHorizontalFlip: ...@@ -814,7 +814,7 @@ class TestHorizontalFlip:
F.horizontal_flip_bounding_boxes, F.horizontal_flip_bounding_boxes,
bounding_boxes, bounding_boxes,
format=format, format=format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
) )
@pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask]) @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
...@@ -874,7 +874,7 @@ class TestHorizontalFlip: ...@@ -874,7 +874,7 @@ class TestHorizontalFlip:
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes): def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
affine_matrix = np.array( affine_matrix = np.array(
[ [
[-1, 0, bounding_boxes.spatial_size[1]], [-1, 0, bounding_boxes.canvas_size[1]],
[0, 1, 0], [0, 1, 0],
], ],
dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
...@@ -883,7 +883,7 @@ class TestHorizontalFlip: ...@@ -883,7 +883,7 @@ class TestHorizontalFlip:
expected_bboxes = reference_affine_bounding_boxes_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, bounding_boxes,
format=bounding_boxes.format, format=bounding_boxes.format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
affine_matrix=affine_matrix, affine_matrix=affine_matrix,
) )
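The flip matrix encodes x -> canvas_width - x in homogeneous coordinates; a quick numeric check with an arbitrary width:

    import numpy as np

    W = 38  # hypothetical canvas width
    affine = np.array([[-1, 0, W], [0, 1, 0]], dtype="float64")
    point = np.array([10.0, 5.0, 1.0])  # homogeneous (x, y, 1)
    print(affine @ point)  # [28. 5.]: x is reflected to W - x, y unchanged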
...@@ -995,7 +995,7 @@ class TestAffine: ...@@ -995,7 +995,7 @@ class TestAffine:
F.affine_bounding_boxes, F.affine_bounding_boxes,
bounding_boxes, bounding_boxes,
format=format, format=format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
**{param: value}, **{param: value},
check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))), check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))),
) )
...@@ -1133,7 +1133,7 @@ class TestAffine: ...@@ -1133,7 +1133,7 @@ class TestAffine:
def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, scale, shear, center): def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, scale, shear, center):
if center is None: if center is None:
center = [s * 0.5 for s in bounding_boxes.spatial_size[::-1]] center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
affine_matrix = self._compute_affine_matrix( affine_matrix = self._compute_affine_matrix(
angle=angle, translate=translate, scale=scale, shear=shear, center=center angle=angle, translate=translate, scale=scale, shear=shear, center=center
...@@ -1143,7 +1143,7 @@ class TestAffine: ...@@ -1143,7 +1143,7 @@ class TestAffine:
expected_bboxes = reference_affine_bounding_boxes_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, bounding_boxes,
format=bounding_boxes.format, format=bounding_boxes.format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
affine_matrix=affine_matrix, affine_matrix=affine_matrix,
) )
...@@ -1202,7 +1202,7 @@ class TestAffine: ...@@ -1202,7 +1202,7 @@ class TestAffine:
@pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("seed", list(range(10)))
def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed): def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed):
image = make_image() image = make_image()
height, width = F.get_spatial_size(image) height, width = F.get_size(image)
transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
...@@ -1293,7 +1293,7 @@ class TestVerticalFlip: ...@@ -1293,7 +1293,7 @@ class TestVerticalFlip:
F.vertical_flip_bounding_boxes, F.vertical_flip_bounding_boxes,
bounding_boxes, bounding_boxes,
format=format, format=format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
) )
@pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask]) @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
...@@ -1352,7 +1352,7 @@ class TestVerticalFlip: ...@@ -1352,7 +1352,7 @@ class TestVerticalFlip:
affine_matrix = np.array( affine_matrix = np.array(
[ [
[1, 0, 0], [1, 0, 0],
[0, -1, bounding_boxes.spatial_size[0]], [0, -1, bounding_boxes.canvas_size[0]],
], ],
dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
) )
...@@ -1360,7 +1360,7 @@ class TestVerticalFlip: ...@@ -1360,7 +1360,7 @@ class TestVerticalFlip:
expected_bboxes = reference_affine_bounding_boxes_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, bounding_boxes,
format=bounding_boxes.format, format=bounding_boxes.format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
affine_matrix=affine_matrix, affine_matrix=affine_matrix,
) )
...@@ -1449,7 +1449,7 @@ class TestRotate: ...@@ -1449,7 +1449,7 @@ class TestRotate:
F.rotate_bounding_boxes, F.rotate_bounding_boxes,
bounding_boxes, bounding_boxes,
format=format, format=format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
**kwargs, **kwargs,
) )
...@@ -1555,7 +1555,7 @@ class TestRotate: ...@@ -1555,7 +1555,7 @@ class TestRotate:
raise ValueError("This reference currently does not support expand=True") raise ValueError("This reference currently does not support expand=True")
if center is None: if center is None:
center = [s * 0.5 for s in bounding_boxes.spatial_size[::-1]] center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
a = np.cos(angle * np.pi / 180.0) a = np.cos(angle * np.pi / 180.0)
b = np.sin(angle * np.pi / 180.0) b = np.sin(angle * np.pi / 180.0)
...@@ -1572,7 +1572,7 @@ class TestRotate: ...@@ -1572,7 +1572,7 @@ class TestRotate:
expected_bboxes = reference_affine_bounding_boxes_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, bounding_boxes,
format=bounding_boxes.format, format=bounding_boxes.format,
spatial_size=bounding_boxes.spatial_size, canvas_size=bounding_boxes.canvas_size,
affine_matrix=affine_matrix, affine_matrix=affine_matrix,
) )
...@@ -1834,7 +1834,7 @@ class TestToDtype: ...@@ -1834,7 +1834,7 @@ class TestToDtype:
mask_dtype = torch.bool mask_dtype = torch.bool
sample = { sample = {
"inpt": make_input(size=(H, W), dtype=inpt_dtype), "inpt": make_input(size=(H, W), dtype=inpt_dtype),
"bbox": make_bounding_box(size=(H, W), dtype=bbox_dtype), "bbox": make_bounding_box(canvas_size=(H, W), dtype=bbox_dtype),
"mask": make_detection_mask(size=(H, W), dtype=mask_dtype), "mask": make_detection_mask(size=(H, W), dtype=mask_dtype),
} }
...@@ -1988,7 +1988,7 @@ class TestCutMixMixUp: ...@@ -1988,7 +1988,7 @@ class TestCutMixMixUp:
for input_with_bad_type in ( for input_with_bad_type in (
F.to_pil_image(imgs[0]), F.to_pil_image(imgs[0]),
datapoints.Mask(torch.rand(12, 12)), datapoints.Mask(torch.rand(12, 12)),
datapoints.BoundingBoxes(torch.rand(2, 4), format="XYXY", spatial_size=12), datapoints.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
): ):
with pytest.raises(ValueError, match="does not support PIL images, "): with pytest.raises(ValueError, match="does not support PIL images, "):
cutmix_mixup(input_with_bad_type) cutmix_mixup(input_with_bad_type)
......
...@@ -4,16 +4,16 @@ import pytest ...@@ -4,16 +4,16 @@ import pytest
import torch import torch
import torchvision.transforms.v2.utils import torchvision.transforms.v2.utils
from common_utils import make_bounding_box, make_detection_mask, make_image from common_utils import DEFAULT_SIZE, make_bounding_box, make_detection_mask, make_image
from torchvision import datapoints from torchvision import datapoints
from torchvision.transforms.v2.functional import to_image_pil from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import has_all, has_any from torchvision.transforms.v2.utils import has_all, has_any
IMAGE = make_image(color_space="RGB") IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) BOUNDING_BOX = make_bounding_box(DEFAULT_SIZE, format=datapoints.BoundingBoxFormat.XYXY)
MASK = make_detection_mask(size=IMAGE.spatial_size) MASK = make_detection_mask(DEFAULT_SIZE)
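These module-level constants feed the has_any/has_all parametrizations below. A hedged sketch of how the helpers behave on such a sample (expected results inferred from their names, not taken from this diff):

    from torchvision import datapoints
    from torchvision.transforms.v2.utils import has_all, has_any

    sample = [IMAGE, BOUNDING_BOX, MASK]
    has_any(sample, datapoints.BoundingBoxes)           # True
    has_all(sample, datapoints.Image, datapoints.Mask)  # True
    has_any(sample, datapoints.Video)                   # False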
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -184,8 +184,8 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): ...@@ -184,8 +184,8 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
return other_args, dict(kwargs, fill=fill) return other_args, dict(kwargs, fill=fill)
def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_size, affine_matrix): def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size, affine_matrix):
def transform(bbox, affine_matrix_, format_, spatial_size_): def transform(bbox, affine_matrix_, format_, canvas_size_):
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype in_dtype = bbox.dtype
if not torch.is_floating_point(bbox): if not torch.is_floating_point(bbox):
...@@ -218,14 +218,14 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_si ...@@ -218,14 +218,14 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_si
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
) )
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64 # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, spatial_size=spatial_size_) out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_)
out_bbox = out_bbox.to(dtype=in_dtype) out_bbox = out_bbox.to(dtype=in_dtype)
return out_bbox return out_bbox
if bounding_boxes.ndim < 2: if bounding_boxes.ndim < 2:
bounding_boxes = [bounding_boxes] bounding_boxes = [bounding_boxes]
expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_boxes] expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes]
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
else: else:
...@@ -321,11 +321,11 @@ def reference_crop_bounding_boxes(bounding_boxes, *, format, top, left, height, ...@@ -321,11 +321,11 @@ def reference_crop_bounding_boxes(bounding_boxes, *, format, top, left, height,
dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
) )
spatial_size = (height, width) canvas_size = (height, width)
expected_bboxes = reference_affine_bounding_boxes_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix bounding_boxes, format=format, canvas_size=canvas_size, affine_matrix=affine_matrix
) )
return expected_bboxes, spatial_size return expected_bboxes, canvas_size
def reference_inputs_crop_bounding_boxes(): def reference_inputs_crop_bounding_boxes():
...@@ -507,7 +507,7 @@ def sample_inputs_pad_bounding_boxes(): ...@@ -507,7 +507,7 @@ def sample_inputs_pad_bounding_boxes():
yield ArgsKwargs( yield ArgsKwargs(
bounding_boxes_loader, bounding_boxes_loader,
format=bounding_boxes_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_boxes_loader.spatial_size, canvas_size=bounding_boxes_loader.canvas_size,
padding=padding, padding=padding,
padding_mode="constant", padding_mode="constant",
) )
...@@ -530,7 +530,7 @@ def sample_inputs_pad_video(): ...@@ -530,7 +530,7 @@ def sample_inputs_pad_video():
yield ArgsKwargs(video_loader, padding=[1]) yield ArgsKwargs(video_loader, padding=[1])
def reference_pad_bounding_boxes(bounding_boxes, *, format, spatial_size, padding, padding_mode): def reference_pad_bounding_boxes(bounding_boxes, *, format, canvas_size, padding, padding_mode):
left, right, top, bottom = _parse_pad_padding(padding) left, right, top, bottom = _parse_pad_padding(padding)
...@@ -542,11 +542,11 @@ def reference_pad_bounding_boxes(bounding_boxes, *, format, spatial_size, paddin ...@@ -542,11 +542,11 @@ def reference_pad_bounding_boxes(bounding_boxes, *, format, spatial_size, paddin
dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
) )
height = spatial_size[0] + top + bottom height = canvas_size[0] + top + bottom
width = spatial_size[1] + left + right width = canvas_size[1] + left + right
expected_bboxes = reference_affine_bounding_boxes_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_boxes, format=format, spatial_size=(height, width), affine_matrix=affine_matrix bounding_boxes, format=format, canvas_size=(height, width), affine_matrix=affine_matrix
) )
return expected_bboxes, (height, width) return expected_bboxes, (height, width)
...@@ -558,7 +558,7 @@ def reference_inputs_pad_bounding_boxes(): ...@@ -558,7 +558,7 @@ def reference_inputs_pad_bounding_boxes():
yield ArgsKwargs( yield ArgsKwargs(
bounding_boxes_loader, bounding_boxes_loader,
format=bounding_boxes_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_boxes_loader.spatial_size, canvas_size=bounding_boxes_loader.canvas_size,
padding=padding, padding=padding,
padding_mode="constant", padding_mode="constant",
) )
...@@ -660,7 +660,7 @@ def sample_inputs_perspective_bounding_boxes(): ...@@ -660,7 +660,7 @@ def sample_inputs_perspective_bounding_boxes():
yield ArgsKwargs( yield ArgsKwargs(
bounding_boxes_loader, bounding_boxes_loader,
format=bounding_boxes_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_boxes_loader.spatial_size, canvas_size=bounding_boxes_loader.canvas_size,
startpoints=None, startpoints=None,
endpoints=None, endpoints=None,
coefficients=_PERSPECTIVE_COEFFS[0], coefficients=_PERSPECTIVE_COEFFS[0],
...@@ -669,7 +669,7 @@ def sample_inputs_perspective_bounding_boxes(): ...@@ -669,7 +669,7 @@ def sample_inputs_perspective_bounding_boxes():
format = datapoints.BoundingBoxFormat.XYXY format = datapoints.BoundingBoxFormat.XYXY
loader = make_bounding_box_loader(format=format) loader = make_bounding_box_loader(format=format)
yield ArgsKwargs( yield ArgsKwargs(
loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
) )
...@@ -742,13 +742,13 @@ KERNEL_INFOS.extend( ...@@ -742,13 +742,13 @@ KERNEL_INFOS.extend(
) )
def _get_elastic_displacement(spatial_size): def _get_elastic_displacement(canvas_size):
return torch.rand(1, *spatial_size, 2) return torch.rand(1, *canvas_size, 2)
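The displacement field is sampled per canvas pixel with a trailing (dx, dy) pair, so its shape follows directly from the canvas. Reusing the helper above with an arbitrary canvas:

    displacement = _get_elastic_displacement((24, 32))
    print(displacement.shape)  # torch.Size([1, 24, 32, 2])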
def sample_inputs_elastic_image_tensor(): def sample_inputs_elastic_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]): for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
displacement = _get_elastic_displacement(image_loader.spatial_size) displacement = _get_elastic_displacement(image_loader.canvas_size)
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype): for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, displacement=displacement, fill=fill) yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)
...@@ -762,18 +762,18 @@ def reference_inputs_elastic_image_tensor(): ...@@ -762,18 +762,18 @@ def reference_inputs_elastic_image_tensor():
F.InterpolationMode.BICUBIC, F.InterpolationMode.BICUBIC,
], ],
): ):
displacement = _get_elastic_displacement(image_loader.spatial_size) displacement = _get_elastic_displacement(image_loader.canvas_size)
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype): for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill) yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)
def sample_inputs_elastic_bounding_boxes(): def sample_inputs_elastic_bounding_boxes():
for bounding_boxes_loader in make_bounding_box_loaders(): for bounding_boxes_loader in make_bounding_box_loaders():
displacement = _get_elastic_displacement(bounding_boxes_loader.spatial_size) displacement = _get_elastic_displacement(bounding_boxes_loader.canvas_size)
yield ArgsKwargs( yield ArgsKwargs(
bounding_boxes_loader, bounding_boxes_loader,
format=bounding_boxes_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_boxes_loader.spatial_size, canvas_size=bounding_boxes_loader.canvas_size,
displacement=displacement, displacement=displacement,
) )
...@@ -850,7 +850,7 @@ def sample_inputs_center_crop_bounding_boxes(): ...@@ -850,7 +850,7 @@ def sample_inputs_center_crop_bounding_boxes():
yield ArgsKwargs( yield ArgsKwargs(
bounding_boxes_loader, bounding_boxes_loader,
format=bounding_boxes_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_boxes_loader.spatial_size, canvas_size=bounding_boxes_loader.canvas_size,
output_size=output_size, output_size=output_size,
) )
...@@ -975,7 +975,7 @@ def reference_inputs_equalize_image_tensor(): ...@@ -975,7 +975,7 @@ def reference_inputs_equalize_image_tensor():
image.mul_(torch.iinfo(dtype).max).round_() image.mul_(torch.iinfo(dtype).max).round_()
return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True) return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)
spatial_size = (256, 256) canvas_size = (256, 256)
for dtype, color_space, fn in itertools.product( for dtype, color_space, fn in itertools.product(
[torch.uint8], [torch.uint8],
["GRAY", "RGB"], ["GRAY", "RGB"],
...@@ -1005,7 +1005,7 @@ def reference_inputs_equalize_image_tensor(): ...@@ -1005,7 +1005,7 @@ def reference_inputs_equalize_image_tensor():
], ],
], ],
): ):
image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype) image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *canvas_size), dtype=dtype)
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -1487,7 +1487,7 @@ def sample_inputs_clamp_bounding_boxes(): ...@@ -1487,7 +1487,7 @@ def sample_inputs_clamp_bounding_boxes():
yield ArgsKwargs( yield ArgsKwargs(
bounding_boxes_loader, bounding_boxes_loader,
format=bounding_boxes_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_boxes_loader.spatial_size, canvas_size=bounding_boxes_loader.canvas_size,
) )
...@@ -1502,7 +1502,7 @@ KERNEL_INFOS.append( ...@@ -1502,7 +1502,7 @@ KERNEL_INFOS.append(
_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]] _FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]
def _get_five_ten_crop_spatial_size(size): def _get_five_ten_crop_canvas_size(size):
if isinstance(size, int): if isinstance(size, int):
crop_height = crop_width = size crop_height = crop_width = size
elif len(size) == 1: elif len(size) == 1:
...@@ -1515,7 +1515,7 @@ def _get_five_ten_crop_spatial_size(size): ...@@ -1515,7 +1515,7 @@ def _get_five_ten_crop_spatial_size(size):
def sample_inputs_five_crop_image_tensor(): def sample_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES: for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], sizes=[_get_five_ten_crop_canvas_size(size)],
color_spaces=["RGB"], color_spaces=["RGB"],
dtypes=[torch.float32], dtypes=[torch.float32],
): ):
...@@ -1525,21 +1525,21 @@ def sample_inputs_five_crop_image_tensor(): ...@@ -1525,21 +1525,21 @@ def sample_inputs_five_crop_image_tensor():
def reference_inputs_five_crop_image_tensor(): def reference_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES: for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8] sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
): ):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
def sample_inputs_five_crop_video(): def sample_inputs_five_crop_video():
size = _FIVE_TEN_CROP_SIZES[0] size = _FIVE_TEN_CROP_SIZES[0]
for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]): for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
yield ArgsKwargs(video_loader, size=size) yield ArgsKwargs(video_loader, size=size)
def sample_inputs_ten_crop_image_tensor(): def sample_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], sizes=[_get_five_ten_crop_canvas_size(size)],
color_spaces=["RGB"], color_spaces=["RGB"],
dtypes=[torch.float32], dtypes=[torch.float32],
): ):
...@@ -1549,14 +1549,14 @@ def sample_inputs_ten_crop_image_tensor(): ...@@ -1549,14 +1549,14 @@ def sample_inputs_ten_crop_image_tensor():
def reference_inputs_ten_crop_image_tensor(): def reference_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8] sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
): ):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
def sample_inputs_ten_crop_video(): def sample_inputs_ten_crop_video():
size = _FIVE_TEN_CROP_SIZES[0] size = _FIVE_TEN_CROP_SIZES[0]
for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]): for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
yield ArgsKwargs(video_loader, size=size) yield ArgsKwargs(video_loader, size=size)
......
...@@ -30,7 +30,7 @@ class BoundingBoxes(Datapoint): ...@@ -30,7 +30,7 @@ class BoundingBoxes(Datapoint):
Args: Args:
data: Any data that can be turned into a tensor with :func:`torch.as_tensor`. data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
format (BoundingBoxFormat, str): Format of the bounding box. format (BoundingBoxFormat, str): Format of the bounding box.
spatial_size (two-tuple of ints): Height and width of the corresponding image or video. canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
``data``. ``data``.
device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
...@@ -40,13 +40,13 @@ class BoundingBoxes(Datapoint): ...@@ -40,13 +40,13 @@ class BoundingBoxes(Datapoint):
""" """
format: BoundingBoxFormat format: BoundingBoxFormat
spatial_size: Tuple[int, int] canvas_size: Tuple[int, int]
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBoxes: def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:
bounding_boxes = tensor.as_subclass(cls) bounding_boxes = tensor.as_subclass(cls)
bounding_boxes.format = format bounding_boxes.format = format
bounding_boxes.spatial_size = spatial_size bounding_boxes.canvas_size = canvas_size
return bounding_boxes return bounding_boxes
def __new__( def __new__(
...@@ -54,7 +54,7 @@ class BoundingBoxes(Datapoint): ...@@ -54,7 +54,7 @@ class BoundingBoxes(Datapoint):
data: Any, data: Any,
*, *,
format: Union[BoundingBoxFormat, str], format: Union[BoundingBoxFormat, str],
spatial_size: Tuple[int, int], canvas_size: Tuple[int, int],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
...@@ -64,7 +64,7 @@ class BoundingBoxes(Datapoint): ...@@ -64,7 +64,7 @@ class BoundingBoxes(Datapoint):
if isinstance(format, str): if isinstance(format, str):
format = BoundingBoxFormat[format.upper()] format = BoundingBoxFormat[format.upper()]
return cls._wrap(tensor, format=format, spatial_size=spatial_size) return cls._wrap(tensor, format=format, canvas_size=canvas_size)
@classmethod @classmethod
def wrap_like( def wrap_like(
...@@ -73,7 +73,7 @@ class BoundingBoxes(Datapoint): ...@@ -73,7 +73,7 @@ class BoundingBoxes(Datapoint):
tensor: torch.Tensor, tensor: torch.Tensor,
*, *,
format: Optional[BoundingBoxFormat] = None, format: Optional[BoundingBoxFormat] = None,
spatial_size: Optional[Tuple[int, int]] = None, canvas_size: Optional[Tuple[int, int]] = None,
) -> BoundingBoxes: ) -> BoundingBoxes:
"""Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference. """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
...@@ -82,7 +82,7 @@ class BoundingBoxes(Datapoint): ...@@ -82,7 +82,7 @@ class BoundingBoxes(Datapoint):
tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes` tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
reference. reference.
spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
omitted, it is taken from the reference. omitted, it is taken from the reference.
""" """
...@@ -92,21 +92,21 @@ class BoundingBoxes(Datapoint): ...@@ -92,21 +92,21 @@ class BoundingBoxes(Datapoint):
return cls._wrap( return cls._wrap(
tensor, tensor,
format=format if format is not None else other.format, format=format if format is not None else other.format,
spatial_size=spatial_size if spatial_size is not None else other.spatial_size, canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
) )
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, spatial_size=self.spatial_size) return self._make_repr(format=self.format, canvas_size=self.canvas_size)
def horizontal_flip(self) -> BoundingBoxes: def horizontal_flip(self) -> BoundingBoxes:
output = self._F.horizontal_flip_bounding_boxes( output = self._F.horizontal_flip_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size
) )
return BoundingBoxes.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
def vertical_flip(self) -> BoundingBoxes: def vertical_flip(self) -> BoundingBoxes:
output = self._F.vertical_flip_bounding_boxes( output = self._F.vertical_flip_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size
) )
return BoundingBoxes.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
...@@ -117,25 +117,25 @@ class BoundingBoxes(Datapoint): ...@@ -117,25 +117,25 @@ class BoundingBoxes(Datapoint):
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBoxes: ) -> BoundingBoxes:
output, spatial_size = self._F.resize_bounding_boxes( output, canvas_size = self._F.resize_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
spatial_size=self.spatial_size, canvas_size=self.canvas_size,
size=size, size=size,
max_size=max_size, max_size=max_size,
) )
return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes: def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes:
output, spatial_size = self._F.crop_bounding_boxes( output, canvas_size = self._F.crop_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
) )
return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def center_crop(self, output_size: List[int]) -> BoundingBoxes: def center_crop(self, output_size: List[int]) -> BoundingBoxes:
output, spatial_size = self._F.center_crop_bounding_boxes( output, canvas_size = self._F.center_crop_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size, output_size=output_size
) )
return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def resized_crop( def resized_crop(
self, self,
...@@ -147,10 +147,10 @@ class BoundingBoxes(Datapoint): ...@@ -147,10 +147,10 @@ class BoundingBoxes(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBoxes: ) -> BoundingBoxes:
output, spatial_size = self._F.resized_crop_bounding_boxes( output, canvas_size = self._F.resized_crop_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
) )
return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def pad( def pad(
self, self,
...@@ -158,14 +158,14 @@ class BoundingBoxes(Datapoint): ...@@ -158,14 +158,14 @@ class BoundingBoxes(Datapoint):
fill: Optional[Union[int, float, List[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> BoundingBoxes: ) -> BoundingBoxes:
output, spatial_size = self._F.pad_bounding_boxes( output, canvas_size = self._F.pad_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
format=self.format, format=self.format,
spatial_size=self.spatial_size, canvas_size=self.canvas_size,
padding=padding, padding=padding,
padding_mode=padding_mode, padding_mode=padding_mode,
) )
return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def rotate( def rotate(
self, self,
...@@ -175,15 +175,15 @@ class BoundingBoxes(Datapoint): ...@@ -175,15 +175,15 @@ class BoundingBoxes(Datapoint):
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> BoundingBoxes: ) -> BoundingBoxes:
output, spatial_size = self._F.rotate_bounding_boxes( output, canvas_size = self._F.rotate_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
format=self.format, format=self.format,
spatial_size=self.spatial_size, canvas_size=self.canvas_size,
angle=angle, angle=angle,
expand=expand, expand=expand,
center=center, center=center,
) )
return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def affine( def affine(
self, self,
...@@ -198,7 +198,7 @@ class BoundingBoxes(Datapoint): ...@@ -198,7 +198,7 @@ class BoundingBoxes(Datapoint):
output = self._F.affine_bounding_boxes( output = self._F.affine_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
self.format, self.format,
self.spatial_size, self.canvas_size,
angle, angle,
translate=translate, translate=translate,
scale=scale, scale=scale,
...@@ -218,7 +218,7 @@ class BoundingBoxes(Datapoint): ...@@ -218,7 +218,7 @@ class BoundingBoxes(Datapoint):
output = self._F.perspective_bounding_boxes( output = self._F.perspective_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
format=self.format, format=self.format,
spatial_size=self.spatial_size, canvas_size=self.canvas_size,
startpoints=startpoints, startpoints=startpoints,
endpoints=endpoints, endpoints=endpoints,
coefficients=coefficients, coefficients=coefficients,
...@@ -232,6 +232,6 @@ class BoundingBoxes(Datapoint): ...@@ -232,6 +232,6 @@ class BoundingBoxes(Datapoint):
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> BoundingBoxes: ) -> BoundingBoxes:
output = self._F.elastic_bounding_boxes( output = self._F.elastic_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement self.as_subclass(torch.Tensor), self.format, self.canvas_size, displacement=displacement
) )
return BoundingBoxes.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
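Taken together, these methods keep canvas_size in sync with each geometric transform. An end-to-end sketch, with the source image size assumed:

    from torchvision import datapoints

    boxes = datapoints.BoundingBoxes(
        [[17, 16, 344, 495]],
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(512, 512),  # assumed size of the source image
    )
    resized = boxes.resize([256, 256], antialias=True)
    print(resized.canvas_size)  # follows the transform: 256 x 256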
...@@ -138,7 +138,7 @@ class Datapoint(torch.Tensor): ...@@ -138,7 +138,7 @@ class Datapoint(torch.Tensor):
# *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad` # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
# attribute is cleared, so we need to refill it before we return. # attribute is cleared, so we need to refill it before we return.
# Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
# `BoundingBoxes.format` and `BoundingBoxes.spatial_size`, which are immutable and thus implicitly deep-copied by # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
# `BoundingBoxes.clone()`. # `BoundingBoxes.clone()`.
return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value] return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value]
......
...@@ -341,13 +341,13 @@ def coco_dectection_wrapper_factory(dataset, target_keys): ...@@ -341,13 +341,13 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
default={"image_id", "boxes", "labels"}, default={"image_id", "boxes", "labels"},
) )
def segmentation_to_mask(segmentation, *, spatial_size): def segmentation_to_mask(segmentation, *, canvas_size):
from pycocotools import mask from pycocotools import mask
segmentation = ( segmentation = (
mask.frPyObjects(segmentation, *spatial_size) mask.frPyObjects(segmentation, *canvas_size)
if isinstance(segmentation, dict) if isinstance(segmentation, dict)
else mask.merge(mask.frPyObjects(segmentation, *spatial_size)) else mask.merge(mask.frPyObjects(segmentation, *canvas_size))
) )
return torch.from_numpy(mask.decode(segmentation)) return torch.from_numpy(mask.decode(segmentation))
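segmentation_to_mask accepts both RLE dicts and polygon lists; pycocotools wants the canvas as separate height/width arguments, hence the unpacking. A toy call with an invented polygon and canvas:

    from pycocotools import mask as coco_mask

    # a small triangular polygon on a 4 x 4 canvas
    rles = coco_mask.frPyObjects([[0, 0, 0, 3, 3, 3]], 4, 4)
    m = coco_mask.decode(coco_mask.merge(rles))  # uint8 array of shape (4, 4)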
...@@ -359,7 +359,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys): ...@@ -359,7 +359,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
if not target: if not target:
return image, dict(image_id=image_id) return image, dict(image_id=image_id)
spatial_size = tuple(F.get_spatial_size(image)) canvas_size = tuple(F.get_size(image))
batched_target = list_of_dicts_to_dict_of_lists(target) batched_target = list_of_dicts_to_dict_of_lists(target)
target = {} target = {}
...@@ -372,7 +372,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys): ...@@ -372,7 +372,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
datapoints.BoundingBoxes( datapoints.BoundingBoxes(
batched_target["bbox"], batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYWH, format=datapoints.BoundingBoxFormat.XYWH,
spatial_size=spatial_size, canvas_size=canvas_size,
), ),
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=datapoints.BoundingBoxFormat.XYXY,
) )
...@@ -381,7 +381,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys): ...@@ -381,7 +381,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
target["masks"] = datapoints.Mask( target["masks"] = datapoints.Mask(
torch.stack( torch.stack(
[ [
segmentation_to_mask(segmentation, spatial_size=spatial_size) segmentation_to_mask(segmentation, canvas_size=canvas_size)
for segmentation in batched_target["segmentation"] for segmentation in batched_target["segmentation"]
] ]
), ),
...@@ -456,7 +456,7 @@ def voc_detection_wrapper_factory(dataset, target_keys): ...@@ -456,7 +456,7 @@ def voc_detection_wrapper_factory(dataset, target_keys):
for bndbox in batched_instances["bndbox"] for bndbox in batched_instances["bndbox"]
], ],
format=datapoints.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(image.height, image.width), canvas_size=(image.height, image.width),
) )
if "labels" in target_keys: if "labels" in target_keys:
...@@ -493,7 +493,7 @@ def celeba_wrapper_factory(dataset, target_keys): ...@@ -493,7 +493,7 @@ def celeba_wrapper_factory(dataset, target_keys):
datapoints.BoundingBoxes( datapoints.BoundingBoxes(
item, item,
format=datapoints.BoundingBoxFormat.XYWH, format=datapoints.BoundingBoxFormat.XYWH,
spatial_size=(image.height, image.width), canvas_size=(image.height, image.width),
), ),
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=datapoints.BoundingBoxFormat.XYXY,
), ),
...@@ -543,7 +543,7 @@ def kitti_wrapper_factory(dataset, target_keys): ...@@ -543,7 +543,7 @@ def kitti_wrapper_factory(dataset, target_keys):
target["boxes"] = datapoints.BoundingBoxes( target["boxes"] = datapoints.BoundingBoxes(
batched_target["bbox"], batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(image.height, image.width), canvas_size=(image.height, image.width),
) )
if "labels" in target_keys: if "labels" in target_keys:
...@@ -638,7 +638,7 @@ def widerface_wrapper(dataset, target_keys): ...@@ -638,7 +638,7 @@ def widerface_wrapper(dataset, target_keys):
if "bbox" in target_keys: if "bbox" in target_keys:
target["bbox"] = F.convert_format_bounding_boxes( target["bbox"] = F.convert_format_bounding_boxes(
datapoints.BoundingBoxes( datapoints.BoundingBoxes(
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
), ),
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=datapoints.BoundingBoxFormat.XYXY,
) )
......
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -56,14 +56,6 @@ class Image(Datapoint): ...@@ -56,14 +56,6 @@ class Image(Datapoint):
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr() return self._make_repr()
@property
def spatial_size(self) -> Tuple[int, int]:
return tuple(self.shape[-2:]) # type: ignore[return-value]
@property
def num_channels(self) -> int:
return self.shape[-3]
def horizontal_flip(self) -> Image: def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor)) output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output) return Image.wrap_like(self, output)
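With the spatial_size and num_channels properties removed, the functional accessors are the remaining way to query this metadata (get_num_channels is assumed to exist alongside get_size; only get_size appears in this diff):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 32, 48))
    print(F.get_size(img))          # [32, 48], replaces img.spatial_size
    print(F.get_num_channels(img))  # 3, replaces img.num_channels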
......
 from __future__ import annotations

-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Union

 import PIL.Image
 import torch
...@@ -51,10 +51,6 @@ class Mask(Datapoint):
     ) -> Mask:
         return cls._wrap(tensor)

-    @property
-    def spatial_size(self) -> Tuple[int, int]:
-        return tuple(self.shape[-2:])  # type: ignore[return-value]
-
     def horizontal_flip(self) -> Mask:
         output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
         return Mask.wrap_like(self, output)
......
 from __future__ import annotations

-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Union

 import torch
 from torchvision.transforms.functional import InterpolationMode
...@@ -46,18 +46,6 @@ class Video(Datapoint):
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()

-    @property
-    def spatial_size(self) -> Tuple[int, int]:
-        return tuple(self.shape[-2:])  # type: ignore[return-value]
-
-    @property
-    def num_channels(self) -> int:
-        return self.shape[-3]
-
-    @property
-    def num_frames(self) -> int:
-        return self.shape[-4]
-
     def horizontal_flip(self) -> Video:
         output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
         return Video.wrap_like(self, output)
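`Video` loses the same two properties plus `num_frames`; all three are one-liners on the tensor shape. A sketch, assuming the usual TCHW layout:

import torch
from torchvision import datapoints

video = datapoints.Video(torch.rand(16, 3, 112, 112))  # (T, C, H, W), illustrative shape

num_frames = video.shape[-4]      # replaces the removed num_frames property
num_channels = video.shape[-3]    # replaces the removed num_channels property
height, width = video.shape[-2:]  # replaces the removed spatial_size property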
......
...@@ -11,7 +11,7 @@ from torchvision.transforms.v2 import functional as F, InterpolationMode, Transf
 from torchvision.transforms.v2._transform import _RandomApplyTransform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_spatial_size
+from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size


 class _BaseMixupCutmix(_RandomApplyTransform):
...@@ -64,7 +64,7 @@ class RandomCutmix(_BaseMixupCutmix):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))  # type: ignore[arg-type]

-        H, W = query_spatial_size(flat_inputs)
+        H, W = query_size(flat_inputs)

         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
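`query_size` plays the same role `query_spatial_size` did: it scans the flat input list and returns one consistent `(H, W)`. For context, here is a sketch of how the sampled `lam` and the random center typically become a CutMix box; the continuation of `_get_params` is not part of this diff, so the lines below follow the standard CutMix recipe rather than quoting the source:

import math
import torch

lam = 0.3            # mixing coefficient sampled from the Beta distribution above
H, W = 224, 224      # as returned by query_size(flat_inputs)
r_x = torch.randint(W, ())
r_y = torch.randint(H, ())

r = 0.5 * math.sqrt(1.0 - lam)
r_w_half, r_h_half = int(r * W), int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
box = (x1, y1, x2, y2)                                # patch pasted from the shuffled batch
lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)  # actual mixed area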
......
...@@ -7,7 +7,7 @@ from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_spatial_size
+from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size


 class FixedSizeCrop(Transform):
...@@ -46,7 +46,7 @@ class FixedSizeCrop(Transform):
         )

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        height, width = query_spatial_size(flat_inputs)
+        height, width = query_size(flat_inputs)
         new_height = min(height, self.crop_height)
         new_width = min(width, self.crop_width)
...@@ -67,7 +67,7 @@ class FixedSizeCrop(Transform):
         if needs_crop and bounding_boxes is not None:
             format = bounding_boxes.format
-            bounding_boxes, spatial_size = F.crop_bounding_boxes(
+            bounding_boxes, canvas_size = F.crop_bounding_boxes(
                 bounding_boxes.as_subclass(torch.Tensor),
                 format=format,
                 top=top,
...@@ -75,7 +75,7 @@ class FixedSizeCrop(Transform):
                 height=new_height,
                 width=new_width,
             )
-            bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size)
+            bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
             height_and_width = F.convert_format_bounding_boxes(
                 bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
             )[..., 2:]
...@@ -115,9 +115,7 @@ class FixedSizeCrop(Transform):
         elif isinstance(inpt, datapoints.BoundingBoxes):
             inpt = datapoints.BoundingBoxes.wrap_like(
                 inpt,
-                F.clamp_bounding_boxes(
-                    inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size
-                ),
+                F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
             )

         if params["needs_pad"]:
......
...@@ -12,7 +12,7 @@ from torchvision.transforms.v2 import functional as F
 from ._transform import _RandomApplyTransform, Transform
 from ._utils import _parse_labels_getter
-from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
+from .utils import has_any, is_simple_tensor, query_chw, query_size


 class RandomErasing(_RandomApplyTransform):
...@@ -284,7 +284,7 @@ class Cutmix(_BaseMixupCutmix):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))  # type: ignore[arg-type]

-        H, W = query_spatial_size(flat_inputs)
+        H, W = query_size(flat_inputs)

         r_x = torch.randint(W, size=(1,))
         r_y = torch.randint(H, size=(1,))
......