Unverified Commit 88591717 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow users to choose whether to return Datapoint subclasses or pure Tensor (#7825)

parent 3065ad59
...@@ -18,3 +18,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`. ...@@ -18,3 +18,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
BoundingBoxes BoundingBoxes
Mask Mask
Datapoint Datapoint
set_return_type
...@@ -6,6 +6,20 @@ from common_utils import assert_equal ...@@ -6,6 +6,20 @@ from common_utils import assert_equal
from PIL import Image from PIL import Image
from torchvision import datapoints from torchvision import datapoints
from common_utils import (
make_bounding_box,
make_detection_mask,
make_image,
make_image_tensor,
make_segmentation_mask,
make_video,
)
@pytest.fixture(autouse=True)
def preserve_default_wrapping_behaviour():
    """Restore the default return type ("Tensor") after every test.

    Several tests in this file call ``datapoints.set_return_type("datapoint")``
    (directly or via the context manager), which mutates module-global state.
    Resetting here keeps each test independent of execution order.
    """
    try:
        yield
    finally:
        # try/finally guarantees the reset even if the test body (or another
        # fixture) raises an exception into this generator during teardown.
        datapoints.set_return_type("Tensor")
@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)]) @pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
...@@ -80,72 +94,88 @@ def test_to_wrapping(): ...@@ -80,72 +94,88 @@ def test_to_wrapping():
assert image_to.dtype is torch.float64 assert image_to.dtype is torch.float64
def test_to_datapoint_reference(): @pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_to_datapoint_reference(return_type):
tensor = torch.rand((3, 16, 16), dtype=torch.float64) tensor = torch.rand((3, 16, 16), dtype=torch.float64)
image = datapoints.Image(tensor) image = datapoints.Image(tensor)
tensor_to = tensor.to(image) with datapoints.set_return_type(return_type):
tensor_to = tensor.to(image)
assert type(tensor_to) is torch.Tensor assert type(tensor_to) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
assert tensor_to.dtype is torch.float64 assert tensor_to.dtype is torch.float64
def test_clone_wrapping(): @pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_clone_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16)) image = datapoints.Image(torch.rand(3, 16, 16))
image_clone = image.clone() with datapoints.set_return_type(return_type):
image_clone = image.clone()
assert type(image_clone) is datapoints.Image assert type(image_clone) is datapoints.Image
assert image_clone.data_ptr() != image.data_ptr() assert image_clone.data_ptr() != image.data_ptr()
def test_requires_grad__wrapping(): @pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_requires_grad__wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16)) image = datapoints.Image(torch.rand(3, 16, 16))
assert not image.requires_grad assert not image.requires_grad
image_requires_grad = image.requires_grad_(True) with datapoints.set_return_type(return_type):
image_requires_grad = image.requires_grad_(True)
assert type(image_requires_grad) is datapoints.Image assert type(image_requires_grad) is datapoints.Image
assert image.requires_grad assert image.requires_grad
assert image_requires_grad.requires_grad assert image_requires_grad.requires_grad
def test_detach_wrapping(): @pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_detach_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True) image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)
image_detached = image.detach() with datapoints.set_return_type(return_type):
image_detached = image.detach()
assert type(image_detached) is datapoints.Image assert type(image_detached) is datapoints.Image
def test_no_wrapping_exceptions_with_metadata(): @pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
# Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata def test_force_subclass_with_metadata(return_type):
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
format, canvas_size = "XYXY", (32, 32) format, canvas_size = "XYXY", (32, 32)
bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size) bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
datapoints.set_return_type(return_type)
bbox = bbox.clone() bbox = bbox.clone()
assert bbox.format, bbox.canvas_size == (format, canvas_size) if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.to(torch.float64) bbox = bbox.to(torch.float64)
assert bbox.format, bbox.canvas_size == (format, canvas_size) if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.detach() bbox = bbox.detach()
assert bbox.format, bbox.canvas_size == (format, canvas_size) if return_type == "datapoint":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert not bbox.requires_grad assert not bbox.requires_grad
bbox.requires_grad_(True) bbox.requires_grad_(True)
assert bbox.format, bbox.canvas_size == (format, canvas_size) if return_type == "datapoint":
assert bbox.requires_grad assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert bbox.requires_grad
def test_other_op_no_wrapping(): @pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_other_op_no_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16)) image = datapoints.Image(torch.rand(3, 16, 16))
# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here with datapoints.set_return_type(return_type):
output = image * 2 # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = image * 2
assert type(output) is torch.Tensor assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -164,19 +194,21 @@ def test_no_tensor_output_op_no_wrapping(op): ...@@ -164,19 +194,21 @@ def test_no_tensor_output_op_no_wrapping(op):
assert type(output) is not datapoints.Image assert type(output) is not datapoints.Image
def test_inplace_op_no_wrapping(): @pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_inplace_op_no_wrapping(return_type):
image = datapoints.Image(torch.rand(3, 16, 16)) image = datapoints.Image(torch.rand(3, 16, 16))
output = image.add_(0) with datapoints.set_return_type(return_type):
output = image.add_(0)
assert type(output) is torch.Tensor assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
assert type(image) is datapoints.Image assert type(image) is datapoints.Image
def test_wrap_like(): def test_wrap_like():
image = datapoints.Image(torch.rand(3, 16, 16)) image = datapoints.Image(torch.rand(3, 16, 16))
# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = image * 2 output = image * 2
image_new = datapoints.Image.wrap_like(image, output) image_new = datapoints.Image.wrap_like(image, output)
...@@ -209,3 +241,91 @@ def test_deepcopy(datapoint, requires_grad): ...@@ -209,3 +241,91 @@ def test_deepcopy(datapoint, requires_grad):
assert type(datapoint_deepcopied) is type(datapoint) assert type(datapoint_deepcopied) is type(datapoint)
assert datapoint_deepcopied.requires_grad is requires_grad assert datapoint_deepcopied.requires_grad is requires_grad
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_operations(return_type):
    """Check the output type of torch operations on Image, Mask and BoundingBoxes.

    With ``return_type="datapoint"`` the outputs keep the Datapoint subclass;
    with ``"Tensor"`` they are unwrapped to pure ``torch.Tensor``.
    The global flag set below is restored by the autouse fixture in this file.
    """
    datapoints.set_return_type(return_type)
    img = datapoints.Image(torch.rand(3, 10, 10))
    t = torch.rand(3, 10, 10)
    mask = datapoints.Mask(torch.rand(1, 10, 10))

    # Binary ops (tensor and scalar operands, on both sides), reductions,
    # reshapes, dtype casts, and ops returning multiple tensors
    # (stack / chunk / unbind) — all must honor the requested return type.
    for out in (
        [
            img + t,
            t + img,
            img * t,
            t * img,
            img + 3,
            3 + img,
            img * 3,
            3 * img,
            img + img,
            img.sum(),
            img.reshape(-1),
            img.float(),
            torch.stack([img, img]),
        ]
        + list(torch.chunk(img, 2))
        + list(torch.unbind(img))
    ):
        assert type(out) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)

    # Same matrix of operations for Mask.
    for out in (
        [
            mask + t,
            t + mask,
            mask * t,
            t * mask,
            mask + 3,
            3 + mask,
            mask * 3,
            3 * mask,
            mask + mask,
            mask.sum(),
            mask.reshape(-1),
            mask.float(),
            torch.stack([mask, mask]),
        ]
        + list(torch.chunk(mask, 2))
        + list(torch.unbind(mask))
    ):
        assert type(out) is (datapoints.Mask if return_type == "datapoint" else torch.Tensor)

    # Mixing two different Datapoint subclasses is ambiguous and must fail.
    with pytest.raises(TypeError, match="unsupported operand type"):
        img + mask
    with pytest.raises(TypeError, match="unsupported operand type"):
        img * mask

    bboxes = datapoints.BoundingBoxes(
        [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000)
    )
    t = torch.rand(2, 4)
    for out in (
        [
            bboxes + t,
            t + bboxes,
            bboxes * t,
            t * bboxes,
            bboxes + 3,
            3 + bboxes,
            bboxes * 3,
            3 * bboxes,
            bboxes + bboxes,
            bboxes.sum(),
            bboxes.reshape(-1),
            bboxes.float(),
            torch.stack([bboxes, bboxes]),
        ]
        + list(torch.chunk(bboxes, 2))
        + list(torch.unbind(bboxes))
    ):
        if return_type == "Tensor":
            assert type(out) is torch.Tensor
        else:
            # In datapoint mode, BoundingBoxes outputs must also carry their
            # metadata (re-attached by BoundingBoxes._wrap_output).
            assert isinstance(out, datapoints.BoundingBoxes)
            assert hasattr(out, "format")
            assert hasattr(out, "canvas_size")
import torch
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
from ._bounding_box import BoundingBoxes, BoundingBoxFormat from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import Datapoint from ._datapoint import Datapoint
from ._image import Image from ._image import Image
from ._mask import Mask from ._mask import Mask
from ._torch_function_helpers import set_return_type
from ._video import Video from ._video import Video
if _WARN_ABOUT_BETA_TRANSFORMS: if _WARN_ABOUT_BETA_TRANSFORMS:
......
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from typing import Any, Optional, Tuple, Union from typing import Any, Mapping, Optional, Sequence, Tuple, Union
import torch import torch
from torch.utils._pytree import tree_flatten
from ._datapoint import Datapoint from ._datapoint import Datapoint
...@@ -48,11 +49,12 @@ class BoundingBoxes(Datapoint): ...@@ -48,11 +49,12 @@ class BoundingBoxes(Datapoint):
canvas_size: Tuple[int, int] canvas_size: Tuple[int, int]
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes: # type: ignore[override]
if tensor.ndim == 1: if check_dims:
tensor = tensor.unsqueeze(0) if tensor.ndim == 1:
elif tensor.ndim != 2: tensor = tensor.unsqueeze(0)
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D") elif tensor.ndim != 2:
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
if isinstance(format, str): if isinstance(format, str):
format = BoundingBoxFormat[format.upper()] format = BoundingBoxFormat[format.upper()]
bounding_boxes = tensor.as_subclass(cls) bounding_boxes = tensor.as_subclass(cls)
...@@ -99,5 +101,29 @@ class BoundingBoxes(Datapoint): ...@@ -99,5 +101,29 @@ class BoundingBoxes(Datapoint):
canvas_size=canvas_size if canvas_size is not None else other.canvas_size, canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
) )
@classmethod
def _wrap_output(
    cls,
    output: torch.Tensor,
    args: Sequence[Any] = (),
    kwargs: Optional[Mapping[str, Any]] = None,
) -> BoundingBoxes:
    """Wrap ``output`` (a tensor, or a tuple/list of tensors) into ``BoundingBoxes``,
    re-attaching the ``format`` and ``canvas_size`` metadata taken from the first
    ``BoundingBoxes`` instance found among the call's parameters.
    """
    # If there are BoundingBoxes instances in the output, their metadata got lost when we called
    # super().__torch_function__. We need to restore the metadata somehow, so we choose to take
    # the metadata from the first bbox in the parameters.
    # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g.
    # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases.
    # NOTE(review): assumes at least one BoundingBoxes is present in args/kwargs;
    # otherwise the next() call below raises StopIteration — confirm all callers guarantee this.
    flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
    first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes))
    format, canvas_size = first_bbox_from_args.format, first_bbox_from_args.canvas_size

    if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
        # check_dims=False: outputs of ops like reshape/unbind may legitimately
        # not be 2D, so skip the constructor's dimensionality check.
        output = BoundingBoxes._wrap(output, format=format, canvas_size=canvas_size, check_dims=False)
    elif isinstance(output, (tuple, list)):
        # Multi-tensor outputs (e.g. chunk/unbind): wrap every part, preserving
        # the container type.
        output = type(output)(
            BoundingBoxes._wrap(part, format=format, canvas_size=canvas_size, check_dims=False) for part in output
        )
    return output
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, canvas_size=self.canvas_size) return self._make_repr(format=self.format, canvas_size=self.canvas_size)
...@@ -6,6 +6,8 @@ import torch ...@@ -6,6 +6,8 @@ import torch
from torch._C import DisableTorchFunctionSubclass from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size from torch.types import _device, _dtype, _size
from torchvision.datapoints._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
D = TypeVar("D", bound="Datapoint") D = TypeVar("D", bound="Datapoint")
...@@ -33,9 +35,21 @@ class Datapoint(torch.Tensor): ...@@ -33,9 +35,21 @@ class Datapoint(torch.Tensor):
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls) return tensor.as_subclass(cls)
# The ops in this set are those that should *preserve* the Datapoint type, @classmethod
# i.e. they are exceptions to the "no wrapping" rule. def _wrap_output(
_NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_} cls,
output: torch.Tensor,
args: Sequence[Any] = (),
kwargs: Optional[Mapping[str, Any]] = None,
) -> torch.Tensor:
# Same as torch._tensor._convert
if isinstance(output, torch.Tensor) and not isinstance(output, cls):
output = output.as_subclass(cls)
if isinstance(output, (tuple, list)):
# Also handles things like namedtuples
output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
return output
@classmethod @classmethod
def __torch_function__( def __torch_function__(
...@@ -60,7 +74,7 @@ class Datapoint(torch.Tensor): ...@@ -60,7 +74,7 @@ class Datapoint(torch.Tensor):
2. For most operations, there is no way of knowing if the input type is still valid for the output. 2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` listed in _FORCE_TORCHFUNCTION_SUBCLASS
""" """
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality. # need to reimplement the functionality.
...@@ -68,19 +82,22 @@ class Datapoint(torch.Tensor): ...@@ -68,19 +82,22 @@ class Datapoint(torch.Tensor):
if not all(issubclass(cls, t) for t in types): if not all(issubclass(cls, t) for t in types):
return NotImplemented return NotImplemented
# Like in the base Tensor.__torch_function__ implementation, it's easier to always use
# DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
with DisableTorchFunctionSubclass(): with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict()) output = func(*args, **kwargs or dict())
if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls): must_return_subclass = _must_return_subclass()
if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
# We also require the primary operand, i.e. `args[0]`, to be # We also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`. # be wrapped into a `datapoints.Image`.
return cls.wrap_like(args[0], output) return cls._wrap_output(output, args, kwargs)
if isinstance(output, cls): if not must_return_subclass and isinstance(output, cls):
# DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`, # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
# so for those, the output is still a Datapoint. Thus, we need to manually unwrap. # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
return output.as_subclass(torch.Tensor) return output.as_subclass(torch.Tensor)
......
import torch
# Global flag: when True, torch operations on Datapoint inputs return the Datapoint
# subclass; when False (the default), they are unwrapped to plain ``torch.Tensor``.
# Read through ``_must_return_subclass`` and toggled by ``set_return_type``.
_TORCHFUNCTION_SUBCLASS = False


