Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
5ac27fe3
Unverified
Commit
5ac27fe3
authored
Apr 28, 2021
by
Nicolas Hug
Committed by
GitHub
Apr 28, 2021
Browse files
Rework transforms example in gallery (#3744)
parent
c8f7d772
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
116 additions
and
100 deletions
+116
-100
gallery/plot_transforms.py
gallery/plot_transforms.py
+116
-100
No files found.
gallery/plot_transforms.py
View file @
5ac27fe3
...
@@ -11,21 +11,40 @@ from pathlib import Path
...
@@ -11,21 +11,40 @@ from pathlib import Path
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
numpy
as
np
import
torch
import
torchvision.transforms
as
T
import
torchvision.transforms
as
T
plt
.
rcParams
[
"savefig.bbox"
]
=
'tight'
orig_img
=
Image
.
open
(
Path
(
'assets'
)
/
'astronaut.jpg'
)
orig_img
=
Image
.
open
(
Path
(
'assets'
)
/
'astronaut.jpg'
)
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
def
plot
(
img
,
title
:
str
=
""
,
with_orig
:
bool
=
True
,
**
kwargs
):
torch
.
manual_seed
(
0
)
def
_plot
(
img
,
title
,
**
kwargs
):
plt
.
figure
().
suptitle
(
title
,
fontsize
=
25
)
plt
.
imshow
(
np
.
asarray
(
img
),
**
kwargs
)
def
plot
(
imgs
,
with_orig
=
True
,
row_title
=
None
,
**
imshow_kwargs
):
plt
.
axis
(
'off'
)
if
not
isinstance
(
imgs
[
0
],
list
):
# Make a 2d grid even if there's just 1 row
imgs
=
[
imgs
]
num_rows
=
len
(
imgs
)
num_cols
=
len
(
imgs
[
0
])
+
with_orig
fig
,
axs
=
plt
.
subplots
(
nrows
=
num_rows
,
ncols
=
num_cols
,
squeeze
=
False
)
for
row_idx
,
row
in
enumerate
(
imgs
):
row
=
[
orig_img
]
+
row
if
with_orig
else
row
for
col_idx
,
img
in
enumerate
(
row
):
ax
=
axs
[
row_idx
,
col_idx
]
ax
.
imshow
(
np
.
asarray
(
img
),
**
imshow_kwargs
)
ax
.
set
(
xticklabels
=
[],
yticklabels
=
[],
xticks
=
[],
yticks
=
[])
if
with_orig
:
if
with_orig
:
_plot
(
orig_img
,
"Original Image"
)
axs
[
0
,
0
].
set
(
title
=
'Original image'
)
_plot
(
img
,
title
,
**
kwargs
)
axs
[
0
,
0
].
title
.
set_size
(
8
)
if
row_title
is
not
None
:
for
row_idx
in
range
(
num_rows
):
axs
[
row_idx
,
0
].
set
(
ylabel
=
row_title
[
row_idx
])
plt
.
tight_layout
()
####################################
####################################
...
@@ -34,8 +53,8 @@ def plot(img, title: str = "", with_orig: bool = True, **kwargs):
...
@@ -34,8 +53,8 @@ def plot(img, title: str = "", with_orig: bool = True, **kwargs):
# The :class:`~torchvision.transforms.Pad` transform
# The :class:`~torchvision.transforms.Pad` transform
# (see also :func:`~torchvision.transforms.functional.pad`)
# (see also :func:`~torchvision.transforms.functional.pad`)
# fills image borders with some pixel values.
# fills image borders with some pixel values.
padded_img
=
T
.
Pad
(
padding
=
30
)(
orig_img
)
padded_img
s
=
[
T
.
Pad
(
padding
=
padding
)(
orig_img
)
for
padding
in
(
3
,
10
,
30
,
50
)]
plot
(
padded_img
,
"Padded Image"
)
plot
(
padded_img
s
)
####################################
####################################
# Resize
# Resize
...
@@ -43,8 +62,8 @@ plot(padded_img, "Padded Image")
...
@@ -43,8 +62,8 @@ plot(padded_img, "Padded Image")
# The :class:`~torchvision.transforms.Resize` transform
# The :class:`~torchvision.transforms.Resize` transform
# (see also :func:`~torchvision.transforms.functional.resize`)
# (see also :func:`~torchvision.transforms.functional.resize`)
# resizes an image.
# resizes an image.
resized_img
=
T
.
Resize
(
size
=
30
)(
orig_img
)
resized_img
s
=
[
T
.
Resize
(
size
=
size
)(
orig_img
)
for
size
in
(
30
,
50
,
100
,
orig_img
.
size
)]
plot
(
resized_img
,
"Resized Image"
)
plot
(
resized_img
s
)
####################################
####################################
# CenterCrop
# CenterCrop
...
@@ -52,9 +71,8 @@ plot(resized_img, "Resized Image")
...
@@ -52,9 +71,8 @@ plot(resized_img, "Resized Image")
# The :class:`~torchvision.transforms.CenterCrop` transform
# The :class:`~torchvision.transforms.CenterCrop` transform
# (see also :func:`~torchvision.transforms.functional.center_crop`)
# (see also :func:`~torchvision.transforms.functional.center_crop`)
# crops the given image at the center.
# crops the given image at the center.
center_cropped_img
=
T
.
CenterCrop
(
size
=
(
100
,
100
))(
orig_img
)
center_crops
=
[
T
.
CenterCrop
(
size
=
size
)(
orig_img
)
for
size
in
(
30
,
50
,
100
,
orig_img
.
size
)]
plot
(
center_cropped_img
,
"Center Cropped Image"
)
plot
(
center_crops
)
####################################
####################################
# FiveCrop
# FiveCrop
...
@@ -62,20 +80,8 @@ plot(center_cropped_img, "Center Cropped Image")
...
@@ -62,20 +80,8 @@ plot(center_cropped_img, "Center Cropped Image")
# The :class:`~torchvision.transforms.FiveCrop` transform
# The :class:`~torchvision.transforms.FiveCrop` transform
# (see also :func:`~torchvision.transforms.functional.five_crop`)
# (see also :func:`~torchvision.transforms.functional.five_crop`)
# crops the given image into four corners and the central crop.
# crops the given image into four corners and the central crop.
(
img1
,
img2
,
img3
,
img4
,
img5
)
=
T
.
FiveCrop
(
size
=
(
100
,
100
))(
orig_img
)
(
top_left
,
top_right
,
bottom_left
,
bottom_right
,
center
)
=
T
.
FiveCrop
(
size
=
(
100
,
100
))(
orig_img
)
plot
(
img1
,
"Top Left Corner Image"
)
plot
([
top_left
,
top_right
,
bottom_left
,
bottom_right
,
center
])
plot
(
img2
,
"Top Right Corner Image"
,
with_orig
=
False
)
plot
(
img3
,
"Bottom Left Corner Image"
,
with_orig
=
False
)
plot
(
img4
,
"Bottom Right Corner Image"
,
with_orig
=
False
)
plot
(
img5
,
"Center Image"
,
with_orig
=
False
)
####################################
# ColorJitter
# -----------
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, saturation, and other properties of an image.
jitted_img
=
T
.
ColorJitter
(
brightness
=
.
5
,
hue
=
.
3
)(
orig_img
)
plot
(
jitted_img
,
"Jitted Image"
)
####################################
####################################
# Grayscale
# Grayscale
...
@@ -84,120 +90,130 @@ plot(jitted_img, "Jitted Image")
...
@@ -84,120 +90,130 @@ plot(jitted_img, "Jitted Image")
# (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# converts an image to grayscale
# converts an image to grayscale
gray_img
=
T
.
Grayscale
()(
orig_img
)
gray_img
=
T
.
Grayscale
()(
orig_img
)
plot
(
gray_img
,
"Grayscale Image"
,
cmap
=
'gray'
)
plot
(
[
gray_img
]
,
cmap
=
'gray'
)
####################################
####################################
# Random
Perspective
# Random
transforms
# -----------------
# -----------------
# The following transforms are random, which means that the same transfomer
# instance will produce different result each time it transforms a given image.
#
# ColorJitter
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, saturation, and other properties of an image.
jitter
=
T
.
ColorJitter
(
brightness
=
.
5
,
hue
=
.
3
)
jitted_imgs
=
[
jitter
(
orig_img
)
for
_
in
range
(
4
)]
plot
(
jitted_imgs
)
####################################
# GaussianBlur
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs gaussian blur transform on an image.
blurrer
=
T
.
GaussianBlur
(
kernel_size
=
(
5
,
9
),
sigma
=
(
0.1
,
5
))
blurred_imgs
=
[
blurrer
(
orig_img
)
for
_
in
range
(
4
)]
plot
(
blurred_imgs
)
####################################
# RandomPerspective
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomPerspective` transform
# The :class:`~torchvision.transforms.RandomPerspective` transform
# (see also :func:`~torchvision.transforms.functional.perspective`)
# (see also :func:`~torchvision.transforms.functional.perspective`)
# performs random perspective transform on an image.
# performs random perspective transform on an image.
perspectived_img
=
T
.
RandomPerspective
(
distortion_scale
=
0.6
,
p
=
1.0
)(
orig_img
)
perspective_transformer
=
T
.
RandomPerspective
(
distortion_scale
=
0.6
,
p
=
1.0
)
plot
(
perspectived_img
,
"Perspective transformed Image"
)
perspective_imgs
=
[
perspective_transformer
(
orig_img
)
for
_
in
range
(
4
)]
plot
(
perspective_imgs
)
####################################
####################################
# RandomRotation
# RandomRotation
#
--------------
#
~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomRotation` transform
# The :class:`~torchvision.transforms.RandomRotation` transform
# (see also :func:`~torchvision.transforms.functional.rotate`)
# (see also :func:`~torchvision.transforms.functional.rotate`)
# rotates an image with random angle.
# rotates an image with random angle.
rotated_img
=
T
.
RandomRotation
(
degrees
=
(
30
,
70
))(
orig_img
)
rotater
=
T
.
RandomRotation
(
degrees
=
(
0
,
180
))
plot
(
rotated_img
,
"Rotated Image"
)
rotated_imgs
=
[
rotater
(
orig_img
)
for
_
in
range
(
4
)]
plot
(
rotated_imgs
)
####################################
####################################
# RandomAffine
# RandomAffine
#
------------
#
~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAffine` transform
# The :class:`~torchvision.transforms.RandomAffine` transform
# (see also :func:`~torchvision.transforms.functional.affine`)
# (see also :func:`~torchvision.transforms.functional.affine`)
# performs random affine transform on an image.
# performs random affine transform on an image.
affined_img
=
T
.
RandomAffine
(
degrees
=
(
30
,
70
),
translate
=
(
0.1
,
0.3
),
scale
=
(
0.5
,
0.75
))(
orig_img
)
affine_transfomer
=
T
.
RandomAffine
(
degrees
=
(
30
,
70
),
translate
=
(
0.1
,
0.3
),
scale
=
(
0.5
,
0.75
))
plot
(
affined_img
,
"Affine transformed Image"
)
affine_imgs
=
[
affine_transfomer
(
orig_img
)
for
_
in
range
(
4
)]
plot
(
affine_imgs
)
####################################
####################################
# RandomCrop
# RandomCrop
#
----------
#
~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomCrop` transform
# The :class:`~torchvision.transforms.RandomCrop` transform
# (see also :func:`~torchvision.transforms.functional.crop`)
# (see also :func:`~torchvision.transforms.functional.crop`)
# crops an image at a random location.
# crops an image at a random location.
crop_img
=
T
.
RandomCrop
(
size
=
(
128
,
128
))(
orig_img
)
cropper
=
T
.
RandomCrop
(
size
=
(
128
,
128
))
plot
(
crop_img
,
"Random cropped Image"
)
crops
=
[
cropper
(
orig_img
)
for
_
in
range
(
4
)]
plot
(
crops
)
####################################
####################################
# RandomResizedCrop
# RandomResizedCrop
#
-----------------
#
~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomResizedCrop` transform
# The :class:`~torchvision.transforms.RandomResizedCrop` transform
# (see also :func:`~torchvision.transforms.functional.resized_crop`)
# (see also :func:`~torchvision.transforms.functional.resized_crop`)
# crops an image at a random location, and then resizes the crop to a given
# crops an image at a random location, and then resizes the crop to a given
# size.
# size.
resized_crop_img
=
T
.
RandomResizedCrop
(
size
=
(
32
,
32
))(
orig_img
)
resize_cropper
=
T
.
RandomResizedCrop
(
size
=
(
32
,
32
))
plot
(
resized_crop_img
,
"Random resized cropped Image"
)
resized_crops
=
[
resize_cropper
(
orig_img
)
for
_
in
range
(
4
)]
plot
(
resized_crops
)
####################################
####################################
# AutoAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies
=
[
T
.
AutoAugmentPolicy
.
CIFAR10
,
T
.
AutoAugmentPolicy
.
IMAGENET
,
T
.
AutoAugmentPolicy
.
SVHN
]
augmenters
=
[
T
.
AutoAugment
(
policy
)
for
policy
in
policies
]
imgs
=
[
[
augmenter
(
orig_img
)
for
_
in
range
(
4
)]
for
augmenter
in
augmenters
]
row_title
=
[
str
(
policy
).
split
(
'.'
)[
-
1
]
for
policy
in
policies
]
plot
(
imgs
,
row_title
=
row_title
)
####################################
# Randomly-applied transforms
# ---------------------------
#
# Some transforms are randomly-applied given a probability ``p``. That is, the
# transformed image may actually be the same as the original one, even when
# called with the same transformer instance!
#
# RandomHorizontalFlip
# RandomHorizontalFlip
#
--------------------
#
~~~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# (see also :func:`~torchvision.transforms.functional.hflip`)
# (see also :func:`~torchvision.transforms.functional.hflip`)
# performs horizontal flip of an image, with a given probability.
# performs horizontal flip of an image, with a given probability.
#
hflipper
=
T
.
RandomHorizontalFlip
(
p
=
0.5
)
# .. note::
transformed_imgs
=
[
hflipper
(
orig_img
)
for
_
in
range
(
4
)]
# Since the transform is applied randomly, the two images below may actually be
plot
(
transformed_imgs
)
# the same.
random_hflip_img
=
T
.
RandomHorizontalFlip
(
p
=
0.5
)(
orig_img
)
plot
(
random_hflip_img
,
"Random horizontal flipped Image"
)
####################################
####################################
# RandomVerticalFlip
# RandomVerticalFlip
#
------------------
#
~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# (see also :func:`~torchvision.transforms.functional.vflip`)
# (see also :func:`~torchvision.transforms.functional.vflip`)
# performs vertical flip of an image, with a given probability.
# performs vertical flip of an image, with a given probability.
#
vflipper
=
T
.
RandomVerticalFlip
(
p
=
0.5
)
# .. note::
transformed_imgs
=
[
vflipper
(
orig_img
)
for
_
in
range
(
4
)]
# Since the transform is applied randomly, the two images below may actually be
plot
(
transformed_imgs
)
# the same.
random_vflip_img
=
T
.
RandomVerticalFlip
(
p
=
0.5
)(
orig_img
)
plot
(
random_vflip_img
,
"Random vertical flipped Image"
)
####################################
####################################
# RandomApply
# RandomApply
#
-----------
#
~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomApply` transform
# The :class:`~torchvision.transforms.RandomApply` transform
# randomly applies a list of transforms, with a given probability
# randomly applies a list of transforms, with a given probability.
#
applier
=
T
.
RandomApply
(
transforms
=
[
T
.
RandomCrop
(
size
=
(
64
,
64
))],
p
=
0.5
)
# .. note::
transformed_imgs
=
[
applier
(
orig_img
)
for
_
in
range
(
4
)]
# Since the transform is applied randomly, the two images below may actually be
plot
(
transformed_imgs
)
# the same.
random_apply_img
=
T
.
RandomApply
(
transforms
=
[
T
.
RandomCrop
(
size
=
(
64
,
64
))],
p
=
0.5
)(
orig_img
)
plot
(
random_apply_img
,
"Random Apply transformed Image"
)
####################################
# GaussianBlur
# ------------
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs gaussianblur transform on an image.
gaus_blur_img
=
T
.
GaussianBlur
(
kernel_size
=
(
5
,
9
),
sigma
=
(
0.4
,
3.0
))(
orig_img
)
plot
(
gaus_blur_img
,
"Gaussian Blurred Image"
)
####################################
# AutoAugment
# -----------
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies
=
[
T
.
AutoAugmentPolicy
.
CIFAR10
,
T
.
AutoAugmentPolicy
.
IMAGENET
,
T
.
AutoAugmentPolicy
.
SVHN
]
num_cols
=
5
fig
,
axs
=
plt
.
subplots
(
nrows
=
len
(
policies
),
ncols
=
num_cols
)
fig
.
suptitle
(
"Auto-augmented images with different policies"
)
for
pol_idx
,
policy
in
enumerate
(
policies
):
auto_augmenter
=
T
.
AutoAugment
(
policy
)
for
col
in
range
(
num_cols
):
augmented_img
=
auto_augmenter
(
orig_img
)
ax
=
axs
[
pol_idx
,
col
]
ax
.
imshow
(
np
.
asarray
(
augmented_img
))
ax
.
set
(
xticklabels
=
[],
yticklabels
=
[],
xticks
=
[],
yticks
=
[])
if
col
==
0
:
ax
.
set
(
ylabel
=
str
(
policy
).
split
(
'.'
)[
-
1
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment