"""
==============
Datapoints FAQ
==============

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_datapoints.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_datapoints.py>` to download the full example code.

Datapoints are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these datapoints are
and how they behave.

.. warning::

    **Intended Audience:** Unless you're writing your own transforms or your own datapoints, you
    probably do not need to read this guide. This is a fairly low-level topic
    that most users will not need to worry about: you do not need to understand
    the internals of datapoints to efficiently rely on
    ``torchvision.transforms.v2``. It may however be useful for advanced users
    trying to implement their own datasets, transforms, or work directly with
    the datapoints.
"""

# %%
import PIL.Image

import torch
from torchvision import datapoints


# %%
# What are datapoints?
# --------------------
#
# Datapoints are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor)

# Wrapping does not copy: the datapoint is still a Tensor and shares the
# original tensor's storage.
assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

# %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# What can I do with a datapoint?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.

# %%
# .. _datapoint_creation:
#
# How do I construct a datapoint?
# -------------------------------
#
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`:

image = datapoints.Image([[[[0, 1], [1, 0]]]])
print(image)


# %%
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.

float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)


# %%
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a
# :class:`PIL.Image.Image` directly:

image = datapoints.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)

# %%
# Some datapoints require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.

bboxes = datapoints.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:],
)
print(bboxes)

# %%
# Using ``datapoints.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input.

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size

# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it.
#
# .. _datapoint_unwrapping_behaviour:
#
# I had a Datapoint but now I have a Tensor. Help!
# ------------------------------------------------
#
# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
# will return a pure Tensor:


assert isinstance(bboxes, datapoints.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# .. note::
#
#    This behavior only affects native ``torch`` operations. If you are using
#    the built-in ``torchvision`` transforms or functionals, you will always get
#    as output the same type that you passed as input (pure ``Tensor`` or
#    ``Datapoint``).

# %%
# But I want a Datapoint back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
# constructor, or by using the :func:`~torchvision.datapoints.wrap` function
# (see more details above in :ref:`datapoint_creation`):

new_bboxes = bboxes + 3
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
# as a global config setting for the whole program, or as a context manager
# (read its docs to learn more about caveats):

with datapoints.set_return_type("datapoint"):
    new_bboxes = bboxes + 3
assert isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.datapoints.Datapoint` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
# This doesn't matter for built-in ``torchvision`` transforms because we can
# avoid the overhead there, but it could be a problem in your model's
# ``forward``.
#
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
# the way, even the model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
# desirable.
#
# .. note::
#
#    This behavior is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# Exceptions
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`~torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type.
#
# In-place operations on datapoints like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of in-place operations will be a pure
# tensor:

image = datapoints.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# image got transformed in-place and is still an Image datapoint, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, datapoints.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()