"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "17a70815259222570feb071034acd7bae2adc019"
Unverified Commit 5ac27fe3 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Rework transforms example in gallery (#3744)

parent c8f7d772
...@@ -11,21 +11,40 @@ from pathlib import Path ...@@ -11,21 +11,40 @@ from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch
import torchvision.transforms as T import torchvision.transforms as T
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / 'astronaut.jpg') orig_img = Image.open(Path('assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
def plot(img, title: str = "", with_orig: bool = True, **kwargs): torch.manual_seed(0)
def _plot(img, title, **kwargs):
plt.figure().suptitle(title, fontsize=25)
plt.imshow(np.asarray(img), **kwargs) def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
plt.axis('off') if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
row = [orig_img] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig: if with_orig:
_plot(orig_img, "Original Image") axs[0, 0].set(title='Original image')
_plot(img, title, **kwargs) axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
#################################### ####################################
...@@ -34,8 +53,8 @@ def plot(img, title: str = "", with_orig: bool = True, **kwargs): ...@@ -34,8 +53,8 @@ def plot(img, title: str = "", with_orig: bool = True, **kwargs):
# The :class:`~torchvision.transforms.Pad` transform # The :class:`~torchvision.transforms.Pad` transform
# (see also :func:`~torchvision.transforms.functional.pad`) # (see also :func:`~torchvision.transforms.functional.pad`)
# fills image borders with some pixel values. # fills image borders with some pixel values.
padded_img = T.Pad(padding=30)(orig_img) padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_img, "Padded Image") plot(padded_imgs)
#################################### ####################################
# Resize # Resize
...@@ -43,8 +62,8 @@ plot(padded_img, "Padded Image") ...@@ -43,8 +62,8 @@ plot(padded_img, "Padded Image")
# The :class:`~torchvision.transforms.Resize` transform # The :class:`~torchvision.transforms.Resize` transform
# (see also :func:`~torchvision.transforms.functional.resize`) # (see also :func:`~torchvision.transforms.functional.resize`)
# resizes an image. # resizes an image.
resized_img = T.Resize(size=30)(orig_img) resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(resized_img, "Resized Image") plot(resized_imgs)
#################################### ####################################
# CenterCrop # CenterCrop
...@@ -52,9 +71,8 @@ plot(resized_img, "Resized Image") ...@@ -52,9 +71,8 @@ plot(resized_img, "Resized Image")
# The :class:`~torchvision.transforms.CenterCrop` transform # The :class:`~torchvision.transforms.CenterCrop` transform
# (see also :func:`~torchvision.transforms.functional.center_crop`) # (see also :func:`~torchvision.transforms.functional.center_crop`)
# crops the given image at the center. # crops the given image at the center.
center_cropped_img = T.CenterCrop(size=(100, 100))(orig_img) center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(center_cropped_img, "Center Cropped Image") plot(center_crops)
#################################### ####################################
# FiveCrop # FiveCrop
...@@ -62,20 +80,8 @@ plot(center_cropped_img, "Center Cropped Image") ...@@ -62,20 +80,8 @@ plot(center_cropped_img, "Center Cropped Image")
# The :class:`~torchvision.transforms.FiveCrop` transform # The :class:`~torchvision.transforms.FiveCrop` transform
# (see also :func:`~torchvision.transforms.functional.five_crop`) # (see also :func:`~torchvision.transforms.functional.five_crop`)
# crops the given image into four corners and the central crop. # crops the given image into four corners and the central crop.
(img1, img2, img3, img4, img5) = T.FiveCrop(size=(100, 100))(orig_img) (top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img)
plot(img1, "Top Left Corner Image") plot([top_left, top_right, bottom_left, bottom_right, center])
plot(img2, "Top Right Corner Image", with_orig=False)
plot(img3, "Bottom Left Corner Image", with_orig=False)
plot(img4, "Bottom Right Corner Image", with_orig=False)
plot(img5, "Center Image", with_orig=False)
####################################
# ColorJitter
# -----------
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, saturation, and other properties of an image.
jitted_img = T.ColorJitter(brightness=.5, hue=.3)(orig_img)
plot(jitted_img, "Jitted Image")
#################################### ####################################
# Grayscale # Grayscale
...@@ -84,120 +90,130 @@ plot(jitted_img, "Jitted Image") ...@@ -84,120 +90,130 @@ plot(jitted_img, "Jitted Image")
# (see also :func:`~torchvision.transforms.functional.to_grayscale`) # (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# converts an image to grayscale # converts an image to grayscale
gray_img = T.Grayscale()(orig_img) gray_img = T.Grayscale()(orig_img)
plot(gray_img, "Grayscale Image", cmap='gray') plot([gray_img], cmap='gray')
#################################### ####################################
# RandomPerspective # Random transforms
# ----------------- # -----------------
# The following transforms are random, which means that the same transformer
# instance will produce a different result each time it transforms a given image.
#
# ColorJitter
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, saturation, and other properties of an image.
jitter = T.ColorJitter(brightness=.5, hue=.3)
jitted_imgs = [jitter(orig_img) for _ in range(4)]
plot(jitted_imgs)
####################################
# GaussianBlur
# ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs gaussian blur transform on an image.
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)
####################################
# RandomPerspective
# ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomPerspective` transform # The :class:`~torchvision.transforms.RandomPerspective` transform
# (see also :func:`~torchvision.transforms.functional.perspective`) # (see also :func:`~torchvision.transforms.functional.perspective`)
# performs random perspective transform on an image. # performs random perspective transform on an image.
perspectived_img = T.RandomPerspective(distortion_scale=0.6, p=1.0)(orig_img) perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
plot(perspectived_img, "Perspective transformed Image") perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
#################################### ####################################
# RandomRotation # RandomRotation
# -------------- # ~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomRotation` transform # The :class:`~torchvision.transforms.RandomRotation` transform
# (see also :func:`~torchvision.transforms.functional.rotate`) # (see also :func:`~torchvision.transforms.functional.rotate`)
# rotates an image with random angle. # rotates an image with random angle.
rotated_img = T.RandomRotation(degrees=(30, 70))(orig_img) rotater = T.RandomRotation(degrees=(0, 180))
plot(rotated_img, "Rotated Image") rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
#################################### ####################################
# RandomAffine # RandomAffine
# ------------ # ~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomAffine` transform # The :class:`~torchvision.transforms.RandomAffine` transform
# (see also :func:`~torchvision.transforms.functional.affine`) # (see also :func:`~torchvision.transforms.functional.affine`)
# performs random affine transform on an image. # performs random affine transform on an image.
affined_img = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))(orig_img) affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
plot(affined_img, "Affine transformed Image") affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot(affine_imgs)
#################################### ####################################
# RandomCrop # RandomCrop
# ---------- # ~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomCrop` transform # The :class:`~torchvision.transforms.RandomCrop` transform
# (see also :func:`~torchvision.transforms.functional.crop`) # (see also :func:`~torchvision.transforms.functional.crop`)
# crops an image at a random location. # crops an image at a random location.
crop_img = T.RandomCrop(size=(128, 128))(orig_img) cropper = T.RandomCrop(size=(128, 128))
plot(crop_img, "Random cropped Image") crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
#################################### ####################################
# RandomResizedCrop # RandomResizedCrop
# ----------------- # ~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomResizedCrop` transform # The :class:`~torchvision.transforms.RandomResizedCrop` transform
# (see also :func:`~torchvision.transforms.functional.resized_crop`) # (see also :func:`~torchvision.transforms.functional.resized_crop`)
# crops an image at a random location, and then resizes the crop to a given # crops an image at a random location, and then resizes the crop to a given
# size. # size.
resized_crop_img = T.RandomResizedCrop(size=(32, 32))(orig_img) resize_cropper = T.RandomResizedCrop(size=(32, 32))
plot(resized_crop_img, "Random resized cropped Image") resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
#################################### ####################################
# AutoAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
[augmenter(orig_img) for _ in range(4)]
for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)
####################################
# Randomly-applied transforms
# ---------------------------
#
# Some transforms are randomly-applied given a probability ``p``. That is, the
# transformed image may actually be the same as the original one, even when
# called with the same transformer instance!
#
# RandomHorizontalFlip # RandomHorizontalFlip
# -------------------- # ~~~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform # The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# (see also :func:`~torchvision.transforms.functional.hflip`) # (see also :func:`~torchvision.transforms.functional.hflip`)
# performs horizontal flip of an image, with a given probability. # performs horizontal flip of an image, with a given probability.
# hflipper = T.RandomHorizontalFlip(p=0.5)
# .. note:: transformed_imgs = [hflipper(orig_img) for _ in range(4)]
# Since the transform is applied randomly, the two images below may actually be plot(transformed_imgs)
# the same.
random_hflip_img = T.RandomHorizontalFlip(p=0.5)(orig_img)
plot(random_hflip_img, "Random horizontal flipped Image")
#################################### ####################################
# RandomVerticalFlip # RandomVerticalFlip
# ------------------ # ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform # The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# (see also :func:`~torchvision.transforms.functional.vflip`) # (see also :func:`~torchvision.transforms.functional.vflip`)
# performs vertical flip of an image, with a given probability. # performs vertical flip of an image, with a given probability.
# vflipper = T.RandomVerticalFlip(p=0.5)
# .. note:: transformed_imgs = [vflipper(orig_img) for _ in range(4)]
# Since the transform is applied randomly, the two images below may actually be plot(transformed_imgs)
# the same.
random_vflip_img = T.RandomVerticalFlip(p=0.5)(orig_img)
plot(random_vflip_img, "Random vertical flipped Image")
#################################### ####################################
# RandomApply # RandomApply
# ----------- # ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomApply` transform # The :class:`~torchvision.transforms.RandomApply` transform
# randomly applies a list of transforms, with a given probability # randomly applies a list of transforms, with a given probability.
# applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
# .. note:: transformed_imgs = [applier(orig_img) for _ in range(4)]
# Since the transform is applied randomly, the two images below may actually be plot(transformed_imgs)
# the same.
random_apply_img = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)(orig_img)
plot(random_apply_img, "Random Apply transformed Image")
####################################
# GaussianBlur
# ------------
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs a Gaussian blur transform on an image.
gaus_blur_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.4, 3.0))(orig_img)
plot(gaus_blur_img, "Gaussian Blurred Image")
####################################
# AutoAugment
# -----------
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
num_cols = 5
fig, axs = plt.subplots(nrows=len(policies), ncols=num_cols)
fig.suptitle("Auto-augmented images with different policies")
for pol_idx, policy in enumerate(policies):
auto_augmenter = T.AutoAugment(policy)
for col in range(num_cols):
augmented_img = auto_augmenter(orig_img)
ax = axs[pol_idx, col]
ax.imshow(np.asarray(augmented_img))
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if col == 0:
ax.set(ylabel=str(policy).split('.')[-1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment