Unverified Commit 47cd5ea8 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Make transforms illutration example use v2 instead of v1 (#7886)

parent 7ebc3ad8
......@@ -84,6 +84,7 @@ class CustomGalleryExampleSortKey:
transforms_subsection_order = [
"plot_transforms_getting_started.py",
"plot_transforms_illustrations.py",
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_custom_transforms.py",
......
......@@ -62,7 +62,7 @@ show([dog1, dog2])
# --------------------------
# Most transforms natively support tensors on top of PIL images (to visualize
# the effect of the transforms, you may refer to see
# :ref:`sphx_glr_auto_examples_others_plot_transforms.py`).
# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
# Using tensor images, we can run the transforms on GPUs if cuda is available!
import torch.nn as nn
......
......@@ -5,7 +5,7 @@ from torchvision import datapoints
from torchvision.transforms.v2 import functional as F
def plot(imgs):
def plot(imgs, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
......@@ -40,7 +40,11 @@ def plot(imgs):
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
ax = axs[row_idx, col_idx]
ax.imshow(img.permute(1, 2, 0).numpy())
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
......@@ -4,55 +4,33 @@ Illustration of transforms
==========================
.. note::
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_transforms.py>` to download the full example code.
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_illustrations.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_transforms_illustrations.py>` to download the full example code.
This example illustrates the various transforms available in :ref:`the
torchvision.transforms module <transforms>`.
This example illustrates some of the various transforms available in :ref:`the
torchvision.transforms.v2 module <transforms>`.
"""
# %%
# sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png"
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.transforms import v2
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('../assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)
def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
row = [orig_img] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig:
axs[0, 0].set(title='Original image')
axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
# If you're trying to run that on collab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
orig_img = Image.open(Path('../assets') / 'astronaut.jpg')
# %%
# Geometric Transforms
......@@ -66,8 +44,8 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
# The :class:`~torchvision.transforms.Pad` transform
# (see also :func:`~torchvision.transforms.functional.pad`)
# pads all image borders with some pixel values.
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)
padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot([orig_img] + padded_imgs)
# %%
# Resize
......@@ -75,8 +53,8 @@ plot(padded_imgs)
# The :class:`~torchvision.transforms.Resize` transform
# (see also :func:`~torchvision.transforms.functional.resize`)
# resizes an image.
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(resized_imgs)
resized_imgs = [v2.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot([orig_img] + resized_imgs)
# %%
# CenterCrop
......@@ -84,8 +62,8 @@ plot(resized_imgs)
# The :class:`~torchvision.transforms.CenterCrop` transform
# (see also :func:`~torchvision.transforms.functional.center_crop`)
# crops the given image at the center.
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(center_crops)
center_crops = [v2.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot([orig_img] + center_crops)
# %%
# FiveCrop
......@@ -93,8 +71,8 @@ plot(center_crops)
# The :class:`~torchvision.transforms.FiveCrop` transform
# (see also :func:`~torchvision.transforms.functional.five_crop`)
# crops the given image into four corners and the central crop.
(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img)
plot([top_left, top_right, bottom_left, bottom_right, center])
(top_left, top_right, bottom_left, bottom_right, center) = v2.FiveCrop(size=(100, 100))(orig_img)
plot([orig_img] + [top_left, top_right, bottom_left, bottom_right, center])
# %%
# RandomPerspective
......@@ -102,9 +80,9 @@ plot([top_left, top_right, bottom_left, bottom_right, center])
# The :class:`~torchvision.transforms.RandomPerspective` transform
# (see also :func:`~torchvision.transforms.functional.perspective`)
# performs random perspective transform on an image.
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_transformer = v2.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
plot([orig_img] + perspective_imgs)
# %%
# RandomRotation
......@@ -112,9 +90,9 @@ plot(perspective_imgs)
# The :class:`~torchvision.transforms.RandomRotation` transform
# (see also :func:`~torchvision.transforms.functional.rotate`)
# rotates an image with random angle.
rotater = T.RandomRotation(degrees=(0, 180))
rotater = v2.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
plot([orig_img] + rotated_imgs)
# %%
# RandomAffine
......@@ -122,9 +100,9 @@ plot(rotated_imgs)
# The :class:`~torchvision.transforms.RandomAffine` transform
# (see also :func:`~torchvision.transforms.functional.affine`)
# performs random affine transform on an image.
affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_transfomer = v2.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot(affine_imgs)
plot([orig_img] + affine_imgs)
# %%
# ElasticTransform
......@@ -133,9 +111,9 @@ plot(affine_imgs)
# (see also :func:`~torchvision.transforms.functional.elastic_transform`)
# Randomly transforms the morphology of objects in images and produces a
# see-through-water-like effect.
elastic_transformer = T.ElasticTransform(alpha=250.0)
elastic_transformer = v2.ElasticTransform(alpha=250.0)
transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)]
plot(transformed_imgs)
plot([orig_img] + transformed_imgs)
# %%
# RandomCrop
......@@ -143,9 +121,9 @@ plot(transformed_imgs)
# The :class:`~torchvision.transforms.RandomCrop` transform
# (see also :func:`~torchvision.transforms.functional.crop`)
# crops an image at a random location.
cropper = T.RandomCrop(size=(128, 128))
cropper = v2.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
plot([orig_img] + crops)
# %%
# RandomResizedCrop
......@@ -154,9 +132,9 @@ plot(crops)
# (see also :func:`~torchvision.transforms.functional.resized_crop`)
# crops an image at a random location, and then resizes the crop to a given
# size.
resize_cropper = T.RandomResizedCrop(size=(32, 32))
resize_cropper = v2.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
plot([orig_img] + resized_crops)
# %%
# Photometric Transforms
......@@ -175,17 +153,17 @@ plot(resized_crops)
# The :class:`~torchvision.transforms.Grayscale` transform
# (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# converts an image to grayscale
gray_img = T.Grayscale()(orig_img)
plot([gray_img], cmap='gray')
gray_img = v2.Grayscale()(orig_img)
plot([orig_img, gray_img], cmap='gray')
# %%
# ColorJitter
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, contrast, saturation, hue, and other properties of an image.
jitter = T.ColorJitter(brightness=.5, hue=.3)
jitted_imgs = [jitter(orig_img) for _ in range(4)]
plot(jitted_imgs)
jitter = v2.ColorJitter(brightness=.5, hue=.3)
jittered_imgs = [jitter(orig_img) for _ in range(4)]
plot([orig_img] + jittered_imgs)
# %%
# GaussianBlur
......@@ -193,9 +171,9 @@ plot(jitted_imgs)
# The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs gaussian blur transform on an image.
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurrer = v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)
plot([orig_img] + blurred_imgs)
# %%
# RandomInvert
......@@ -203,9 +181,9 @@ plot(blurred_imgs)
# The :class:`~torchvision.transforms.RandomInvert` transform
# (see also :func:`~torchvision.transforms.functional.invert`)
# randomly inverts the colors of the given image.
inverter = T.RandomInvert()
inverter = v2.RandomInvert()
invertered_imgs = [inverter(orig_img) for _ in range(4)]
plot(invertered_imgs)
plot([orig_img] + invertered_imgs)
# %%
# RandomPosterize
......@@ -214,9 +192,9 @@ plot(invertered_imgs)
# (see also :func:`~torchvision.transforms.functional.posterize`)
# randomly posterizes the image by reducing the number of bits
# of each color channel.
posterizer = T.RandomPosterize(bits=2)
posterizer = v2.RandomPosterize(bits=2)
posterized_imgs = [posterizer(orig_img) for _ in range(4)]
plot(posterized_imgs)
plot([orig_img] + posterized_imgs)
# %%
# RandomSolarize
......@@ -225,9 +203,9 @@ plot(posterized_imgs)
# (see also :func:`~torchvision.transforms.functional.solarize`)
# randomly solarizes the image by inverting all pixel values above
# the threshold.
solarizer = T.RandomSolarize(threshold=192.0)
solarizer = v2.RandomSolarize(threshold=192.0)
solarized_imgs = [solarizer(orig_img) for _ in range(4)]
plot(solarized_imgs)
plot([orig_img] + solarized_imgs)
# %%
# RandomAdjustSharpness
......@@ -235,9 +213,9 @@ plot(solarized_imgs)
# The :class:`~torchvision.transforms.RandomAdjustSharpness` transform
# (see also :func:`~torchvision.transforms.functional.adjust_sharpness`)
# randomly adjusts the sharpness of the given image.
sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2)
sharpness_adjuster = v2.RandomAdjustSharpness(sharpness_factor=2)
sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
plot(sharpened_imgs)
plot([orig_img] + sharpened_imgs)
# %%
# RandomAutocontrast
......@@ -245,9 +223,9 @@ plot(sharpened_imgs)
# The :class:`~torchvision.transforms.RandomAutocontrast` transform
# (see also :func:`~torchvision.transforms.functional.autocontrast`)
# randomly applies autocontrast to the given image.
autocontraster = T.RandomAutocontrast()
autocontraster = v2.RandomAutocontrast()
autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
plot(autocontrasted_imgs)
plot([orig_img] + autocontrasted_imgs)
# %%
# RandomEqualize
......@@ -255,9 +233,9 @@ plot(autocontrasted_imgs)
# The :class:`~torchvision.transforms.RandomEqualize` transform
# (see also :func:`~torchvision.transforms.functional.equalize`)
# randomly equalizes the histogram of the given image.
equalizer = T.RandomEqualize()
equalizer = v2.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot(equalized_imgs)
plot([orig_img] + equalized_imgs)
# %%
# Augmentation Transforms
......@@ -270,22 +248,22 @@ plot(equalized_imgs)
# The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
policies = [v2.AutoAugmentPolicy.CIFAR10, v2.AutoAugmentPolicy.IMAGENET, v2.AutoAugmentPolicy.SVHN]
augmenters = [v2.AutoAugment(policy) for policy in policies]
imgs = [
[augmenter(orig_img) for _ in range(4)]
for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)
plot([[orig_img] + row for row in imgs], row_title=row_title)
# %%
# RandAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` is an alternate version of AutoAugment.
augmenter = T.RandAugment()
augmenter = v2.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
plot([orig_img] + imgs)
# %%
# TrivialAugmentWide
......@@ -293,17 +271,17 @@ plot(imgs)
# The :class:`~torchvision.transforms.TrivialAugmentWide` is an alternate implementation of AutoAugment.
# However, instead of transforming an image multiple times, it transforms an image only once
# using a random transform from a given list with a random strength number.
augmenter = T.TrivialAugmentWide()
augmenter = v2.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
plot([orig_img] + imgs)
# %%
# AugMix
# ~~~~~~
# The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image.
augmenter = T.AugMix()
augmenter = v2.AugMix()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
plot([orig_img] + imgs)
# %%
# Randomly-applied Transforms
......@@ -318,9 +296,9 @@ plot(imgs)
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# (see also :func:`~torchvision.transforms.functional.hflip`)
# performs horizontal flip of an image, with a given probability.
hflipper = T.RandomHorizontalFlip(p=0.5)
hflipper = v2.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
plot([orig_img] + transformed_imgs)
# %%
# RandomVerticalFlip
......@@ -328,15 +306,15 @@ plot(transformed_imgs)
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# (see also :func:`~torchvision.transforms.functional.vflip`)
# performs vertical flip of an image, with a given probability.
vflipper = T.RandomVerticalFlip(p=0.5)
vflipper = v2.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
plot([orig_img] + transformed_imgs)
# %%
# RandomApply
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomApply` transform
# randomly applies a list of transforms, with a given probability.
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
applier = v2.RandomApply(transforms=[v2.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)
plot([orig_img] + transformed_imgs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment