Unverified commit d5f4cc38, authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
@@ -5,10 +5,10 @@ from typing import Any, Optional, Union
 import PIL.Image
 import torch
 
-from ._datapoint import Datapoint
+from ._tv_tensor import TVTensor
 
 
-class Image(Datapoint):
+class Image(TVTensor):
     """[BETA] :class:`torch.Tensor` subclass for images.
 
     .. note::
...
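Not part of the diff, but for orientation: a minimal usage sketch of the renamed class, assuming the post-rename `torchvision.tv_tensors` module; the tensor shape is an arbitrary illustrative choice.

```python
import torch
from torchvision import tv_tensors  # formerly torchvision.datapoints

# Wrap a CHW float tensor as an Image. The wrapper is a torch.Tensor subclass,
# so it can be passed anywhere a plain tensor is expected.
img = tv_tensors.Image(torch.rand(3, 224, 224))
assert isinstance(img, torch.Tensor)
```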
@@ -5,10 +5,10 @@ from typing import Any, Optional, Union
 import PIL.Image
 import torch
 
-from ._datapoint import Datapoint
+from ._tv_tensor import TVTensor
 
 
-class Mask(Datapoint):
+class Mask(TVTensor):
     """[BETA] :class:`torch.Tensor` subclass for segmentation and detection masks.
 
     Args:
...
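Again not part of the diff: a hedged sketch of the renamed `Mask` class. The PIL path below is an assumption suggested by the `import PIL.Image` in this file; the accepted inputs are documented in the elided `Args:` section.

```python
import PIL.Image
import torch
from torchvision import tv_tensors

# From a tensor: an integer segmentation map.
mask = tv_tensors.Mask(torch.zeros(224, 224, dtype=torch.int64))

# From a PIL image (assumed supported, given the PIL.Image import above).
pil_mask = tv_tensors.Mask(PIL.Image.new("L", (224, 224)))

assert isinstance(mask, torch.Tensor) and isinstance(pil_mask, torch.Tensor)
```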
@@ -16,7 +16,7 @@ class _ReturnTypeCM:
 
 def set_return_type(return_type: str):
-    """[BETA] Set the return type of torch operations on datapoints.
+    """[BETA] Set the return type of torch operations on tv_tensors.
 
     This only affects the behaviour of torch operations. It has no effect on
     ``torchvision`` transforms or functionals, which will always return as
@@ -33,28 +33,28 @@ def set_return_type(return_type: str):
     .. code:: python
 
-        img = datapoints.Image(torch.rand(3, 5, 5))
+        img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor (default behaviour)
 
-        set_return_type("datapoints")
+        set_return_type("tv_tensors")
         img + 2  # This is an Image
 
     or as a context manager to restrict the scope:
 
     .. code:: python
 
-        img = datapoints.Image(torch.rand(3, 5, 5))
+        img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor
 
-        with set_return_type("datapoints"):
+        with set_return_type("tv_tensors"):
             img + 2  # This is an Image
         img + 2  # This is a pure Tensor
 
     Args:
-        return_type (str): Can be "datapoint" or "tensor". Default is "tensor".
+        return_type (str): Can be "tv_tensor" or "tensor". Default is "tensor".
     """
     global _TORCHFUNCTION_SUBCLASS
     to_restore = _TORCHFUNCTION_SUBCLASS
 
-    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "datapoint": True}[return_type.lower()]
+    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tv_tensor": True}[return_type.lower()]
 
     return _ReturnTypeCM(to_restore)
...
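A small end-to-end sketch of the renamed API (not part of the diff). Note that the lookup table above only accepts "tensor" and "tv_tensor" (case-insensitive), even though the docstring examples spell it "tv_tensors"; the value below follows the `Args:` line. Later releases may spell this value differently.

```python
import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 5, 5))
assert type(img + 2) is torch.Tensor              # default: pure Tensor

with tv_tensors.set_return_type("tv_tensor"):     # value per the Args: line above
    assert type(img + 2) is tv_tensors.Image      # ops now preserve the subclass

assert type(img + 2) is torch.Tensor              # previous setting restored on exit
```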
@@ -6,18 +6,18 @@ import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size
 
-from torchvision.datapoints._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
+from torchvision.tv_tensors._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
 
 
-D = TypeVar("D", bound="Datapoint")
+D = TypeVar("D", bound="TVTensor")
 
 
-class Datapoint(torch.Tensor):
-    """[Beta] Base class for all datapoints.
+class TVTensor(torch.Tensor):
+    """[Beta] Base class for all tv_tensors.
 
     You probably don't want to use this class unless you're defining your own
-    custom Datapoints. See
-    :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for details.
+    custom TVTensors. See
+    :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for details.
     """
 
     @staticmethod
@@ -62,9 +62,9 @@ class Datapoint(torch.Tensor):
         ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
         ``args`` and ``kwargs`` of the original call.
 
-        Why do we override this? Because the base implementation in torch.Tensor would preserve the Datapoint type
+        Why do we override this? Because the base implementation in torch.Tensor would preserve the TVTensor type
         of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
-        "Datapoints FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).
+        "TVTensors FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).
 
         Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
         """
@@ -79,7 +79,7 @@ class Datapoint(torch.Tensor):
         must_return_subclass = _must_return_subclass()
         if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
             # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
-            # in test_to_datapoint_reference().
+            # in test_to_tv_tensor_reference().
             # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
             # the computation by walking the MRO upwards. For example,
             # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
@@ -89,7 +89,7 @@ class Datapoint(torch.Tensor):
 
         if not must_return_subclass and isinstance(output, cls):
             # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
-            # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
+            # so for those, the output is still a TVTensor. Thus, we need to manually unwrap.
             return output.as_subclass(torch.Tensor)
 
         return output
...
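To make the behaviour described in the docstring and comments above concrete, a hedged sketch (not part of the diff): ordinary ops unwrap to a plain `Tensor`, while an inplace op leaves the input's subclass intact but its return value is unwrapped manually, as the last hunk explains.

```python
import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 5, 5))

# Regular ops go through TVTensor.__torch_function__ and return a pure Tensor.
assert type(img + 2) is torch.Tensor

# An inplace op mutates `img`, which keeps its subclass, but the returned handle
# is unwrapped via `output.as_subclass(torch.Tensor)` as shown above.
ret = img.add_(2)
assert type(img) is tv_tensors.Image
assert type(ret) is torch.Tensor
```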
@@ -4,10 +4,10 @@ from typing import Any, Optional, Union
 
 import torch
 
-from ._datapoint import Datapoint
+from ._tv_tensor import TVTensor
 
 
-class Video(Datapoint):
+class Video(TVTensor):
     """[BETA] :class:`torch.Tensor` subclass for videos.
 
     Args:
...
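And the same pattern for the renamed `Video` class (not part of the diff); the `(T, C, H, W)` layout below is an assumption about the conventional video shape, not something stated in this hunk.

```python
import torch
from torchvision import tv_tensors

# Wrap a stack of frames as a Video; shape assumed to be (T, C, H, W).
video = tv_tensors.Video(torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8))
assert isinstance(video, torch.Tensor)
```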