"""
=========================
Tensor transforms and JIT
=========================

.. note::
    Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code.

This example illustrates various features that are now supported by the
:ref:`image transformations <transforms>` on Tensor images. In particular, we
show how image transforms can be performed on GPU, and how one can also script
them using JIT compilation.

Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric
and presented multiple limitations due to that. Now, since v0.8.0, transforms
implementations are Tensor and PIL compatible, and we can achieve the following
new features:

- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformation such as for videos
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)

.. note::
    These features are only possible with **Tensor** images.
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T
from torchvision.io import read_image


# Crop figure margins tightly when saving, and fix the RNG seed so the
# random transforms below are reproducible across runs.
plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)


def show(imgs):
    """Display each tensor image in ``imgs`` in its own column of one figure.

    Every image is moved to CPU, converted to a PIL image, and drawn
    without axis ticks or tick labels.
    """
    # `fig` was previously misspelled `fix`; it is kept so the figure
    # handle is nameable, matching the matplotlib (fig, axs) convention.
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


# %%
# The :func:`~torchvision.io.read_image` function allows to read an image and
# directly load it as a tensor

dog1 = read_image(str(Path('../assets') / 'dog1.jpg'))
dog2 = read_image(str(Path('../assets') / 'dog2.jpg'))

show([dog1, dog2])

# %%
# Transforming images on GPU
# --------------------------
# Most transforms natively support tensors on top of PIL images (to visualize
# the effect of the transforms, you may refer to see
# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
# Using tensor images, we can run the transforms on GPUs if cuda is available!

import torch.nn as nn

# nn.Sequential chains the transforms so the whole pipeline is a Module,
# which is what makes it scriptable with torch.jit further below.
transforms = torch.nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dog1 = dog1.to(device)
dog2 = dog2.to(device)

transformed_dog1 = transforms(dog1)
transformed_dog2 = transforms(dog2)
show([transformed_dog1, transformed_dog2])

# %%
# Scriptable transforms for easier deployment via torchscript
# -----------------------------------------------------------
# We now show how to combine image transformations and a model forward pass,
# while using ``torch.jit.script`` to obtain a single scripted module.
#
# Let's define a ``Predictor`` module that transforms the input tensor and then
# applies an ImageNet model on it.

from torchvision.models import resnet18, ResNet18_Weights

class Predictor(nn.Module):
    """Bundle the ImageNet preprocessing and a ResNet-18 forward pass.

    Packaging both steps in a single ``nn.Module`` lets the whole
    pipeline be scripted with ``torch.jit.script`` and deployed as one
    artifact.
    """

    def __init__(self):
        super().__init__()
        # DEFAULT picks the best available pretrained weights; the
        # matching preprocessing transforms come bundled with them.
        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False).eval()
        self.transforms = weights.transforms()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the predicted ImageNet class index for each image in ``x``."""
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            x = self.transforms(x)
            y_pred = self.resnet18(x)
            return y_pred.argmax(dim=1)


# %%
# Now, let's define scripted and non-scripted instances of ``Predictor`` and
# apply it on multiple tensor images of the same size

predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

batch = torch.stack([dog1, dog2]).to(device)

res = predictor(batch)
res_scripted = scripted_predictor(batch)

# %%
# We can verify that the prediction of the scripted and non-scripted models are
# the same:

import json

with open(Path('../assets') / 'imagenet_class_index.json') as labels_file:
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    assert pred == pred_scripted
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")

# %%
# Since the model is scripted, it can be easily dumped on disk and re-used

import tempfile

with tempfile.NamedTemporaryFile() as f:
    scripted_predictor.save(f.name)

    # Re-load the scripted module from disk and check it still produces
    # the same predictions as the in-memory scripted model.
    dumped_scripted_predictor = torch.jit.load(f.name)
    res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all()