You need to sign in or sign up before continuing.
Unverified Commit 655ebdbc authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Redo torchscript example (#7889)

parent 6472a5cb
...@@ -214,7 +214,8 @@ Torchscript support ...@@ -214,7 +214,8 @@ Torchscript support
------------------- -------------------
Most transform classes and functionals support torchscript. For composing Most transform classes and functionals support torchscript. For composing
transforms, use :class:`torch.nn.Sequential` instead of ``Compose``: transforms, use :class:`torch.nn.Sequential` instead of
:class:`~torchvision.transforms.v2.Compose`:
.. code:: python .. code:: python
...@@ -232,7 +233,7 @@ transforms, use :class:`torch.nn.Sequential` instead of ``Compose``: ...@@ -232,7 +233,7 @@ transforms, use :class:`torch.nn.Sequential` instead of ``Compose``:
scripted and eager executions due to implementation differences between v1 scripted and eager executions due to implementation differences between v1
and v2. and v2.
If you really need torchscript support for the v2 tranforms, we recommend If you really need torchscript support for the v2 transforms, we recommend
scripting the **functionals** from the scripting the **functionals** from the
``torchvision.transforms.v2.functional`` namespace to avoid surprises. ``torchvision.transforms.v2.functional`` namespace to avoid surprises.
...@@ -242,7 +243,10 @@ are always treated as images. If you need torchscript support for other types ...@@ -242,7 +243,10 @@ are always treated as images. If you need torchscript support for other types
like bounding boxes or masks, you can rely on the :ref:`low-level kernels like bounding boxes or masks, you can rely on the :ref:`low-level kernels
<functional_transforms>`. <functional_transforms>`.
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``. For any custom transformations to be used with ``torch.jit.script``, they should
be derived from ``torch.nn.Module``.
See also: :ref:`sphx_glr_auto_examples_others_plot_scripted_tensor_transforms.py`.
V2 API reference - Recommended V2 API reference - Recommended
------------------------------ ------------------------------
......
""" """
========================= ===================
Tensor transforms and JIT Torchscript support
========================= ===================
.. note:: .. note::
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_ Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code. or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code.
This example illustrates various features that are now supported by the This example illustrates `torchscript
:ref:`image transformations <transforms>` on Tensor images. In particular, we <https://pytorch.org/docs/stable/jit.html>`_ support of the torchvision
show how image transforms can be performed on GPU, and how one can also script :ref:`transforms <transforms>` on Tensor images.
them using JIT compilation.
Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric
and presented multiple limitations due to that. Now, since v0.8.0, transforms
implementations are Tensor and PIL compatible, and we can achieve the following
new features:
- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformation such as for videos
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)
.. note::
These features are only possible with **Tensor** images.
""" """
# %%
from pathlib import Path from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
import torch import torch
import torchvision.transforms as T import torch.nn as nn
from torchvision.io import read_image
import torchvision.transforms as v1
from torchvision.io import read_image
plt.rcParams["savefig.bbox"] = 'tight' plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1) torch.manual_seed(1)
# If you're trying to run that on collab, you can download the assets and the
def show(imgs): # helpers from https://github.com/pytorch/vision/tree/main/gallery/
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False) import sys
for i, img in enumerate(imgs): sys.path += ["../transforms"]
img = T.ToPILImage()(img.to('cpu')) from helpers import plot
axs[0, i].imshow(np.asarray(img)) ASSETS_PATH = Path('../assets')
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# %% # %%
# The :func:`~torchvision.io.read_image` function allows to read an image and # Most transforms support torchscript. For composing transforms, we use
# directly load it as a tensor # :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:
dog1 = read_image(str(Path('../assets') / 'dog1.jpg')) dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(Path('../assets') / 'dog2.jpg')) dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))
show([dog1, dog2])
# %%
# Transforming images on GPU
# --------------------------
# Most transforms natively support tensors on top of PIL images (to visualize
# the effect of the transforms, you may refer to see
# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
# Using tensor images, we can run the transforms on GPUs if cuda is available!
import torch.nn as nn
transforms = torch.nn.Sequential( transforms = torch.nn.Sequential(
T.RandomCrop(224), v1.RandomCrop(224),
T.RandomHorizontalFlip(p=0.3), v1.RandomHorizontalFlip(p=0.3),
) )
device = 'cuda' if torch.cuda.is_available() else 'cpu' scripted_transforms = torch.jit.script(transforms)
dog1 = dog1.to(device)
dog2 = dog2.to(device) plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])
transformed_dog1 = transforms(dog1)
transformed_dog2 = transforms(dog2)
show([transformed_dog1, transformed_dog2])
# %% # %%
# Scriptable transforms for easier deployment via torchscript # .. warning::
# ----------------------------------------------------------- #
# We now show how to combine image transformations and a model forward pass, # Above we have used transforms from the ``torchvision.transforms``
# while using ``torch.jit.script`` to obtain a single scripted module. # namespace, i.e. the "v1" transforms. The v2 transforms from the
# ``torchvision.transforms.v2`` namespace are the :ref:`recommended
# <v1_or_v2>` way to use transforms in your code.
#
# The v2 transforms also support torchscript, but if you call
# ``torch.jit.script()`` on a v2 **class** transform, you'll actually end up
# with its (scripted) v1 equivalent. This may lead to slightly different
# results between the scripted and eager executions due to implementation
# differences between v1 and v2.
#
# If you really need torchscript support for the v2 transforms, **we
# recommend scripting the functionals** from the
# ``torchvision.transforms.v2.functional`` namespace to avoid surprises.
#
# Below we now show how to combine image transformations and a model forward
# pass, while using ``torch.jit.script`` to obtain a single scripted module.
# #
# Let's define a ``Predictor`` module that transforms the input tensor and then # Let's define a ``Predictor`` module that transforms the input tensor and then
# applies an ImageNet model on it. # applies an ImageNet model on it.
...@@ -98,7 +85,7 @@ class Predictor(nn.Module): ...@@ -98,7 +85,7 @@ class Predictor(nn.Module):
super().__init__() super().__init__()
weights = ResNet18_Weights.DEFAULT weights = ResNet18_Weights.DEFAULT
self.resnet18 = resnet18(weights=weights, progress=False).eval() self.resnet18 = resnet18(weights=weights, progress=False).eval()
self.transforms = weights.transforms() self.transforms = weights.transforms(antialias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad(): with torch.no_grad():
...@@ -111,6 +98,8 @@ class Predictor(nn.Module): ...@@ -111,6 +98,8 @@ class Predictor(nn.Module):
# Now, let's define scripted and non-scripted instances of ``Predictor`` and # Now, let's define scripted and non-scripted instances of ``Predictor`` and
# apply it on multiple tensor images of the same size # apply it on multiple tensor images of the same size
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = Predictor().to(device) predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device) scripted_predictor = torch.jit.script(predictor).to(device)
...@@ -143,3 +132,5 @@ with tempfile.NamedTemporaryFile() as f: ...@@ -143,3 +132,5 @@ with tempfile.NamedTemporaryFile() as f:
dumped_scripted_predictor = torch.jit.load(f.name) dumped_scripted_predictor = torch.jit.load(f.name)
res_scripted_dumped = dumped_scripted_predictor(batch) res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all() assert (res_scripted_dumped == res_scripted).all()
# %%
...@@ -172,6 +172,7 @@ target = { ...@@ -172,6 +172,7 @@ target = {
# Re-using the transforms and definitions from above. # Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target) out_img, out_target = transforms(img, target)
# sphinx_gallery_thumbnail_number = 4
plot([(img, target["boxes"]), (out_img, out_target["boxes"])]) plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}") print(f"{out_target['this_is_ignored']}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment