Unverified commit 332bff93, authored by Nicolas Hug and committed by GitHub

Renaming: `BoundingBox` -> `BoundingBoxes` (#7778)

parent d4e5aa21
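For readers migrating code, a minimal before/after sketch of the rename this commit performs (the constructor arguments mirror the gallery example further down; the spatial size value is illustrative rather than taken from a real image):

# Before this commit (old, singular class name):
# box = datapoints.BoundingBox(
#     [17, 16, 344, 495],
#     format=datapoints.BoundingBoxFormat.XYXY,
#     spatial_size=(512, 512),
# )

# After this commit (new, pluralized class name):
box = datapoints.BoundingBoxes(
    [17, 16, 344, 495],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(512, 512),  # illustrative (H, W) of the reference image
)

The related transforms and kernels are pluralized the same way, e.g. SanitizeBoundingBox -> SanitizeBoundingBoxes, ClampBoundingBox -> ClampBoundingBoxes, clamp_bounding_box -> clamp_bounding_boxes.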
@@ -15,5 +15,5 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
 Image
 Video
 BoundingBoxFormat
-BoundingBox
+BoundingBoxes
 Mask

@@ -206,8 +206,8 @@ Miscellaneous
 v2.RandomErasing
 Lambda
 v2.Lambda
-v2.SanitizeBoundingBox
-v2.ClampBoundingBox
+v2.SanitizeBoundingBoxes
+v2.ClampBoundingBoxes
 v2.UniformTemporalSubsample
 .. _conversion_transforms:
@@ -47,7 +47,7 @@ assert image.data_ptr() == tensor.data_ptr()
 #
 # * :class:`~torchvision.datapoints.Image`
 # * :class:`~torchvision.datapoints.Video`
-# * :class:`~torchvision.datapoints.BoundingBox`
+# * :class:`~torchvision.datapoints.BoundingBoxes`
 # * :class:`~torchvision.datapoints.Mask`
 #
 # How do I construct a datapoint?

@@ -76,10 +76,10 @@ print(image.shape, image.dtype)
 ########################################################################################################################
 # In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
-# :class:`~torchvision.datapoints.BoundingBox` stores the coordinate format as well as the spatial size of the
+# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
 # corresponding image alongside the actual values:
-bounding_box = datapoints.BoundingBox(
+bounding_box = datapoints.BoundingBoxes(
 [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
 )
 print(bounding_box)
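The hunk above is the gallery snippet that constructs a box datapoint. For orientation, a minimal sketch of reading that metadata back after construction; the .format and .spatial_size attributes are the ones referenced throughout this commit, while the concrete spatial size below is illustrative (the gallery passes image.shape[-2:]):

boxes = datapoints.BoundingBoxes(
    [17, 16, 344, 495],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(512, 512),  # illustrative (H, W)
)
print(boxes.format)        # BoundingBoxFormat.XYXY
print(boxes.spatial_size)  # (512, 512)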
@@ -105,7 +105,7 @@ class PennFudanDataset(torch.utils.data.Dataset):
 def __getitem__(self, item):
 ...
-target["boxes"] = datapoints.BoundingBox(
+target["boxes"] = datapoints.BoundingBoxes(
 boxes,
 format=datapoints.BoundingBoxFormat.XYXY,
 spatial_size=F.get_spatial_size(img),

@@ -126,7 +126,7 @@ class PennFudanDataset(torch.utils.data.Dataset):
 class WrapPennFudanDataset:
 def __call__(self, img, target):
-target["boxes"] = datapoints.BoundingBox(
+target["boxes"] = datapoints.BoundingBoxes(
 target["boxes"],
 format=datapoints.BoundingBoxFormat.XYXY,
 spatial_size=F.get_spatial_size(img),

@@ -147,7 +147,7 @@ def get_transform(train):
 ########################################################################################################################
 # .. note::
 #
-# If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
+# If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in
 # the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
 # at least not wrapping the obsolete parts, can lead to a significant performance boost.
 #
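Relatedly, a hedged sketch of what "not wrapping the obsolete parts" could look like, modelled on the WrapPennFudanDataset wrapper in the hunk above but assuming only boxes are needed downstream (the mask handling here is hypothetical and not part of this commit):

class WrapBoxesOnly:
    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBoxes(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=F.get_spatial_size(img),
        )
        # Hypothetical: drop masks (or leave them as plain tensors) so transforms.v2
        # does not spend time transforming data the training loop never uses.
        target.pop("masks", None)
        return img, target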
@@ -29,7 +29,7 @@ def load_data():
 masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))
-bounding_boxes = datapoints.BoundingBox(
+bounding_boxes = datapoints.BoundingBoxes(
 masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
 )
@@ -106,13 +106,13 @@ transform = transforms.Compose(
 transforms.RandomHorizontalFlip(),
 transforms.ToImageTensor(),
 transforms.ConvertImageDtype(torch.float32),
-transforms.SanitizeBoundingBox(),
+transforms.SanitizeBoundingBoxes(),
 ]
 )
 ########################################################################################################################
 # .. note::
-# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, but it
+# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it
 # should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
 # the corresponding labels and optionally masks. It is particularly critical to add it if
 # :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
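Read together with the Compose hunk above, a minimal sketch of the placement the note recommends, assuming a detection pipeline that includes RandomIoUCrop (the exact set and order of the other transforms is illustrative; all names are the v2 ones appearing in this commit):

transform = transforms.Compose(
    [
        transforms.RandomIoUCrop(),          # may leave some boxes degenerate
        transforms.RandomHorizontalFlip(),
        transforms.ToImageTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.SanitizeBoundingBoxes(),  # placed last: drops degenerate boxes plus their labels (and masks)
    ]
)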
@@ -78,7 +78,7 @@ class DetectionPresetTrain:
 if use_v2:
 transforms += [
 T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
-T.SanitizeBoundingBox(),
+T.SanitizeBoundingBoxes(),
 ]
 self.transforms = T.Compose(transforms)
@@ -620,7 +620,7 @@ def make_image_loaders_for_interpolation(
 @dataclasses.dataclass
-class BoundingBoxLoader(TensorLoader):
+class BoundingBoxesLoader(TensorLoader):
 format: datapoints.BoundingBoxFormat
 spatial_size: Tuple[int, int]

@@ -639,7 +639,7 @@ def make_bounding_box(
 - (box[3] - box[1], box[2] - box[0]) for XYXY
 - (H, W) for XYWH and CXCYWH
 spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on
-returned datapoints.BoundingBox
+returned datapoints.BoundingBoxes
 To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
 functions, e.g.

@@ -647,8 +647,8 @@ def make_bounding_box(
 .. code::
 image = make_image=(size=size)
-bounding_box = make_bounding_box(spatial_size=size)
-assert F.get_spatial_size(bounding_box) == F.get_spatial_size(image)
+bounding_boxes = make_bounding_box(spatial_size=size)
+assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
 For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value as size for all
 other maker functions, e.g.

@@ -656,8 +656,8 @@ def make_bounding_box(
 .. code::
 image = make_image=()
-bounding_box = make_bounding_box()
-assert F.get_spatial_size(bounding_box) == F.get_spatial_size(image)
+bounding_boxes = make_bounding_box()
+assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
 """
 def sample_position(values, max_value):

@@ -679,7 +679,7 @@ def make_bounding_box(
 dtype = dtype or torch.float32
 if any(dim == 0 for dim in batch_dims):
-return datapoints.BoundingBox(
+return datapoints.BoundingBoxes(
 torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
 )

@@ -705,7 +705,7 @@ def make_bounding_box(
 else:
 raise ValueError(f"Format {format} is not supported")
-return datapoints.BoundingBox(
+return datapoints.BoundingBoxes(
 torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
 )

@@ -725,7 +725,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
 format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
 )
-return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
 def make_bounding_box_loaders(
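The make_bounding_box docstring above couples spatial_size to the size passed to the other maker functions. A compact restatement of that invariant as it would actually be written (note that the docstring's "make_image=(size=size)" contains a stray "="; the call below uses the plain form, and the concrete size is illustrative):

size = (32, 32)  # illustrative
image = make_image(size=size)
bounding_boxes = make_bounding_box(spatial_size=size)
assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)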
@@ -27,7 +27,7 @@ def test_mask_instance(data):
 "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
 )
 def test_bbox_instance(data, format):
-bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32))
+bboxes = datapoints.BoundingBoxes(data, format=format, spatial_size=(32, 32))
 assert isinstance(bboxes, torch.Tensor)
 assert bboxes.ndim == 2 and bboxes.shape[1] == 4
 if isinstance(format, str):

@@ -164,7 +164,7 @@ def test_wrap_like():
 [
 datapoints.Image(torch.rand(3, 16, 16)),
 datapoints.Video(torch.rand(2, 3, 16, 16)),
-datapoints.BoundingBox([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
+datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
 datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
 ],
 )
@@ -20,7 +20,7 @@ from common_utils import (
 from prototype_common_utils import make_label, make_one_hot_labels
-from torchvision.datapoints import BoundingBox, BoundingBoxFormat, Image, Mask, Video
+from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2._utils import _convert_fill_arg
 from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil

@@ -101,10 +101,10 @@ class TestSimpleCopyPaste:
 self.create_fake_image(mocker, Image),
 # labels, bboxes, masks
 mocker.MagicMock(spec=datapoints.Label),
-mocker.MagicMock(spec=BoundingBox),
+mocker.MagicMock(spec=BoundingBoxes),
 mocker.MagicMock(spec=Mask),
 # labels, bboxes, masks
-mocker.MagicMock(spec=BoundingBox),
+mocker.MagicMock(spec=BoundingBoxes),
 mocker.MagicMock(spec=Mask),
 ]

@@ -122,11 +122,11 @@ class TestSimpleCopyPaste:
 self.create_fake_image(mocker, image_type),
 # labels, bboxes, masks
 mocker.MagicMock(spec=label_type),
-mocker.MagicMock(spec=BoundingBox),
+mocker.MagicMock(spec=BoundingBoxes),
 mocker.MagicMock(spec=Mask),
 # labels, bboxes, masks
 mocker.MagicMock(spec=label_type),
-mocker.MagicMock(spec=BoundingBox),
+mocker.MagicMock(spec=BoundingBoxes),
 mocker.MagicMock(spec=Mask),
 ]

@@ -142,7 +142,7 @@ class TestSimpleCopyPaste:
 for target in targets:
 for key, type_ in [
-("boxes", BoundingBox),
+("boxes", BoundingBoxes),
 ("masks", Mask),
 ("labels", label_type),
 ]:

@@ -163,7 +163,7 @@ class TestSimpleCopyPaste:
 if label_type == datapoints.OneHotLabel:
 labels = torch.nn.functional.one_hot(labels, num_classes=5)
 target = {
-"boxes": BoundingBox(
+"boxes": BoundingBoxes(
 torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
 ),
 "masks": Mask(masks),

@@ -178,7 +178,7 @@ class TestSimpleCopyPaste:
 if label_type == datapoints.OneHotLabel:
 paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
 paste_target = {
-"boxes": BoundingBox(
+"boxes": BoundingBoxes(
 torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
 ),
 "masks": Mask(paste_masks),

@@ -332,7 +332,7 @@ class TestFixedSizeCrop:
 assert_equal(output["masks"], masks[is_valid])
 assert_equal(output["labels"], labels[is_valid])
-def test__transform_bounding_box_clamping(self, mocker):
+def test__transform_bounding_boxes_clamping(self, mocker):
 batch_size = 3
 spatial_size = (10, 10)

@@ -349,15 +349,15 @@ class TestFixedSizeCrop:
 ),
 )
-bounding_box = make_bounding_box(
+bounding_boxes = make_bounding_box(
 format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
 )
-mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
+mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")
 transform = transforms.FixedSizeCrop((-1, -1))
 mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
-transform(bounding_box)
+transform(bounding_boxes)
 mock.assert_called_once()

@@ -390,7 +390,7 @@ class TestPermuteDimensions:
 def test_call(self, dims, inverse_dims):
 sample = dict(
 image=make_image(),
-bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY),
+bounding_boxes=make_bounding_box(format=BoundingBoxFormat.XYXY),
 video=make_video(),
 str="str",
 int=0,

@@ -434,7 +434,7 @@ class TestTransposeDimensions:
 def test_call(self, dims):
 sample = dict(
 image=make_image(),
-bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY),
+bounding_boxes=make_bounding_box(format=BoundingBoxFormat.XYXY),
 video=make_video(),
 str="str",
 int=0,
@@ -46,8 +46,8 @@ def make_pil_images(*args, **kwargs):
 def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
-for bounding_box in make_bounding_boxes(*args, **kwargs):
-yield bounding_box.data
+for bounding_boxes in make_bounding_boxes(*args, **kwargs):
+yield bounding_boxes.data
 def parametrize(transforms_with_inputs):

@@ -69,7 +69,7 @@ def auto_augment_adapter(transform, input, device):
 adapted_input = {}
 image_or_video_found = False
 for key, value in input.items():
-if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
 # AA transforms don't support bounding boxes or masks
 continue
 elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):

@@ -143,7 +143,7 @@ class TestSmoke:
 (transforms.RandomZoomOut(p=1.0), None),
 (transforms.Resize([16, 16], antialias=True), None),
 (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
-(transforms.ClampBoundingBox(), None),
+(transforms.ClampBoundingBoxes(), None),
 (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
 (transforms.ConvertImageDtype(), None),
 (transforms.GaussianBlur(kernel_size=3), None),

@@ -180,16 +180,16 @@ class TestSmoke:
 image_datapoint=make_image(size=spatial_size),
 video_datapoint=make_video(size=spatial_size),
 image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
-bounding_box_xyxy=make_bounding_box(
+bounding_boxes_xyxy=make_bounding_box(
 format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(3,)
 ),
-bounding_box_xywh=make_bounding_box(
+bounding_boxes_xywh=make_bounding_box(
 format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, batch_dims=(4,)
 ),
-bounding_box_cxcywh=make_bounding_box(
+bounding_boxes_cxcywh=make_bounding_box(
 format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, batch_dims=(5,)
 ),
-bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes(
 [
 [0, 0, 0, 0], # no height or width
 [0, 0, 0, 1], # no height

@@ -201,7 +201,7 @@ class TestSmoke:
 format=datapoints.BoundingBoxFormat.XYXY,
 spatial_size=spatial_size,
 ),
-bounding_box_degenerate_xywh=datapoints.BoundingBox(
+bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes(
 [
 [0, 0, 0, 0], # no height or width
 [0, 0, 0, 1], # no height

@@ -213,7 +213,7 @@ class TestSmoke:
 format=datapoints.BoundingBoxFormat.XYWH,
 spatial_size=spatial_size,
 ),
-bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes(
 [
 [0, 0, 0, 0], # no height or width
 [0, 0, 0, 1], # no height

@@ -261,7 +261,7 @@ class TestSmoke:
 else:
 assert output_item is input_item
-if isinstance(input_item, datapoints.BoundingBox) and not isinstance(
+if isinstance(input_item, datapoints.BoundingBoxes) and not isinstance(
 transform, transforms.ConvertBoundingBoxFormat
 ):
 assert output_item.format == input_item.format

@@ -271,10 +271,10 @@ class TestSmoke:
 # TODO: we should test that against all degenerate boxes above
 for format in list(datapoints.BoundingBoxFormat):
 sample = dict(
-boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
+boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
 labels=torch.tensor([3]),
 )
-assert transforms.SanitizeBoundingBox()(sample)["boxes"].shape == (0, 4)
+assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
 @parametrize(
 [

@@ -942,7 +942,7 @@ class TestRandomErasing:
 class TestTransform:
 @pytest.mark.parametrize(
 "inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
 )
 def test_check_transformed_types(self, inpt_type, mocker):
 # This test ensures that we correctly handle which types to transform and which to bypass

@@ -960,7 +960,7 @@ class TestTransform:
 class TestToImageTensor:
 @pytest.mark.parametrize(
 "inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
 )
 def test__transform(self, inpt_type, mocker):
 fn = mocker.patch(

@@ -971,7 +971,7 @@ class TestToImageTensor:
 inpt = mocker.MagicMock(spec=inpt_type)
 transform = transforms.ToImageTensor()
 transform(inpt)
-if inpt_type in (datapoints.BoundingBox, datapoints.Image, str, int):
+if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
 assert fn.call_count == 0
 else:
 fn.assert_called_once_with(inpt)

@@ -980,7 +980,7 @@ class TestToImageTensor:
 class TestToImagePIL:
 @pytest.mark.parametrize(
 "inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
 )
 def test__transform(self, inpt_type, mocker):
 fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")

@@ -988,7 +988,7 @@ class TestToImagePIL:
 inpt = mocker.MagicMock(spec=inpt_type)
 transform = transforms.ToImagePIL()
 transform(inpt)
-if inpt_type in (datapoints.BoundingBox, PIL.Image.Image, str, int):
+if inpt_type in (datapoints.BoundingBoxes, PIL.Image.Image, str, int):
 assert fn.call_count == 0
 else:
 fn.assert_called_once_with(inpt, mode=transform.mode)

@@ -997,7 +997,7 @@ class TestToPILImage:
 class TestToPILImage:
 @pytest.mark.parametrize(
 "inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
 )
 def test__transform(self, inpt_type, mocker):
 fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")

@@ -1005,7 +1005,7 @@ class TestToPILImage:
 inpt = mocker.MagicMock(spec=inpt_type)
 transform = transforms.ToPILImage()
 transform(inpt)
-if inpt_type in (PIL.Image.Image, datapoints.BoundingBox, str, int):
+if inpt_type in (PIL.Image.Image, datapoints.BoundingBoxes, str, int):
 assert fn.call_count == 0
 else:
 fn.assert_called_once_with(inpt, mode=transform.mode)

@@ -1014,7 +1014,7 @@ class TestToTensor:
 class TestToTensor:
 @pytest.mark.parametrize(
 "inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
 )
 def test__transform(self, inpt_type, mocker):
 fn = mocker.patch("torchvision.transforms.functional.to_tensor")

@@ -1023,7 +1023,7 @@ class TestToTensor:
 with pytest.warns(UserWarning, match="deprecated and will be removed"):
 transform = transforms.ToTensor()
 transform(inpt)
-if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBox, str, int):
+if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBoxes, str, int):
 assert fn.call_count == 0
 else:
 fn.assert_called_once_with(inpt)

@@ -1065,7 +1065,7 @@ class TestRandomIoUCrop:
 image = mocker.MagicMock(spec=datapoints.Image)
 image.num_channels = 3
 image.spatial_size = (24, 32)
-bboxes = datapoints.BoundingBox(
+bboxes = datapoints.BoundingBoxes(
 torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
 format="XYXY",
 spatial_size=image.spatial_size,

@@ -1103,7 +1103,7 @@ class TestRandomIoUCrop:
 def test__transform_empty_params(self, mocker):
 transform = transforms.RandomIoUCrop(sampler_options=[2.0])
 image = datapoints.Image(torch.rand(1, 3, 4, 4))
-bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
+bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
 label = torch.tensor([1])
 sample = [image, bboxes, label]
 # Let's mock transform._get_params to control the output:

@@ -1147,7 +1147,7 @@ class TestRandomIoUCrop:
 # check number of bboxes vs number of labels:
 output_bboxes = output[1]
-assert isinstance(output_bboxes, datapoints.BoundingBox)
+assert isinstance(output_bboxes, datapoints.BoundingBoxes)
 assert (output_bboxes[~is_within_crop_area] == 0).all()
 output_masks = output[2]

@@ -1505,7 +1505,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
 transforms.ConvertImageDtype(torch.float),
 ]
 if sanitize:
-t += [transforms.SanitizeBoundingBox()]
+t += [transforms.SanitizeBoundingBoxes()]
 t = transforms.Compose(t)
 num_boxes = 5

@@ -1523,7 +1523,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
 boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
 boxes[:, 2:] += boxes[:, :2]
 boxes = boxes.clamp(min=0, max=min(H, W))
-boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))
+boxes = datapoints.BoundingBoxes(boxes, format="XYXY", spatial_size=(H, W))
 masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

@@ -1546,7 +1546,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
 # ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
 # doesn't remove them strictly speaking, it just marks some boxes as
 # degenerate and those boxes will be later removed by
-# SanitizeBoundingBox(), which we add to the pipelines if the sanitize
+# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
 # param is True.
 # Note that the values below are probably specific to the random seed
 # set above (which is fine).

@@ -1594,7 +1594,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
 boxes = torch.tensor(boxes)
 labels = torch.arange(boxes.shape[0])
-boxes = datapoints.BoundingBox(
+boxes = datapoints.BoundingBoxes(
 boxes,
 format=datapoints.BoundingBoxFormat.XYXY,
 spatial_size=(H, W),

@@ -1616,7 +1616,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
 img = sample.pop("image")
 sample = (img, sample)
-out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample)
+out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
 if sample_type is tuple:
 out_image = out[0]

@@ -1634,7 +1634,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
 assert out_image is input_img
 assert out_whatever is whatever
-assert isinstance(out_boxes, datapoints.BoundingBox)
+assert isinstance(out_boxes, datapoints.BoundingBoxes)
 assert isinstance(out_masks, datapoints.Mask)
 if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):

@@ -1648,31 +1648,31 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
 def test_sanitize_bounding_boxes_errors():
-good_bbox = datapoints.BoundingBox(
+good_bbox = datapoints.BoundingBoxes(
 [[0, 0, 10, 10]],
 format=datapoints.BoundingBoxFormat.XYXY,
 spatial_size=(20, 20),
 )
 with pytest.raises(ValueError, match="min_size must be >= 1"):
-transforms.SanitizeBoundingBox(min_size=0)
+transforms.SanitizeBoundingBoxes(min_size=0)
 with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
-transforms.SanitizeBoundingBox(labels_getter=12)
+transforms.SanitizeBoundingBoxes(labels_getter=12)
 with pytest.raises(ValueError, match="Could not infer where the labels are"):
 bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
-transforms.SanitizeBoundingBox()(bad_labels_key)
+transforms.SanitizeBoundingBoxes()(bad_labels_key)
 with pytest.raises(ValueError, match="must be a tensor"):
 not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
-transforms.SanitizeBoundingBox()(not_a_tensor)
+transforms.SanitizeBoundingBoxes()(not_a_tensor)
 with pytest.raises(ValueError, match="Number of boxes"):
 different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
-transforms.SanitizeBoundingBox()(different_sizes)
+transforms.SanitizeBoundingBoxes()(different_sizes)
 with pytest.raises(ValueError, match="boxes must be of shape"):
-bad_bbox = datapoints.BoundingBox( # batch with 2 elements
+bad_bbox = datapoints.BoundingBoxes( # batch with 2 elements
 [
 [[0, 0, 10, 10]],
 [[0, 0, 10, 10]],

@@ -1681,7 +1681,7 @@ def test_sanitize_bounding_boxes_errors():
 spatial_size=(20, 20),
 )
 different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
-transforms.SanitizeBoundingBox()(different_sizes)
+transforms.SanitizeBoundingBoxes()(different_sizes)
 @pytest.mark.parametrize(
@@ -1127,7 +1127,7 @@ class TestRefDetTransforms:
 v2_transforms.Compose(
 [
 v2_transforms.RandomIoUCrop(),
-v2_transforms.SanitizeBoundingBox(labels_getter=lambda sample: sample[1]["labels"]),
+v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
 ]
 ),
 {"with_mask": False},
@@ -26,7 +26,7 @@ from torchvision import datapoints
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
-from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box
+from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
 from torchvision.transforms.v2.utils import is_simple_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS

@@ -176,7 +176,7 @@ class TestKernels:
 # Everything to the left is considered a batch dimension.
 data_dims = {
 datapoints.Image: 3,
-datapoints.BoundingBox: 1,
+datapoints.BoundingBoxes: 1,
 # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
 # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
 # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
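The comment above distinguishes detection masks from segmentation masks by their data dimensions. A small illustration of the two layouts it describes (the Mask constructor and randint pattern appear elsewhere in this diff; the concrete sizes are made up):

# Detection masks: one binary mask per instance, data dims (N, H, W)
detection_masks = datapoints.Mask(torch.randint(0, 2, size=(5, 32, 32), dtype=torch.uint8))
# Segmentation mask: a single label map, data dims (H, W)
segmentation_mask = datapoints.Mask(torch.randint(0, 10, size=(32, 32), dtype=torch.uint8))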
@@ -515,15 +515,15 @@ class TestDispatchers:
 [
 info
 for info in DISPATCHER_INFOS
-if datapoints.BoundingBox in info.kernels and info.dispatcher is not F.convert_format_bounding_box
+if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
 ],
-args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBox),
+args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
 )
-def test_bounding_box_format_consistency(self, info, args_kwargs):
-(bounding_box, *other_args), kwargs = args_kwargs.load()
-format = bounding_box.format
-output = info.dispatcher(bounding_box, *other_args, **kwargs)
+def test_bounding_boxes_format_consistency(self, info, args_kwargs):
+(bounding_boxes, *other_args), kwargs = args_kwargs.load()
+format = bounding_boxes.format
+output = info.dispatcher(bounding_boxes, *other_args, **kwargs)
 assert output.format == format

@@ -562,7 +562,7 @@ def test_normalize_image_tensor_stats(device, num_channels):
 assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
-class TestClampBoundingBox:
+class TestClampBoundingBoxes:
 @pytest.mark.parametrize(
 "metadata",
 [

@@ -575,7 +575,7 @@ class TestClampBoundingBox:
 simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
 with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")):
-F.clamp_bounding_box(simple_tensor, **metadata)
+F.clamp_bounding_boxes(simple_tensor, **metadata)
 @pytest.mark.parametrize(
 "metadata",

@@ -589,10 +589,10 @@ class TestClampBoundingBox:
 datapoint = next(make_bounding_boxes())
 with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")):
-F.clamp_bounding_box(datapoint, **metadata)
-class TestConvertFormatBoundingBox:
+F.clamp_bounding_boxes(datapoint, **metadata)
+class TestConvertFormatBoundingBoxes:
 @pytest.mark.parametrize(
 ("inpt", "old_format"),
 [

@@ -602,19 +602,19 @@ class TestConvertFormatBoundingBox:
 )
 def test_missing_new_format(self, inpt, old_format):
 with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
-F.convert_format_bounding_box(inpt, old_format)
+F.convert_format_bounding_boxes(inpt, old_format)
 def test_simple_tensor_insufficient_metadata(self):
 simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
 with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
 def test_datapoint_explicit_metadata(self):
 datapoint = next(make_bounding_boxes())
 with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
-F.convert_format_bounding_box(
+F.convert_format_bounding_boxes(
 datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
 )

@@ -658,7 +658,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
 [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
 ],
 )
-def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes):
+def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):
 # Expected bboxes computed using Albumentations:
 # import numpy as np

@@ -681,13 +681,13 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
 ]
 in_boxes = torch.tensor(in_boxes, device=device)
 if format != datapoints.BoundingBoxFormat.XYXY:
-in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
-expected_bboxes = clamp_bounding_box(
-datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+expected_bboxes = clamp_bounding_boxes(
+datapoints.BoundingBoxes(expected_bboxes, format="XYXY", spatial_size=spatial_size)
 ).tolist()
-output_boxes, output_spatial_size = F.crop_bounding_box(
+output_boxes, output_spatial_size = F.crop_bounding_boxes(
 in_boxes,
 format,
 top,

@@ -697,7 +697,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
 )
 if format != datapoints.BoundingBoxFormat.XYXY:
-output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
 torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
 torch.testing.assert_close(output_spatial_size, spatial_size)

@@ -727,7 +727,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
 [-5, 5, 35, 45, (32, 34)],
 ],
 )
-def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
+def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
 def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
 # bbox should be xyxy
 bbox[0] = (bbox[0] - left_) * size_[1] / width_

@@ -747,16 +747,16 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
 expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
 expected_bboxes = torch.tensor(expected_bboxes, device=device)
-in_boxes = datapoints.BoundingBox(
+in_boxes = datapoints.BoundingBoxes(
 in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device
 )
 if format != datapoints.BoundingBoxFormat.XYXY:
-in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
-output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+output_boxes, output_spatial_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
 if format != datapoints.BoundingBoxFormat.XYXY:
-output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
 torch.testing.assert_close(output_boxes, expected_bboxes)
 torch.testing.assert_close(output_spatial_size, size)

@@ -776,7 +776,7 @@ def _parse_padding(padding):
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
-def test_correctness_pad_bounding_box(device, padding):
+def test_correctness_pad_bounding_boxes(device, padding):
 def _compute_expected_bbox(bbox, padding_):
 pad_left, pad_up, _, _ = _parse_padding(padding_)

@@ -785,13 +785,13 @@ def test_correctness_pad_bounding_box(device, padding):
 bbox = (
 bbox.clone()
 if format == datapoints.BoundingBoxFormat.XYXY
-else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
 )
 bbox[0::2] += pad_left
 bbox[1::2] += pad_up
-bbox = convert_format_bounding_box(bbox, new_format=format)
+bbox = convert_format_bounding_boxes(bbox, new_format=format)
 if bbox.dtype != dtype:
 # Temporary cast to original dtype
 # e.g. float32 -> int

@@ -808,7 +808,7 @@ def test_correctness_pad_bounding_box(device, padding):
 bboxes_format = bboxes.format
 bboxes_spatial_size = bboxes.spatial_size
-output_boxes, output_spatial_size = F.pad_bounding_box(
+output_boxes, output_spatial_size = F.pad_bounding_boxes(
 bboxes, format=bboxes_format, spatial_size=bboxes_spatial_size, padding=padding
 )

@@ -819,7 +819,7 @@ def test_correctness_pad_bounding_box(device, padding):
 expected_bboxes = []
 for bbox in bboxes:
-bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
 expected_bboxes.append(_compute_expected_bbox(bbox, padding))
 if len(expected_bboxes) > 1:

@@ -849,7 +849,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
 [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
 ],
 )
-def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
+def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
 def _compute_expected_bbox(bbox, pcoeffs_):
 m1 = np.array(
 [

@@ -864,7 +864,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
 ]
 )
-bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
 points = np.array(
 [
 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],

@@ -884,14 +884,14 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
 np.max(transformed_points[:, 1]),
 ]
 )
-out_bbox = datapoints.BoundingBox(
+out_bbox = datapoints.BoundingBoxes(
 out_bbox,
 format=datapoints.BoundingBoxFormat.XYXY,
 spatial_size=bbox.spatial_size,
 dtype=bbox.dtype,
 device=bbox.device,
 )
-return clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
+return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))
 spatial_size = (32, 38)

@@ -901,7 +901,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
 for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
 bboxes = bboxes.to(device)
-output_bboxes = F.perspective_bounding_box(
+output_bboxes = F.perspective_bounding_boxes(
 bboxes.as_subclass(torch.Tensor),
 format=bboxes.format,
 spatial_size=bboxes.spatial_size,

@@ -915,7 +915,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
 expected_bboxes = []
 for bbox in bboxes:
-bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
+bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
 expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
 if len(expected_bboxes) > 1:
 expected_bboxes = torch.stack(expected_bboxes)

@@ -929,12 +929,12 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
 "output_size",
 [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
 )
-def test_correctness_center_crop_bounding_box(device, output_size):
+def test_correctness_center_crop_bounding_boxes(device, output_size):
 def _compute_expected_bbox(bbox, output_size_):
 format_ = bbox.format
 spatial_size_ = bbox.spatial_size
 dtype = bbox.dtype
-bbox = convert_format_bounding_box(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
+bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
 if len(output_size_) == 1:
 output_size_.append(output_size_[-1])

@@ -948,8 +948,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
 bbox[3].item(),
 ]
 out_bbox = torch.tensor(out_bbox)
-out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
-out_bbox = clamp_bounding_box(out_bbox, format=format_, spatial_size=output_size)
+out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
+out_bbox = clamp_bounding_boxes(out_bbox, format=format_, spatial_size=output_size)
 return out_bbox.to(dtype=dtype, device=bbox.device)
 for bboxes in make_bounding_boxes(extra_dims=((4,),)):

@@ -957,7 +957,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
 bboxes_format = bboxes.format
 bboxes_spatial_size = bboxes.spatial_size
-output_boxes, output_spatial_size = F.center_crop_bounding_box(
+output_boxes, output_spatial_size = F.center_crop_bounding_boxes(
 bboxes, bboxes_format, bboxes_spatial_size, output_size
 )

@@ -966,7 +966,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
 expected_bboxes = []
 for bbox in bboxes:
-bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
 expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
 if len(expected_bboxes) > 1:
......
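The perspective correctness test above recomputes the expected boxes by hand: each box is converted to XYXY, its four corners are pushed through the perspective transform, and the axis-aligned envelope of the result is clamped and converted back to the input format. A minimal standalone sketch of that corner math (illustrative only, not the test's helper; `perspective_xyxy_box` and the 8-coefficient layout `(a, b, c, d, e, f, g, h)` are assumptions of this example):

```python
import numpy as np

# Hedged sketch: map the four corners of an XYXY box through a perspective
# transform with coefficients (a, b, c, d, e, f, g, h), then take the
# axis-aligned envelope of the transformed quadrilateral.
def perspective_xyxy_box(box_xyxy, coeffs):
    a, b, c, d, e, f, g, h = coeffs
    x1, y1, x2, y2 = box_xyxy
    corners = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float64)
    x, y = corners[:, 0], corners[:, 1]
    denom = g * x + h * y + 1.0
    new_x = (a * x + b * y + c) / denom
    new_y = (d * x + e * y + f) / denom
    return [float(new_x.min()), float(new_y.min()), float(new_x.max()), float(new_y.max())]

# A pure translation (a=e=1, c=2, f=3, all other coefficients 0) shifts the box by (2, 3).
print(perspective_xyxy_box([10, 10, 30, 40], (1, 0, 2, 0, 1, 3, 0, 0)))
# [12.0, 13.0, 32.0, 43.0]
```

The real test additionally clamps to `spatial_size` and converts back to the original format, which is exactly what the renamed `clamp_bounding_boxes` and `convert_format_bounding_boxes` are used for above.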
...@@ -20,20 +20,20 @@ MASK = make_detection_mask(size=IMAGE.spatial_size) ...@@ -20,20 +20,20 @@ MASK = make_detection_mask(size=IMAGE.spatial_size)
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
((MASK,), (datapoints.Image, datapoints.BoundingBox), False), ((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False),
((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False), ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
((IMAGE,), (datapoints.BoundingBox, datapoints.Mask), False), ((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBox, datapoints.Mask), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
True, True,
), ),
((), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), ((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
...@@ -58,30 +58,30 @@ def test_has_any(sample, types, expected): ...@@ -58,30 +58,30 @@ def test_has_any(sample, types, expected):
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBox, datapoints.Mask), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
((IMAGE, MASK), (datapoints.BoundingBox, datapoints.Mask), False), ((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBox, datapoints.Mask), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)),), (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),),
True, True,
), ),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
......
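The parametrizations above exercise `has_any`/`has_all` style checks: given a sample and a tuple of datapoint types (or arbitrary callables), they report whether any, respectively all, of the entries match something in the sample. A rough sketch of the "any" variant over a flat tuple, assuming nothing about the real implementation (which flattens arbitrarily nested samples):

```python
# Hedged sketch of a has_any-style check over a flat tuple of datapoints.
# Entries in types_or_checks may be types or plain callables, mirroring the
# lambda entries in the parametrization above.
def has_any(flat_sample, *types_or_checks):
    for type_or_check in types_or_checks:
        for obj in flat_sample:
            match = (
                isinstance(obj, type_or_check)
                if isinstance(type_or_check, type)
                else type_or_check(obj)
            )
            if match:
                return True
    return False
```

The `has_all` cases that follow use the same idea but require every entry in `types_or_checks` to find a match, which is why `(BOUNDING_BOX, MASK)` against `(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)` is expected to be `False`.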
...@@ -143,7 +143,7 @@ DISPATCHER_INFOS = [ ...@@ -143,7 +143,7 @@ DISPATCHER_INFOS = [
kernels={ kernels={
datapoints.Image: F.crop_image_tensor, datapoints.Image: F.crop_image_tensor,
datapoints.Video: F.crop_video, datapoints.Video: F.crop_video,
datapoints.BoundingBox: F.crop_bounding_box, datapoints.BoundingBoxes: F.crop_bounding_boxes,
datapoints.Mask: F.crop_mask, datapoints.Mask: F.crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"), pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
...@@ -153,7 +153,7 @@ DISPATCHER_INFOS = [ ...@@ -153,7 +153,7 @@ DISPATCHER_INFOS = [
kernels={ kernels={
datapoints.Image: F.resized_crop_image_tensor, datapoints.Image: F.resized_crop_image_tensor,
datapoints.Video: F.resized_crop_video, datapoints.Video: F.resized_crop_video,
datapoints.BoundingBox: F.resized_crop_bounding_box, datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
datapoints.Mask: F.resized_crop_mask, datapoints.Mask: F.resized_crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil), pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
...@@ -163,7 +163,7 @@ DISPATCHER_INFOS = [ ...@@ -163,7 +163,7 @@ DISPATCHER_INFOS = [
kernels={ kernels={
datapoints.Image: F.pad_image_tensor, datapoints.Image: F.pad_image_tensor,
datapoints.Video: F.pad_video, datapoints.Video: F.pad_video,
datapoints.BoundingBox: F.pad_bounding_box, datapoints.BoundingBoxes: F.pad_bounding_boxes,
datapoints.Mask: F.pad_mask, datapoints.Mask: F.pad_mask,
}, },
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"), pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
...@@ -185,7 +185,7 @@ DISPATCHER_INFOS = [ ...@@ -185,7 +185,7 @@ DISPATCHER_INFOS = [
kernels={ kernels={
datapoints.Image: F.perspective_image_tensor, datapoints.Image: F.perspective_image_tensor,
datapoints.Video: F.perspective_video, datapoints.Video: F.perspective_video,
datapoints.BoundingBox: F.perspective_bounding_box, datapoints.BoundingBoxes: F.perspective_bounding_boxes,
datapoints.Mask: F.perspective_mask, datapoints.Mask: F.perspective_mask,
}, },
pil_kernel_info=PILKernelInfo(F.perspective_image_pil), pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
...@@ -199,7 +199,7 @@ DISPATCHER_INFOS = [ ...@@ -199,7 +199,7 @@ DISPATCHER_INFOS = [
kernels={ kernels={
datapoints.Image: F.elastic_image_tensor, datapoints.Image: F.elastic_image_tensor,
datapoints.Video: F.elastic_video, datapoints.Video: F.elastic_video,
datapoints.BoundingBox: F.elastic_bounding_box, datapoints.BoundingBoxes: F.elastic_bounding_boxes,
datapoints.Mask: F.elastic_mask, datapoints.Mask: F.elastic_mask,
}, },
pil_kernel_info=PILKernelInfo(F.elastic_image_pil), pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
...@@ -210,7 +210,7 @@ DISPATCHER_INFOS = [ ...@@ -210,7 +210,7 @@ DISPATCHER_INFOS = [
kernels={ kernels={
datapoints.Image: F.center_crop_image_tensor, datapoints.Image: F.center_crop_image_tensor,
datapoints.Video: F.center_crop_video, datapoints.Video: F.center_crop_video,
datapoints.BoundingBox: F.center_crop_bounding_box, datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
datapoints.Mask: F.center_crop_mask, datapoints.Mask: F.center_crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil), pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
...@@ -374,15 +374,15 @@ DISPATCHER_INFOS = [ ...@@ -374,15 +374,15 @@ DISPATCHER_INFOS = [
], ],
), ),
DispatcherInfo( DispatcherInfo(
F.clamp_bounding_box, F.clamp_bounding_boxes,
kernels={datapoints.BoundingBox: F.clamp_bounding_box}, kernels={datapoints.BoundingBoxes: F.clamp_bounding_boxes},
test_marks=[ test_marks=[
skip_dispatch_datapoint, skip_dispatch_datapoint,
], ],
), ),
DispatcherInfo( DispatcherInfo(
F.convert_format_bounding_box, F.convert_format_bounding_boxes,
kernels={datapoints.BoundingBox: F.convert_format_bounding_box}, kernels={datapoints.BoundingBoxes: F.convert_format_bounding_boxes},
test_marks=[ test_marks=[
skip_dispatch_datapoint, skip_dispatch_datapoint,
], ],
......
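Each `DispatcherInfo` above pairs a dispatcher such as `F.crop`, `F.pad`, or `F.perspective` with the per-type kernels it routes to (`*_image_tensor`, `*_video`, `*_bounding_boxes`, `*_mask`). Conceptually the dispatch is a type lookup; a minimal illustrative sketch of that pattern (not the actual torchvision dispatch machinery, just what the `kernels=` tables encode):

```python
# Hedged sketch: route an input to the kernel registered for its datapoint type.
# The kernels dict mirrors the kernels= mappings in DISPATCHER_INFOS above.
def dispatch(kernels, inpt, *args, **kwargs):
    for datapoint_type, kernel in kernels.items():
        if isinstance(inpt, datapoint_type):
            return kernel(inpt, *args, **kwargs)
    raise TypeError(f"no kernel registered for {type(inpt).__name__}")
```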
...@@ -184,13 +184,13 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): ...@@ -184,13 +184,13 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
return other_args, dict(kwargs, fill=fill) return other_args, dict(kwargs, fill=fill)
def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix): def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_size, affine_matrix):
def transform(bbox, affine_matrix_, format_, spatial_size_): def transform(bbox, affine_matrix_, format_, spatial_size_):
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype in_dtype = bbox.dtype
if not torch.is_floating_point(bbox): if not torch.is_floating_point(bbox):
bbox = bbox.float() bbox = bbox.float()
bbox_xyxy = F.convert_format_bounding_box( bbox_xyxy = F.convert_format_bounding_boxes(
bbox.as_subclass(torch.Tensor), bbox.as_subclass(torch.Tensor),
old_format=format_, old_format=format_,
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=datapoints.BoundingBoxFormat.XYXY,
...@@ -214,18 +214,18 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, ...@@ -214,18 +214,18 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size,
], ],
dtype=bbox_xyxy.dtype, dtype=bbox_xyxy.dtype,
) )
out_bbox = F.convert_format_bounding_box( out_bbox = F.convert_format_bounding_boxes(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
) )
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64 # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_) out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, spatial_size=spatial_size_)
out_bbox = out_bbox.to(dtype=in_dtype) out_bbox = out_bbox.to(dtype=in_dtype)
return out_bbox return out_bbox
if bounding_box.ndim < 2: if bounding_boxes.ndim < 2:
bounding_box = [bounding_box] bounding_boxes = [bounding_boxes]
expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box] expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_boxes]
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
else: else:
...@@ -234,30 +234,30 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, ...@@ -234,30 +234,30 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size,
return expected_bboxes return expected_bboxes
def sample_inputs_convert_format_bounding_box(): def sample_inputs_convert_format_bounding_boxes():
formats = list(datapoints.BoundingBoxFormat) formats = list(datapoints.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format) yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
def reference_convert_format_bounding_box(bounding_box, old_format, new_format): def reference_convert_format_bounding_boxes(bounding_boxes, old_format, new_format):
return torchvision.ops.box_convert( return torchvision.ops.box_convert(
bounding_box, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower() bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
).to(bounding_box.dtype) ).to(bounding_boxes.dtype)
def reference_inputs_convert_format_bounding_box(): def reference_inputs_convert_format_bounding_boxes():
for args_kwargs in sample_inputs_convert_format_bounding_box(): for args_kwargs in sample_inputs_convert_format_bounding_boxes():
if len(args_kwargs.args[0].shape) == 2: if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs yield args_kwargs
KERNEL_INFOS.append( KERNEL_INFOS.append(
KernelInfo( KernelInfo(
F.convert_format_bounding_box, F.convert_format_bounding_boxes,
sample_inputs_fn=sample_inputs_convert_format_bounding_box, sample_inputs_fn=sample_inputs_convert_format_bounding_boxes,
reference_fn=reference_convert_format_bounding_box, reference_fn=reference_convert_format_bounding_boxes,
reference_inputs_fn=reference_inputs_convert_format_bounding_box, reference_inputs_fn=reference_inputs_convert_format_bounding_boxes,
logs_usage=True, logs_usage=True,
closeness_kwargs={ closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0), (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
...@@ -290,11 +290,11 @@ def reference_inputs_crop_image_tensor(): ...@@ -290,11 +290,11 @@ def reference_inputs_crop_image_tensor():
yield ArgsKwargs(image_loader, **params) yield ArgsKwargs(image_loader, **params)
def sample_inputs_crop_bounding_box(): def sample_inputs_crop_bounding_boxes():
for bounding_box_loader, params in itertools.product( for bounding_boxes_loader, params in itertools.product(
make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]] make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
): ):
yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params) yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **params)
def sample_inputs_crop_mask(): def sample_inputs_crop_mask():
...@@ -312,27 +312,27 @@ def sample_inputs_crop_video(): ...@@ -312,27 +312,27 @@ def sample_inputs_crop_video():
yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8) yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)
def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width): def reference_crop_bounding_boxes(bounding_boxes, *, format, top, left, height, width):
affine_matrix = np.array( affine_matrix = np.array(
[ [
[1, 0, -left], [1, 0, -left],
[0, 1, -top], [0, 1, -top],
], ],
dtype="float64" if bounding_box.dtype == torch.float64 else "float32", dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
) )
spatial_size = (height, width) spatial_size = (height, width)
expected_bboxes = reference_affine_bounding_box_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix bounding_boxes, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
) )
return expected_bboxes, spatial_size return expected_bboxes, spatial_size
def reference_inputs_crop_bounding_box(): def reference_inputs_crop_bounding_boxes():
for bounding_box_loader, params in itertools.product( for bounding_boxes_loader, params in itertools.product(
make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]] make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
): ):
yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params) yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **params)
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
...@@ -346,10 +346,10 @@ KERNEL_INFOS.extend( ...@@ -346,10 +346,10 @@ KERNEL_INFOS.extend(
float32_vs_uint8=True, float32_vs_uint8=True,
), ),
KernelInfo( KernelInfo(
F.crop_bounding_box, F.crop_bounding_boxes,
sample_inputs_fn=sample_inputs_crop_bounding_box, sample_inputs_fn=sample_inputs_crop_bounding_boxes,
reference_fn=reference_crop_bounding_box, reference_fn=reference_crop_bounding_boxes,
reference_inputs_fn=reference_inputs_crop_bounding_box, reference_inputs_fn=reference_inputs_crop_bounding_boxes,
), ),
KernelInfo( KernelInfo(
F.crop_mask, F.crop_mask,
...@@ -406,9 +406,9 @@ def reference_inputs_resized_crop_image_tensor(): ...@@ -406,9 +406,9 @@ def reference_inputs_resized_crop_image_tensor():
) )
def sample_inputs_resized_crop_bounding_box(): def sample_inputs_resized_crop_bounding_boxes():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_boxes_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **_RESIZED_CROP_PARAMS[0]) yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **_RESIZED_CROP_PARAMS[0])
def sample_inputs_resized_crop_mask(): def sample_inputs_resized_crop_mask():
...@@ -436,8 +436,8 @@ KERNEL_INFOS.extend( ...@@ -436,8 +436,8 @@ KERNEL_INFOS.extend(
}, },
), ),
KernelInfo( KernelInfo(
F.resized_crop_bounding_box, F.resized_crop_bounding_boxes,
sample_inputs_fn=sample_inputs_resized_crop_bounding_box, sample_inputs_fn=sample_inputs_resized_crop_bounding_boxes,
), ),
KernelInfo( KernelInfo(
F.resized_crop_mask, F.resized_crop_mask,
...@@ -500,14 +500,14 @@ def reference_inputs_pad_image_tensor(): ...@@ -500,14 +500,14 @@ def reference_inputs_pad_image_tensor():
yield ArgsKwargs(image_loader, fill=fill, **params) yield ArgsKwargs(image_loader, fill=fill, **params)
def sample_inputs_pad_bounding_box(): def sample_inputs_pad_bounding_boxes():
for bounding_box_loader, padding in itertools.product( for bounding_boxes_loader, padding in itertools.product(
make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]] make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
): ):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_boxes_loader,
format=bounding_box_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_box_loader.spatial_size, spatial_size=bounding_boxes_loader.spatial_size,
padding=padding, padding=padding,
padding_mode="constant", padding_mode="constant",
) )
...@@ -530,7 +530,7 @@ def sample_inputs_pad_video(): ...@@ -530,7 +530,7 @@ def sample_inputs_pad_video():
yield ArgsKwargs(video_loader, padding=[1]) yield ArgsKwargs(video_loader, padding=[1])
def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode): def reference_pad_bounding_boxes(bounding_boxes, *, format, spatial_size, padding, padding_mode):
left, right, top, bottom = _parse_pad_padding(padding) left, right, top, bottom = _parse_pad_padding(padding)
...@@ -539,26 +539,26 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, p ...@@ -539,26 +539,26 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, p
[1, 0, left], [1, 0, left],
[0, 1, top], [0, 1, top],
], ],
dtype="float64" if bounding_box.dtype == torch.float64 else "float32", dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
) )
height = spatial_size[0] + top + bottom height = spatial_size[0] + top + bottom
width = spatial_size[1] + left + right width = spatial_size[1] + left + right
expected_bboxes = reference_affine_bounding_box_helper( expected_bboxes = reference_affine_bounding_boxes_helper(
bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix bounding_boxes, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
) )
return expected_bboxes, (height, width) return expected_bboxes, (height, width)
def reference_inputs_pad_bounding_box(): def reference_inputs_pad_bounding_boxes():
for bounding_box_loader, padding in itertools.product( for bounding_boxes_loader, padding in itertools.product(
make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]] make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
): ):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_boxes_loader,
format=bounding_box_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_box_loader.spatial_size, spatial_size=bounding_boxes_loader.spatial_size,
padding=padding, padding=padding,
padding_mode="constant", padding_mode="constant",
) )
...@@ -591,10 +591,10 @@ KERNEL_INFOS.extend( ...@@ -591,10 +591,10 @@ KERNEL_INFOS.extend(
], ],
), ),
KernelInfo( KernelInfo(
F.pad_bounding_box, F.pad_bounding_boxes,
sample_inputs_fn=sample_inputs_pad_bounding_box, sample_inputs_fn=sample_inputs_pad_bounding_boxes,
reference_fn=reference_pad_bounding_box, reference_fn=reference_pad_bounding_boxes,
reference_inputs_fn=reference_inputs_pad_bounding_box, reference_inputs_fn=reference_inputs_pad_bounding_boxes,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("padding"), xfail_jit_python_scalar_arg("padding"),
], ],
...@@ -655,12 +655,12 @@ def reference_inputs_perspective_image_tensor(): ...@@ -655,12 +655,12 @@ def reference_inputs_perspective_image_tensor():
) )
def sample_inputs_perspective_bounding_box(): def sample_inputs_perspective_bounding_boxes():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_boxes_loader in make_bounding_box_loaders():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_boxes_loader,
format=bounding_box_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_box_loader.spatial_size, spatial_size=bounding_boxes_loader.spatial_size,
startpoints=None, startpoints=None,
endpoints=None, endpoints=None,
coefficients=_PERSPECTIVE_COEFFS[0], coefficients=_PERSPECTIVE_COEFFS[0],
...@@ -712,8 +712,8 @@ KERNEL_INFOS.extend( ...@@ -712,8 +712,8 @@ KERNEL_INFOS.extend(
test_marks=[xfail_jit_python_scalar_arg("fill")], test_marks=[xfail_jit_python_scalar_arg("fill")],
), ),
KernelInfo( KernelInfo(
F.perspective_bounding_box, F.perspective_bounding_boxes,
sample_inputs_fn=sample_inputs_perspective_bounding_box, sample_inputs_fn=sample_inputs_perspective_bounding_boxes,
closeness_kwargs={ closeness_kwargs={
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6), **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6), **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
...@@ -767,13 +767,13 @@ def reference_inputs_elastic_image_tensor(): ...@@ -767,13 +767,13 @@ def reference_inputs_elastic_image_tensor():
yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill) yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)
def sample_inputs_elastic_bounding_box(): def sample_inputs_elastic_bounding_boxes():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_boxes_loader in make_bounding_box_loaders():
displacement = _get_elastic_displacement(bounding_box_loader.spatial_size) displacement = _get_elastic_displacement(bounding_boxes_loader.spatial_size)
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_boxes_loader,
format=bounding_box_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_box_loader.spatial_size, spatial_size=bounding_boxes_loader.spatial_size,
displacement=displacement, displacement=displacement,
) )
...@@ -804,8 +804,8 @@ KERNEL_INFOS.extend( ...@@ -804,8 +804,8 @@ KERNEL_INFOS.extend(
test_marks=[xfail_jit_python_scalar_arg("fill")], test_marks=[xfail_jit_python_scalar_arg("fill")],
), ),
KernelInfo( KernelInfo(
F.elastic_bounding_box, F.elastic_bounding_boxes,
sample_inputs_fn=sample_inputs_elastic_bounding_box, sample_inputs_fn=sample_inputs_elastic_bounding_boxes,
), ),
KernelInfo( KernelInfo(
F.elastic_mask, F.elastic_mask,
...@@ -845,12 +845,12 @@ def reference_inputs_center_crop_image_tensor(): ...@@ -845,12 +845,12 @@ def reference_inputs_center_crop_image_tensor():
yield ArgsKwargs(image_loader, output_size=output_size) yield ArgsKwargs(image_loader, output_size=output_size)
def sample_inputs_center_crop_bounding_box(): def sample_inputs_center_crop_bounding_boxes():
for bounding_box_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES): for bounding_boxes_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_boxes_loader,
format=bounding_box_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_box_loader.spatial_size, spatial_size=bounding_boxes_loader.spatial_size,
output_size=output_size, output_size=output_size,
) )
...@@ -887,8 +887,8 @@ KERNEL_INFOS.extend( ...@@ -887,8 +887,8 @@ KERNEL_INFOS.extend(
], ],
), ),
KernelInfo( KernelInfo(
F.center_crop_bounding_box, F.center_crop_bounding_boxes,
sample_inputs_fn=sample_inputs_center_crop_bounding_box, sample_inputs_fn=sample_inputs_center_crop_bounding_boxes,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("output_size"), xfail_jit_python_scalar_arg("output_size"),
], ],
...@@ -1482,19 +1482,19 @@ KERNEL_INFOS.extend( ...@@ -1482,19 +1482,19 @@ KERNEL_INFOS.extend(
) )
def sample_inputs_clamp_bounding_box(): def sample_inputs_clamp_bounding_boxes():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_boxes_loader in make_bounding_box_loaders():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_boxes_loader,
format=bounding_box_loader.format, format=bounding_boxes_loader.format,
spatial_size=bounding_box_loader.spatial_size, spatial_size=bounding_boxes_loader.spatial_size,
) )
KERNEL_INFOS.append( KERNEL_INFOS.append(
KernelInfo( KernelInfo(
F.clamp_bounding_box, F.clamp_bounding_boxes,
sample_inputs_fn=sample_inputs_clamp_bounding_box, sample_inputs_fn=sample_inputs_clamp_bounding_boxes,
logs_usage=True, logs_usage=True,
) )
) )
......
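The reference kernels in this file reduce crop and pad to translations: crop builds the affine matrix `[[1, 0, -left], [0, 1, -top]]`, pad builds `[[1, 0, left], [0, 1, top]]`, and both hand it to `reference_affine_bounding_boxes_helper`, which applies it to the XYXY corners. A standalone sketch of that corner application, assuming plain numpy (the helpers above, with their format conversion and clamping, are the authoritative version):

```python
import numpy as np

# Hedged sketch: apply a 2x3 affine matrix to the four corners of an XYXY box
# and take the axis-aligned envelope, as the reference crop/pad kernels do.
def affine_xyxy_box(box_xyxy, affine_matrix):
    x1, y1, x2, y2 = box_xyxy
    corners = np.array([[x1, y1, 1.0], [x2, y1, 1.0], [x2, y2, 1.0], [x1, y2, 1.0]])
    new = corners @ affine_matrix.T  # (4, 2) transformed corners
    return [float(new[:, 0].min()), float(new[:, 1].min()),
            float(new[:, 0].max()), float(new[:, 1].max())]

# Padding by (left=3, top=5) shifts every box by (+3, +5); the padded spatial
# size grows accordingly, which is why reference_pad_bounding_boxes also
# returns (height + top + bottom, width + left + right).
pad_matrix = np.array([[1.0, 0.0, 3.0], [0.0, 1.0, 5.0]])
print(affine_xyxy_box([10, 10, 30, 40], pad_matrix))  # [13.0, 15.0, 33.0, 45.0]
```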
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
from ._bounding_box import BoundingBox, BoundingBoxFormat from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask from ._mask import Mask
......
...@@ -24,7 +24,7 @@ class BoundingBoxFormat(Enum): ...@@ -24,7 +24,7 @@ class BoundingBoxFormat(Enum):
CXCYWH = "CXCYWH" CXCYWH = "CXCYWH"
class BoundingBox(Datapoint): class BoundingBoxes(Datapoint):
"""[BETA] :class:`torch.Tensor` subclass for bounding boxes. """[BETA] :class:`torch.Tensor` subclass for bounding boxes.
Args: Args:
...@@ -43,11 +43,11 @@ class BoundingBox(Datapoint): ...@@ -43,11 +43,11 @@ class BoundingBox(Datapoint):
spatial_size: Tuple[int, int] spatial_size: Tuple[int, int]
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox: def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBoxes:
bounding_box = tensor.as_subclass(cls) bounding_boxes = tensor.as_subclass(cls)
bounding_box.format = format bounding_boxes.format = format
bounding_box.spatial_size = spatial_size bounding_boxes.spatial_size = spatial_size
return bounding_box return bounding_boxes
def __new__( def __new__(
cls, cls,
...@@ -58,7 +58,7 @@ class BoundingBox(Datapoint): ...@@ -58,7 +58,7 @@ class BoundingBox(Datapoint):
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
) -> BoundingBox: ) -> BoundingBoxes:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if isinstance(format, str): if isinstance(format, str):
...@@ -69,17 +69,17 @@ class BoundingBox(Datapoint): ...@@ -69,17 +69,17 @@ class BoundingBox(Datapoint):
@classmethod @classmethod
def wrap_like( def wrap_like(
cls, cls,
other: BoundingBox, other: BoundingBoxes,
tensor: torch.Tensor, tensor: torch.Tensor,
*, *,
format: Optional[BoundingBoxFormat] = None, format: Optional[BoundingBoxFormat] = None,
spatial_size: Optional[Tuple[int, int]] = None, spatial_size: Optional[Tuple[int, int]] = None,
) -> BoundingBox: ) -> BoundingBoxes:
"""Wrap a :class:`torch.Tensor` as :class:`BoundingBox` from a reference. """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
Args: Args:
other (BoundingBox): Reference bounding box. other (BoundingBoxes): Reference bounding box.
tensor (Tensor): Tensor to be wrapped as :class:`BoundingBox` tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
reference. reference.
spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
...@@ -98,17 +98,17 @@ class BoundingBox(Datapoint): ...@@ -98,17 +98,17 @@ class BoundingBox(Datapoint):
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, spatial_size=self.spatial_size) return self._make_repr(format=self.format, spatial_size=self.spatial_size)
def horizontal_flip(self) -> BoundingBox: def horizontal_flip(self) -> BoundingBoxes:
output = self._F.horizontal_flip_bounding_box( output = self._F.horizontal_flip_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
) )
return BoundingBox.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
def vertical_flip(self) -> BoundingBox: def vertical_flip(self) -> BoundingBoxes:
output = self._F.vertical_flip_bounding_box( output = self._F.vertical_flip_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
) )
return BoundingBox.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
def resize( # type: ignore[override] def resize( # type: ignore[override]
self, self,
...@@ -116,26 +116,26 @@ class BoundingBox(Datapoint): ...@@ -116,26 +116,26 @@ class BoundingBox(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox: ) -> BoundingBoxes:
output, spatial_size = self._F.resize_bounding_box( output, spatial_size = self._F.resize_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
spatial_size=self.spatial_size, spatial_size=self.spatial_size,
size=size, size=size,
max_size=max_size, max_size=max_size,
) )
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes:
output, spatial_size = self._F.crop_bounding_box( output, spatial_size = self._F.crop_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
) )
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)
def center_crop(self, output_size: List[int]) -> BoundingBox: def center_crop(self, output_size: List[int]) -> BoundingBoxes:
output, spatial_size = self._F.center_crop_bounding_box( output, spatial_size = self._F.center_crop_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
) )
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)
def resized_crop( def resized_crop(
self, self,
...@@ -146,26 +146,26 @@ class BoundingBox(Datapoint): ...@@ -146,26 +146,26 @@ class BoundingBox(Datapoint):
size: List[int], size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox: ) -> BoundingBoxes:
output, spatial_size = self._F.resized_crop_bounding_box( output, spatial_size = self._F.resized_crop_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
) )
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)
def pad( def pad(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, List[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> BoundingBox: ) -> BoundingBoxes:
output, spatial_size = self._F.pad_bounding_box( output, spatial_size = self._F.pad_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
format=self.format, format=self.format,
spatial_size=self.spatial_size, spatial_size=self.spatial_size,
padding=padding, padding=padding,
padding_mode=padding_mode, padding_mode=padding_mode,
) )
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)
def rotate( def rotate(
self, self,
...@@ -174,8 +174,8 @@ class BoundingBox(Datapoint): ...@@ -174,8 +174,8 @@ class BoundingBox(Datapoint):
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> BoundingBox: ) -> BoundingBoxes:
output, spatial_size = self._F.rotate_bounding_box( output, spatial_size = self._F.rotate_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
format=self.format, format=self.format,
spatial_size=self.spatial_size, spatial_size=self.spatial_size,
...@@ -183,7 +183,7 @@ class BoundingBox(Datapoint): ...@@ -183,7 +183,7 @@ class BoundingBox(Datapoint):
expand=expand, expand=expand,
center=center, center=center,
) )
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)
def affine( def affine(
self, self,
...@@ -194,8 +194,8 @@ class BoundingBox(Datapoint): ...@@ -194,8 +194,8 @@ class BoundingBox(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBoxes:
output = self._F.affine_bounding_box( output = self._F.affine_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
self.format, self.format,
self.spatial_size, self.spatial_size,
...@@ -205,7 +205,7 @@ class BoundingBox(Datapoint): ...@@ -205,7 +205,7 @@ class BoundingBox(Datapoint):
shear=shear, shear=shear,
center=center, center=center,
) )
return BoundingBox.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
def perspective( def perspective(
self, self,
...@@ -214,8 +214,8 @@ class BoundingBox(Datapoint): ...@@ -214,8 +214,8 @@ class BoundingBox(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBoxes:
output = self._F.perspective_bounding_box( output = self._F.perspective_bounding_boxes(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
format=self.format, format=self.format,
spatial_size=self.spatial_size, spatial_size=self.spatial_size,
...@@ -223,15 +223,15 @@ class BoundingBox(Datapoint): ...@@ -223,15 +223,15 @@ class BoundingBox(Datapoint):
endpoints=endpoints, endpoints=endpoints,
coefficients=coefficients, coefficients=coefficients,
) )
return BoundingBox.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
def elastic( def elastic(
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> BoundingBox: ) -> BoundingBoxes:
output = self._F.elastic_bounding_box( output = self._F.elastic_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
) )
return BoundingBox.wrap_like(self, output) return BoundingBoxes.wrap_like(self, output)
...@@ -138,8 +138,8 @@ class Datapoint(torch.Tensor): ...@@ -138,8 +138,8 @@ class Datapoint(torch.Tensor):
# *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad` # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
# attribute is cleared, so we need to refill it before we return. # attribute is cleared, so we need to refill it before we return.
# Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
# `BoundingBox.format` and `BoundingBox.spatial_size`, which are immutable and thus implicitly deep-copied by # `BoundingBoxes.format` and `BoundingBoxes.spatial_size`, which are immutable and thus implicitly deep-copied by
# `BoundingBox.clone()`. # `BoundingBoxes.clone()`.
return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value] return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value]
def horizontal_flip(self) -> Datapoint: def horizontal_flip(self) -> Datapoint:
......
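All of the `BoundingBoxes` methods above follow one pattern: unwrap to a plain tensor with `as_subclass(torch.Tensor)`, call the renamed `*_bounding_boxes` kernel, then re-wrap with `BoundingBoxes.wrap_like` so the `format` and `spatial_size` metadata (and, for size-changing ops, the updated spatial size) travel with the result. For a concrete feel of what one of those kernels computes, here is a hedged sketch of the horizontal-flip arithmetic for XYXY boxes; it is only the coordinate math, not the torchvision kernel:

```python
import torch

# Hedged sketch: horizontally flipping an XYXY box mirrors the x coordinates
# around the image width and swaps them so that x1 <= x2 still holds.
def horizontal_flip_xyxy(boxes: torch.Tensor, spatial_size) -> torch.Tensor:
    height, width = spatial_size
    x1, y1, x2, y2 = boxes.unbind(-1)
    return torch.stack([width - x2, y1, width - x1, y2], dim=-1)

boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])
print(horizontal_flip_xyxy(boxes, spatial_size=(300, 400)))
# tensor([[290.,  20., 390., 220.]])
```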
...@@ -44,7 +44,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): ...@@ -44,7 +44,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
preserved. If ``target_keys`` is ommitted, returns only the values for the ``"boxes"`` and ``"labels"``. preserved. If ``target_keys`` is ommitted, returns only the values for the ``"boxes"`` and ``"labels"``.
* :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY`` * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint. coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
* :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts, the wrapper returns a * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts, the wrapper returns a
dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
...@@ -56,7 +56,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): ...@@ -56,7 +56,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
``"labels"``. ``"labels"``.
* :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY`` * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint. coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
Image classification datasets Image classification datasets
...@@ -360,8 +360,8 @@ def coco_dectection_wrapper_factory(dataset, target_keys): ...@@ -360,8 +360,8 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
target["image_id"] = image_id target["image_id"] = image_id
if "boxes" in target_keys: if "boxes" in target_keys:
target["boxes"] = F.convert_format_bounding_box( target["boxes"] = F.convert_format_bounding_boxes(
datapoints.BoundingBox( datapoints.BoundingBoxes(
batched_target["bbox"], batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYWH, format=datapoints.BoundingBoxFormat.XYWH,
spatial_size=spatial_size, spatial_size=spatial_size,
...@@ -442,7 +442,7 @@ def voc_detection_wrapper_factory(dataset, target_keys): ...@@ -442,7 +442,7 @@ def voc_detection_wrapper_factory(dataset, target_keys):
target = {} target = {}
if "boxes" in target_keys: if "boxes" in target_keys:
target["boxes"] = datapoints.BoundingBox( target["boxes"] = datapoints.BoundingBoxes(
[ [
[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for bndbox in batched_instances["bndbox"] for bndbox in batched_instances["bndbox"]
...@@ -481,8 +481,8 @@ def celeba_wrapper_factory(dataset, target_keys): ...@@ -481,8 +481,8 @@ def celeba_wrapper_factory(dataset, target_keys):
target, target,
target_types=dataset.target_type, target_types=dataset.target_type,
type_wrappers={ type_wrappers={
"bbox": lambda item: F.convert_format_bounding_box( "bbox": lambda item: F.convert_format_bounding_boxes(
datapoints.BoundingBox( datapoints.BoundingBoxes(
item, item,
format=datapoints.BoundingBoxFormat.XYWH, format=datapoints.BoundingBoxFormat.XYWH,
spatial_size=(image.height, image.width), spatial_size=(image.height, image.width),
...@@ -532,7 +532,7 @@ def kitti_wrapper_factory(dataset, target_keys): ...@@ -532,7 +532,7 @@ def kitti_wrapper_factory(dataset, target_keys):
target = {} target = {}
if "boxes" in target_keys: if "boxes" in target_keys:
target["boxes"] = datapoints.BoundingBox( target["boxes"] = datapoints.BoundingBoxes(
batched_target["bbox"], batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(image.height, image.width), spatial_size=(image.height, image.width),
...@@ -628,8 +628,8 @@ def widerface_wrapper(dataset, target_keys): ...@@ -628,8 +628,8 @@ def widerface_wrapper(dataset, target_keys):
target = {key: target[key] for key in target_keys} target = {key: target[key] for key in target_keys}
if "bbox" in target_keys: if "bbox" in target_keys:
target["bbox"] = F.convert_format_bounding_box( target["bbox"] = F.convert_format_bounding_boxes(
datapoints.BoundingBox( datapoints.BoundingBoxes(
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
), ),
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=datapoints.BoundingBoxFormat.XYXY,
......
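The detection wrappers above (COCO, CelebA, WIDERFace) all perform the same conversion: wrap the dataset's native XYWH boxes as a `BoundingBoxes` datapoint and convert them to XYXY. The kernel tests earlier in this diff use `torchvision.ops.box_convert` as the reference for exactly that conversion, so a short plain-tensor example of what the wrappers produce coordinate-wise:

```python
import torch
from torchvision.ops import box_convert

# XYWH -> XYXY: x2 = x + w, y2 = y + h. This is the coordinate conversion the
# dataset wrappers above apply before handing boxes to the v2 transforms.
xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # x, y, width, height
xyxy = box_convert(xywh, in_fmt="xywh", out_fmt="xyxy")
print(xyxy)  # tensor([[10., 20., 40., 60.]])
```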