Unverified commit 25ec3f26 authored by Nicolas Hug, committed by GitHub

tv_tensor -> TVTensor where it matters (#7904)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent d5f4cc38
@@ -183,7 +183,7 @@ Transforms are available as classes like
This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`.
The functionals support PIL images, pure tensors, or :ref:`tv_tensors
The functionals support PIL images, pure tensors, or :ref:`TVTensors
<tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
valid.
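For instance, a minimal sketch (the shapes and values below are illustrative):

.. code-block:: python

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    image_tensor = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
    bboxes = tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10]], format="XYXY", canvas_size=(64, 64)
    )

    out_image = F.resize(image_tensor, size=[32, 32])  # dispatches to the image kernel
    out_bboxes = F.resize(bboxes, size=[32, 32])       # dispatches to the bounding-box kernel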
......
@@ -5,10 +5,11 @@ TVTensors
.. currentmodule:: torchvision.tv_tensors
TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
dispatch their inputs to the appropriate lower-level kernels. Most users do not
need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
TVTensors are :class:`torch.Tensor` subclasses which the v2 :ref:`transforms
<transforms>` use under the hood to dispatch their inputs to the appropriate
lower-level kernels. Most users do not need to manipulate TVTensors directly and
can simply rely on dataset wrapping - see e.g.
:ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
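A quick sketch of what "tensor subclass" means in practice:

.. code-block:: python

    import torch
    from torchvision import tv_tensors

    img = tv_tensors.Image(torch.rand(3, 32, 32))
    assert isinstance(img, torch.Tensor)  # an Image *is* a Tensor
    print(img.shape, img.dtype)           # regular tensor attributes work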
.. autosummary::
:toctree: generated/
......
@@ -74,7 +74,7 @@ out_img, out_bboxes, out_label = transforms(img, bboxes, label)
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
# While working with tv_tensor classes in your code, make sure to
# While working with TVTensor classes in your code, make sure to
# familiarize yourself with this section:
# :ref:`tv_tensor_unwrapping_behaviour`
#
@@ -111,7 +111,7 @@ print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all tv_tensors are
# based on the **class** of the entries, as all TVTensors are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
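# To make the flatten/repack idea concrete, here is a minimal sketch using the
# private ``torch.utils._pytree`` helpers (an internal API that may change
# across PyTorch versions):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

structured = {"image": torch.rand(3, 4, 4), "meta": ["not", "a", "tensor"]}
flat, spec = tree_flatten(structured)  # arbitrary structure -> flat list + spec
# Transform only the entries we know how to handle, pass the rest through.
flat = [x * 2 if isinstance(x, torch.Tensor) else x for x in flat]
repacked = tree_unflatten(flat, spec)  # same structure as the input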
......
"""
=====================================
====================================
How to write your own TVTensor class
=====================================
====================================
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.
This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own tv_tensor class, and how to make it compatible with the built-in
write your own TVTensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
"""
......
@@ -115,7 +115,7 @@ plot([(img, boxes), (out_img, out_boxes)])
# segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
# passed them to the transforms in exactly the same way.
#
# By now you likely have a few questions: what are these tv_tensors, how do we
# By now you likely have a few questions: what are these TVTensors, how do we
# use them, and what is the expected input/output of those transforms? We'll
# answer these in the next sections.
@@ -126,7 +126,7 @@ plot([(img, boxes), (out_img, out_boxes)])
# What are TVTensors?
# --------------------
#
# TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
# TVTensors are :class:`torch.Tensor` subclasses. The available TVTensors are
# :class:`~torchvision.tv_tensors.Image`,
# :class:`~torchvision.tv_tensors.BoundingBoxes`,
# :class:`~torchvision.tv_tensors.Mask`, and
@@ -134,7 +134,7 @@ plot([(img, boxes), (out_img, out_boxes)])
#
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
# or any ``torch.*`` operator will also work on a tv_tensor:
# or any ``torch.*`` operator will also work on a TVTensor:
img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
@@ -146,7 +146,7 @@ print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
# transform a given input, the transforms first look at the **class** of the
# object, and dispatch to the appropriate implementation accordingly.
#
# You don't need to know much more about tv_tensors at this point, but advanced
# You don't need to know much more about TVTensors at this point, but advanced
# users who want to learn more can refer to
# :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
#
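# As a small illustration of this dispatch, a sketch reusing ``img_dp`` from
# the cell above:

from torchvision.transforms import v2

flip = v2.RandomHorizontalFlip(p=1.0)
flipped = flip(img_dp)
print(type(flipped))  # still tv_tensors.Image: dispatch happened on the class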
@@ -234,9 +234,9 @@ print(f"{out_target['this_is_ignored']}")
# Torchvision also supports datasets for object detection or segmentation like
# :class:`torchvision.datasets.CocoDetection`. Those datasets predate
# the existence of the :mod:`torchvision.transforms.v2` module and of the
# tv_tensors, so they don't return tv_tensors out of the box.
# TVTensors, so they don't return TVTensors out of the box.
#
# An easy way to force those datasets to return tv_tensors and to make them
# An easy way to force those datasets to return TVTensors and to make them
# compatible with v2 transforms is to use the
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
#
@@ -246,7 +246,7 @@ print(f"{out_target['this_is_ignored']}")
#
# dataset = CocoDetection(..., transforms=my_transforms)
# dataset = wrap_dataset_for_transforms_v2(dataset)
# # Now the dataset returns tv_tensors!
# # Now the dataset returns TVTensors!
#
# Using your own datasets
# ^^^^^^^^^^^^^^^^^^^^^^^
......
@@ -9,18 +9,18 @@ TVTensors FAQ
TVTensors are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these tv_tensors are
``torchvision.transforms.v2``. This example showcases what these TVTensors are
and how they behave.
.. warning::
**Intended Audience** Unless you're writing your own transforms or your own tv_tensors, you
**Intended Audience** Unless you're writing your own transforms or your own TVTensors, you
probably do not need to read this guide. This is a fairly low-level topic
that most users will not need to worry about: you do not need to understand
the internals of tv_tensors to efficiently rely on
the internals of TVTensors to efficiently rely on
``torchvision.transforms.v2``. It may however be useful for advanced users
trying to implement their own datasets or transforms, or to work directly with
the tv_tensors.
the TVTensors.
"""
# %%
@@ -31,8 +31,8 @@ from torchvision import tv_tensors
# %%
# What are tv_tensors?
# --------------------
# What are TVTensors?
# -------------------
#
# TVTensors are zero-copy tensor subclasses:
@@ -46,31 +46,31 @@ assert image.data_ptr() == tensor.data_ptr()
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.tv_tensors` supports four types of tv_tensors:
# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
#
# * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.tv_tensors.Mask`
#
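# A minimal construction sketch, one of each of the four types (shapes and
# values are illustrative):

img = tv_tensors.Image(torch.rand(3, 16, 16))
video = tv_tensors.Video(torch.rand(2, 3, 16, 16))  # (T, C, H, W)
boxes = tv_tensors.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(16, 16))
mask = tv_tensors.Mask(torch.zeros(16, 16, dtype=torch.uint8))

# %%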
# What can I do with a tv_tensor?
# -------------------------------
# What can I do with a TVTensor?
# ------------------------------
#
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on tv_tensors. See
# any ``torch.*`` operator will also work on TVTensors. See
# :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.
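# For instance, a quick sketch (reusing the imports from the top of this
# example):

img = tv_tensors.Image(torch.rand(3, 8, 8))
print(img.sum())           # regular reductions work
print(img.flip(-1).shape)  # and so does any other torch operation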
# %%
# .. _tv_tensor_creation:
#
# How do I construct a tv_tensor?
# -------------------------------
# How do I construct a TVTensor?
# ------------------------------
#
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each tv_tensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
# Each TVTensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)
@@ -92,7 +92,7 @@ image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)
# %%
# Some tv_tensors require additional metadata to be passed in ordered to be constructed. For example,
# Some TVTensors require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.
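# For example, a minimal sketch (the values are illustrative; the collapsed
# code nearby constructs something similar):

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495]],
    format=tv_tensors.BoundingBoxFormat.XYXY,  # coordinate convention
    canvas_size=(512, 512),                    # (height, width) of the image
)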
@@ -109,7 +109,7 @@ print(bboxes)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
# into a tv_tensor. This is useful when you already have an object of the
# into a TVTensor. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input.
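# A minimal sketch, assuming ``bboxes`` was constructed as above:

new_data = bboxes + 10                               # plain Tensor: metadata is gone
new_bboxes = tv_tensors.wrap(new_data, like=bboxes)  # same data, metadata re-attached
assert new_bboxes.format == bboxes.format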
@@ -125,7 +125,7 @@ assert new_bboxes.canvas_size == bboxes.canvas_size
# .. _tv_tensor_unwrapping_behaviour:
#
# I had a TVTensor but now I have a Tensor. Help!
# ------------------------------------------------
# -----------------------------------------------
#
# By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
# will return a pure Tensor:
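# A quick sketch of this default "unwrapping" behaviour:

img = tv_tensors.Image(torch.rand(3, 8, 8))
out = img + 1
assert type(out) is torch.Tensor  # no longer an Image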
@@ -151,7 +151,7 @@ assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# But I want a TVTensor back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a tv_tensor by just calling the tv_tensor
# You can re-wrap a pure tensor into a TVTensor by just calling the TVTensor
# constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
# (see more details above in :ref:`tv_tensor_creation`):
@@ -164,7 +164,7 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# as a global config setting for the whole program, or as a context manager
# (read its docs to learn more about caveats):
with tv_tensors.set_return_type("tv_tensor"):
with tv_tensors.set_return_type("TVTensor"):
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
@@ -203,9 +203,9 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the tv_tensor type.
# the TVTensor type.
#
# Inplace operations on tv_tensors like ``obj.add_()`` will preserve the type of
# Inplace operations on TVTensors like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of inplace operations will be a pure
# tensor:
@@ -213,7 +213,7 @@ image = tv_tensors.Image([[[0, 1], [1, 0]]])
new_image = image.add_(1).mul_(2)
# image got transformed in-place and is still an Image tv_tensor, but new_image
# image got transformed in-place and is still a TVTensor Image, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, tv_tensors.Image)
......
@@ -91,7 +91,7 @@ def test_to_wrapping(make_input):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_to_tv_tensor_reference(make_input, return_type):
tensor = torch.rand((3, 16, 16), dtype=torch.float64)
dp = make_input()
@@ -99,13 +99,13 @@ def test_to_tv_tensor_reference(make_input, return_type):
with tv_tensors.set_return_type(return_type):
tensor_to = tensor.to(dp)
assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
assert tensor_to.dtype is dp.dtype
assert type(tensor) is torch.Tensor
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_clone_wrapping(make_input, return_type):
dp = make_input()
@@ -117,7 +117,7 @@ def test_clone_wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_requires_grad__wrapping(make_input, return_type):
dp = make_input(dtype=torch.float)
@@ -132,7 +132,7 @@ def test_requires_grad__wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_detach_wrapping(make_input, return_type):
dp = make_input(dtype=torch.float).requires_grad_(True)
@@ -142,7 +142,7 @@ def test_detach_wrapping(make_input, return_type):
assert type(dp_detached) is type(dp)
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_force_subclass_with_metadata(return_type):
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
# Largely the same as above, we additionally check that the metadata is preserved
@@ -151,27 +151,27 @@ def test_force_subclass_with_metadata(return_type):
tv_tensors.set_return_type(return_type)
bbox = bbox.clone()
if return_type == "tv_tensor":
if return_type == "TVTensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
bbox = bbox.to(torch.float64)
if return_type == "tv_tensor":
if return_type == "TVTensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
bbox = bbox.detach()
if return_type == "tv_tensor":
if return_type == "TVTensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
assert not bbox.requires_grad
bbox.requires_grad_(True)
if return_type == "tv_tensor":
if return_type == "TVTensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
assert bbox.requires_grad
tv_tensors.set_return_type("tensor")
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_other_op_no_wrapping(make_input, return_type):
dp = make_input()
@@ -179,7 +179,7 @@ def test_other_op_no_wrapping(make_input, return_type):
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2
assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@@ -200,7 +200,7 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_inplace_op_no_wrapping(make_input, return_type):
dp = make_input()
original_type = type(dp)
@@ -208,7 +208,7 @@ def test_inplace_op_no_wrapping(make_input, return_type):
with tv_tensors.set_return_type(return_type):
output = dp.add_(0)
assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
assert type(dp) is original_type
@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
@pytest.mark.parametrize(
"op",
(
@@ -267,8 +267,8 @@ def test_usual_operations(make_input, return_type, op):
dp = make_input()
with tv_tensors.set_return_type(return_type):
out = op(dp)
assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
assert hasattr(out, "format")
assert hasattr(out, "canvas_size")
@@ -286,16 +286,16 @@ def test_set_return_type():
assert type(img + 3) is torch.Tensor
with tv_tensors.set_return_type("tv_tensor"):
with tv_tensors.set_return_type("TVTensor"):
assert type(img + 3) is tv_tensors.Image
assert type(img + 3) is torch.Tensor
tv_tensors.set_return_type("tv_tensor")
tv_tensors.set_return_type("TVTensor")
assert type(img + 3) is tv_tensors.Image
with tv_tensors.set_return_type("tensor"):
assert type(img + 3) is torch.Tensor
with tv_tensors.set_return_type("tv_tensor"):
with tv_tensors.set_return_type("TVTensor"):
assert type(img + 3) is tv_tensors.Image
tv_tensors.set_return_type("tensor")
assert type(img + 3) is torch.Tensor
@@ -305,3 +305,16 @@ def test_set_return_type():
assert type(img + 3) is tv_tensors.Image
tv_tensors.set_return_type("tensor")
def test_return_type_input():
img = make_image()
# Case-insensitive
with tv_tensors.set_return_type("tvtensor"):
assert type(img + 3) is tv_tensors.Image
with pytest.raises(ValueError, match="return_type must be"):
tv_tensors.set_return_type("typo")
tv_tensors.set_return_type("tensor")
@@ -16,7 +16,7 @@ class _ReturnTypeCM:
def set_return_type(return_type: str):
"""[BETA] Set the return type of torch operations on tv_tensors.
"""[BETA] Set the return type of torch operations on :class:`~torchvision.tv_tensors.TVTensor`.
This only affects the behaviour of torch operations. It has no effect on
``torchvision`` transforms or functionals, which will always return as
@@ -26,7 +26,7 @@ def set_return_type(return_type: str):
We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
the end of your transform pipelines if you use
``set_return_type("dataptoint")``. This will avoid the
``set_return_type("TVTensor")``. This will avoid the
``__torch_function__`` overhead in the model's ``forward()``.
Can be used as a global flag for the entire program:
@@ -36,7 +36,7 @@ def set_return_type(return_type: str):
img = tv_tensors.Image(torch.rand(3, 5, 5))
img + 2 # This is a pure Tensor (default behaviour)
set_return_type("tv_tensors")
set_return_type("TVTensor")
img + 2 # This is an Image
or as a context manager to restrict the scope:
@@ -45,16 +45,21 @@ def set_return_type(return_type: str):
img = tv_tensors.Image(torch.rand(3, 5, 5))
img + 2 # This is a pure Tensor
with set_return_type("tv_tensors"):
with set_return_type("TVTensor"):
img + 2 # This is an Image
img + 2 # This is a pure Tensor
Args:
return_type (str): Can be "tv_tensor" or "tensor". Default is "tensor".
return_type (str): Can be "TVTensor" or "Tensor" (case-insensitive).
Default is "Tensor" (i.e. pure :class:`torch.Tensor`).
"""
global _TORCHFUNCTION_SUBCLASS
to_restore = _TORCHFUNCTION_SUBCLASS
_TORCHFUNCTION_SUBCLASS = {"tensor": False, "tv_tensor": True}[return_type.lower()]
try:
_TORCHFUNCTION_SUBCLASS = {"tensor": False, "tvtensor": True}[return_type.lower()]
except KeyError:
raise ValueError(f"return_type must be 'TVTensor' or 'Tensor', got {return_type}") from None
return _ReturnTypeCM(to_restore)
......
@@ -13,7 +13,7 @@ D = TypeVar("D", bound="TVTensor")
class TVTensor(torch.Tensor):
"""[Beta] Base class for all tv_tensors.
"""[Beta] Base class for all TVTensors.
You probably don't want to use this class unless you're defining your own
custom TVTensors. See
......