Unverified commit c1592f96, authored by Nicolas Hug and committed by GitHub

Remove wrap_like class method and add datapoints.wrap() function (#7832)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 498b9c86
@@ -19,3 +19,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
     Mask
     Datapoint
     set_return_type
+    wrap
@@ -53,11 +53,11 @@ from torchvision.transforms.v2 import functional as F
 def hflip_my_datapoint(my_dp, *args, **kwargs):
     print("Flipping!")
     out = my_dp.flip(-1)
-    return MyDatapoint.wrap_like(my_dp, out)
+    return datapoints.wrap(out, like=my_dp)

 # %%
-# To understand why ``wrap_like`` is used, see
+# To understand why :func:`~torchvision.datapoints.wrap` is used, see
 # :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
 # we will explain it below in :ref:`param_forwarding`.
 #
@@ -107,7 +107,7 @@ _ = t(my_dp)
 def hflip_my_datapoint(my_dp):  # noqa
     print("Flipping!")
     out = my_dp.flip(-1)
-    return MyDatapoint.wrap_like(my_dp, out)
+    return datapoints.wrap(out, like=my_dp)

 # %%
...
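A minimal sketch of the unwrapping behaviour that motivates the re-wrap in ``hflip_my_datapoint``, assuming the default return type where most operations on a datapoint return a pure tensor (the same behaviour the tests below rely on):

    import torch
    from torchvision import datapoints

    img = datapoints.Image(torch.rand(3, 32, 32))
    out = img * 2
    # Under the default return type, the subclass is lost after the op ...
    assert not isinstance(out, datapoints.Image)
    # ... so kernels and transforms re-wrap their output like their input:
    out = datapoints.wrap(out, like=img)
    assert isinstance(out, datapoints.Image)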
@@ -107,26 +107,23 @@ bboxes = datapoints.BoundingBoxes(
 print(bboxes)

 # %%
-# Using the ``wrap_like()`` class method
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Using ``datapoints.wrap()``
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
-# You can also use the ``wrap_like()`` class method to wrap a tensor object
+# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object
 # into a datapoint. This is useful when you already have an object of the
 # desired type, which typically happens when writing transforms: you just want
-# to wrap the output like the input. This API is inspired by utils like
-# :func:`torch.zeros_like`:
+# to wrap the output like the input.

 new_bboxes = torch.tensor([0, 20, 30, 40])
-new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
 assert isinstance(new_bboxes, datapoints.BoundingBoxes)
 assert new_bboxes.canvas_size == bboxes.canvas_size

 # %%
 # The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
-# it as a parameter to override it. Check the
-# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
-# more details.
+# it as a parameter to override it.
 #
 # Do I have to wrap the output of the datasets myself?
 # ----------------------------------------------------
@@ -230,11 +227,11 @@ assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # You can re-wrap a pure tensor into a datapoint by just calling the datapoint
-# constructor, or by using the ``.wrap_like()`` class method (see more details
-# above in :ref:`datapoint_creation`):
+# constructor, or by using the :func:`~torchvision.datapoints.wrap` function
+# (see more details above in :ref:`datapoint_creation`):

 new_bboxes = bboxes + 3
-new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
 assert isinstance(new_bboxes, datapoints.BoundingBoxes)

 # %%
...
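To illustrate the metadata override mentioned in the hunk above, a short sketch assuming the ``wrap()`` signature added in this commit (metadata comes from ``like`` unless passed as a keyword):

    import torch
    from torchvision import datapoints

    bboxes = datapoints.BoundingBoxes(
        [[10, 10, 20, 20]], format="XYXY", canvas_size=(50, 50)
    )
    # By default the metadata of `like` is propagated ...
    same = datapoints.wrap(torch.tensor([[0, 0, 5, 5]]), like=bboxes)
    assert same.canvas_size == (50, 50)
    # ... but it can be overridden through kwargs:
    resized = datapoints.wrap(
        torch.tensor([[0, 0, 10, 10]]), like=bboxes, canvas_size=(100, 100)
    )
    assert resized.canvas_size == (100, 100)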
@@ -213,13 +213,13 @@ def test_inplace_op_no_wrapping(make_input, return_type):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
-def test_wrap_like(make_input):
+def test_wrap(make_input):
     dp = make_input()

     # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
     output = dp * 2

-    dp_new = type(dp).wrap_like(dp, output)
+    dp_new = datapoints.wrap(output, like=dp)
     assert type(dp_new) is type(dp)
     assert dp_new.data_ptr() == output.data_ptr()
...
@@ -570,7 +570,7 @@ class TestResize:
             canvas_size=(new_height, new_width),
             affine_matrix=affine_matrix,
         )
-        return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes, canvas_size=(new_height, new_width))
+        return datapoints.wrap(expected_bboxes, like=bounding_boxes, canvas_size=(new_height, new_width))

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
@@ -815,7 +815,7 @@ class TestHorizontalFlip:
             affine_matrix=affine_matrix,
         )
-        return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes)
+        return datapoints.wrap(expected_bboxes, like=bounding_boxes)

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize(
@@ -1278,7 +1278,7 @@ class TestVerticalFlip:
             affine_matrix=affine_matrix,
         )
-        return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes)
+        return datapoints.wrap(expected_bboxes, like=bounding_boxes)

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
...
@@ -12,3 +12,25 @@ if _WARN_ABOUT_BETA_TRANSFORMS:
     import warnings

     warnings.warn(_BETA_TRANSFORMS_WARNING)
+
+
+def wrap(wrappee, *, like, **kwargs):
+    """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``.
+
+    If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of
+    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
+
+    Args:
+        wrappee (Tensor): The tensor to convert.
+        like (Datapoint): The reference. ``wrappee`` will be converted into the same subclass as ``like``.
+        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`.
+            Ignored otherwise.
+    """
+    if isinstance(like, BoundingBoxes):
+        return BoundingBoxes._wrap(
+            wrappee,
+            format=kwargs.get("format", like.format),
+            canvas_size=kwargs.get("canvas_size", like.canvas_size),
+        )
+    else:
+        return wrappee.as_subclass(type(like))
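The new function has two branches: ``BoundingBoxes`` carry metadata that must be propagated, while every other datapoint reduces to a plain ``as_subclass`` call. A minimal sketch of the second branch (``Mask`` carries no metadata), assuming this commit's API:

    import torch
    from torchvision import datapoints

    mask = datapoints.Mask(torch.zeros(32, 32, dtype=torch.uint8))
    plain = torch.ones(32, 32, dtype=torch.uint8)
    # Non-BoundingBoxes datapoints have no metadata, so wrap() is just
    # wrappee.as_subclass(type(like)):
    wrapped = datapoints.wrap(plain, like=mask)
    assert type(wrapped) is datapoints.Mask
    # The data is shared, not copied, as the data_ptr test above checks:
    assert wrapped.data_ptr() == plain.data_ptr()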
@@ -75,32 +75,6 @@ class BoundingBoxes(Datapoint):
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         return cls._wrap(tensor, format=format, canvas_size=canvas_size)

-    @classmethod
-    def wrap_like(
-        cls,
-        other: BoundingBoxes,
-        tensor: torch.Tensor,
-        *,
-        format: Optional[Union[BoundingBoxFormat, str]] = None,
-        canvas_size: Optional[Tuple[int, int]] = None,
-    ) -> BoundingBoxes:
-        """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
-
-        Args:
-            other (BoundingBoxes): Reference bounding box.
-            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
-            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
-                reference.
-            canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
-                omitted, it is taken from the reference.
-        """
-        return cls._wrap(
-            tensor,
-            format=format if format is not None else other.format,
-            canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
-        )
-
     @classmethod
     def _wrap_output(
         cls,
...
@@ -31,10 +31,6 @@ class Datapoint(torch.Tensor):
         requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
         return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

-    @classmethod
-    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        return tensor.as_subclass(cls)
-
     @classmethod
     def _wrap_output(
         cls,
...
@@ -32,13 +32,6 @@ class _LabelBase(Datapoint):
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         return cls._wrap(tensor, categories=categories)

-    @classmethod
-    def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L:
-        return cls._wrap(
-            tensor,
-            categories=categories if categories is not None else other.categories,
-        )
-
     @classmethod
     def from_category(
         cls: Type[L],
...
@@ -36,11 +36,9 @@ class SimpleCopyPaste(Transform):
         antialias: Optional[bool],
     ) -> Tuple[torch.Tensor, Dict[str, Any]]:
-        paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
-        paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
-        paste_labels = paste_target["labels"].wrap_like(
-            paste_target["labels"], paste_target["labels"][random_selection]
-        )
+        paste_masks = datapoints.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
+        paste_boxes = datapoints.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
+        paste_labels = datapoints.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])

         masks = target["masks"]
@@ -143,7 +141,7 @@ class SimpleCopyPaste(Transform):
         c0, c1, c2, c3 = 0, 0, 0, 0
         for i, obj in enumerate(flat_sample):
             if isinstance(obj, datapoints.Image):
-                flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0])
+                flat_sample[i] = datapoints.wrap(output_images[c0], like=obj)
                 c0 += 1
             elif isinstance(obj, PIL.Image.Image):
                 flat_sample[i] = F.to_image_pil(output_images[c0])
@@ -152,13 +150,13 @@ class SimpleCopyPaste(Transform):
                 flat_sample[i] = output_images[c0]
                 c0 += 1
             elif isinstance(obj, datapoints.BoundingBoxes):
-                flat_sample[i] = datapoints.BoundingBoxes.wrap_like(obj, output_targets[c1]["boxes"])
+                flat_sample[i] = datapoints.wrap(output_targets[c1]["boxes"], like=obj)
                 c1 += 1
             elif isinstance(obj, datapoints.Mask):
-                flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
+                flat_sample[i] = datapoints.wrap(output_targets[c2]["masks"], like=obj)
                 c2 += 1
             elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
-                flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
+                flat_sample[i] = datapoints.wrap(output_targets[c3]["labels"], like=obj)
                 c3 += 1

     def forward(self, *inputs: Any) -> Any:
...
@@ -112,11 +112,11 @@ class FixedSizeCrop(Transform):
         if params["is_valid"] is not None:
             if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)):
-                inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
+                inpt = datapoints.wrap(inpt[params["is_valid"]], like=inpt)
             elif isinstance(inpt, datapoints.BoundingBoxes):
-                inpt = datapoints.BoundingBoxes.wrap_like(
-                    inpt,
-                    F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
+                inpt = datapoints.wrap(
+                    F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
+                    like=inpt,
                 )

         if params["needs_pad"]:
...
@@ -249,7 +249,7 @@ class MixUp(_BaseMixUpCutMix):
             output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

             if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+                output = datapoints.wrap(output, like=inpt)

             return output
         else:
@@ -319,7 +319,7 @@ class CutMix(_BaseMixUpCutMix):
             output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

             if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
+                output = datapoints.wrap(output, like=inpt)

             return output
         else:
...
@@ -620,7 +620,7 @@ class AugMix(_AutoAugmentBase):
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
         if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
-            mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
+            mix = datapoints.wrap(mix, like=orig_image_or_video)
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_image_pil(mix)
...
@@ -338,7 +338,7 @@ class FiveCrop(Transform):
         ...     images_or_videos, labels = sample
         ...     batch_size = len(images_or_videos)
         ...     image_or_video = images_or_videos[0]
-        ...     images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
+        ...     images_or_videos = datapoints.wrap(torch.stack(images_or_videos), like=image_or_video)
         ...     labels = torch.full((batch_size,), label, device=images_or_videos.device)
         ...     return images_or_videos, labels
         ...
...
@@ -131,7 +131,7 @@ class LinearTransformation(Transform):
         output = output.reshape(shape)

         if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+            output = datapoints.wrap(output, like=inpt)
         return output
@@ -423,4 +423,4 @@ class SanitizeBoundingBoxes(Transform):
         if is_label:
             return output

-        return type(inpt).wrap_like(inpt, output)
+        return datapoints.wrap(output, like=inpt)
@@ -87,7 +87,7 @@ def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) ->
     output = horizontal_flip_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 @_register_kernel_internal(horizontal_flip, datapoints.Video)
@@ -143,7 +143,7 @@ def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> da
     output = vertical_flip_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 @_register_kernel_internal(vertical_flip, datapoints.Video)
@@ -321,7 +321,7 @@ def _resize_mask_dispatch(
     inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
 ) -> datapoints.Mask:
     output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
-    return datapoints.Mask.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 def resize_bounding_boxes(
@@ -349,7 +349,7 @@ def _resize_bounding_boxes_dispatch(
     output, canvas_size = resize_bounding_boxes(
         inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)

 @_register_kernel_internal(resize, datapoints.Video)
@@ -857,7 +857,7 @@ def _affine_bounding_boxes_dispatch(
         shear=shear,
         center=center,
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 def affine_mask(
@@ -912,7 +912,7 @@ def _affine_mask_dispatch(
         fill=fill,
         center=center,
     )
-    return datapoints.Mask.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 @_register_kernel_internal(affine, datapoints.Video)
@@ -1058,7 +1058,7 @@ def _rotate_bounding_boxes_dispatch(
         expand=expand,
         center=center,
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)

 def rotate_mask(
@@ -1099,7 +1099,7 @@ def _rotate_mask_dispatch(
     **kwargs,
 ) -> datapoints.Mask:
     output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
-    return datapoints.Mask.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 @_register_kernel_internal(rotate, datapoints.Video)
@@ -1321,7 +1321,7 @@ def _pad_bounding_boxes_dispatch(
         padding=padding,
         padding_mode=padding_mode,
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)

 @_register_kernel_internal(pad, datapoints.Video)
@@ -1396,7 +1396,7 @@ def _crop_bounding_boxes_dispatch(
     output, canvas_size = crop_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)

 @_register_kernel_internal(crop, datapoints.Mask)
@@ -1670,7 +1670,7 @@ def _perspective_bounding_boxes_dispatch(
         endpoints=endpoints,
         coefficients=coefficients,
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 def perspective_mask(
@@ -1712,7 +1712,7 @@ def _perspective_mask_dispatch(
         fill=fill,
         coefficients=coefficients,
     )
-    return datapoints.Mask.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 @_register_kernel_internal(perspective, datapoints.Video)
@@ -1887,7 +1887,7 @@ def _elastic_bounding_boxes_dispatch(
     output = elastic_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 def elastic_mask(
@@ -1914,7 +1914,7 @@ def _elastic_mask_dispatch(
     inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
 ) -> datapoints.Mask:
     output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
-    return datapoints.Mask.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 @_register_kernel_internal(elastic, datapoints.Video)
@@ -2022,7 +2022,7 @@ def _center_crop_bounding_boxes_dispatch(
     output, canvas_size = center_crop_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)

 @_register_kernel_internal(center_crop, datapoints.Mask)
@@ -2156,7 +2156,7 @@ def _resized_crop_bounding_boxes_dispatch(
     output, canvas_size = resized_crop_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
     )
-    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)

 def resized_crop_mask(
@@ -2178,7 +2178,7 @@ def _resized_crop_mask_dispatch(
     output = resized_crop_mask(
         inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
     )
-    return datapoints.Mask.wrap_like(inpt, output)
+    return datapoints.wrap(output, like=inpt)

 @_register_kernel_internal(resized_crop, datapoints.Video)
...
@@ -223,7 +223,7 @@ def convert_format_bounding_boxes(
         output = _convert_format_bounding_boxes(
             inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
         )
-        return datapoints.BoundingBoxes.wrap_like(inpt, output, format=new_format)
+        return datapoints.wrap(output, like=inpt, format=new_format)
     else:
         raise TypeError(
             f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
@@ -265,7 +265,7 @@ def clamp_bounding_boxes(
         if format is not None or canvas_size is not None:
             raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
         output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
-        return datapoints.BoundingBoxes.wrap_like(inpt, output)
+        return datapoints.wrap(output, like=inpt)
     else:
         raise TypeError(
             f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
...
@@ -25,11 +25,11 @@ def _kernel_datapoint_wrapper(kernel):
         # regardless of whether we override __torch_function__ in our base class
         # or not.
         # Also, even if we didn't call `as_subclass` here, we would still need
-        # this wrapper to call wrap_like(), because the Datapoint type would be
+        # this wrapper to call wrap(), because the Datapoint type would be
         # lost after the first operation due to our own __torch_function__
         # logic.
         output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
-        return type(inpt).wrap_like(inpt, output)
+        return datapoints.wrap(output, like=inpt)

     return wrapper
@@ -137,7 +137,7 @@ def _register_five_ten_crop_kernel_internal(functional, input_type):
     def wrapper(inpt, *args, **kwargs):
         output = kernel(inpt, *args, **kwargs)
         container_type = type(output)
-        return container_type(type(inpt).wrap_like(inpt, o) for o in output)
+        return container_type(datapoints.wrap(o, like=inpt) for o in output)

     return wrapper
...
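The comment in the hunk above captures the pattern shared by all the dispatch functions in this diff: unwrap the datapoint to a pure tensor, run the kernel, then restore the subclass with ``wrap()``. A minimal sketch of that shape, with ``my_kernel`` and ``my_dispatch`` as hypothetical stand-ins for the tensor-only kernels and their dispatchers:

    import torch
    from torchvision import datapoints

    def my_kernel(t: torch.Tensor) -> torch.Tensor:
        # Hypothetical tensor-only kernel; the real ones live in
        # torchvision.transforms.v2.functional.
        return t.flip(-1)

    def my_dispatch(inpt: datapoints.Image) -> datapoints.Image:
        # Unwrap so the kernel sees a plain tensor, then re-wrap the
        # result like the input, as _kernel_datapoint_wrapper does above.
        output = my_kernel(inpt.as_subclass(torch.Tensor))
        return datapoints.wrap(output, like=inpt)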