"""
===============================================================
Transforms v2: End-to-end object detection/segmentation example
===============================================================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_e2e.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_transforms_e2e.py>` to download the full example code.

Object detection and segmentation tasks are natively supported:
``torchvision.transforms.v2`` enables jointly transforming images, videos,
bounding boxes, and masks.

This example showcases an end-to-end instance segmentation training case using
Torchvision utils from ``torchvision.datasets``, ``torchvision.models`` and
``torchvision.transforms.v2``. Everything covered here can be applied similarly
to object detection or semantic segmentation tasks.
"""

# %%
import pathlib

import torch
import torch.utils.data

from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

torch.manual_seed(0)

# This loads fake data for illustration purposes. In practice, you'll have
# to replace it with your own data.
# If you're trying to run this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
ROOT = pathlib.Path("../assets") / "coco"
IMAGES_PATH = str(ROOT / "images")
ANNOTATIONS_PATH = str(ROOT / "instances.json")
from helpers import plot


# %%
# Dataset preparation
# -------------------
#
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
# returns.

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")


# %%
# Torchvision datasets preserve the data structure and types as intended by
# the dataset authors. So by default, the output structure may not always be
# compatible with the models or the transforms.
#
# To overcome that, we can use the
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
# :class:`~torchvision.datasets.CocoDetection`, this changes the target
# structure to a single dictionary of lists:

dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }")

# %%
# We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is a dict where the values
# are :ref:`TVTensors <what_are_tv_tensors>` (all are :class:`torch.Tensor`
# subclasses). We dropped all unnecessary keys from the previous output, but
# if you need any of the original keys, e.g. "image_id", you can still ask for
# them.
#
# .. note::
#
#     If you just want to do detection, you don't need and shouldn't pass
#     "masks" in ``target_keys``: if masks are present in the sample, they will
#     be transformed, slowing down your transformations unnecessarily.
#
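# As a quick sanity check that these values really are TVTensors, we can
# inspect the metadata attached to the bounding boxes. This is a minimal
# sketch; the ``format`` and ``canvas_size`` attributes below come from
# :class:`~torchvision.tv_tensors.BoundingBoxes`.

print(f"{target['boxes'].format = }\n{target['boxes'].canvas_size = }")

# %%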
# As a baseline, let's have a look at a sample without transformations:

plot([dataset[0], dataset[1]])


# %%
# Transforms
# ----------
#
# Let's now define our pre-processing transforms. All the transforms know how
# to handle images, bounding boxes and masks when relevant.
#
# Transforms are typically passed as the ``transforms`` parameter of the
# dataset so that they can leverage multi-processing from the
# :class:`torch.utils.data.DataLoader`.

transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomPhotometricDistort(p=1),
        v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
        v2.RandomIoUCrop(),
        v2.RandomHorizontalFlip(p=1),
        v2.SanitizeBoundingBoxes(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])

# %%
# A few things are worth noting here:
#
# - We're converting the PIL image into a
#   :class:`~torchvision.tv_tensors.Image` object. This isn't strictly
#   necessary, but relying on Tensors (here: a Tensor subclass) will
#   :ref:`generally be faster <transforms_perf>`.
# - We are calling :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` to
#   make sure we remove degenerate bounding boxes, as well as their
#   corresponding labels and masks.
#   :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` should be placed
#   at least once at the end of a detection pipeline; it is particularly
#   critical if :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
#
# Let's see how the sample looks with our augmentation pipeline in place:

# sphinx_gallery_thumbnail_number = 2
plot([dataset[0], dataset[1]])
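
# %%
# The ``transforms`` pipeline can also be called directly on an
# ``(image, target)`` pair outside of the dataset, which can be handy when
# debugging an augmentation pipeline. Below is a minimal sketch reusing the
# untransformed ``img`` and ``target`` from above; the variable names are
# purely illustrative.

debug_img, debug_target = transforms(img, target)
print(f"{type(debug_img) = }\n{debug_target['boxes'].shape = }")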


# %%
# We can see that the colors of the images were distorted, and that the images
# were zoomed in or out and flipped.
# The bounding boxes and the masks were transformed accordingly. And without any further ado, we can start training.
#
# Data loading and training loop
# ------------------------------
#
# Below we're using Mask R-CNN, which is an instance segmentation model, but
# everything we've covered in this tutorial also applies to object detection and
# semantic segmentation tasks.

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    # We need a custom collation function here, since the object detection
    # models expect a sequence of images and target dictionaries. The default
    # collation function tries to torch.stack() the individual elements,
    # which fails in general for object detection, because the number of bounding
    # boxes varies between the images within the same batch.
    collate_fn=lambda batch: tuple(zip(*batch)),
)

model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

for imgs, targets in data_loader:
    loss_dict = model(imgs, targets)
    # Put your training logic here

    print(f"{[img.shape for img in imgs] = }")
    print(f"{[type(target) for target in targets] = }")
    for name, loss_val in loss_dict.items():
        print(f"{name:<20}{loss_val:.3f}")
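
# %%
# The loop above only runs a forward pass to illustrate the expected inputs and
# outputs. A complete training step would typically sum the losses and
# back-propagate them. Here is a minimal sketch; the optimizer and its
# hyper-parameters are illustrative choices, not prescribed by this example.

optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
for imgs, targets in data_loader:
    optimizer.zero_grad()
    loss_dict = model(imgs, targets)  # in training mode, the model returns a dict of losses
    total_loss = sum(loss_dict.values())
    total_loss.backward()
    optimizer.step()
    break  # a single step is enough for illustration purposes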