"""
===================
Torchscript support
===================

.. note::
    Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code.

This example illustrates `torchscript
<https://pytorch.org/docs/stable/jit.html>`_ support of the torchvision
:ref:`transforms <transforms>` on Tensor images.
"""

# %%
from pathlib import Path

import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import read_image

# Keep figures tight and make the random transforms below reproducible.
plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

# If you're trying to run that on collab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
import sys
sys.path += ["../transforms"]
from helpers import plot
ASSETS_PATH = Path('../assets')

# %%
# Most transforms support torchscript. For composing transforms, we use
# :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:

dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))

# ``nn.Sequential`` of transforms is scriptable, unlike ``Compose``.
transforms = torch.nn.Sequential(
    v1.RandomCrop(224),
    v1.RandomHorizontalFlip(p=0.3),
)

scripted_transforms = torch.jit.script(transforms)

plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])

# %%
# .. warning::
#
#     Above we have used transforms from the ``torchvision.transforms``
#     namespace, i.e. the "v1" transforms. The v2 transforms from the
#     ``torchvision.transforms.v2`` namespace are the :ref:`recommended
#     <v1_or_v2>` way to use transforms in your code.
#
#     The v2 transforms also support torchscript, but if you call
#     ``torch.jit.script()`` on a v2 **class** transform, you'll actually end up
#     with its (scripted) v1 equivalent.  This may lead to slightly different
#     results between the scripted and eager executions due to implementation
#     differences between v1 and v2.
#
#     If you really need torchscript support for the v2 transforms, **we
#     recommend scripting the functionals** from the
#     ``torchvision.transforms.v2.functional`` namespace to avoid surprises.
#
# Below we now show how to combine image transformations and a model forward
# pass, while using ``torch.jit.script`` to obtain a single scripted module.
#
# Let's define a ``Predictor`` module that transforms the input tensor and then
# applies an ImageNet model on it.

from torchvision.models import resnet18, ResNet18_Weights

class Predictor(nn.Module):
    """Scriptable module that preprocesses a batch of images and returns
    the index of the top ImageNet class predicted by ResNet18.
    """

    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT
        # eval() so dropout/batchnorm behave deterministically at inference.
        self.resnet18 = resnet18(weights=weights, progress=False).eval()
        # Use the preprocessing transforms bundled with the chosen weights.
        self.transforms = weights.transforms(antialias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            x = self.transforms(x)
            y_pred = self.resnet18(x)
            # One class index per image in the batch.
            return y_pred.argmax(dim=1)

# %%
# Now, let's define scripted and non-scripted instances of ``Predictor`` and
# apply it on multiple tensor images of the same size

device = "cuda" if torch.cuda.is_available() else "cpu"

predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

# Both images must have the same size to be stacked into one batch.
batch = torch.stack([dog1, dog2]).to(device)

res = predictor(batch)
res_scripted = scripted_predictor(batch)

# %%
# We can verify that the prediction of the scripted and non-scripted models are
# the same:

import json

with open(Path('../assets') / 'imagenet_class_index.json') as labels_file:
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    # Scripting must not change the predicted class.
    assert pred == pred_scripted
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")

# %%
# Since the model is scripted, it can be easily dumped on disk and re-used

import tempfile

with tempfile.NamedTemporaryFile() as f:
    scripted_predictor.save(f.name)

    dumped_scripted_predictor = torch.jit.load(f.name)
    res_scripted_dumped = dumped_scripted_predictor(batch)
# Round-tripping through disk must preserve the predictions exactly.
assert (res_scripted_dumped == res_scripted).all()

# %%