"""
==============
Datapoints FAQ
==============

Datapoints are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these datapoints are
and how they behave.

.. warning::

    **Intended Audience** Unless you're writing your own transforms or your own datapoints, you
    probably do not need to read this guide. This is a fairly low-level topic
    that most users will not need to worry about: you do not need to understand
    the internals of datapoints to efficiently rely on
    ``torchvision.transforms.v2``. It may however be useful for advanced users
    trying to implement their own datasets or transforms, or to work directly
    with the datapoints.
"""

# %%
import PIL.Image

import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


# %%
# What are datapoints?
# --------------------
#
# Datapoints are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

# %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
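# For instance, passing an :class:`~torchvision.datapoints.Image` or a
# :class:`~torchvision.datapoints.BoundingBoxes` to the same functional
# dispatches to a different implementation under the hood. A minimal sketch:

img = datapoints.Image(torch.rand(3, 256, 256))
boxes = datapoints.BoundingBoxes(
    [[10, 10, 50, 50]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(256, 256),
)

resized_img = F.resize(img, size=(128, 128), antialias=True)  # the pixel data is resampled
resized_boxes = F.resize(boxes, size=(128, 128))  # the coordinates and canvas_size are rescaled
print(resized_img.shape, resized_boxes)

# %%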
# What can I do with a datapoint?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
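#
# For example, a couple of plain tensor operations on the ``image`` created
# above (a minimal sketch; as explained in the section linked above, the
# results of these operations are plain tensors):

print(image.sum())
print(torch.stack([image, image]).shape)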

# %%
# .. _datapoint_creation:
#
# How do I construct a datapoint?
# -------------------------------
#
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`

image = datapoints.Image([[[[0, 1], [1, 0]]]])
print(image)


# %%
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.

float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)


# %%
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a
# :class:`PIL.Image.Image` directly:

image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
print(image.shape, image.dtype)

# %%
# Some datapoints require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.

bboxes = datapoints.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]
)
print(bboxes)

# %%
# Using the ``wrap_like()`` class method
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the ``wrap_like()`` class method to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input. This API is inspired by utils like
# :func:`torch.zeros_like`:

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size


# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it. Check the
# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
# more details.
#
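# For example, a minimal sketch that overrides the ``canvas_size`` while still
# taking the ``format`` from ``bboxes``:

other_bboxes = datapoints.BoundingBoxes.wrap_like(
    bboxes, torch.tensor([[0, 10, 20, 30]]), canvas_size=(100, 100)
)
print(other_bboxes.canvas_size)

# %%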
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
#
# TODO: Move this to another guide - this is user-facing, not dev-facing.
#
# Only if you are using custom datasets. For the built-in ones, you can use
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the
# built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you
# also don't have to wrap manually.
#
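# For example, a minimal sketch assuming a local COCO download (the two paths
# below are placeholders)::
#
#     from torchvision import datasets
#
#     dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)
#     dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
#     img, target = dataset[0]  # the target entries are now datapoints
#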
# If you have a custom dataset, for example the ``PennFudanDataset`` from
# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options:
#
# 1. Perform the wrapping inside ``__getitem__``:

class PennFudanDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, item):
        ...

        target["bboxes"] = datapoints.BoundingBoxes(
            bboxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
        target["labels"] = labels
        target["masks"] = datapoints.Mask(masks)

        ...

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        ...

# %%
# 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline:


class WrapPennFudanDataset:
    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBoxes(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
        target["masks"] = datapoints.Mask(target["masks"])
        return img, target


...


def get_transform(train):
    transforms = []
    transforms.append(WrapPennFudanDataset())
    transforms.append(T.PILToTensor())
    ...

# %%
# .. note::
#
#    If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in
#    the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
#    at least not wrapping the obsolete parts can lead to a significant performance boost.
#
#    For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids
#    transforming them over and over again in the pipeline just to ultimately ignore them. In general, it would be
#    even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
#    generated from the masks.
#
# .. _datapoint_unwrapping_behaviour:
#
# I had a Datapoint but now I have a Tensor. Help!
# ------------------------------------------------
#
# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
# will return a pure Tensor:


assert isinstance(bboxes, datapoints.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# .. note::
#
#    This behavior only affects native ``torch`` operations. If you are using
#    the built-in ``torchvision`` transforms or functionals, you will always get
#    as output the same type that you passed as input (pure ``Tensor`` or
#    ``Datapoint``).
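#
# For example, a v2 functional returns the same type it was given (a minimal
# check; the metadata of the boxes is preserved as well):

flipped_bboxes = F.horizontal_flip(bboxes)
assert isinstance(flipped_bboxes, datapoints.BoundingBoxes)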

# %%
# But I want a Datapoint back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
# constructor, or by using the ``.wrap_like()`` class method (see more details
# above in :ref:`datapoint_creation`):

new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
# as a global config setting for the whole program, or as a context manager:

with datapoints.set_return_type("datapoint"):
    new_bboxes = bboxes + 3
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
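
# The same function can also be called globally (a sketch; here we reset it to
# the default "tensor" behaviour right away so that the rest of this example
# is unaffected):
datapoints.set_return_type("datapoint")
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
datapoints.set_return_type("tensor")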

# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.datapoints.Datapoint` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
# This doesn't matter for built-in ``torchvision`` transforms because we can
# avoid the overhead there, but it could be a problem in your model's
# ``forward``.
#
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
# the way, even a model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
# desirable.
#
# .. note::
#
#    This behaviour is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# Exceptions
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
#
# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
#    :meth:`~torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
#    the datapoint type (see the check at the end of this example).
# 2. Inplace operations on datapoints like ``.add_()`` preserve their type. However,
#    the **returned** value of inplace operations will be unwrapped into a pure
#    tensor:

image = datapoints.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# image got transformed in-place and is still an Image datapoint, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, datapoints.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
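
# %%
# The first exception can be checked in a similar way: operations like
# ``.clone()`` and ``.to()`` keep the datapoint type (a minimal sketch):

cloned = image.clone()
converted = image.to(torch.float64)

assert isinstance(cloned, datapoints.Image)
assert isinstance(converted, datapoints.Image)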