Unverified Commit 25ec3f26 authored by Nicolas Hug, committed by GitHub

tv_tensor -> TVTensor where it matters (#7904)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent d5f4cc38
@@ -183,7 +183,7 @@ Transforms are available as classes like
 This is very much like the :mod:`torch.nn` package which defines both classes
 and functional equivalents in :mod:`torch.nn.functional`.
-The functionals support PIL images, pure tensors, or :ref:`tv_tensors
+The functionals support PIL images, pure tensors, or :ref:`TVTensors
 <tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
 valid.
...
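For illustration (not part of this diff), the class-based dispatch this paragraph describes can be exercised as below; the tensor sizes and box coordinates are arbitrary:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

image_tensor = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
bboxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(224, 224)
)

# The same functional accepts both inputs and dispatches on their class.
print(type(F.resize(image_tensor, size=[128, 128])))  # plain torch.Tensor
print(type(F.resize(bboxes, size=[128, 128])))        # BoundingBoxes, coordinates rescaled
```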
@@ -5,10 +5,11 @@ TVTensors
 .. currentmodule:: torchvision.tv_tensors
-TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
-dispatch their inputs to the appropriate lower-level kernels. Most users do not
-need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
-see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
+TVTensors are :class:`torch.Tensor` subclasses which the v2 :ref:`transforms
+<transforms>` use under the hood to dispatch their inputs to the appropriate
+lower-level kernels. Most users do not need to manipulate TVTensors directly and
+can simply rely on dataset wrapping - see e.g.
+:ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
 .. autosummary::
     :toctree: generated/
...
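As a quick illustration of the wrapping this page documents (not part of the diff; it mirrors the zero-copy behaviour demonstrated in the FAQ further down):

```python
import torch
from torchvision import tv_tensors

tensor = torch.rand(3, 32, 32)
image = tv_tensors.Image(tensor)  # wrap into the Image subclass

assert isinstance(image, torch.Tensor)        # it is still a tensor
assert image.data_ptr() == tensor.data_ptr()  # and shares the same storage (zero-copy)
```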
@@ -74,7 +74,7 @@ out_img, out_bboxes, out_label = transforms(img, bboxes, label)
 print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
 # %%
 # .. note::
-#    While working with tv_tensor classes in your code, make sure to
+#    While working with TVTensor classes in your code, make sure to
 #    familiarize yourself with this section:
 #    :ref:`tv_tensor_unwrapping_behaviour`
 #
@@ -111,7 +111,7 @@ print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
 # In brief, the core logic is to unpack the input into a flat list using `pytree
 # <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
 # then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all tv_tensors are
+# based on the **class** of the entries, as all TVTensors are
 # tensor subclasses) plus some custom logic that is out of scope here - check the
 # code for details. The (potentially transformed) entries are then repacked and
 # returned, in the same structure as the input.
...
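To make the flatten/transform/unflatten logic above concrete, here is a minimal sketch (not part of this diff) using `torch.utils._pytree`; the input structure and the flip "transform" are made up for illustration:

```python
import torch
from torch.utils import _pytree as pytree
from torchvision import tv_tensors

sample = {
    "image": tv_tensors.Image(torch.rand(3, 8, 8)),
    "label": 3,  # non-tensor entries pass through untouched
}

# 1. Flatten the arbitrarily nested input into a flat list of leaves.
flat, spec = pytree.tree_flatten(sample)

# 2. Transform only the leaves whose class marks them as transformable,
#    re-wrapping so the TVTensor type (and metadata) survives the op.
flat = [
    tv_tensors.wrap(leaf.flip(-1), like=leaf) if isinstance(leaf, tv_tensors.TVTensor) else leaf
    for leaf in flat
]

# 3. Repack into the same structure as the input.
out = pytree.tree_unflatten(flat, spec)
print(type(out["image"]), out["label"])  # Image, 3
```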
""" """
===================================== ====================================
How to write your own TVTensor class How to write your own TVTensor class
===================================== ====================================
.. note:: .. note::
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_ Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.
This guide is intended for advanced users and downstream library maintainers. We explain how to This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own tv_tensor class, and how to make it compatible with the built-in write your own TVTensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`. :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
""" """
......
@@ -115,7 +115,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 # segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
 # passed them to the transforms in exactly the same way.
 #
-# By now you likely have a few questions: what are these tv_tensors, how do we
+# By now you likely have a few questions: what are these TVTensors, how do we
 # use them, and what is the expected input/output of those transforms? We'll
 # answer these in the next sections.
@@ -126,7 +126,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 # What are TVTensors?
 # --------------------
 #
-# TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
+# TVTensors are :class:`torch.Tensor` subclasses. The available TVTensors are
 # :class:`~torchvision.tv_tensors.Image`,
 # :class:`~torchvision.tv_tensors.BoundingBoxes`,
 # :class:`~torchvision.tv_tensors.Mask`, and
@@ -134,7 +134,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
-# or any ``torch.*`` operator will also work on a tv_tensor:
+# or any ``torch.*`` operator will also work on a TVTensor:
 img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
@@ -146,7 +146,7 @@ print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
 # transform a given input, the transforms first look at the **class** of the
 # object, and dispatch to the appropriate implementation accordingly.
 #
-# You don't need to know much more about tv_tensors at this point, but advanced
+# You don't need to know much more about TVTensors at this point, but advanced
 # users who want to learn more can refer to
 # :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
 #
@@ -234,9 +234,9 @@ print(f"{out_target['this_is_ignored']}")
 # Torchvision also supports datasets for object detection or segmentation like
 # :class:`torchvision.datasets.CocoDetection`. Those datasets predate
 # the existence of the :mod:`torchvision.transforms.v2` module and of the
-# tv_tensors, so they don't return tv_tensors out of the box.
+# TVTensors, so they don't return TVTensors out of the box.
 #
-# An easy way to force those datasets to return tv_tensors and to make them
+# An easy way to force those datasets to return TVTensors and to make them
 # compatible with v2 transforms is to use the
 # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
 #
@@ -246,7 +246,7 @@ print(f"{out_target['this_is_ignored']}")
 #
 #     dataset = CocoDetection(..., transforms=my_transforms)
 #     dataset = wrap_dataset_for_transforms_v2(dataset)
-#     # Now the dataset returns tv_tensors!
+#     # Now the dataset returns TVTensors!
 #
 # Using your own datasets
 # ^^^^^^^^^^^^^^^^^^^^^^^
...
@@ -9,18 +9,18 @@ TVTensors FAQ
 TVTensors are Tensor subclasses introduced together with
-``torchvision.transforms.v2``. This example showcases what these tv_tensors are
+``torchvision.transforms.v2``. This example showcases what these TVTensors are
 and how they behave.
 .. warning::
-    **Intended Audience** Unless you're writing your own transforms or your own tv_tensors, you
+    **Intended Audience** Unless you're writing your own transforms or your own TVTensors, you
     probably do not need to read this guide. This is a fairly low-level topic
     that most users will not need to worry about: you do not need to understand
-    the internals of tv_tensors to efficiently rely on
+    the internals of TVTensors to efficiently rely on
     ``torchvision.transforms.v2``. It may however be useful for advanced users
     trying to implement their own datasets, transforms, or work directly with
-    the tv_tensors.
+    the TVTensors.
 """
 # %%
@@ -31,8 +31,8 @@ from torchvision import tv_tensors
 # %%
-# What are tv_tensors?
-# --------------------
+# What are TVTensors?
+# -------------------
 #
 # TVTensors are zero-copy tensor subclasses:
@@ -46,31 +46,31 @@ assert image.data_ptr() == tensor.data_ptr()
 # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
 # for the input data.
 #
-# :mod:`torchvision.tv_tensors` supports four types of tv_tensors:
+# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
 #
 # * :class:`~torchvision.tv_tensors.Image`
 # * :class:`~torchvision.tv_tensors.Video`
 # * :class:`~torchvision.tv_tensors.BoundingBoxes`
 # * :class:`~torchvision.tv_tensors.Mask`
 #
-# What can I do with a tv_tensor?
-# -------------------------------
+# What can I do with a TVTensor?
+# ------------------------------
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
-# any ``torch.*`` operator will also work on tv_tensors. See
+# any ``torch.*`` operator will also work on TVTensors. See
 # :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.
 # %%
 # .. _tv_tensor_creation:
 #
-# How do I construct a tv_tensor?
-# -------------------------------
+# How do I construct a TVTensor?
+# ------------------------------
 #
 # Using the constructor
 # ^^^^^^^^^^^^^^^^^^^^^
 #
-# Each tv_tensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
+# Each TVTensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
 image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
 print(image)
@@ -92,7 +92,7 @@ image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
 print(image.shape, image.dtype)
 # %%
-# Some tv_tensors require additional metadata to be passed in ordered to be constructed. For example,
+# Some TVTensors require additional metadata to be passed in order to be constructed. For example,
 # :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
 # corresponding image (``canvas_size``) alongside the actual values. These
 # metadata are required to properly transform the bounding boxes.
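The construction this paragraph describes looks like the sketch below (the gallery's own code is collapsed in this diff; the values are illustrative):

```python
import torch
from torchvision import tv_tensors

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,  # or simply the string "XYXY"
    canvas_size=(512, 512),  # (height, width) of the corresponding image
)
print(bboxes)
```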
@@ -109,7 +109,7 @@ print(bboxes)
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
-# into a tv_tensor. This is useful when you already have an object of the
+# into a TVTensor. This is useful when you already have an object of the
 # desired type, which typically happens when writing transforms: you just want
 # to wrap the output like the input.
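A minimal sketch of that pattern (illustrative values; the gallery's code is collapsed in this diff):

```python
import torch
from torchvision import tv_tensors

bboxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(64, 64)
)

new_data = bboxes + 2  # plain torch arithmetic returns a pure Tensor by default
new_bboxes = tv_tensors.wrap(new_data, like=bboxes)  # metadata copied from ``like``

assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size
```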
@@ -125,7 +125,7 @@ assert new_bboxes.canvas_size == bboxes.canvas_size
 # .. _tv_tensor_unwrapping_behaviour:
 #
 # I had a TVTensor but now I have a Tensor. Help!
-# ------------------------------------------------
+# -----------------------------------------------
 #
 # By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
 # will return a pure Tensor:
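Concretely (illustrative; the gallery's assertions are collapsed in this diff):

```python
import torch
from torchvision import tv_tensors

bboxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(64, 64)
)
new_bboxes = bboxes + 3  # an ordinary torch operation...

assert isinstance(new_bboxes, torch.Tensor)                  # ...returns a Tensor
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)  # not a TVTensor anymore
```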
@@ -151,7 +151,7 @@ assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # But I want a TVTensor back!
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
-# You can re-wrap a pure tensor into a tv_tensor by just calling the tv_tensor
+# You can re-wrap a pure tensor into a TVTensor by just calling the TVTensor
 # constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
 # (see more details above in :ref:`tv_tensor_creation`):
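Both re-wrapping routes, sketched with illustrative values:

```python
import torch
from torchvision import tv_tensors

bboxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(64, 64)
)
new_bboxes = bboxes + 3  # a pure Tensor at this point

# Route 1: call the constructor again, passing the metadata explicitly.
rewrapped = tv_tensors.BoundingBoxes(
    new_bboxes, format=bboxes.format, canvas_size=bboxes.canvas_size
)

# Route 2: let tv_tensors.wrap copy the metadata from an existing object.
rewrapped = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(rewrapped, tv_tensors.BoundingBoxes)
```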
@@ -164,7 +164,7 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # as a global config setting for the whole program, or as a context manager
 # (read its docs to learn more about caveats):
-with tv_tensors.set_return_type("tv_tensor"):
+with tv_tensors.set_return_type("TVTensor"):
     new_bboxes = bboxes + 3
 assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
@@ -203,9 +203,9 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # There are a few exceptions to this "unwrapping" rule:
 # :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
 # :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
-# the tv_tensor type.
+# the TVTensor type.
 #
-# Inplace operations on tv_tensors like ``obj.add_()`` will preserve the type of
+# Inplace operations on TVTensors like ``obj.add_()`` will preserve the type of
 # ``obj``. However, the **returned** value of inplace operations will be a pure
 # tensor:
@@ -213,7 +213,7 @@ image = tv_tensors.Image([[[0, 1], [1, 0]]])
 new_image = image.add_(1).mul_(2)
-# image got transformed in-place and is still an Image tv_tensor, but new_image
+# image got transformed in-place and is still a TVTensor Image, but new_image
 # is a Tensor. They share the same underlying data and they're equal, just
 # different classes.
 assert isinstance(image, tv_tensors.Image)
...
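A self-contained version of the in-place behaviour shown in the hunk above (illustrative; the remaining assertions are collapsed in this diff):

```python
import torch
from torchvision import tv_tensors

image = tv_tensors.Image([[[0, 1], [1, 0]]])
new_image = image.add_(1).mul_(2)  # in-place ops mutate ``image``...

assert isinstance(image, tv_tensors.Image)          # ``image`` keeps its type
assert not isinstance(new_image, tv_tensors.Image)  # ...but the return value is a pure Tensor
assert (new_image == image).all()                   # same underlying data
```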
@@ -91,7 +91,7 @@ def test_to_wrapping(make_input):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_to_tv_tensor_reference(make_input, return_type):
     tensor = torch.rand((3, 16, 16), dtype=torch.float64)
     dp = make_input()
@@ -99,13 +99,13 @@ def test_to_tv_tensor_reference(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         tensor_to = tensor.to(dp)
-    assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert tensor_to.dtype is dp.dtype
     assert type(tensor) is torch.Tensor
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_clone_wrapping(make_input, return_type):
     dp = make_input()
@@ -117,7 +117,7 @@ def test_clone_wrapping(make_input, return_type):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_requires_grad__wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float)
@@ -132,7 +132,7 @@ def test_requires_grad__wrapping(make_input, return_type):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_detach_wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float).requires_grad_(True)
@@ -142,7 +142,7 @@ def test_detach_wrapping(make_input, return_type):
     assert type(dp_detached) is type(dp)
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_force_subclass_with_metadata(return_type):
     # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
     # Largely the same as above, we additionally check that the metadata is preserved
@@ -151,27 +151,27 @@ def test_force_subclass_with_metadata(return_type):
     tv_tensors.set_return_type(return_type)
     bbox = bbox.clone()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
     bbox = bbox.to(torch.float64)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
     bbox = bbox.detach()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
     assert not bbox.requires_grad
     bbox.requires_grad_(True)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
     assert bbox.requires_grad
     tv_tensors.set_return_type("tensor")
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_other_op_no_wrapping(make_input, return_type):
     dp = make_input()
@@ -179,7 +179,7 @@ def test_other_op_no_wrapping(make_input, return_type):
     # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
     output = dp * 2
-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@@ -200,7 +200,7 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_inplace_op_no_wrapping(make_input, return_type):
     dp = make_input()
     original_type = type(dp)
@@ -208,7 +208,7 @@ def test_inplace_op_no_wrapping(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         output = dp.add_(0)
-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert type(dp) is original_type
@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 @pytest.mark.parametrize(
     "op",
     (
@@ -267,8 +267,8 @@ def test_usual_operations(make_input, return_type, op):
     dp = make_input()
     with tv_tensors.set_return_type(return_type):
         out = op(dp)
-    assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
-    if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
+    assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
+    if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
         assert hasattr(out, "format")
         assert hasattr(out, "canvas_size")
@@ -286,16 +286,16 @@ def test_set_return_type():
     assert type(img + 3) is torch.Tensor
-    with tv_tensors.set_return_type("tv_tensor"):
+    with tv_tensors.set_return_type("TVTensor"):
         assert type(img + 3) is tv_tensors.Image
     assert type(img + 3) is torch.Tensor
-    tv_tensors.set_return_type("tv_tensor")
+    tv_tensors.set_return_type("TVTensor")
     assert type(img + 3) is tv_tensors.Image
     with tv_tensors.set_return_type("tensor"):
         assert type(img + 3) is torch.Tensor
-        with tv_tensors.set_return_type("tv_tensor"):
+        with tv_tensors.set_return_type("TVTensor"):
             assert type(img + 3) is tv_tensors.Image
         assert type(img + 3) is torch.Tensor
     tv_tensors.set_return_type("tensor")
     assert type(img + 3) is torch.Tensor
@@ -305,3 +305,16 @@ def test_set_return_type():
     assert type(img + 3) is tv_tensors.Image
     tv_tensors.set_return_type("tensor")
+
+
+def test_return_type_input():
+    img = make_image()
+
+    # Case-insensitive
+    with tv_tensors.set_return_type("tvtensor"):
+        assert type(img + 3) is tv_tensors.Image
+
+    with pytest.raises(ValueError, match="return_type must be"):
+        tv_tensors.set_return_type("typo")
+
+    tv_tensors.set_return_type("tensor")
@@ -16,7 +16,7 @@ class _ReturnTypeCM:
 def set_return_type(return_type: str):
-    """[BETA] Set the return type of torch operations on tv_tensors.
+    """[BETA] Set the return type of torch operations on :class:`~torchvision.tv_tensors.TVTensor`.
     This only affects the behaviour of torch operations. It has no effect on
     ``torchvision`` transforms or functionals, which will always return as
@@ -26,7 +26,7 @@ def set_return_type(return_type: str):
     We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
     the end of your transform pipelines if you use
-    ``set_return_type("dataptoint")``. This will avoid the
+    ``set_return_type("TVTensor")``. This will avoid the
     ``__torch_function__`` overhead in the model's ``forward()``.
     Can be used as a global flag for the entire program:
@@ -36,7 +36,7 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor (default behaviour)
-        set_return_type("tv_tensors")
+        set_return_type("TVTensor")
         img + 2  # This is an Image
     or as a context manager to restrict the scope:
@@ -45,16 +45,21 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor
-        with set_return_type("tv_tensors"):
+        with set_return_type("TVTensor"):
             img + 2  # This is an Image
         img + 2  # This is a pure Tensor
     Args:
-        return_type (str): Can be "tv_tensor" or "tensor". Default is "tensor".
+        return_type (str): Can be "TVTensor" or "Tensor" (case-insensitive).
+            Default is "Tensor" (i.e. pure :class:`torch.Tensor`).
     """
     global _TORCHFUNCTION_SUBCLASS
     to_restore = _TORCHFUNCTION_SUBCLASS
-    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tv_tensor": True}[return_type.lower()]
+    try:
+        _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tvtensor": True}[return_type.lower()]
+    except KeyError:
+        raise ValueError(f"return_type must be 'TVTensor' or 'Tensor', got {return_type}") from None
     return _ReturnTypeCM(to_restore)
...
@@ -13,7 +13,7 @@ D = TypeVar("D", bound="TVTensor")
 class TVTensor(torch.Tensor):
-    """[Beta] Base class for all tv_tensors.
+    """[Beta] Base class for all TVTensors.
     You probably don't want to use this class unless you're defining your own
     custom TVTensors. See
...
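For readers landing on this hunk: a minimal sketch of what subclassing this base class can look like (the class name is hypothetical; the full walkthrough is the linked plot_custom_tv_tensors example):

```python
import torch
from torchvision import tv_tensors


class MyTVTensor(tv_tensors.TVTensor):
    """A hypothetical custom TVTensor; real subclasses usually add metadata."""
    pass


# ``as_subclass`` is a plain torch.Tensor method, so this wrap is zero-copy.
t = torch.rand(3, 8, 8)
my_dp = t.as_subclass(MyTVTensor)

assert isinstance(my_dp, tv_tensors.TVTensor)
assert my_dp.data_ptr() == t.data_ptr()
```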