Unverified Commit c1592f96 authored by Nicolas Hug, committed by GitHub

Remove wrap_like class method and add datapoints.wrap() function (#7832)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 498b9c86
......@@ -19,3 +19,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
Mask
Datapoint
set_return_type
+wrap
......@@ -53,11 +53,11 @@ from torchvision.transforms.v2 import functional as F
def hflip_my_datapoint(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
-return MyDatapoint.wrap_like(my_dp, out)
+return datapoints.wrap(out, like=my_dp)
# %%
-# To understand why ``wrap_like`` is used, see
+# To understand why :func:`~torchvision.datapoints.wrap` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
......@@ -107,7 +107,7 @@ _ = t(my_dp)
def hflip_my_datapoint(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
-return MyDatapoint.wrap_like(my_dp, out)
+return datapoints.wrap(out, like=my_dp)
# %%
......
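For readers following along, here is the tutorial's pattern as one self-contained sketch (assuming the beta datapoints API from this release; MyDatapoint is the tutorial's own toy subclass):

    import torch
    from torchvision import datapoints

    class MyDatapoint(datapoints.Datapoint):
        pass

    my_dp = MyDatapoint(torch.rand(3, 256, 256))
    out = my_dp.flip(-1)                        # plain torch.Tensor: ops unwrap datapoints
    flipped = datapoints.wrap(out, like=my_dp)  # re-attach the MyDatapoint subclass
    assert type(out) is torch.Tensor
    assert type(flipped) is MyDatapoint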
......@@ -107,26 +107,23 @@ bboxes = datapoints.BoundingBoxes(
print(bboxes)
# %%
-# Using the ``wrap_like()`` class method
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Using ``datapoints.wrap()``
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
-# You can also use the ``wrap_like()`` class method to wrap a tensor object
+# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
-# to wrap the output like the input. This API is inspired by utils like
-# :func:`torch.zeros_like`:
+# to wrap the output like the input.
new_bboxes = torch.tensor([0, 20, 30, 40])
-new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size
# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
-# it as a parameter to override it. Check the
-# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
-# more details.
+# it as a parameter to override it.
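The override mentioned in the line above is just a keyword argument. A quick sketch, reusing bboxes and new_bboxes from this example:

    # inherit format from `bboxes`, but override canvas_size explicitly
    bigger = datapoints.wrap(new_bboxes, like=bboxes, canvas_size=(100, 100))
    assert bigger.canvas_size == (100, 100)
    assert bigger.format == bboxes.format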
#
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
......@@ -230,11 +227,11 @@ assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
-# constructor, or by using the ``.wrap_like()`` class method (see more details
-# above in :ref:`datapoint_creation`):
+# constructor, or by using the :func:`~torchvision.datapoints.wrap` function
+# (see more details above in :ref:`datapoint_creation`):
new_bboxes = bboxes + 3
-new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
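Both re-wrapping routes produce the same type; the constructor needs the metadata spelled out, while wrap() infers it from the reference. A sketch reusing the names above:

    via_ctor = datapoints.BoundingBoxes(new_bboxes, format=bboxes.format, canvas_size=bboxes.canvas_size)
    via_wrap = datapoints.wrap(new_bboxes, like=bboxes)
    assert type(via_ctor) is type(via_wrap) is datapoints.BoundingBoxes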
......
......@@ -213,13 +213,13 @@ def test_inplace_op_no_wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
-def test_wrap_like(make_input):
+def test_wrap(make_input):
dp = make_input()
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2
-dp_new = type(dp).wrap_like(dp, output)
+dp_new = datapoints.wrap(output, like=dp)
assert type(dp_new) is type(dp)
assert dp_new.data_ptr() == output.data_ptr()
......
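The data_ptr() assertion in the updated test pins down a useful property: wrap() is zero-copy and only re-views the input as the target subclass. A minimal sketch:

    t = torch.rand(4)
    m = datapoints.wrap(t, like=datapoints.Mask(torch.empty(0)))
    assert m.data_ptr() == t.data_ptr()  # same storage, new subclass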
......@@ -570,7 +570,7 @@ class TestResize:
canvas_size=(new_height, new_width),
affine_matrix=affine_matrix,
)
-return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes, canvas_size=(new_height, new_width))
+return datapoints.wrap(expected_bboxes, like=bounding_boxes, canvas_size=(new_height, new_width))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("size", OUTPUT_SIZES)
......@@ -815,7 +815,7 @@ class TestHorizontalFlip:
affine_matrix=affine_matrix,
)
-return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes)
+return datapoints.wrap(expected_bboxes, like=bounding_boxes)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize(
......@@ -1278,7 +1278,7 @@ class TestVerticalFlip:
affine_matrix=affine_matrix,
)
-return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes)
+return datapoints.wrap(expected_bboxes, like=bounding_boxes)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
......
......@@ -12,3 +12,25 @@ if _WARN_ABOUT_BETA_TRANSFORMS:
import warnings
warnings.warn(_BETA_TRANSFORMS_WARNING)
+def wrap(wrappee, *, like, **kwargs):
+    """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``.
+
+    If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of
+    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
+
+    Args:
+        wrappee (Tensor): The tensor to convert.
+        like (Datapoint): The reference. ``wrappee`` will be converted into the same subclass as ``like``.
+        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`.
+            Ignored otherwise.
+    """
+    if isinstance(like, BoundingBoxes):
+        return BoundingBoxes._wrap(
+            wrappee,
+            format=kwargs.get("format", like.format),
+            canvas_size=kwargs.get("canvas_size", like.canvas_size),
+        )
+    else:
+        return wrappee.as_subclass(type(like))
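Both branches of the new function in one sketch (only the public API shown in this diff is assumed):

    import torch
    from torchvision import datapoints

    # generic branch: a plain as_subclass() re-view
    img = datapoints.Image(torch.rand(3, 4, 4))
    assert type(datapoints.wrap(img + 1, like=img)) is datapoints.Image

    # BoundingBoxes branch: format/canvas_size come from `like` unless overridden
    boxes = datapoints.BoundingBoxes([[0, 0, 2, 2]], format="XYXY", canvas_size=(4, 4))
    wrapped = datapoints.wrap(boxes + 1, like=boxes)
    assert wrapped.canvas_size == boxes.canvas_size
    assert datapoints.wrap(boxes + 1, like=boxes, canvas_size=(8, 8)).canvas_size == (8, 8)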
......@@ -75,32 +75,6 @@ class BoundingBoxes(Datapoint):
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor, format=format, canvas_size=canvas_size)
-    @classmethod
-    def wrap_like(
-        cls,
-        other: BoundingBoxes,
-        tensor: torch.Tensor,
-        *,
-        format: Optional[Union[BoundingBoxFormat, str]] = None,
-        canvas_size: Optional[Tuple[int, int]] = None,
-    ) -> BoundingBoxes:
-        """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
-
-        Args:
-            other (BoundingBoxes): Reference bounding box.
-            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
-            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
-                reference.
-            canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
-                omitted, it is taken from the reference.
-        """
-        return cls._wrap(
-            tensor,
-            format=format if format is not None else other.format,
-            canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
-        )
@classmethod
def _wrap_output(
cls,
......
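For downstream code migrating off the removed class method, the call shape changes as follows (a hypothetical before/after mirroring this diff):

    ref = datapoints.BoundingBoxes([[0, 0, 1, 1]], format="XYXY", canvas_size=(2, 2))
    t = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
    # before (removed): datapoints.BoundingBoxes.wrap_like(ref, t, canvas_size=(4, 4))
    new = datapoints.wrap(t, like=ref, canvas_size=(4, 4))
    assert new.canvas_size == (4, 4) and new.format == ref.format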
......@@ -31,10 +31,6 @@ class Datapoint(torch.Tensor):
requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
-    @classmethod
-    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        return tensor.as_subclass(cls)
@classmethod
def _wrap_output(
cls,
......
......@@ -32,13 +32,6 @@ class _LabelBase(Datapoint):
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor, categories=categories)
-    @classmethod
-    def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L:
-        return cls._wrap(
-            tensor,
-            categories=categories if categories is not None else other.categories,
-        )
@classmethod
def from_category(
cls: Type[L],
......
......@@ -36,11 +36,9 @@ class SimpleCopyPaste(Transform):
antialias: Optional[bool],
) -> Tuple[torch.Tensor, Dict[str, Any]]:
-paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
-paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
-paste_labels = paste_target["labels"].wrap_like(
-    paste_target["labels"], paste_target["labels"][random_selection]
-)
+paste_masks = datapoints.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
+paste_boxes = datapoints.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
+paste_labels = datapoints.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])
masks = target["masks"]
......@@ -143,7 +141,7 @@ class SimpleCopyPaste(Transform):
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, datapoints.Image):
-flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0])
+flat_sample[i] = datapoints.wrap(output_images[c0], like=obj)
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
......@@ -152,13 +150,13 @@ class SimpleCopyPaste(Transform):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, datapoints.BoundingBoxes):
-flat_sample[i] = datapoints.BoundingBoxes.wrap_like(obj, output_targets[c1]["boxes"])
+flat_sample[i] = datapoints.wrap(output_targets[c1]["boxes"], like=obj)
c1 += 1
elif isinstance(obj, datapoints.Mask):
-flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
+flat_sample[i] = datapoints.wrap(output_targets[c2]["masks"], like=obj)
c2 += 1
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
-flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
+flat_sample[i] = datapoints.wrap(output_targets[c3]["labels"], like=obj)
c3 += 1
def forward(self, *inputs: Any) -> Any:
......
......@@ -112,11 +112,11 @@ class FixedSizeCrop(Transform):
if params["is_valid"] is not None:
if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)):
-inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
+inpt = datapoints.wrap(inpt[params["is_valid"]], like=inpt)
elif isinstance(inpt, datapoints.BoundingBoxes):
-inpt = datapoints.BoundingBoxes.wrap_like(
-    inpt,
+inpt = datapoints.wrap(
F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
+    like=inpt,
)
if params["needs_pad"]:
......
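The recurring idiom in these transform hunks: boolean indexing returns a plain tensor, so the result is wrapped back against the original input. A sketch with hypothetical values:

    boxes = datapoints.BoundingBoxes(
        [[0, 0, 2, 2], [0, 0, 9, 9]], format="XYXY", canvas_size=(10, 10)
    )
    is_valid = torch.tensor([True, False])
    kept = datapoints.wrap(boxes[is_valid], like=boxes)  # metadata survives the filter
    assert isinstance(kept, datapoints.BoundingBoxes) and len(kept) == 1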
......@@ -249,7 +249,7 @@ class MixUp(_BaseMixUpCutMix):
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+output = datapoints.wrap(output, like=inpt)
return output
else:
......@@ -319,7 +319,7 @@ class CutMix(_BaseMixUpCutMix):
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
+output = datapoints.wrap(output, like=inpt)
return output
else:
......
......@@ -620,7 +620,7 @@ class AugMix(_AutoAugmentBase):
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
-mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
+mix = datapoints.wrap(mix, like=orig_image_or_video)
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix)
......
......@@ -338,7 +338,7 @@ class FiveCrop(Transform):
... images_or_videos, labels = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
-... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
+... images_or_videos = datapoints.wrap(torch.stack(images_or_videos), like=image_or_video)
... labels = torch.full((batch_size,), label, device=images_or_videos.device)
... return images_or_videos, labels
...
......
......@@ -131,7 +131,7 @@ class LinearTransformation(Transform):
output = output.reshape(shape)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+output = datapoints.wrap(output, like=inpt)
return output
......@@ -423,4 +423,4 @@ class SanitizeBoundingBoxes(Transform):
if is_label:
return output
-return type(inpt).wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
......@@ -87,7 +87,7 @@ def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) ->
output = horizontal_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
@_register_kernel_internal(horizontal_flip, datapoints.Video)
......@@ -143,7 +143,7 @@ def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> da
output = vertical_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
@_register_kernel_internal(vertical_flip, datapoints.Video)
......@@ -321,7 +321,7 @@ def _resize_mask_dispatch(
inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.Mask:
output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
-return datapoints.Mask.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
def resize_bounding_boxes(
......@@ -349,7 +349,7 @@ def _resize_bounding_boxes_dispatch(
output, canvas_size = resize_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(resize, datapoints.Video)
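A pattern worth noting across these dispatchers: shape-preserving kernels (the flips) wrap with metadata inherited wholesale, while geometry-changing kernels (resize, rotate, pad, crop) pass the recomputed canvas_size as an override. Condensed into a sketch, with names taken from this diff and inpt/size left free:

    # flip: canvas unchanged -> plain wrap
    out = horizontal_flip_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
    )
    flipped = datapoints.wrap(out, like=inpt)

    # resize: canvas changes -> override the inherited metadata
    out, canvas_size = resize_bounding_boxes(inpt.as_subclass(torch.Tensor), inpt.canvas_size, size)
    resized = datapoints.wrap(out, like=inpt, canvas_size=canvas_size)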
......@@ -857,7 +857,7 @@ def _affine_bounding_boxes_dispatch(
shear=shear,
center=center,
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
def affine_mask(
......@@ -912,7 +912,7 @@ def _affine_mask_dispatch(
fill=fill,
center=center,
)
-return datapoints.Mask.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
@_register_kernel_internal(affine, datapoints.Video)
......@@ -1058,7 +1058,7 @@ def _rotate_bounding_boxes_dispatch(
expand=expand,
center=center,
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
def rotate_mask(
......@@ -1099,7 +1099,7 @@ def _rotate_mask_dispatch(
**kwargs,
) -> datapoints.Mask:
output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
-return datapoints.Mask.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
@_register_kernel_internal(rotate, datapoints.Video)
......@@ -1321,7 +1321,7 @@ def _pad_bounding_boxes_dispatch(
padding=padding,
padding_mode=padding_mode,
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(pad, datapoints.Video)
......@@ -1396,7 +1396,7 @@ def _crop_bounding_boxes_dispatch(
output, canvas_size = crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(crop, datapoints.Mask)
......@@ -1670,7 +1670,7 @@ def _perspective_bounding_boxes_dispatch(
endpoints=endpoints,
coefficients=coefficients,
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
def perspective_mask(
......@@ -1712,7 +1712,7 @@ def _perspective_mask_dispatch(
fill=fill,
coefficients=coefficients,
)
-return datapoints.Mask.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
@_register_kernel_internal(perspective, datapoints.Video)
......@@ -1887,7 +1887,7 @@ def _elastic_bounding_boxes_dispatch(
output = elastic_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
def elastic_mask(
......@@ -1914,7 +1914,7 @@ def _elastic_mask_dispatch(
inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> datapoints.Mask:
output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
-return datapoints.Mask.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
@_register_kernel_internal(elastic, datapoints.Video)
......@@ -2022,7 +2022,7 @@ def _center_crop_bounding_boxes_dispatch(
output, canvas_size = center_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(center_crop, datapoints.Mask)
......@@ -2156,7 +2156,7 @@ def _resized_crop_bounding_boxes_dispatch(
output, canvas_size = resized_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
def resized_crop_mask(
......@@ -2178,7 +2178,7 @@ def _resized_crop_mask_dispatch(
output = resized_crop_mask(
inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
)
-return datapoints.Mask.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
@_register_kernel_internal(resized_crop, datapoints.Video)
......
......@@ -223,7 +223,7 @@ def convert_format_bounding_boxes(
output = _convert_format_bounding_boxes(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
)
-return datapoints.BoundingBoxes.wrap_like(inpt, output, format=new_format)
+return datapoints.wrap(output, like=inpt, format=new_format)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
......@@ -265,7 +265,7 @@ def clamp_bounding_boxes(
if format is not None or canvas_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
-return datapoints.BoundingBoxes.wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
......
......@@ -25,11 +25,11 @@ def _kernel_datapoint_wrapper(kernel):
# regardless of whether we override __torch_function__ in our base class
# or not.
# Also, even if we didn't call `as_subclass` here, we would still need
-# this wrapper to call wrap_like(), because the Datapoint type would be
+# this wrapper to call wrap(), because the Datapoint type would be
# lost after the first operation due to our own __torch_function__
# logic.
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
-return type(inpt).wrap_like(inpt, output)
+return datapoints.wrap(output, like=inpt)
return wrapper
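The comment above is the crux of the design: after any tensor operation, the Datapoint subclass is deliberately dropped by the __torch_function__ logic, so the wrapper has to restore it. A minimal demonstration, assuming the default return type:

    img = datapoints.Image(torch.rand(3, 2, 2))
    assert type(img * 2) is torch.Tensor                                  # subclass lost
    assert type(datapoints.wrap(img * 2, like=img)) is datapoints.Image   # restored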
......@@ -137,7 +137,7 @@ def _register_five_ten_crop_kernel_internal(functional, input_type):
def wrapper(inpt, *args, **kwargs):
output = kernel(inpt, *args, **kwargs)
container_type = type(output)
-return container_type(type(inpt).wrap_like(inpt, o) for o in output)
+return container_type(datapoints.wrap(o, like=inpt) for o in output)
return wrapper
......