"""
=============
TVTensors FAQ
=============

.. note::
    Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_tv_tensors.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_tv_tensors.py>` to download the full example code.

TVTensors are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these TVTensors are
and how they behave.

.. warning::

    **Intended Audience** Unless you're writing your own transforms or your own TVTensors, you
    probably do not need to read this guide. This is a fairly low-level topic
    that most users will not need to worry about: you do not need to understand
    the internals of TVTensors to efficiently rely on
    ``torchvision.transforms.v2``. It may however be useful for advanced users
    trying to implement their own datasets, transforms, or work directly with
    the TVTensors.
"""

# %%
import PIL.Image

import torch
from torchvision import tv_tensors

# %%
# What are TVTensors?
# -------------------
#
# TVTensors are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = tv_tensors.Image(tensor)

# An Image is still a Tensor, and wrapping copies no data: both objects
# point at the same underlying storage.
assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

# %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
#
# * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.tv_tensors.Mask`
#
# What can I do with a TVTensor?
# ------------------------------
#
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on TVTensors. See
# :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.

# %%
# .. _tv_tensor_creation:
#
# How do I construct a TVTensor?
# ------------------------------
#
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each TVTensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`

image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)


# %%
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.

float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)


# %%
# In addition, :class:`~torchvision.tv_tensors.Image` and :class:`~torchvision.tv_tensors.Mask` can also take a
# :class:`PIL.Image.Image` directly:

image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)

# %%
# Some TVTensors require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]
)
print(bboxes)

# %%
# Using ``tv_tensors.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
# into a TVTensor. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input.

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
# wrap() copies the metadata (format, canvas_size) over from ``bboxes``.
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size

# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it.
#
# .. _tv_tensor_unwrapping_behaviour:
#
# I had a TVTensor but now I have a Tensor. Help!
# -----------------------------------------------
#
# By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
# will return a pure Tensor:


assert isinstance(bboxes, tv_tensors.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

# The result of the native torch op is "unwrapped" back to a plain Tensor.
assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)


# %%
# .. note::
#
#    This behavior only affects native ``torch`` operations. If you are using
#    the built-in ``torchvision`` transforms or functionals, you will always get
#    as output the same type that you passed as input (pure ``Tensor`` or
#    ``TVTensor``).

# %%
# But I want a TVTensor back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a TVTensor by just calling the TVTensor
# constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
# (see more details above in :ref:`tv_tensor_creation`):

new_bboxes = bboxes + 3
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

# %%
# Alternatively, you can use the :func:`~torchvision.tv_tensors.set_return_type`
# as a global config setting for the whole program, or as a context manager
# (read its docs to learn more about caveats):

with tv_tensors.set_return_type("TVTensor"):
    new_bboxes = bboxes + 3
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)


# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **For performance reasons**. :class:`~torchvision.tv_tensors.TVTensor`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.tv_tensors.TVTensor` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
# This doesn't matter for built-in ``torchvision`` transforms because we can
# avoid the overhead there, but it could be a problem in your model's
# ``forward``.
#
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.tv_tensors.TVTensor` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.tv_tensors.Image`?
# If we were to preserve :class:`~torchvision.tv_tensors.TVTensor` types all
# the way, even model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.tv_tensors.Image`, and surely that's not
# desirable.
#
# .. note::
#
#    This behaviour is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# Exceptions
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the TVTensor type.
#
# Inplace operations on TVTensors like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of inplace operations will be a pure
# tensor:

image = tv_tensors.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# image got transformed in-place and is still a TVTensor Image, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, tv_tensors.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()