Unverified Commit 655ebdbc authored by Nicolas Hug, committed by GitHub

Redo torchscript example (#7889)

parent 6472a5cb
@@ -214,7 +214,8 @@ Torchscript support
-------------------
Most transform classes and functionals support torchscript. For composing
transforms, use :class:`torch.nn.Sequential` instead of ``Compose``:
transforms, use :class:`torch.nn.Sequential` instead of
:class:`~torchvision.transforms.v2.Compose`:
.. code:: python
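
    # A minimal sketch (the original block is collapsed in this diff);
    # assumes v1-style transforms applied to tensor images.
    import torch
    from torchvision import transforms as T

    pipeline = torch.nn.Sequential(
        T.CenterCrop(224),
        T.RandomHorizontalFlip(p=0.3),
    )
    scripted_pipeline = torch.jit.script(pipeline)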
@@ -232,7 +233,7 @@ transforms, use :class:`torch.nn.Sequential` instead of ``Compose``:
scripted and eager executions due to implementation differences between v1
and v2.
If you really need torchscript support for the v2 tranforms, we recommend
If you really need torchscript support for the v2 transforms, we recommend
scripting the **functionals** from the
``torchvision.transforms.v2.functional`` namespace to avoid surprises.
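For illustration, a minimal sketch of that approach (the input tensor and the
names below are assumed for the example):

.. code:: python

    import torch
    from torchvision.transforms.v2 import functional as F

    # Scripting the functional directly avoids the silent fallback to the
    # v1 class implementation.
    scripted_hflip = torch.jit.script(F.horizontal_flip)
    img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
    out = scripted_hflip(img)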
@@ -242,7 +243,10 @@ are always treated as images. If you need torchscript support for other types
like bounding boxes or masks, you can rely on the :ref:`low-level kernels
<functional_transforms>`.
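As a sketch of that route (hedged heavily: the kernel name and its
``format``/``canvas_size`` parameters are assumed from the v2 low-level API and
may differ across torchvision versions):

.. code:: python

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0]])
    # Low-level kernel call; signature assumed, check your installed version.
    flipped = F.horizontal_flip_bounding_boxes(
        boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(224, 224)
    )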
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
For any custom transformations to be used with ``torch.jit.script``, they should
be derived from ``torch.nn.Module``.
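For example, a minimal sketch of such a module (the ``GammaCorrection`` class
and its parameter are hypothetical):

.. code:: python

    import torch
    from torch import nn

    class GammaCorrection(nn.Module):
        # Hypothetical example: any transform deriving from nn.Module
        # with tensor-only ops can be scripted.
        def __init__(self, gamma: float):
            super().__init__()
            self.gamma = gamma

        def forward(self, img: torch.Tensor) -> torch.Tensor:
            return img.float().div(255).pow(self.gamma).mul(255).to(img.dtype)

    scripted = torch.jit.script(GammaCorrection(gamma=0.8))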
See also: :ref:`sphx_glr_auto_examples_others_plot_scripted_tensor_transforms.py`.
V2 API reference - Recommended
------------------------------
......
"""
=========================
Tensor transforms and JIT
=========================
===================
Torchscript support
===================
.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code.
This example illustrates various features that are now supported by the
:ref:`image transformations <transforms>` on Tensor images. In particular, we
show how image transforms can be performed on GPU, and how one can also script
them using JIT compilation.
Prior to v0.8.0, transforms in torchvision were PIL-centric and presented
multiple limitations as a result. Since v0.8.0, transform implementations are
both Tensor- and PIL-compatible, which enables the following
new features:
- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformations, e.g. for videos
- read and decode data directly as torch tensors, with torchscript support (for PNG and JPEG image formats)
.. note::
    These features are only possible with **Tensor** images.
This example illustrates `torchscript
<https://pytorch.org/docs/stable/jit.html>`_ support of the torchvision
:ref:`transforms <transforms>` on Tensor images.
"""
# %%
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.io import read_image
import torch.nn as nn
import torchvision.transforms as v1
from torchvision.io import read_image
plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)
def show(imgs):
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# If you're trying to run this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
import sys
sys.path += ["../transforms"]
from helpers import plot
ASSETS_PATH = Path('../assets')
# %%
# The :func:`~torchvision.io.read_image` function reads an image and
# directly loads it as a tensor
# Most transforms support torchscript. For composing transforms, we use
# :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:
dog1 = read_image(str(Path('../assets') / 'dog1.jpg'))
dog2 = read_image(str(Path('../assets') / 'dog2.jpg'))
show([dog1, dog2])
# %%
# Transforming images on GPU
# --------------------------
# Most transforms natively support tensors in addition to PIL images (to
# visualize the effect of the transforms, you may refer to
# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
# Using tensor images, we can run the transforms on a GPU if CUDA is available!
import torch.nn as nn
dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))
transforms = torch.nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
    v1.RandomCrop(224),
    v1.RandomHorizontalFlip(p=0.3),
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dog1 = dog1.to(device)
dog2 = dog2.to(device)
scripted_transforms = torch.jit.script(transforms)
plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])
transformed_dog1 = transforms(dog1)
transformed_dog2 = transforms(dog2)
show([transformed_dog1, transformed_dog2])
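# %%
# A small aside (a sketch, not part of the original example): because these are
# pure tensor ops, the same scripted transforms can also run on GPU by moving
# the inputs to a CUDA device first.

device = "cuda" if torch.cuda.is_available() else "cpu"
gpu_out = scripted_transforms(dog1.to(device))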
# %%
# Scriptable transforms for easier deployment via torchscript
# -----------------------------------------------------------
# We now show how to combine image transformations and a model forward pass,
# while using ``torch.jit.script`` to obtain a single scripted module.
# .. warning::
#
#     Above we have used transforms from the ``torchvision.transforms``
#     namespace, i.e. the "v1" transforms. The v2 transforms from the
#     ``torchvision.transforms.v2`` namespace are the :ref:`recommended
#     <v1_or_v2>` way to use transforms in your code.
#
#     The v2 transforms also support torchscript, but if you call
#     ``torch.jit.script()`` on a v2 **class** transform, you'll actually end up
#     with its (scripted) v1 equivalent. This may lead to slightly different
#     results between the scripted and eager executions due to implementation
#     differences between v1 and v2.
#
#     If you really need torchscript support for the v2 transforms, **we
#     recommend scripting the functionals** from the
#     ``torchvision.transforms.v2.functional`` namespace to avoid surprises.
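#
# For instance, a minimal sketch of scripting a v2 functional (the names here
# are illustrative only, not part of the original example):

from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
scripted_vflip = torch.jit.script(F.vertical_flip)
_ = scripted_vflip(img)

# %%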
#
# Below we now show how to combine image transformations and a model forward
# pass, while using ``torch.jit.script`` to obtain a single scripted module.
#
# Let's define a ``Predictor`` module that transforms the input tensor and then
# applies an ImageNet model on it.
@@ -98,7 +85,7 @@ class Predictor(nn.Module):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False).eval()
        self.transforms = weights.transforms()
        self.transforms = weights.transforms(antialias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
@@ -111,6 +98,8 @@ class Predictor(nn.Module):
# Now, let's define scripted and non-scripted instances of ``Predictor`` and
# apply them on multiple tensor images of the same size
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)
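# %%
# As a sketch of the next step (assuming ``dog1`` and ``dog2`` have the same
# shape, as in the original example), stack the images into a batch and run
# both the eager and the scripted predictor on it:

batch = torch.stack([dog1, dog2]).to(device)
res = predictor(batch)
res_scripted = scripted_predictor(batch)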
@@ -143,3 +132,5 @@ with tempfile.NamedTemporaryFile() as f:
    dumped_scripted_predictor = torch.jit.load(f.name)
    res_scripted_dumped = dumped_scripted_predictor(batch)

assert (res_scripted_dumped == res_scripted).all()
# %%
@@ -172,6 +172,7 @@ target = {
# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)
# sphinx_gallery_thumbnail_number = 4
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")
......