class _ReturnTypeCM:
    """Context manager handed back by ``set_return_type``.

    Does nothing on entry (the flag was already flipped by ``set_return_type``);
    on exit it restores the flag value that was active before the call, which is
    what makes ``with set_return_type(...):`` scope-limited.
    """

    def __init__(self, to_restore):
        # Value of _TORCHFUNCTION_SUBCLASS that was active before set_return_type ran.
        self.to_restore = to_restore

    def __enter__(self):
        return self

    def __exit__(self, *args):
        global _TORCHFUNCTION_SUBCLASS
        _TORCHFUNCTION_SUBCLASS = self.to_restore
def set_return_type(return_type: str):
    """Set the return type of torch operations on datapoints.

    Can be used as a global flag for the entire program:

    .. code:: python

        set_return_type("datapoint")
        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is an Image

    or as a context manager to restrict the scope:

    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        with set_return_type("datapoint"):
            img + 2  # This is an Image
        img + 2  # This is a pure Tensor

    Args:
        return_type (str): Can be "datapoint" or "tensor" (case-insensitive).
            Default is "tensor".

    Raises:
        ValueError: If ``return_type`` is neither "datapoint" nor "tensor".
    """
    global _TORCHFUNCTION_SUBCLASS
    to_restore = _TORCHFUNCTION_SUBCLASS

    try:
        _TORCHFUNCTION_SUBCLASS = {"tensor": False, "datapoint": True}[return_type.lower()]
    except KeyError:
        # Surface a clear error instead of a bare KeyError on the lookup dict.
        raise ValueError(f"return_type must be 'tensor' or 'datapoint', but got {return_type!r}") from None

    # Returning the CM makes both usages above work: as a plain call the return
    # value is simply discarded and the flag stays set globally.
    return _ReturnTypeCM(to_restore)
def _must_return_subclass():
    # True when the user globally opted into Datapoint outputs via set_return_type.
    return _TORCHFUNCTION_SUBCLASS


# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
...@@ -401,7 +401,7 @@ class SanitizeBoundingBoxes(Transform): ...@@ -401,7 +401,7 @@ class SanitizeBoundingBoxes(Transform):
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
params = dict(valid=valid, labels=labels) params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
flat_outputs = [ flat_outputs = [
# Even-though it may look like we're transforming all inputs, we don't: # Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxeses and the labels # _transform() will only care about BoundingBoxeses and the labels
......
...@@ -19,6 +19,8 @@ _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {} ...@@ -19,6 +19,8 @@ _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
def _kernel_datapoint_wrapper(kernel): def _kernel_datapoint_wrapper(kernel):
@functools.wraps(kernel) @functools.wraps(kernel)
def wrapper(inpt, *args, **kwargs): def wrapper(inpt, *args, **kwargs):
# We always pass datapoints as pure tensors to the kernels to avoid going through the
# Tensor.__torch_function__ logic, which is costly.
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
return type(inpt).wrap_like(inpt, output) return type(inpt).wrap_like(inpt, output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment