"""
=====================================
How to write your own Datapoint class
=====================================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_datapoints.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_datapoints.py>` to download the full example code.

This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`.
"""

# %%
import torch
from torchvision import datapoints
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.


class MyDatapoint(datapoints.Datapoint):
    pass


my_dp = MyDatapoint([1, 2, 3])
my_dp
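
# %%
# As a small taste of what carrying meta-data can look like, here is a
# hypothetical sketch. This is *not* the exact pattern
# :class:`~torchvision.datapoints.BoundingBoxes` uses; it merely illustrates
# that a subclass can attach extra attributes when wrapping the underlying
# tensor:


class TimestampedDatapoint(datapoints.Datapoint):
    # ``timestamp`` is a made-up meta-data field, purely for illustration.
    timestamp: float

    def __new__(cls, data, *, timestamp=0.0):
        # Convert the input to a tensor, view it as our subclass, and attach
        # the meta-data attribute.
        tensor = torch.as_tensor(data)
        dp = tensor.as_subclass(cls)
        dp.timestamp = timestamp
        return dp


ts_dp = TimestampedDatapoint([1, 2, 3], timestamp=42.0)
ts_dp.timestamp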

# %%
# Now that we have defined our custom Datapoint class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our ``MyDatapoint`` class, and register it with the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return datapoints.wrap(out, like=my_dp)


# %%
# To understand why :func:`~torchvision.datapoints.wrap` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now;
# we will explain them below in :ref:`param_forwarding`.
#
# .. note::
#
#     In our call to ``register_kernel`` above we used a string
#     ``functional="hflip"`` to refer to the functional we want to hook into. We
#     could also have used the functional *itself*, i.e.
#     ``@register_kernel(functional=F.hflip, ...)``.
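#
# For illustration, here is that second form in action. This is a hedged
# sketch: it assumes ``F.vflip`` participates in the same kernel dispatch
# mechanism as ``F.hflip``, and registers a kernel for it by passing the
# functional object itself. We won't use this vertical-flip kernel further in
# this guide.

@F.register_kernel(functional=F.vflip, datapoint_cls=MyDatapoint)
def vflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping vertically!")
    out = my_dp.flip(-2)
    return datapoints.wrap(out, like=my_dp)


# %%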
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:

my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
# And we can also use the :class:`~torchvision.transforms.v2.RandomHorizontalFlip`
# transform, since it relies on :func:`~torchvision.transforms.v2.functional.hflip`
# internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
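
# %%
# The same dispatch happens when the transform runs as part of a larger
# pipeline, for example (purely illustrative) inside a
# :class:`~torchvision.transforms.v2.Compose`:
pipeline = v2.Compose([v2.RandomHorizontalFlip(p=1)])
_ = pipeline(my_dp)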

# %%
# .. note::
#
#     We cannot register a kernel for a transform class; we can only register a
#     kernel for a **functional**. The reason we can't register a transform
#     class is because one transform may internally rely on more than one
#     functional, so in general we can't register a single kernel for a given
#     class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# -----------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# had already defined and registered your own kernel as

def hflip_my_datapoint(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return datapoints.wrap(out, like=my_dp)


# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend always defining your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding ``**kwargs`` alone should be enough.)
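#
# As a quick, purely illustrative check (``hflip_robust`` below is a
# hypothetical helper, not registered anywhere), a kernel defined with
# ``*args, **kwargs`` simply absorbs a parameter it does not know about:

def hflip_robust(my_dp, *args, **kwargs):
    out = my_dp.flip(-1)
    return datapoints.wrap(out, like=my_dp)


# The extra ``inplace`` argument is silently ignored instead of raising a
# TypeError.
_ = hflip_robust(my_dp, inplace=False)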