"""
==========================
Illustration of transforms
==========================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_transforms.py>` to download the full example code.

This example illustrates the various transforms available in :ref:`the
torchvision.transforms module <transforms>`.
"""

# sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png"

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

21
import torch
22
23
24
import torchvision.transforms as T


# Use a tight bounding box whenever a figure is saved.
plt.rcParams["savefig.bbox"] = 'tight'
# Reference image that every transform in this example is applied to.
orig_img = Image.open(Path('../assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)


def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    """Show ``imgs`` as a grid of matplotlib subplots, one image per cell.

    Args:
        imgs: Either a flat sequence of images (drawn as a single row) or a
            sequence of rows, each row being a list/tuple of images. Every
            image is converted with ``np.asarray`` before being displayed.
            The number of columns is taken from the first row.
        with_orig: If true, the module-level ``orig_img`` is prepended to
            every row and the top-left cell is titled "Original image".
        row_title: Optional sequence with one label per row; each label is
            set as the y-axis label of the first cell of its row.
        **imshow_kwargs: Forwarded unchanged to ``Axes.imshow``.
    """
    if not isinstance(imgs[0], (list, tuple)):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    # ``with_orig`` is a bool, so this adds one column when the original is shown.
    num_cols = len(imgs[0]) + with_orig
    # squeeze=False keeps ``axs`` two-dimensional even for a single row/column.
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        # Unpacking (instead of list concatenation) also accepts tuple rows.
        row = [orig_img, *row] if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            # Hide ticks and tick labels: only the pictures matter here.
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

# %%
# Geometric Transforms
# --------------------
# Geometric image transformation refers to the process of altering the geometric properties of an image,
# such as its shape, size, orientation, or position.
# It involves applying mathematical operations to the image pixels or coordinates to achieve the desired transformation.
#
# Pad
# ~~~
# The :class:`~torchvision.transforms.Pad` transform
# (see also :func:`~torchvision.transforms.functional.pad`)
# pads all image borders with some pixel values.
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)

# %%
# Resize
# ~~~~~~
# The :class:`~torchvision.transforms.Resize` transform
# (see also :func:`~torchvision.transforms.functional.resize`)
# resizes an image.
# Note that PIL's ``Image.size`` is ``(width, height)`` whereas ``Resize``
# expects a ``(height, width)`` sequence, so the tuple is reversed below.
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size[::-1])]
plot(resized_imgs)

# %%
# CenterCrop
# ~~~~~~~~~~
# The :class:`~torchvision.transforms.CenterCrop` transform
# (see also :func:`~torchvision.transforms.functional.center_crop`)
# crops the given image at the center.
# Note that PIL's ``Image.size`` is ``(width, height)`` whereas ``CenterCrop``
# expects a ``(height, width)`` sequence, so the tuple is reversed below.
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size[::-1])]
plot(center_crops)

# %%
# FiveCrop
# ~~~~~~~~
# The :class:`~torchvision.transforms.FiveCrop` transform
# (see also :func:`~torchvision.transforms.functional.five_crop`)
# crops the given image into four corners and the central crop.
(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img)
plot([top_left, top_right, bottom_left, bottom_right, center])

# %%
# RandomPerspective
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomPerspective` transform
# (see also :func:`~torchvision.transforms.functional.perspective`)
# performs a random perspective transform on an image.
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)

# %%
# RandomRotation
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomRotation` transform
# (see also :func:`~torchvision.transforms.functional.rotate`)
# rotates an image by a random angle.
rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)

# %%
# RandomAffine
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAffine` transform
# (see also :func:`~torchvision.transforms.functional.affine`)
# performs a random affine transform on an image.
affine_transformer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transformer(orig_img) for _ in range(4)]
plot(affine_imgs)

# %%
# ElasticTransform
# ~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.ElasticTransform` transform
# (see also :func:`~torchvision.transforms.functional.elastic_transform`)
# randomly transforms the morphology of objects in images and produces a
# see-through-water-like effect.
elastic_transformer = T.ElasticTransform(alpha=250.0)
transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)]
plot(transformed_imgs)

# %%
# RandomCrop
# ~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomCrop` transform
# (see also :func:`~torchvision.transforms.functional.crop`)
# crops an image at a random location.
cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)

