Unverified Commit 47cd5ea8 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Make transforms illutration example use v2 instead of v1 (#7886)

parent 7ebc3ad8
...@@ -84,6 +84,7 @@ class CustomGalleryExampleSortKey: ...@@ -84,6 +84,7 @@ class CustomGalleryExampleSortKey:
transforms_subsection_order = [ transforms_subsection_order = [
"plot_transforms_getting_started.py", "plot_transforms_getting_started.py",
"plot_transforms_illustrations.py",
"plot_transforms_e2e.py", "plot_transforms_e2e.py",
"plot_cutmix_mixup.py", "plot_cutmix_mixup.py",
"plot_custom_transforms.py", "plot_custom_transforms.py",
......
...@@ -62,7 +62,7 @@ show([dog1, dog2]) ...@@ -62,7 +62,7 @@ show([dog1, dog2])
# -------------------------- # --------------------------
# Most transforms natively support tensors on top of PIL images (to visualize # Most transforms natively support tensors on top of PIL images (to visualize
# the effect of the transforms, you may refer to see # the effect of the transforms, you may refer to see
# :ref:`sphx_glr_auto_examples_others_plot_transforms.py`). # :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
# Using tensor images, we can run the transforms on GPUs if cuda is available! # Using tensor images, we can run the transforms on GPUs if cuda is available!
import torch.nn as nn import torch.nn as nn
......
...@@ -5,7 +5,7 @@ from torchvision import datapoints ...@@ -5,7 +5,7 @@ from torchvision import datapoints
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
def plot(imgs): def plot(imgs, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list): if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row # Make a 2d grid even if there's just 1 row
imgs = [imgs] imgs = [imgs]
...@@ -40,7 +40,11 @@ def plot(imgs): ...@@ -40,7 +40,11 @@ def plot(imgs):
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65) img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
ax = axs[row_idx, col_idx] ax = axs[row_idx, col_idx]
ax.imshow(img.permute(1, 2, 0).numpy()) ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout() plt.tight_layout()
...@@ -4,55 +4,33 @@ Illustration of transforms ...@@ -4,55 +4,33 @@ Illustration of transforms
========================== ==========================
.. note:: .. note::
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms.ipynb>`_ Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_illustrations.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_transforms.py>` to download the full example code. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_transforms_illustrations.py>` to download the full example code.
This example illustrates the various transforms available in :ref:`the This example illustrates some of the various transforms available in :ref:`the
torchvision.transforms module <transforms>`. torchvision.transforms.v2 module <transforms>`.
""" """
# %%
# sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png" # sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png"
from PIL import Image from PIL import Image
from pathlib import Path from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
import torch import torch
import torchvision.transforms as T from torchvision.transforms import v2
plt.rcParams["savefig.bbox"] = 'tight' plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('../assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms # if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed! # properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0) torch.manual_seed(0)
# If you're trying to run that on collab, you can download the assets and the
def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # helpers from https://github.com/pytorch/vision/tree/main/gallery/
if not isinstance(imgs[0], list): from helpers import plot
# Make a 2d grid even if there's just 1 row orig_img = Image.open(Path('../assets') / 'astronaut.jpg')
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
row = [orig_img] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig:
axs[0, 0].set(title='Original image')
axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
# %% # %%
# Geometric Transforms # Geometric Transforms
...@@ -66,8 +44,8 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): ...@@ -66,8 +44,8 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
# The :class:`~torchvision.transforms.Pad` transform # The :class:`~torchvision.transforms.Pad` transform
# (see also :func:`~torchvision.transforms.functional.pad`) # (see also :func:`~torchvision.transforms.functional.pad`)
# pads all image borders with some pixel values. # pads all image borders with some pixel values.
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)] padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs) plot([orig_img] + padded_imgs)
# %% # %%
# Resize # Resize
...@@ -75,8 +53,8 @@ plot(padded_imgs) ...@@ -75,8 +53,8 @@ plot(padded_imgs)
# The :class:`~torchvision.transforms.Resize` transform # The :class:`~torchvision.transforms.Resize` transform
# (see also :func:`~torchvision.transforms.functional.resize`) # (see also :func:`~torchvision.transforms.functional.resize`)
# resizes an image. # resizes an image.
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] resized_imgs = [v2.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(resized_imgs) plot([orig_img] + resized_imgs)
# %% # %%
# CenterCrop # CenterCrop
...@@ -84,8 +62,8 @@ plot(resized_imgs) ...@@ -84,8 +62,8 @@ plot(resized_imgs)
# The :class:`~torchvision.transforms.CenterCrop` transform # The :class:`~torchvision.transforms.CenterCrop` transform
# (see also :func:`~torchvision.transforms.functional.center_crop`) # (see also :func:`~torchvision.transforms.functional.center_crop`)
# crops the given image at the center. # crops the given image at the center.
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] center_crops = [v2.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(center_crops) plot([orig_img] + center_crops)
# %% # %%
# FiveCrop # FiveCrop
...@@ -93,8 +71,8 @@ plot(center_crops) ...@@ -93,8 +71,8 @@ plot(center_crops)
# The :class:`~torchvision.transforms.FiveCrop` transform # The :class:`~torchvision.transforms.FiveCrop` transform
# (see also :func:`~torchvision.transforms.functional.five_crop`) # (see also :func:`~torchvision.transforms.functional.five_crop`)
# crops the given image into four corners and the central crop. # crops the given image into four corners and the central crop.
(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img) (top_left, top_right, bottom_left, bottom_right, center) = v2.FiveCrop(size=(100, 100))(orig_img)
plot([top_left, top_right, bottom_left, bottom_right, center]) plot([orig_img] + [top_left, top_right, bottom_left, bottom_right, center])
# %% # %%
# RandomPerspective # RandomPerspective
...@@ -102,9 +80,9 @@ plot([top_left, top_right, bottom_left, bottom_right, center]) ...@@ -102,9 +80,9 @@ plot([top_left, top_right, bottom_left, bottom_right, center])
# The :class:`~torchvision.transforms.RandomPerspective` transform # The :class:`~torchvision.transforms.RandomPerspective` transform
# (see also :func:`~torchvision.transforms.functional.perspective`) # (see also :func:`~torchvision.transforms.functional.perspective`)
# performs random perspective transform on an image. # performs random perspective transform on an image.
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0) perspective_transformer = v2.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)] perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs) plot([orig_img] + perspective_imgs)
# %% # %%
# RandomRotation # RandomRotation
...@@ -112,9 +90,9 @@ plot(perspective_imgs) ...@@ -112,9 +90,9 @@ plot(perspective_imgs)
# The :class:`~torchvision.transforms.RandomRotation` transform # The :class:`~torchvision.transforms.RandomRotation` transform
# (see also :func:`~torchvision.transforms.functional.rotate`) # (see also :func:`~torchvision.transforms.functional.rotate`)
# rotates an image with random angle. # rotates an image with random angle.
rotater = T.RandomRotation(degrees=(0, 180)) rotater = v2.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)] rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs) plot([orig_img] + rotated_imgs)
# %% # %%
# RandomAffine # RandomAffine
...@@ -122,9 +100,9 @@ plot(rotated_imgs) ...@@ -122,9 +100,9 @@ plot(rotated_imgs)
# The :class:`~torchvision.transforms.RandomAffine` transform # The :class:`~torchvision.transforms.RandomAffine` transform
# (see also :func:`~torchvision.transforms.functional.affine`) # (see also :func:`~torchvision.transforms.functional.affine`)
# performs random affine transform on an image. # performs random affine transform on an image.
affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75)) affine_transfomer = v2.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)] affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot(affine_imgs) plot([orig_img] + affine_imgs)
# %% # %%
# ElasticTransform # ElasticTransform
...@@ -133,9 +111,9 @@ plot(affine_imgs) ...@@ -133,9 +111,9 @@ plot(affine_imgs)
# (see also :func:`~torchvision.transforms.functional.elastic_transform`) # (see also :func:`~torchvision.transforms.functional.elastic_transform`)
# Randomly transforms the morphology of objects in images and produces a # Randomly transforms the morphology of objects in images and produces a
# see-through-water-like effect. # see-through-water-like effect.
elastic_transformer = T.ElasticTransform(alpha=250.0) elastic_transformer = v2.ElasticTransform(alpha=250.0)
transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)] transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)]
plot(transformed_imgs) plot([orig_img] + transformed_imgs)
# %% # %%
# RandomCrop # RandomCrop
...@@ -143,9 +121,9 @@ plot(transformed_imgs) ...@@ -143,9 +121,9 @@ plot(transformed_imgs)
# The :class:`~torchvision.transforms.RandomCrop` transform # The :class:`~torchvision.transforms.RandomCrop` transform
# (see also :func:`~torchvision.transforms.functional.crop`) # (see also :func:`~torchvision.transforms.functional.crop`)
# crops an image at a random location. # crops an image at a random location.
cropper = T.RandomCrop(size=(128, 128)) cropper = v2.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)] crops = [cropper(orig_img) for _ in range(4)]
plot(crops) plot([orig_img] + crops)
# %% # %%
# RandomResizedCrop # RandomResizedCrop
...@@ -154,9 +132,9 @@ plot(crops) ...@@ -154,9 +132,9 @@ plot(crops)
# (see also :func:`~torchvision.transforms.functional.resized_crop`) # (see also :func:`~torchvision.transforms.functional.resized_crop`)
# crops an image at a random location, and then resizes the crop to a given # crops an image at a random location, and then resizes the crop to a given
# size. # size.
resize_cropper = T.RandomResizedCrop(size=(32, 32)) resize_cropper = v2.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)] resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops) plot([orig_img] + resized_crops)
# %% # %%
# Photometric Transforms # Photometric Transforms
...@@ -175,17 +153,17 @@ plot(resized_crops) ...@@ -175,17 +153,17 @@ plot(resized_crops)
# The :class:`~torchvision.transforms.Grayscale` transform # The :class:`~torchvision.transforms.Grayscale` transform
# (see also :func:`~torchvision.transforms.functional.to_grayscale`) # (see also :func:`~torchvision.transforms.functional.to_grayscale`)
# converts an image to grayscale # converts an image to grayscale
gray_img = T.Grayscale()(orig_img) gray_img = v2.Grayscale()(orig_img)
plot([gray_img], cmap='gray') plot([orig_img, gray_img], cmap='gray')
# %% # %%
# ColorJitter # ColorJitter
# ~~~~~~~~~~~ # ~~~~~~~~~~~
# The :class:`~torchvision.transforms.ColorJitter` transform # The :class:`~torchvision.transforms.ColorJitter` transform
# randomly changes the brightness, contrast, saturation, hue, and other properties of an image. # randomly changes the brightness, contrast, saturation, hue, and other properties of an image.
jitter = T.ColorJitter(brightness=.5, hue=.3) jitter = v2.ColorJitter(brightness=.5, hue=.3)
jitted_imgs = [jitter(orig_img) for _ in range(4)] jittered_imgs = [jitter(orig_img) for _ in range(4)]
plot(jitted_imgs) plot([orig_img] + jittered_imgs)
# %% # %%
# GaussianBlur # GaussianBlur
...@@ -193,9 +171,9 @@ plot(jitted_imgs) ...@@ -193,9 +171,9 @@ plot(jitted_imgs)
# The :class:`~torchvision.transforms.GaussianBlur` transform # The :class:`~torchvision.transforms.GaussianBlur` transform
# (see also :func:`~torchvision.transforms.functional.gaussian_blur`) # (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
# performs gaussian blur transform on an image. # performs gaussian blur transform on an image.
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)) blurrer = v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.))
blurred_imgs = [blurrer(orig_img) for _ in range(4)] blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs) plot([orig_img] + blurred_imgs)
# %% # %%
# RandomInvert # RandomInvert
...@@ -203,9 +181,9 @@ plot(blurred_imgs) ...@@ -203,9 +181,9 @@ plot(blurred_imgs)
# The :class:`~torchvision.transforms.RandomInvert` transform # The :class:`~torchvision.transforms.RandomInvert` transform
# (see also :func:`~torchvision.transforms.functional.invert`) # (see also :func:`~torchvision.transforms.functional.invert`)
# randomly inverts the colors of the given image. # randomly inverts the colors of the given image.
inverter = T.RandomInvert() inverter = v2.RandomInvert()
invertered_imgs = [inverter(orig_img) for _ in range(4)] invertered_imgs = [inverter(orig_img) for _ in range(4)]
plot(invertered_imgs) plot([orig_img] + invertered_imgs)
# %% # %%
# RandomPosterize # RandomPosterize
...@@ -214,9 +192,9 @@ plot(invertered_imgs) ...@@ -214,9 +192,9 @@ plot(invertered_imgs)
# (see also :func:`~torchvision.transforms.functional.posterize`) # (see also :func:`~torchvision.transforms.functional.posterize`)
# randomly posterizes the image by reducing the number of bits # randomly posterizes the image by reducing the number of bits
# of each color channel. # of each color channel.
posterizer = T.RandomPosterize(bits=2) posterizer = v2.RandomPosterize(bits=2)
posterized_imgs = [posterizer(orig_img) for _ in range(4)] posterized_imgs = [posterizer(orig_img) for _ in range(4)]
plot(posterized_imgs) plot([orig_img] + posterized_imgs)
# %% # %%
# RandomSolarize # RandomSolarize
...@@ -225,9 +203,9 @@ plot(posterized_imgs) ...@@ -225,9 +203,9 @@ plot(posterized_imgs)
# (see also :func:`~torchvision.transforms.functional.solarize`) # (see also :func:`~torchvision.transforms.functional.solarize`)
# randomly solarizes the image by inverting all pixel values above # randomly solarizes the image by inverting all pixel values above
# the threshold. # the threshold.
solarizer = T.RandomSolarize(threshold=192.0) solarizer = v2.RandomSolarize(threshold=192.0)
solarized_imgs = [solarizer(orig_img) for _ in range(4)] solarized_imgs = [solarizer(orig_img) for _ in range(4)]
plot(solarized_imgs) plot([orig_img] + solarized_imgs)
# %% # %%
# RandomAdjustSharpness # RandomAdjustSharpness
...@@ -235,9 +213,9 @@ plot(solarized_imgs) ...@@ -235,9 +213,9 @@ plot(solarized_imgs)
# The :class:`~torchvision.transforms.RandomAdjustSharpness` transform # The :class:`~torchvision.transforms.RandomAdjustSharpness` transform
# (see also :func:`~torchvision.transforms.functional.adjust_sharpness`) # (see also :func:`~torchvision.transforms.functional.adjust_sharpness`)
# randomly adjusts the sharpness of the given image. # randomly adjusts the sharpness of the given image.
sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2) sharpness_adjuster = v2.RandomAdjustSharpness(sharpness_factor=2)
sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)] sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
plot(sharpened_imgs) plot([orig_img] + sharpened_imgs)
# %% # %%
# RandomAutocontrast # RandomAutocontrast
...@@ -245,9 +223,9 @@ plot(sharpened_imgs) ...@@ -245,9 +223,9 @@ plot(sharpened_imgs)
# The :class:`~torchvision.transforms.RandomAutocontrast` transform # The :class:`~torchvision.transforms.RandomAutocontrast` transform
# (see also :func:`~torchvision.transforms.functional.autocontrast`) # (see also :func:`~torchvision.transforms.functional.autocontrast`)
# randomly applies autocontrast to the given image. # randomly applies autocontrast to the given image.
autocontraster = T.RandomAutocontrast() autocontraster = v2.RandomAutocontrast()
autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)] autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
plot(autocontrasted_imgs) plot([orig_img] + autocontrasted_imgs)
# %% # %%
# RandomEqualize # RandomEqualize
...@@ -255,9 +233,9 @@ plot(autocontrasted_imgs) ...@@ -255,9 +233,9 @@ plot(autocontrasted_imgs)
# The :class:`~torchvision.transforms.RandomEqualize` transform # The :class:`~torchvision.transforms.RandomEqualize` transform
# (see also :func:`~torchvision.transforms.functional.equalize`) # (see also :func:`~torchvision.transforms.functional.equalize`)
# randomly equalizes the histogram of the given image. # randomly equalizes the histogram of the given image.
equalizer = T.RandomEqualize() equalizer = v2.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)] equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot(equalized_imgs) plot([orig_img] + equalized_imgs)
# %% # %%
# Augmentation Transforms # Augmentation Transforms
...@@ -270,22 +248,22 @@ plot(equalized_imgs) ...@@ -270,22 +248,22 @@ plot(equalized_imgs)
# The :class:`~torchvision.transforms.AutoAugment` transform # The :class:`~torchvision.transforms.AutoAugment` transform
# automatically augments data based on a given auto-augmentation policy. # automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies. # See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN] policies = [v2.AutoAugmentPolicy.CIFAR10, v2.AutoAugmentPolicy.IMAGENET, v2.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies] augmenters = [v2.AutoAugment(policy) for policy in policies]
imgs = [ imgs = [
[augmenter(orig_img) for _ in range(4)] [augmenter(orig_img) for _ in range(4)]
for augmenter in augmenters for augmenter in augmenters
] ]
row_title = [str(policy).split('.')[-1] for policy in policies] row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title) plot([[orig_img] + row for row in imgs], row_title=row_title)
# %% # %%
# RandAugment # RandAugment
# ~~~~~~~~~~~ # ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` is an alternate version of AutoAugment. # The :class:`~torchvision.transforms.RandAugment` is an alternate version of AutoAugment.
augmenter = T.RandAugment() augmenter = v2.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)] imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs) plot([orig_img] + imgs)
# %% # %%
# TrivialAugmentWide # TrivialAugmentWide
...@@ -293,17 +271,17 @@ plot(imgs) ...@@ -293,17 +271,17 @@ plot(imgs)
# The :class:`~torchvision.transforms.TrivialAugmentWide` is an alternate implementation of AutoAugment. # The :class:`~torchvision.transforms.TrivialAugmentWide` is an alternate implementation of AutoAugment.
# However, instead of transforming an image multiple times, it transforms an image only once # However, instead of transforming an image multiple times, it transforms an image only once
# using a random transform from a given list with a random strength number. # using a random transform from a given list with a random strength number.
augmenter = T.TrivialAugmentWide() augmenter = v2.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)] imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs) plot([orig_img] + imgs)
# %% # %%
# AugMix # AugMix
# ~~~~~~ # ~~~~~~
# The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image. # The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image.
augmenter = T.AugMix() augmenter = v2.AugMix()
imgs = [augmenter(orig_img) for _ in range(4)] imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs) plot([orig_img] + imgs)
# %% # %%
# Randomly-applied Transforms # Randomly-applied Transforms
...@@ -318,9 +296,9 @@ plot(imgs) ...@@ -318,9 +296,9 @@ plot(imgs)
# The :class:`~torchvision.transforms.RandomHorizontalFlip` transform # The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
# (see also :func:`~torchvision.transforms.functional.hflip`) # (see also :func:`~torchvision.transforms.functional.hflip`)
# performs horizontal flip of an image, with a given probability. # performs horizontal flip of an image, with a given probability.
hflipper = T.RandomHorizontalFlip(p=0.5) hflipper = v2.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)] transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs) plot([orig_img] + transformed_imgs)
# %% # %%
# RandomVerticalFlip # RandomVerticalFlip
...@@ -328,15 +306,15 @@ plot(transformed_imgs) ...@@ -328,15 +306,15 @@ plot(transformed_imgs)
# The :class:`~torchvision.transforms.RandomVerticalFlip` transform # The :class:`~torchvision.transforms.RandomVerticalFlip` transform
# (see also :func:`~torchvision.transforms.functional.vflip`) # (see also :func:`~torchvision.transforms.functional.vflip`)
# performs vertical flip of an image, with a given probability. # performs vertical flip of an image, with a given probability.
vflipper = T.RandomVerticalFlip(p=0.5) vflipper = v2.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)] transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs) plot([orig_img] + transformed_imgs)
# %% # %%
# RandomApply # RandomApply
# ~~~~~~~~~~~ # ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandomApply` transform # The :class:`~torchvision.transforms.RandomApply` transform
# randomly applies a list of transforms, with a given probability. # randomly applies a list of transforms, with a given probability.
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5) applier = v2.RandomApply(transforms=[v2.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)] transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs) plot([orig_img] + transformed_imgs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment