"""
==================================
Getting started with transforms v2
==================================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_getting_started.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_transforms_getting_started.py>` to download the full example code.

This example illustrates everything you need to know to get started with the
new :mod:`torchvision.transforms.v2` API. We'll cover simple tasks like
image classification, and more advanced ones like object detection /
segmentation.
"""

# %%
# First, a bit of setup
from pathlib import Path
import torch
import matplotlib.pyplot as plt
plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import read_image

torch.manual_seed(1)

# If you're trying to run this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
img = read_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")

# %%
# The basics
# ----------
#
# The Torchvision transforms behave like a regular :class:`torch.nn.Module` (in
# fact, most of them are): instantiate a transform, pass an input, get a
# transformed output:

transform = v2.RandomCrop(size=(224, 224))
out = transform(img)

plot([img, out])

# %%
# I just want to do image classification
# --------------------------------------
#
# If you just care about image classification, things are very simple. A basic
# classification pipeline may look like this:

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)

plot([img, out])

# %%
# Such a transformation pipeline is typically passed as the ``transform``
# argument to the :ref:`Datasets <datasets>`, e.g.
# ``ImageNet(..., transform=transforms)``.
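#
# For example, a minimal sketch (assuming the ImageNet data lives under a
# hypothetical ``path/to/imagenet`` directory):
#
# .. code-block:: python
#
#    from torch.utils.data import DataLoader
#    from torchvision.datasets import ImageNet
#
#    dataset = ImageNet("path/to/imagenet", split="train", transform=transforms)
#    loader = DataLoader(dataset, batch_size=32, shuffle=True)
#    imgs, labels = next(iter(loader))  # imgs is a float batch of shape (32, 3, 224, 224)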
#
# That's pretty much all there is. From there, read through our :ref:`main docs
# <transforms>` to learn more about recommended practices and conventions, or
# explore more :ref:`examples <transforms_gallery>` e.g. how to use augmentation
# transforms like :ref:`CutMix and MixUp
# <sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py>`.
#
# .. note::
#
#     If you're already relying on the ``torchvision.transforms`` v1 API,
#     we recommend that you :ref:`switch to the new v2 transforms<v1_or_v2>`. It's
#     very easy: the v2 transforms are fully compatible with the v1 API, so you
#     only need to change the import!
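#
#     A minimal sketch of that import change (the rest of your code can stay
#     as-is):
#
#     .. code-block:: python
#
#        # v1 (legacy):
#        # from torchvision import transforms
#        # v2 drop-in replacement:
#        from torchvision.transforms import v2 as transforms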
#
# Detection, Segmentation, Videos
# -------------------------------
#
# The new Torchvision transforms in the ``torchvision.transforms.v2`` namespace
# support tasks beyond image classification: they can also transform bounding
# boxes, segmentation / detection masks, or videos.
#
# Let's briefly look at a detection example with bounding boxes.

from torchvision import tv_tensors  # we'll describe this a bit later, bear with us

boxes = tv_tensors.BoundingBoxes(
    [
        [15, 10, 370, 510],
        [275, 340, 510, 510],
        [130, 345, 210, 425]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])

# %%
#
# The example above focuses on object detection. But if we had masks
# (:class:`torchvision.tv_tensors.Mask`) for object segmentation or semantic
# segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
# passed them to the transforms in exactly the same way.
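#
# For instance, a minimal sketch with a dummy mask (purely illustrative, not a
# real annotation):
#
# .. code-block:: python
#
#    mask = tv_tensors.Mask(torch.zeros((1, *img.shape[-2:]), dtype=torch.uint8))
#    out_img, out_boxes, out_mask = transforms(img, boxes, mask)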
#
# By now you likely have a few questions: what are these TVTensors, how do we
# use them, and what is the expected input/output of those transforms? We'll
# answer these in the next sections.

# %%
#
# .. _what_are_tv_tensors:
#
# What are TVTensors?
# --------------------
#
# TVTensors are :class:`torch.Tensor` subclasses. The available TVTensors are
# :class:`~torchvision.tv_tensors.Image`,
# :class:`~torchvision.tv_tensors.BoundingBoxes`,
# :class:`~torchvision.tv_tensors.Mask`, and
# :class:`~torchvision.tv_tensors.Video`.
#
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
# or any ``torch.*`` operator will also work on a TVTensor:

img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))

print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")

# %%
# These TVTensor classes are at the core of the transforms: in order to
# transform a given input, the transforms first look at the **class** of the
# object, and dispatch to the appropriate implementation accordingly.
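#
# For instance (a small sketch just to illustrate the point): applying a flip to
# the ``Image`` we created above dispatches to the image implementation, and the
# output keeps its TVTensor class:
#
# .. code-block:: python
#
#    out_dp = v2.RandomHorizontalFlip(p=1)(img_dp)
#    print(type(out_dp))  # still a tv_tensors.Image, not a plain Tensor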
#
# You don't need to know much more about TVTensors at this point, but advanced
# users who want to learn more can refer to
# :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
#
# What do I pass as input?
# ------------------------
#
# Above, we've seen two examples: one where we passed a single image as input
# i.e. ``out = transforms(img)``, and one where we passed both an image and
# bounding boxes, i.e. ``out_img, out_boxes = transforms(img, boxes)``.
#
# In fact, transforms support **arbitrary input structures**. The input can be a
# single image, a tuple, an arbitrarily nested dictionary... pretty much
# anything. The same structure will be returned as output. Below, we use the
# same detection transforms, but pass a tuple (image, target_dict) as input and
# we're getting the same structure as output:

target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
    "this_is_ignored": ("arbitrary", {"structure": "!"})
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

# sphinx_gallery_thumbnail_number = 4
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")

# %%
# We passed a tuple so we get a tuple back, and the second element is the
# transformed target dict. Transforms don't really care about the structure of
# the input; as mentioned above, they only care about the **type** of the
# objects and transform them accordingly.
#
# *Foreign* objects like strings or ints are simply passed through. This can be
# useful e.g. if you want to associate a path with every single sample when
# debugging!
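#
# For instance, a tiny sketch re-using the detection transforms from above (the
# path is just a made-up string):
#
# .. code-block:: python
#
#    out_img, out_path = transforms(img, "path/to/my_sample.jpg")
#    # out_path is still the plain string "path/to/my_sample.jpg"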
#
# .. _passthrough_heuristic:
#
# .. note::
#
#     **Disclaimer** This note is slightly advanced and can be safely skipped on
#     a first read.
#
#     Pure :class:`torch.Tensor` objects are, in general, treated as images (or
#     as videos for video-specific transforms). Indeed, you may have noticed
#     that in the code above we haven't used the
#     :class:`~torchvision.tv_tensors.Image` class at all, and yet our images
#     got transformed properly. Transforms follow the following logic to
#     determine whether a pure Tensor should be treated as an image (or video),
#     or just ignored:
#
#     * If there is an :class:`~torchvision.tv_tensors.Image`,
#       :class:`~torchvision.tv_tensors.Video`,
#       or :class:`PIL.Image.Image` instance in the input, all other pure
#       tensors are passed-through.
#     * If there is no :class:`~torchvision.tv_tensors.Image` or
#       :class:`~torchvision.tv_tensors.Video` instance, only the first pure
#       :class:`torch.Tensor` will be transformed as image or video, while all
#       others will be passed-through. Here "first" means "first in a depth-wise
#       traversal".
#
#     This is what happened in the detection example above: the first pure
#     tensor was the image so it got transformed properly, and all other pure
#     tensor instances like the ``labels`` were passed-through (although labels
#     can still be transformed by some transforms like
#     :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`!).
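#
#     A quick sketch of this heuristic (hypothetical tensors, just for
#     illustration):
#
#     .. code-block:: python
#
#        plain_img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
#        extra = torch.tensor([1, 2, 3])
#        # plain_img is the first pure tensor, so it is treated as an image and
#        # gets flipped; extra is simply passed through unchanged.
#        out_img, out_extra = v2.RandomHorizontalFlip(p=1)(plain_img, extra)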
#
# .. _transforms_datasets_intercompatibility:
#
# Transforms and Datasets intercompatibility
# ------------------------------------------
#
# Roughly speaking, the output of the datasets must correspond to the input of
# the transforms. How to do that depends on whether you're using the torchvision
# :ref:`built-in datasets <datasets>`, or your own custom datasets.
#
# Using built-in datasets
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# If you're just doing image classification, you don't need to do anything. Just
# use the ``transform`` argument of the dataset, e.g.
# ``ImageNet(..., transform=transforms)``, and you're good to go.
#
# Torchvision also supports datasets for object detection or segmentation like
# :class:`torchvision.datasets.CocoDetection`. Those datasets predate
# the existence of the :mod:`torchvision.transforms.v2` module and of the
# TVTensors, so they don't return TVTensors out of the box.
#
# An easy way to force those datasets to return TVTensors and to make them
# compatible with v2 transforms is to use the
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
#
# .. code-block:: python
#
#    from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
#
#    dataset = CocoDetection(..., transforms=my_transforms)
#    dataset = wrap_dataset_for_transforms_v2(dataset)
#    # Now the dataset returns TVTensors!
#
# Using your own datasets
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# If you have a custom dataset, then you'll need to convert your objects into
# the appropriate TVTensor classes. Creating TVTensor instances is very easy,
# refer to :ref:`tv_tensor_creation` for more details.
#
# There are two main places where you can implement that conversion logic:
#
# - At the end of the dataset's ``__getitem__`` method, before returning the
#   sample (or by sub-classing the dataset), as in the sketch below.
# - As the very first step of your transforms pipeline.
#
# Either way, the logic will depend on your specific dataset.
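#
# For instance, a hypothetical sketch of the first option. Here
# ``load_image_and_boxes`` is a placeholder for your own loading logic, and
# ``torch`` / ``tv_tensors`` are the imports used throughout this example:
#
# .. code-block:: python
#
#    class MyDetectionDataset(torch.utils.data.Dataset):
#        def __init__(self, transforms=None):
#            self.transforms = transforms
#
#        def __getitem__(self, idx):
#            img, raw_boxes = load_image_and_boxes(idx)  # placeholder loader
#            img = tv_tensors.Image(img)
#            boxes = tv_tensors.BoundingBoxes(
#                raw_boxes, format="XYXY", canvas_size=img.shape[-2:]
#            )
#            if self.transforms is not None:
#                img, boxes = self.transforms(img, boxes)
#            return img, boxes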