"""
==============
Datapoints FAQ
==============

The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example
showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need
to worry about: you do not need to understand the internals of datapoints to efficiently rely on
``torchvision.transforms.v2``. It may, however, be useful for advanced users trying to implement their own datasets
or transforms, or who want to work directly with the datapoints.
"""

import PIL.Image

import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


# %%
# What are datapoints?
# --------------------
#
# Datapoints are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()
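

# %%
# Because no copy is made, changes applied through the datapoint are visible through the original tensor (and vice
# versa). A minimal check of this zero-copy behaviour:

image.add_(1)  # in-place update through the datapoint ...
assert (tensor == image).all()  # ... is reflected in the tensor it was built from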


# %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# What datapoints are supported?
# ------------------------------
#
# So far :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# How do I construct a datapoint?
# -------------------------------
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`:

image = datapoints.Image([[[[0, 1], [1, 0]]]])
print(image)


# %%
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.

float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)


# %%
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a
# :class:`PIL.Image.Image` directly:

image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
print(image.shape, image.dtype)

# %%
# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
# corresponding image alongside the actual values:

bounding_box = datapoints.BoundingBoxes(
    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)
print(bounding_box)
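

# %%
# This metadata is what lets :mod:`torchvision.transforms.v2` dispatch to the appropriate kernel for each input.
# As a minimal sketch (assuming the beta v2 functional API), the same functional call flips the image pixels and
# the box coordinates, the latter using the stored ``canvas_size``:

flipped_image = F.horizontal_flip(image)
flipped_box = F.horizontal_flip(bounding_box)

assert isinstance(flipped_box, datapoints.BoundingBoxes)
print(flipped_box)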


# %%
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
#
# Only if you are using custom datasets. For the built-in ones, you can use
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the
# built-in datasets. That means if your custom dataset subclasses a built-in one and its output type is the same, you
# don't have to wrap it manually either.
#
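# As a minimal sketch, wrapping a built-in detection dataset could look like this (``make_wrapped_coco`` and the
# dataset paths are hypothetical placeholders, not part of torchvision):


def make_wrapped_coco(root, ann_file):
    from torchvision import datasets

    dataset = datasets.CocoDetection(root, ann_file)
    # The wrapper changes __getitem__ so that the targets come back as datapoints (e.g. boxes as
    # datapoints.BoundingBoxes) that ``torchvision.transforms.v2`` understands.
    return datasets.wrap_dataset_for_transforms_v2(dataset)


# %%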
# If you have a custom dataset, for example the ``PennFudanDataset`` from
# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options:
#
# 1. Perform the wrapping inside ``__getitem__``:

class PennFudanDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, item):
        ...

        target["boxes"] = datapoints.BoundingBoxes(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
        target["labels"] = labels
        target["masks"] = datapoints.Mask(masks)

        ...

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        ...

# %%
# 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline:


class WrapPennFudanDataset:
    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBoxes(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
        target["masks"] = datapoints.Mask(target["masks"])
        return img, target


...


def get_transform(train):
    transforms = []
    transforms.append(WrapPennFudanDataset())
    transforms.append(T.PILToTensor())  # ``T`` is assumed to be a transforms namespace, e.g. torchvision.transforms.v2
    ...

# %%
# .. note::
#
#    If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask` instances are included in
#    the sample, ``torchvision.transforms.v2`` will transform them both. That means if you don't need both, dropping,
#    or at least not wrapping, the obsolete parts can lead to a significant performance boost.
#
#    For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids
#    transforming them over and over again in the pipeline just to ultimately ignore them. In general, it would be
#    even better not to load the masks at all, but this is not possible in this example, since the bounding boxes are
#    generated from the masks.
#
# How do the datapoints behave inside a computation?
# --------------------------------------------------
#
# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor`
# also works on datapoints.
# Since, for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change; see the note below):


assert isinstance(image, datapoints.Image)

new_image = image + 0

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)

# %%
# .. note::
#
#    This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# There are two exceptions to this rule:
#
# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_`
#    retain the datapoint type (see the last example below).
# 2. In-place operations on datapoints cannot change the type of the datapoint they are called on. However, if you
#    use the flow (method-chaining) style, the returned value will be unwrapped:

image = datapoints.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

assert isinstance(image, torch.Tensor)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all()
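

# %%
# And, for completeness, a minimal check of the first exception, i.e. that :meth:`~torch.Tensor.clone` and
# :meth:`~torch.Tensor.to` retain the datapoint type:

image = datapoints.Image([[[0, 1], [1, 0]]])

assert isinstance(image.clone(), datapoints.Image)
assert isinstance(image.to(torch.float32), datapoints.Image)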