"""
==========================
Illustration of transforms
==========================

This example illustrates the various transforms available in :ref:`the
torchvision.transforms module <transforms>`.
"""

# sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png"
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

17
import torch
18
19
20
import torchvision.transforms as T


21
# Global figure-saving option used for every plot produced below.
plt.rcParams["savefig.bbox"] = 'tight'

# Reference image that every transform in this example is applied to.
orig_img = Image.open(Path('assets') / 'astronaut.jpg')

# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)


def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    """Draw a grid of images with one subplot per image.

    Args:
        imgs: Either a flat list of images (drawn as a single row) or a list
            of lists of images (one inner list per row). Each image is
            converted with ``np.asarray`` before being shown. The number of
            columns is taken from the first row.
        with_orig: When true, the module-level ``orig_img`` is prepended to
            every row and the top-left subplot is titled "Original image".
        row_title: Optional sequence holding one label per row; each label is
            set as the y-axis label of the first subplot of its row.
        **imshow_kwargs: Extra keyword arguments forwarded to every
            ``ax.imshow`` call (e.g. ``cmap='gray'``).
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    # One extra column for the original image when it is shown.
    num_cols = len(imgs[0]) + (1 if with_orig else 0)
    # squeeze=False so that ``axs`` can always be indexed as axs[row, col].
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            # Hide ticks and tick labels: only the images matter here.
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
# %%
# Geometric Transforms
# --------------------
# Geometric image transformation refers to the process of altering the geometric properties of an image,
# such as its shape, size, orientation, or position.
# It involves applying mathematical operations to the image pixels or coordinates to achieve the desired transformation.
#
# Pad
# ~~~
# The :class:`~torchvision.transforms.Pad` transform
# (see also :func:`~torchvision.transforms.functional.pad`)
# pads all image borders with some pixel values.
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)
# %%
# Resize
# ~~~~~~
# The :class:`~torchvision.transforms.Resize` transform
# (see also :func:`~torchvision.transforms.functional.resize`)
# resizes an image.
orig_hw = orig_img.size[::-1]  # PIL reports (width, height); Resize expects (height, width)
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_hw)]
plot(resized_imgs)
# %%
# CenterCrop
# ~~~~~~~~~~
# The :class:`~torchvision.transforms.CenterCrop` transform
# (see also :func:`~torchvision.transforms.functional.center_crop`)
# crops the given image at the center.
full_hw = orig_img.size[::-1]  # PIL reports (width, height); CenterCrop expects (height, width)
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, full_hw)]
plot(center_crops)
# %%
# FiveCrop
# ~~~~~~~~
# The :class:`~torchvision.transforms.FiveCrop` transform
# (see also :func:`~torchvision.transforms.functional.five_crop`)
# crops the given image into four corners and the central crop.
(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img)
plot([top_left, top_right, bottom_left, bottom_right, center])
# %%
# RandomPerspective
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomPerspective` transform
# (see also :func:`~torchvision.transforms.functional.perspective`)
# performs a random perspective transform on an image.
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
# %%
# RandomRotation
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomRotation` transform
# (see also :func:`~torchvision.transforms.functional.rotate`)
# rotates an image by a random angle.
rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
# %%
# RandomAffine
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAffine` transform
# (see also :func:`~torchvision.transforms.functional.affine`)
# performs a random affine transform on an image.
affine_transformer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transformer(orig_img) for _ in range(4)]
plot(affine_imgs)
# %%
# ElasticTransform
# ~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.ElasticTransform` transform
# (see also :func:`~torchvision.transforms.functional.elastic_transform`)
# randomly transforms the morphology of objects in images and produces a
# see-through-water-like effect.
elastic_transformer = T.ElasticTransform(alpha=250.0)
transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)]
plot(transformed_imgs)

# %%
# RandomCrop
# ~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomCrop` transform
# (see also :func:`~torchvision.transforms.functional.crop`)
# crops an image at a random location.
cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
# %%
# RandomResizedCrop
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomResizedCrop` transform
# (see also :func:`~torchvision.transforms.functional.resized_crop`)
# crops an image at a random location, and then resizes the crop to a given
# size.
resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
# %%
# Photometric Transforms
# ----------------------
# Photometric image transformation refers to the process of modifying the photometric properties of an image,
# such as its brightness, contrast, color, or tone.
# These transformations are applied to change the visual appearance of an image
# while preserving its geometric structure.
#
# Except :class:`~torchvision.transforms.Grayscale`, the following transforms are random,
# which means that the same transform
# instance will produce a different result each time it transforms a given image.
#
# Grayscale
# ~~~~~~~~~
# The :class:`~torchvision.transforms.Grayscale` transform
# (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# converts an image to grayscale.
gray_img = T.Grayscale()(orig_img)
plot([gray_img], cmap='gray')

# %%
# ColorJitter
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, contrast, saturation, and hue of an image.
jitter = T.ColorJitter(brightness=.5, hue=.3)
jittered_imgs = [jitter(orig_img) for _ in range(4)]
plot(jittered_imgs)

# %%
# GaussianBlur
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs a Gaussian blur transform on an image.
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)

# %%
# RandomInvert
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomInvert` transform
# (see also :func:`~torchvision.transforms.functional.invert`)
# randomly inverts the colors of the given image.
inverter = T.RandomInvert()
inverted_imgs = [inverter(orig_img) for _ in range(4)]
plot(inverted_imgs)

# %%
# RandomPosterize
# ~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomPosterize` transform
# (see also :func:`~torchvision.transforms.functional.posterize`)
# randomly posterizes the image by reducing the number of bits
# of each color channel.
posterizer = T.RandomPosterize(bits=2)
posterized_imgs = [posterizer(orig_img) for _ in range(4)]
plot(posterized_imgs)

# %%
# RandomSolarize
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomSolarize` transform
# (see also :func:`~torchvision.transforms.functional.solarize`)
# randomly solarizes the image by inverting all pixel values above
# the threshold.
solarizer = T.RandomSolarize(threshold=192.0)
solarized_imgs = [solarizer(orig_img) for _ in range(4)]
plot(solarized_imgs)

# %%
# RandomAdjustSharpness
# ~~~~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAdjustSharpness` transform
# (see also :func:`~torchvision.transforms.functional.adjust_sharpness`)
# randomly adjusts the sharpness of the given image.
sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2)
sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
plot(sharpened_imgs)

# %%
# RandomAutocontrast
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAutocontrast` transform
# (see also :func:`~torchvision.transforms.functional.autocontrast`)
# randomly applies autocontrast to the given image.
autocontraster = T.RandomAutocontrast()
autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
plot(autocontrasted_imgs)

# %%
# RandomEqualize
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomEqualize` transform
# (see also :func:`~torchvision.transforms.functional.equalize`)
# randomly equalizes the histogram of the given image.
equalizer = T.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot(equalized_imgs)

# %%
# Augmentation Transforms
# -----------------------
# The following transforms are combinations of multiple transforms,
# either geometric or photometric, or both.
#
# AutoAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
row_title = [policy.name for policy in policies]  # enum member name, e.g. "CIFAR10"
plot(imgs, row_title=row_title)

# %%
# RandAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` is an alternate version of AutoAugment.
augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

# %%
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` is an alternate implementation of AutoAugment.
# However, instead of transforming an image multiple times, it transforms an image only once
# using a random transform from a given list with a random strength number.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

# %%
# AugMix
# ~~~~~~
# The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image.
augmenter = T.AugMix()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

# %%
# Randomly-applied Transforms
# ---------------------------
#
# The following transforms are randomly-applied given a probability ``p``.  That is, given ``p = 0.5``,
# there is a 50% chance to return the original image, and a 50% chance to return the transformed image,
# even when called with the same transform instance!
#
# RandomHorizontalFlip
# ~~~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# (see also :func:`~torchvision.transforms.functional.hflip`)
# performs a horizontal flip of an image, with a given probability.
hflipper = T.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
# %%
# RandomVerticalFlip
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# (see also :func:`~torchvision.transforms.functional.vflip`)
# performs a vertical flip of an image, with a given probability.
vflipper = T.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
# %%
# RandomApply
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomApply` transform
# randomly applies a list of transforms, with a given probability.
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)