"""
=====================================
How to write your own TVTensor class
=====================================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.

This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own tv_tensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
"""

# %%
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.tv_tensors.TVTensor` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.tv_tensors.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/tv_tensors/_bounding_box.py>`_.


class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor([1, 2, 3])
my_dp
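
# %%
# If your class needs to carry meta-data, you can attach it in ``__new__``. The
# sketch below is purely illustrative: the ``label`` attribute is hypothetical,
# and the real :class:`~torchvision.tv_tensors.BoundingBoxes` class does more
# work, e.g. to preserve its meta-data when instances get re-wrapped.


class MyLabeledTVTensor(tv_tensors.TVTensor):
    label: str  # hypothetical meta-data attribute, for illustration only

    def __new__(cls, data, *, label="unknown"):
        # Convert the input to a tensor and view it as our subclass
        instance = torch.as_tensor(data).as_subclass(cls)
        instance.label = label
        return instance


labeled_dp = MyLabeledTVTensor([1, 2, 3], label="demo")
labeled_dp.label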

# %%
# Now that we have defined our custom TVTensor class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyTVTensor class, and register it to the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)


# %%
# To understand why :func:`~torchvision.tv_tensors.wrap` is used, see
# :ref:`tv_tensor_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now;
# we will explain them below in :ref:`param_forwarding`.
#
# .. note::
#
#     In our call to ``register_kernel`` above we used a string
#     ``functional="hflip"`` to refer to the functional we want to hook into. We
#     could also have used the functional *itself*, i.e.
#     ``@register_kernel(functional=F.hflip, ...)``.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyTVTensor`` instance:

my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it
# relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
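
# %%
# The same pattern extends to other functionals. For example, here is a sketch
# (following the exact same recipe as above, with an illustrative kernel name of
# our own) that registers a kernel for the ``vflip`` functional, so that
# :class:`~torchvision.transforms.v2.RandomVerticalFlip` works on ``MyTVTensor``
# instances too:


@F.register_kernel(functional="vflip", tv_tensor_cls=MyTVTensor)
def vflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping vertically!")
    out = my_dp.flip(-2)
    return tv_tensors.wrap(out, like=my_dp)


_ = F.vflip(my_dp)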

# %%
# .. note::
#
#     We cannot register a kernel for a transform class; we can only register a
#     kernel for a **functional**. The reason we can't register a transform
#     class is that one transform may internally rely on more than one
#     functional, so in general we can't register a single kernel for a given
#     class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# -----------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as

def hflip_my_tv_tensor(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)


# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend always defining your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding ``**kwargs`` alone should be enough.)
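
# %%
# A quick sketch of that failure mode: simulating the future call by passing a
# hypothetical ``inplace`` argument to the kernel we just redefined raises a
# ``TypeError``, while the ``*args, **kwargs`` version defined earlier would
# simply ignore it:

try:
    hflip_my_tv_tensor(my_dp, inplace=False)  # simulates the hypothetical future call
except TypeError as e:
    print(f"Failed as expected: {e}")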