# %%
# RandomResizedCrop
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomResizedCrop` transform
# (see also :func:`~torchvision.transforms.functional.resized_crop`)
# crops an image at a random location, and then resizes the crop to a given
# size.
resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)

# %%
# Photometric Transforms
# ----------------------
# Photometric image transformation refers to the process of modifying the photometric properties of an image,
# such as its brightness, contrast, color, or tone.
# These transformations are applied to change the visual appearance of an image
# while preserving its geometric structure.
#
# Except :class:`~torchvision.transforms.Grayscale`, the following transforms are random,
# which means that the same transform
# instance will produce a different result each time it transforms a given image.
#
# Grayscale
# ~~~~~~~~~
# The :class:`~torchvision.transforms.Grayscale` transform
# (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# converts an image to grayscale.
gray_img = T.Grayscale()(orig_img)
plot([gray_img], cmap='gray')

# %%
# ColorJitter
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, contrast, saturation, and hue of an image.
jitter = T.ColorJitter(brightness=.5, hue=.3)
jittered_imgs = [jitter(orig_img) for _ in range(4)]
plot(jittered_imgs)

# %%
# GaussianBlur
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs a Gaussian blur transform on an image.
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)

# %%
# RandomInvert
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomInvert` transform
# (see also :func:`~torchvision.transforms.functional.invert`)
# randomly inverts the colors of the given image.
inverter = T.RandomInvert()
inverted_imgs = [inverter(orig_img) for _ in range(4)]
plot(inverted_imgs)

# %%
# RandomPosterize
# ~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomPosterize` transform
# (see also :func:`~torchvision.transforms.functional.posterize`)
# randomly posterizes the image by reducing the number of bits
# of each color channel.
posterizer = T.RandomPosterize(bits=2)
posterized_imgs = [posterizer(orig_img) for _ in range(4)]
plot(posterized_imgs)

# %%
# RandomSolarize
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomSolarize` transform
# (see also :func:`~torchvision.transforms.functional.solarize`)
# randomly solarizes the image by inverting all pixel values above
# the threshold.
solarizer = T.RandomSolarize(threshold=192.0)
solarized_imgs = [solarizer(orig_img) for _ in range(4)]
plot(solarized_imgs)

# %%
# RandomAdjustSharpness
# ~~~~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAdjustSharpness` transform
# (see also :func:`~torchvision.transforms.functional.adjust_sharpness`)
# randomly adjusts the sharpness of the given image.
sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2)
sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
plot(sharpened_imgs)

# %%
# RandomAutocontrast
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAutocontrast` transform
# (see also :func:`~torchvision.transforms.functional.autocontrast`)
# randomly applies autocontrast to the given image.
autocontraster = T.RandomAutocontrast()
autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
plot(autocontrasted_imgs)

# %%
# RandomEqualize
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomEqualize` transform
# (see also :func:`~torchvision.transforms.functional.equalize`)
# randomly equalizes the histogram of the given image.
equalizer = T.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot(equalized_imgs)

# %%
# Augmentation Transforms
# -----------------------
# The following transforms are combinations of multiple transforms,
# either geometric or photometric, or both.
#
# AutoAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
# One row of four samples per policy.
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
# Label each row with the enum member name (e.g. "CIFAR10") instead of
# parsing it back out of the member's string representation.
row_title = [policy.name for policy in policies]
plot(imgs, row_title=row_title)

# %%
# RandAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` transform is an alternate version of AutoAugment.
augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

# %%
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform is an alternate implementation of AutoAugment.
# However, instead of transforming an image multiple times, it transforms an image only once
# using a random transform from a given list with a random strength number.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

# %%
# AugMix
# ~~~~~~
# The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image.
augmenter = T.AugMix()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

# %%
# Randomly-applied Transforms
# ---------------------------
#
# The following transforms are randomly-applied given a probability ``p``. That is, given ``p = 0.5``,
# there is a 50% chance to return the original image, and a 50% chance to return the transformed image,
# even when called with the same transform instance!
#
# RandomHorizontalFlip
# ~~~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# (see also :func:`~torchvision.transforms.functional.hflip`)
# performs horizontal flip of an image, with a given probability.
hflipper = T.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)

# %%
# RandomVerticalFlip
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# (see also :func:`~torchvision.transforms.functional.vflip`)
# performs vertical flip of an image, with a given probability.
vflipper = T.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)

# %%
# RandomApply
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomApply` transform
# randomly applies a list of transforms, with a given probability.
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)