Unverified Commit 88591717 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow users to choose whether to return Datapoint subclasses or pure Tensor (#7825)

parent 3065ad59
......@@ -18,3 +18,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
BoundingBoxes
Mask
Datapoint
set_return_type
......@@ -6,6 +6,20 @@ from common_utils import assert_equal
from PIL import Image
from torchvision import datapoints
from common_utils import (
make_bounding_box,
make_detection_mask,
make_image,
make_image_tensor,
make_segmentation_mask,
make_video,
)
@pytest.fixture(autouse=True)
def preserve_default_wrapping_behaviour():
    # Runs around every test in this module: after the test body finishes (the
    # code past `yield`), reset the global return-type flag back to "Tensor" so
    # tests that call datapoints.set_return_type(...) don't leak state into
    # later tests.
    yield
    datapoints.set_return_type("Tensor")
@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
......@@ -80,72 +94,88 @@ def test_to_wrapping():
assert image_to.dtype is torch.float64
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_to_datapoint_reference(return_type):
    # `.to(other_tensor)` with a Datapoint reference should honour the configured
    # return type: pure Tensor by default, Image when "datapoint" is requested.
    tensor = torch.rand((3, 16, 16), dtype=torch.float64)
    image = datapoints.Image(tensor)

    with datapoints.set_return_type(return_type):
        tensor_to = tensor.to(image)

    assert type(tensor_to) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
    assert tensor_to.dtype is torch.float64
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_clone_wrapping(return_type):
    # clone() is one of the forced-subclass ops, so the result stays an Image
    # regardless of the configured return type, and the data is copied.
    image = datapoints.Image(torch.rand(3, 16, 16))

    with datapoints.set_return_type(return_type):
        image_clone = image.clone()

    assert type(image_clone) is datapoints.Image
    assert image_clone.data_ptr() != image.data_ptr()
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_requires_grad__wrapping(return_type):
    # requires_grad_() is a forced-subclass op: it mutates in place and still
    # returns an Image in both return-type modes.
    image = datapoints.Image(torch.rand(3, 16, 16))

    assert not image.requires_grad

    with datapoints.set_return_type(return_type):
        image_requires_grad = image.requires_grad_(True)

    assert type(image_requires_grad) is datapoints.Image
    assert image.requires_grad
    assert image_requires_grad.requires_grad
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_detach_wrapping(return_type):
    # detach() is a forced-subclass op: the detached view keeps the Image type
    # in both return-type modes.
    image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)

    with datapoints.set_return_type(return_type):
        image_detached = image.detach()

    assert type(image_detached) is datapoints.Image
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_force_subclass_with_metadata(return_type):
    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
    format, canvas_size = "XYXY", (32, 32)
    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
    # BoundingBoxes._wrap() converts the string format to a BoundingBoxFormat
    # enum member, so compare against the enum, not the raw string.
    expected_format = datapoints.BoundingBoxFormat[format]

    datapoints.set_return_type(return_type)

    bbox = bbox.clone()
    if return_type == "datapoint":
        # NOTE: the previous `assert bbox.format, bbox.canvas_size == (...)` parsed as
        # assert-with-message and never compared anything; split into real checks.
        assert bbox.format == expected_format
        assert bbox.canvas_size == canvas_size

    bbox = bbox.to(torch.float64)
    if return_type == "datapoint":
        assert bbox.format == expected_format
        assert bbox.canvas_size == canvas_size

    bbox = bbox.detach()
    if return_type == "datapoint":
        assert bbox.format == expected_format
        assert bbox.canvas_size == canvas_size

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
    if return_type == "datapoint":
        assert bbox.format == expected_format
        assert bbox.canvas_size == canvas_size
        assert bbox.requires_grad
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_other_op_no_wrapping(return_type):
    # Ops outside the forced-subclass set only keep the Image type when the
    # "datapoint" return type is active; otherwise they yield a pure Tensor.
    image = datapoints.Image(torch.rand(3, 16, 16))

    with datapoints.set_return_type(return_type):
        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
        output = image * 2

    assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
@pytest.mark.parametrize(
......@@ -164,19 +194,21 @@ def test_no_tensor_output_op_no_wrapping(op):
assert type(output) is not datapoints.Image
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_inplace_op_no_wrapping(return_type):
    # In-place ops follow the configured return type for their *return value*,
    # but the original object must keep its Image type either way.
    image = datapoints.Image(torch.rand(3, 16, 16))

    with datapoints.set_return_type(return_type):
        output = image.add_(0)

    assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
    assert type(image) is datapoints.Image
def test_wrap_like():
image = datapoints.Image(torch.rand(3, 16, 16))
# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = image * 2
image_new = datapoints.Image.wrap_like(image, output)
......@@ -209,3 +241,91 @@ def test_deepcopy(datapoint, requires_grad):
assert type(datapoint_deepcopied) is type(datapoint)
assert datapoint_deepcopied.requires_grad is requires_grad
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_operations(return_type):
    """Smoke-test that a battery of tensor ops returns the type matching the configured return type."""
    # Set globally (not as a context manager); the module's autouse fixture
    # restores the "Tensor" default after each test.
    datapoints.set_return_type(return_type)
    img = datapoints.Image(torch.rand(3, 10, 10))
    t = torch.rand(3, 10, 10)
    mask = datapoints.Mask(torch.rand(1, 10, 10))

    # Arithmetic with tensors and scalars, reductions, reshapes, dtype casts and
    # torch.* functions applied to an Image.
    for out in (
        [
            img + t,
            t + img,
            img * t,
            t * img,
            img + 3,
            3 + img,
            img * 3,
            3 * img,
            img + img,
            img.sum(),
            img.reshape(-1),
            img.float(),
            torch.stack([img, img]),
        ]
        + list(torch.chunk(img, 2))
        + list(torch.unbind(img))
    ):
        assert type(out) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)

    # Same battery of ops for Mask.
    for out in (
        [
            mask + t,
            t + mask,
            mask * t,
            t * mask,
            mask + 3,
            3 + mask,
            mask * 3,
            3 * mask,
            mask + mask,
            mask.sum(),
            mask.reshape(-1),
            mask.float(),
            torch.stack([mask, mask]),
        ]
        + list(torch.chunk(mask, 2))
        + list(torch.unbind(mask))
    ):
        assert type(out) is (datapoints.Mask if return_type == "datapoint" else torch.Tensor)

    # Mixing different datapoint subclasses is expected to fail with a TypeError
    # in both modes.
    with pytest.raises(TypeError, match="unsupported operand type"):
        img + mask
    with pytest.raises(TypeError, match="unsupported operand type"):
        img * mask

    bboxes = datapoints.BoundingBoxes(
        [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000)
    )
    t = torch.rand(2, 4)
    # For BoundingBoxes we additionally check that the metadata attributes are
    # present on the result (restored by BoundingBoxes._wrap_output).
    for out in (
        [
            bboxes + t,
            t + bboxes,
            bboxes * t,
            t * bboxes,
            bboxes + 3,
            3 + bboxes,
            bboxes * 3,
            3 * bboxes,
            bboxes + bboxes,
            bboxes.sum(),
            bboxes.reshape(-1),
            bboxes.float(),
            torch.stack([bboxes, bboxes]),
        ]
        + list(torch.chunk(bboxes, 2))
        + list(torch.unbind(bboxes))
    ):
        if return_type == "Tensor":
            assert type(out) is torch.Tensor
        else:
            # isinstance (not exact type) check, plus metadata presence.
            assert isinstance(out, datapoints.BoundingBoxes)
            assert hasattr(out, "format")
            assert hasattr(out, "canvas_size")
import torch
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import Datapoint
from ._image import Image
from ._mask import Mask
from ._torch_function_helpers import set_return_type
from ._video import Video
if _WARN_ABOUT_BETA_TRANSFORMS:
......
from __future__ import annotations
from enum import Enum
from typing import Any, Optional, Tuple, Union
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
import torch
from torch.utils._pytree import tree_flatten
from ._datapoint import Datapoint
......@@ -48,11 +49,12 @@ class BoundingBoxes(Datapoint):
canvas_size: Tuple[int, int]
@classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
if tensor.ndim == 1:
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 2:
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes: # type: ignore[override]
if check_dims:
if tensor.ndim == 1:
tensor = tensor.unsqueeze(0)
elif tensor.ndim != 2:
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]
bounding_boxes = tensor.as_subclass(cls)
......@@ -99,5 +101,29 @@ class BoundingBoxes(Datapoint):
canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
)
@classmethod
def _wrap_output(
    cls,
    output: torch.Tensor,
    args: Sequence[Any] = (),
    kwargs: Optional[Mapping[str, Any]] = None,
) -> BoundingBoxes:
    """Re-wrap ``output`` as :class:`BoundingBoxes`, restoring lost metadata.

    ``super().__torch_function__`` drops ``format`` and ``canvas_size``, so
    they are copied from the first ``BoundingBoxes`` found among the call's
    parameters. That is the right choice in most cases; when it is not (e.g.
    ``some_xyxy_bbox + some_xywh_bbox``) it is probably a mis-use anyway, and
    we do not guard against it.
    """
    call_params = args + (tuple(kwargs.values()) if kwargs else ())  # type: ignore[operator]
    leaves, _ = tree_flatten(call_params)
    # Raises StopIteration if no BoundingBoxes is present, like the original lookup.
    reference = next(leaf for leaf in leaves if isinstance(leaf, BoundingBoxes))

    if isinstance(output, (tuple, list)):
        # Rebuild the container type, wrapping each element with the reference metadata.
        return type(output)(
            BoundingBoxes._wrap(
                part, format=reference.format, canvas_size=reference.canvas_size, check_dims=False
            )
            for part in output
        )
    if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
        output = BoundingBoxes._wrap(
            output, format=reference.format, canvas_size=reference.canvas_size, check_dims=False
        )
    return output
