Unverified Commit 5ac27fe3 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Rework transforms example in gallery (#3744)

parent c8f7d772
......@@ -11,21 +11,40 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / 'astronaut.jpg')
def plot(img, title: str = "", with_orig: bool = True, **kwargs):
def _plot(img, title, **kwargs):
plt.figure().suptitle(title, fontsize=25)
plt.imshow(np.asarray(img), **kwargs)
plt.axis('off')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)
def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
row = [orig_img] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig:
_plot(orig_img, "Original Image")
_plot(img, title, **kwargs)
axs[0, 0].set(title='Original image')
axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
####################################
......@@ -34,8 +53,8 @@ def plot(img, title: str = "", with_orig: bool = True, **kwargs):
# The :class:`~torchvision.transforms.Pad` transform
# (see also :func:`~torchvision.transforms.functional.pad`)
# fills image borders with some pixel values.
padded_img = T.Pad(padding=30)(orig_img)
plot(padded_img, "Padded Image")
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)
####################################
# Resize
......@@ -43,8 +62,8 @@ plot(padded_img, "Padded Image")
# The :class:`~torchvision.transforms.Resize` transform
# (see also :func:`~torchvision.transforms.functional.resize`)
# resizes an image.
resized_img = T.Resize(size=30)(orig_img)
plot(resized_img, "Resized Image")
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(resized_imgs)
####################################
# CenterCrop
......@@ -52,9 +71,8 @@ plot(resized_img, "Resized Image")
# The :class:`~torchvision.transforms.CenterCrop` transform
# (see also :func:`~torchvision.transforms.functional.center_crop`)
# crops the given image at the center.
center_cropped_img = T.CenterCrop(size=(100, 100))(orig_img)
plot(center_cropped_img, "Center Cropped Image")
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(center_crops)
####################################
# FiveCrop
......@@ -62,20 +80,8 @@ plot(center_cropped_img, "Center Cropped Image")
# The :class:`~torchvision.transforms.FiveCrop` transform
# (see also :func:`~torchvision.transforms.functional.five_crop`)
# crops the given image into four corners and the central crop.
(img1, img2, img3, img4, img5) = T.FiveCrop(size=(100, 100))(orig_img)
plot(img1, "Top Left Corner Image")
plot(img2, "Top Right Corner Image", with_orig=False)
plot(img3, "Bottom Left Corner Image", with_orig=False)
plot(img4, "Bottom Right Corner Image", with_orig=False)
plot(img5, "Center Image", with_orig=False)
####################################
# ColorJitter
# -----------
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, saturation, and other properties of an image.
jitted_img = T.ColorJitter(brightness=.5, hue=.3)(orig_img)
plot(jitted_img, "Jitted Image")
(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img)
plot([top_left, top_right, bottom_left, bottom_right, center])
####################################
# Grayscale
......@@ -84,120 +90,130 @@ plot(jitted_img, "Jitted Image")
# (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# converts an image to grayscale
gray_img = T.Grayscale()(orig_img)
plot(gray_img, "Grayscale Image", cmap='gray')
plot([gray_img], cmap='gray')
####################################
# RandomPerspective
# Random transforms
# -----------------
# The following transforms are random, which means that the same transfomer
# instance will produce different result each time it transforms a given image.
#
# ColorJitter
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, saturation, and other properties of an image.
jitter = T.ColorJitter(brightness=.5, hue=.3)
jitted_imgs = [jitter(orig_img) for _ in range(4)]
plot(jitted_imgs)
####################################
# GaussianBlur
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs gaussian blur transform on an image.
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)
####################################
# RandomPerspective
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomPerspective` transform
# (see also :func:`~torchvision.transforms.functional.perspective`)
# performs random perspective transform on an image.
perspectived_img = T.RandomPerspective(distortion_scale=0.6, p=1.0)(orig_img)
plot(perspectived_img, "Perspective transformed Image")
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
####################################
# RandomRotation
# --------------
# ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomRotation` transform
# (see also :func:`~torchvision.transforms.functional.rotate`)
# rotates an image with random angle.
rotated_img = T.RandomRotation(degrees=(30, 70))(orig_img)
plot(rotated_img, "Rotated Image")
rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
####################################
# RandomAffine
# ------------
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAffine` transform
# (see also :func:`~torchvision.transforms.functional.affine`)
# performs random affine transform on an image.
affined_img = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))(orig_img)
plot(affined_img, "Affine transformed Image")
affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot(affine_imgs)
####################################
# RandomCrop
# ----------
# ~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomCrop` transform
# (see also :func:`~torchvision.transforms.functional.crop`)
# crops an image at a random location.
crop_img = T.RandomCrop(size=(128, 128))(orig_img)
plot(crop_img, "Random cropped Image")
cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
####################################
# RandomResizedCrop
# -----------------
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomResizedCrop` transform
# (see also :func:`~torchvision.transforms.functional.resized_crop`)
# crops an image at a random location, and then resizes the crop to a given
# size.
resized_crop_img = T.RandomResizedCrop(size=(32, 32))(orig_img)
plot(resized_crop_img, "Random resized cropped Image")
resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
####################################
# AutoAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
[augmenter(orig_img) for _ in range(4)]
for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)
####################################
# Randomly-applied transforms
# ---------------------------
#
# Some transforms are randomly-applied given a probability ``p``. That is, the
# transformed image may actually be the same as the original one, even when
# called with the same transformer instance!
#
# RandomHorizontalFlip
# --------------------
# ~~~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# (see also :func:`~torchvision.transforms.functional.hflip`)
# performs horizontal flip of an image, with a given probability.
#
# .. note::
# Since the transform is applied randomly, the two images below may actually be
# the same.
random_hflip_img = T.RandomHorizontalFlip(p=0.5)(orig_img)
plot(random_hflip_img, "Random horizontal flipped Image")
hflipper = T.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
####################################
# RandomVerticalFlip
# ------------------
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# (see also :func:`~torchvision.transforms.functional.vflip`)
# performs vertical flip of an image, with a given probability.
#
# .. note::
# Since the transform is applied randomly, the two images below may actually be
# the same.
random_vflip_img = T.RandomVerticalFlip(p=0.5)(orig_img)
plot(random_vflip_img, "Random vertical flipped Image")
vflipper = T.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
####################################
# RandomApply
# -----------
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomApply` transform
# randomly applies a list of transforms, with a given probability
#
# .. note::
# Since the transform is applied randomly, the two images below may actually be
# the same.
random_apply_img = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)(orig_img)
plot(random_apply_img, "Random Apply transformed Image")
####################################
# GaussianBlur
# ------------
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs gaussianblur transform on an image.
gaus_blur_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.4, 3.0))(orig_img)
plot(gaus_blur_img, "Gaussian Blurred Image")
####################################
# AutoAugment
# -----------
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
num_cols = 5
fig, axs = plt.subplots(nrows=len(policies), ncols=num_cols)
fig.suptitle("Auto-augmented images with different policies")
for pol_idx, policy in enumerate(policies):
auto_augmenter = T.AutoAugment(policy)
for col in range(num_cols):
augmented_img = auto_augmenter(orig_img)
ax = axs[pol_idx, col]
ax.imshow(np.asarray(augmented_img))
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if col == 0:
ax.set(ylabel=str(policy).split('.')[-1])
# randomly applies a list of transforms, with a given probability.
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment