"""
==============
Datapoints FAQ
==============

The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example
showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need
to worry about: you do not need to understand the internals of datapoints to use ``torchvision.transforms.v2``
effectively. It may however be useful for advanced users trying to implement their own datasets or transforms, or
for those working directly with the datapoints.
"""

import PIL.Image

import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


########################################################################################################################
# What are datapoints?
# --------------------
#
# Datapoints are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()


########################################################################################################################
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
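#
# For instance, the same ``F.resize`` functional resizes the pixel data when given an
# :class:`~torchvision.datapoints.Image`, but rescales the coordinates when given a
# :class:`~torchvision.datapoints.BoundingBoxes`. A quick sketch (the target size is arbitrary):

boxes = datapoints.BoundingBoxes(
    [[17, 16, 344, 495]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)
print(F.resize(image, size=[128, 128], antialias=True).shape)
print(F.resize(boxes, size=[128, 128]))

########################################################################################################################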
#
# What datapoints are supported?
# ------------------------------
#
# So far :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# How do I construct a datapoint?
# -------------------------------
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`

image = datapoints.Image([[[[0, 1], [1, 0]]]])
print(image)


########################################################################################################################
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.

float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
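
########################################################################################################################
# Since ``requires_grad=True`` was passed above, the datapoint also participates in autograd like
# any plain tensor. A quick sketch:

float_image.sum().backward()
print(float_image.grad)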


########################################################################################################################
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a
# :class:`PIL.Image.Image` directly:

image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
print(image.shape, image.dtype)
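
########################################################################################################################
# :class:`~torchvision.datapoints.Mask` accepts a :class:`PIL.Image.Image` in the same way. A quick
# sketch, using a synthetic single-channel image in place of a segmentation mask on disk:

import numpy as np

mask = datapoints.Mask(PIL.Image.fromarray(np.zeros((256, 256), dtype=np.uint8)))
print(mask.shape, mask.dtype)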

########################################################################################################################
# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
# corresponding image alongside the actual values:

bounding_box = datapoints.BoundingBoxes(
    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)
print(bounding_box)
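
########################################################################################################################
# The metadata travels with the datapoint and can be accessed directly:

print(bounding_box.format, bounding_box.canvas_size)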


########################################################################################################################
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
#
# Only if you are using custom datasets. For the built-in ones, you can use
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the
# built-in datasets. That is, if your custom dataset subclasses a built-in one and returns the same output type, you
# don't have to wrap manually either.
#
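# A minimal sketch of the built-in wrapper; the COCO paths below are placeholders rather than
# files shipped with this example:

from torchvision import datasets

coco = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
coco = datasets.wrap_dataset_for_transforms_v2(coco)

########################################################################################################################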
# If you have a custom dataset, for example the ``PennFudanDataset`` from
# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options:
#
# 1. Perform the wrapping inside ``__getitem__``:

class PennFudanDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, item):
        # load ``img`` and compute ``boxes``, ``masks``, and ``labels`` as in the original tutorial
        ...
        target["boxes"] = datapoints.BoundingBoxes(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
        target["labels"] = labels
        target["masks"] = datapoints.Mask(masks)

        ...

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        ...

########################################################################################################################
# 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline:


class WrapPennFudanDataset:
    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBoxes(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
        target["masks"] = datapoints.Mask(target["masks"])
        return img, target


...


def get_transform(train):
    # ``T`` refers to ``torchvision.transforms.v2`` here; wrap the targets before any other transform
    transforms = []
    transforms.append(WrapPennFudanDataset())
    transforms.append(T.PILToTensor())
    ...

########################################################################################################################
# .. note::
#
#    If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask` instances are
#    included in the sample, ``torchvision.transforms.v2`` will transform them both. That is, if you don't need both,
#    dropping, or at least not wrapping, the unneeded parts can lead to a significant performance boost.
#
#    For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids
#    transforming them over and over again in the pipeline just to ultimately ignore them. In general, it would be
#    even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
#    generated from the masks; see the sketch after this note.
#
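# For this example, option 2 from above could be adapted as follows. This is a sketch; the class
# name is made up for illustration:

class WrapPennFudanDatasetBoxesOnly:
    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBoxes(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
        # drop the masks instead of wrapping them, so the transforms never touch them
        target.pop("masks", None)
        return img, target

########################################################################################################################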
# How do the datapoints behave inside a computation?
# --------------------------------------------------
#
# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor`
# also works on datapoints.
# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below):

assert isinstance(image, datapoints.Image)

new_image = image + 0

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)

########################################################################################################################
# .. note::
#
#    This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# There are two exceptions to this rule:
#
# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_`
#    retain the datapoint type.
# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you
#    use the return value of such an operation, e.g. in a chained (fluent) style, it will be unwrapped:

image = datapoints.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

assert isinstance(image, torch.Tensor)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all()
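
########################################################################################################################
# To see the first exception in action: :meth:`~torch.Tensor.clone` and :meth:`~torch.Tensor.to`
# keep the datapoint type. A quick sketch:

assert isinstance(image.clone(), datapoints.Image)
assert isinstance(image.to(torch.float64), datapoints.Image)

########################################################################################################################
# If you need the datapoint type back after an unwrapping operation, you can simply re-wrap the
# plain tensor; as shown above, the constructor is zero-copy:

new_image = datapoints.Image(new_image)
assert isinstance(new_image, datapoints.Image)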