def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
    # Delegate to the base Datapoint repr helper so the bounding-box metadata
    # (format, canvas_size) is included alongside the tensor data.
    return self._make_repr(format=self.format, canvas_size=self.canvas_size)
......@@ -6,6 +6,8 @@ import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size
from torchvision.datapoints._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
D = TypeVar("D", bound="Datapoint")
......@@ -33,9 +35,21 @@ class Datapoint(torch.Tensor):
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)
# The ops in this set are those that should *preserve* the Datapoint type,
# i.e. they are exceptions to the "no wrapping" rule.
_NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
@classmethod
def _wrap_output(
    cls,
    output: torch.Tensor,
    args: Sequence[Any] = (),
    kwargs: Optional[Mapping[str, Any]] = None,
) -> torch.Tensor:
    """Recursively re-wrap ``output`` as ``cls``; mirrors ``torch._tensor._convert``.

    A plain tensor is viewed as the subclass via ``as_subclass``; tuple/list
    containers (including namedtuple-like results) are rebuilt element-wise.
    ``args``/``kwargs`` are unused here but kept for signature compatibility
    with overrides that need the call's parameters (e.g. ``BoundingBoxes``).
    """
    if isinstance(output, torch.Tensor) and not isinstance(output, cls):
        return output.as_subclass(cls)
    if isinstance(output, (tuple, list)):
        # Preserve the concrete container type (handles namedtuple-like containers).
        return type(output)(cls._wrap_output(member, args, kwargs) for member in output)
    return output
@classmethod
def __torch_function__(
......@@ -60,7 +74,7 @@ class Datapoint(torch.Tensor):
2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
listed in _FORCE_TORCHFUNCTION_SUBCLASS
"""
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.
......@@ -68,19 +82,22 @@ class Datapoint(torch.Tensor):
if not all(issubclass(cls, t) for t in types):
return NotImplemented
# Like in the base Tensor.__torch_function__ implementation, it's easier to always use
# DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict())
if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
must_return_subclass = _must_return_subclass()
if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
# We also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`.
return cls.wrap_like(args[0], output)
return cls._wrap_output(output, args, kwargs)
if isinstance(output, cls):
if not must_return_subclass and isinstance(output, cls):
# DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
# so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
return output.as_subclass(torch.Tensor)
......
import torch
_TORCHFUNCTION_SUBCLASS = False
class _ReturnTypeCM:
    """Context manager returned by ``set_return_type``; restores the previous
    value of the global ``_TORCHFUNCTION_SUBCLASS`` flag on exit."""

    def __init__(self, to_restore):
        # The flag value that was active before set_return_type() changed it.
        self.to_restore = to_restore

    def __enter__(self):
        # Nothing to do on entry: set_return_type() already flipped the flag.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Reinstate the saved flag whether or not an exception occurred;
        # returning None lets any exception propagate.
        global _TORCHFUNCTION_SUBCLASS
        _TORCHFUNCTION_SUBCLASS = self.to_restore
def set_return_type(return_type: str):
    """Set the return type of torch operations on datapoints.

    Can be used as a global flag for the entire program:

    .. code:: python

        set_return_type("datapoint")
        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is an Image

    or as a context manager to restrict the scope:

    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        with set_return_type("datapoint"):
            img + 2  # This is an Image
        img + 2  # This is a pure Tensor

    Args:
        return_type (str): Can be "datapoint" or "tensor" (case-insensitive).
            Default is "tensor". Any other value raises ``KeyError``.
    """
    # NOTE(review): the docstring examples previously used "datapoints", which
    # the lookup below rejects with a KeyError; corrected to "datapoint".
    global _TORCHFUNCTION_SUBCLASS
    # Remember the current flag so the returned context manager can restore it.
    to_restore = _TORCHFUNCTION_SUBCLASS

    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "datapoint": True}[return_type.lower()]
    # Returned object is usable as a context manager; ignoring it leaves the
    # flag set globally.
    return _ReturnTypeCM(to_restore)
def _must_return_subclass():
    # True when set_return_type("datapoint") is in effect, i.e. torch ops on
    # datapoints should return the Datapoint subclass instead of a pure Tensor.
    return _TORCHFUNCTION_SUBCLASS


# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
......@@ -401,7 +401,7 @@ class SanitizeBoundingBoxes(Transform):
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
params = dict(valid=valid, labels=labels)
params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
flat_outputs = [
# Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxeses and the labels
......
......@@ -19,6 +19,8 @@ _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
def _kernel_datapoint_wrapper(kernel):
@functools.wraps(kernel)
def wrapper(inpt, *args, **kwargs):
# We always pass datapoints as pure tensors to the kernels to avoid going through the
# Tensor.__torch_function__ logic, which is costly.
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
return type(inpt).wrap_like(inpt, output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.