"""
=====================================
How to write your own Datapoint class
=====================================

This guide is intended for downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
"""

# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.


class MyDatapoint(datapoints.Datapoint):
    pass


my_dp = MyDatapoint([1, 2, 3])
my_dp
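
# %%
# If your class needs to carry meta-data like
# :class:`~torchvision.datapoints.BoundingBoxes` does, here is a rough,
# hypothetical sketch of that pattern (the ``source`` attribute and ``_wrap``
# helper below are made up for illustration and are not part of the
# torchvision API):


class TaggedDatapoint(datapoints.Datapoint):
    source: str  # hypothetical meta-data field

    @classmethod
    def _wrap(cls, tensor, *, source):
        # View the tensor as our subclass, then attach the meta-data.
        tagged = tensor.as_subclass(cls)
        tagged.source = source
        return tagged


tagged = TaggedDatapoint._wrap(torch.rand(3, 2), source="annotator-1")
tagged.source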

# %%
# Now that we have defined our custom Datapoint class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our ``MyDatapoint`` class, and register it with the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)


# %%
# To understand why ``wrap_like`` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now;
# we will explain them below in :ref:`param_forwarding`.
#
# .. note::
#
#     In our call to ``register_kernel`` above we used a string
#     ``functional="hflip"`` to refer to the functional we want to hook into. We
#     could also have used the functional *itself*, i.e.
#     ``@register_kernel(functional=F.hflip, ...)``.
#
#     The functionals that you can hook into are the ones in
#     ``torchvision.transforms.v2.functional`` and they are documented in
#     :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:

my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it
# relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
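
# %%
# Custom datapoints also flow through composed pipelines alongside other
# inputs. A small sketch (assuming, as in v2 generally, that plain tensors are
# transformed as images):

pipeline = v2.Compose([v2.RandomHorizontalFlip(p=1)])
image = torch.rand(3, 256, 256)  # plain tensor, handled by the built-in image kernel
out_image, out_dp = pipeline(image, my_dp)  # our registered kernel handles ``my_dp``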

# %%
# .. note::
#
#     We cannot register a kernel for a transform class; we can only register a
#     kernel for a **functional**. The reason we can't register a transform
#     class is that one transform may internally rely on more than one
#     functional, so in general we can't register a single kernel for a given
#     class.
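#
#     For instance, to also support
#     :class:`~torchvision.transforms.v2.RandomVerticalFlip`, we would register
#     a second kernel, this time for the ``vflip`` functional. A minimal sketch
#     following the same pattern as above:

@F.register_kernel(functional="vflip", datapoint_cls=MyDatapoint)
def vflip_my_datapoint(my_dp, *args, **kwargs):
    # Flip along the vertical (second-to-last) dimension.
    out = my_dp.flip(-2)
    return MyDatapoint.wrap_like(my_dp, out)


# %%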
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# -----------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as

def hflip_my_datapoint(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)


# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend always defining your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding ``**kwargs`` alone should be enough.)
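
# %%
# That is, a kernel that only takes a ``**kwargs`` catch-all, along the lines
# of the sketch below, should also keep accepting any parameters added later:

def hflip_my_datapoint_kwargs_only(my_dp, **kwargs):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)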