Unverified Commit 84db2ac4 authored by Nicolas Hug, committed by GitHub

Add tuto for custom transforms and custom datapoints in gallery example (#7795)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent bf03f4ed
@@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
    used within the autoclass directive.
    """
    if obj.__name__.endswith(("_Weights", "_QuantizedWeights")):
    if getattr(obj, "__name__", "").endswith(("_Weights", "_QuantizedWeights")):
        if len(obj) == 0:
            lines[:] = ["There are no available pre-trained weights."]
@@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
BoundingBoxFormat
BoundingBoxes
Mask
Datapoint
@@ -375,3 +375,14 @@ you can use a functional transform to build transform classes with custom behavi
to_pil_image
to_tensor
vflip
Developer tools
---------------
.. currentmodule:: torchvision.transforms.v2.functional
.. autosummary::
:toctree: generated/
:template: function.rst
register_kernel
"""
=====================================
How to write your own Datapoint class
=====================================
This guide is intended for downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
"""
# %%
import torch
import torchvision
# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2
# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# what you need to know to implement your more elaborate use-cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
class MyDatapoint(datapoints.Datapoint):
    pass

my_dp = MyDatapoint([1, 2, 3])
my_dp
# %%
# Now that we have defined our custom Datapoint class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API.
from torchvision.transforms.v2 import functional as F
@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)
# %%
# To understand why ``wrap_like`` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
# .. note::
#
# In our call to ``register_kernel`` above we used a string
# ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
# could also have used the functional *itself*, i.e.
# ``@register_kernel(dispatcher=F.hflip, ...)``.
#
# The functionals that can be hooked into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:
my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it
# relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
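# %%
# As a quick sketch of the same pattern applied to another functional, here is
# a made-up, purely illustrative kernel for ``vflip``, which
# :class:`~torchvision.transforms.v2.RandomVerticalFlip` relies on internally:


@F.register_kernel(dispatcher="vflip", datapoint_cls=MyDatapoint)
def vflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping vertically!")
    out = my_dp.flip(-2)  # flip along the height (second-to-last) dimension
    return MyDatapoint.wrap_like(my_dp, out)


_ = v2.RandomVerticalFlip(p=1)(my_dp)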
# %%
# .. note::
#
#    We cannot register a kernel for a transform class; we can only register a
#    kernel for a **functional**. The reason we can't register a transform
# class is because one transform may internally rely on more than one
# functional, so in general we can't register a single kernel for a given
# class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# -----------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as
def hflip_my_datapoint(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return MyDatapoint.wrap_like(my_dp, out)
# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
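#
# For illustration, the failure would look something like this (a sketch of
# the standard Python error, not output captured from any real future
# version):
#
# .. code:: python
#
#    F.hflip(my_dp)  # hypothetical future version passing inplace
#    # TypeError: hflip_my_datapoint() got an unexpected keyword argument 'inplace'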
#
# For this reason, we recommend always defining your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding ``**kwargs`` alone should be enough.)
"""
===================================
How to write your own v2 transforms
===================================
This guide explains how to write transforms that are compatible with the
torchvision transforms V2 API.
"""
# %%
import torch
import torchvision
# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2
# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example, if
# you're just doing image classification, your transform will typically accept a
# single image as input, or a ``(img, label)`` input. So you can just hard-code
# your ``forward`` method to accept just that, e.g.
#
# .. code:: python
#
# class MyCustomTransform(torch.nn.Module):
# def forward(self, img, label):
# # Do some transformations
# return new_img, new_label
#
# .. note::
#
# This means that if you have a custom transform that is already compatible
# with the V1 transforms (those in ``torchvision.transforms``), it will
# still work with the V2 transforms without any change!
#
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:
class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing through the input
        return img, bboxes, label

transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3
out_img, out_bboxes, out_label = transforms(img, bboxes, label)
# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
# While working with datapoint classes in your code, make sure to
# familiarize yourself with this section:
# :ref:`datapoint_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limiting.
#
# A key feature of the built-in Torchvision V2 transforms is that they can
# accept arbitrary input structures and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
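# %%
# As a rough, unofficial sketch of that core logic (note that this relies on
# the *private* ``torch.utils._pytree`` module, so treat it as an illustration
# rather than a supported API - it may break in future PyTorch releases):
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision.transforms.v2 import functional as F


class MyFlexibleHFlip(torch.nn.Module):
    def forward(self, sample):
        # Unpack the arbitrarily-nested input into a flat list, remembering
        # its structure
        flat_inputs, spec = tree_flatten(sample)
        # Only transform the entries we know how to handle, based on their
        # class: pure tensors and all datapoints (which are tensor
        # subclasses); everything else is passed through untouched
        flat_outputs = [
            F.hflip(entry) if isinstance(entry, torch.Tensor) else entry
            for entry in flat_inputs
        ]
        # Repack the (potentially transformed) entries into the original
        # structure
        return tree_unflatten(flat_outputs, spec)


structured_output = MyFlexibleHFlip()(structured_input)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")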
@@ -3,13 +3,22 @@
Datapoints FAQ
==============
The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example
showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need
to worry about: you do not need to understand the internals of datapoints to efficiently rely on
``torchvision.transforms.v2``. It may however be useful for advanced users trying to implement their own datasets,
transforms, or work directly with the datapoints.
Datapoints are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these datapoints are
and how they behave.
.. warning::
**Intended Audience** Unless you're writing your own transforms or your own datapoints, you
probably do not need to read this guide. This is a fairly low-level topic
that most users will not need to worry about: you do not need to understand
the internals of datapoints to efficiently rely on
``torchvision.transforms.v2``. It may however be useful for advanced users
trying to implement their own datasets, transforms, or work directly with
the datapoints.
"""
# %%
import PIL.Image
import torch
@@ -35,11 +44,20 @@ image = datapoints.Image(tensor)
assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()
# %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# What can I do with a datapoint?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
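
# As a quick illustration (using the ``image`` datapoint created above):
print(image.sum())           # any plain Tensor method works
print(image.float().mean())  # torch operations can be chained, too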
# %%
#
# What datapoints are supported?
# ------------------------------
#
@@ -50,9 +68,14 @@ assert image.data_ptr() == tensor.data_ptr()
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# .. _datapoint_creation:
#
# How do I construct a datapoint?
# -------------------------------
#
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
image = datapoints.Image([[[[0, 1], [1, 0]]]])
@@ -68,27 +91,52 @@ print(float_image)
# %%
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a
# :class:`PIL.Image.Image` directly:
image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
print(image.shape, image.dtype)
# %%
# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
# corresponding image alongside the actual values:
# Some datapoints require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. This
# metadata is required to properly transform the bounding boxes.
bounding_box = datapoints.BoundingBoxes(
    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
bboxes = datapoints.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]
)
print(bounding_box)
print(bboxes)
# %%
# Using the ``wrap_like()`` class method
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the ``wrap_like()`` class method to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input. This API is inspired by utils like
# :func:`torch.zeros_like`:
new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size
# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it. Check the
# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
# more details.
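#
# For example, here is a small sketch (to the best of our knowledge ``format``
# and ``canvas_size`` are the metadata that can be overridden for bounding
# boxes, but check the documentation linked above):

smaller_canvas = datapoints.BoundingBoxes.wrap_like(
    bboxes, new_bboxes, canvas_size=(100, 100)
)
assert smaller_canvas.canvas_size == (100, 100)

# %%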
#
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
#
# TODO: Move this in another guide - this is user-facing, not dev-facing.
#
# Only if you are using custom datasets. For the built-in ones, you can use
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the
# built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you
@@ -105,8 +153,8 @@ class PennFudanDataset(torch.utils.data.Dataset):
    def __getitem__(self, item):
        ...
        target["boxes"] = datapoints.BoundingBoxes(
            boxes,
        target["bboxes"] = datapoints.BoundingBoxes(
            bboxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),
        )
@@ -147,7 +195,7 @@ def get_transform(train):
# %%
# .. note::
#
# If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in
# If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in
# the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
# at least not wrapping the obsolete parts can lead to a significant performance boost.
#
@@ -156,41 +204,66 @@ def get_transform(train):
# even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
# generated from the masks.
#
# How do the datapoints behave inside a computation?
# --------------------------------------------------
# .. _datapoint_unwrapping_behaviour:
#
# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor`
# also works on datapoints.
# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below):
# I had a Datapoint but now I have a Tensor. Help!
# ------------------------------------------------
#
# For a lot of operations involving datapoints, we cannot safely infer whether
# the result should retain the datapoint type, so we choose to return a plain
# tensor instead of a datapoint (this might change, see note below):
assert isinstance(image, datapoints.Image)
assert isinstance(bboxes, datapoints.BoundingBoxes)
new_image = image + 0
# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# If you're writing your own custom transforms or code involving datapoints, you
# can re-wrap the output into a datapoint by just calling their constructor, or
# by using the ``.wrap_like()`` class method:
new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# See more details above in :ref:`datapoint_creation`.
#
# .. note::
#
# You never need to re-wrap manually if you're using the built-in transforms
# or their functional equivalents: this is automatically taken care of for
# you.
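#
# For example (a quick check; the imports are repeated here only for clarity):

from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

assert isinstance(F.hflip(bboxes), datapoints.BoundingBoxes)
assert isinstance(v2.RandomHorizontalFlip(p=1)(bboxes), datapoints.BoundingBoxes)

# %%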
#
# .. note::
#
# This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
# have any suggestions on how to better support your use-cases, please reach out to us via this issue:
# https://github.com/pytorch/vision/issues/7319
#
# There are two exceptions to this rule:
# There are a few exceptions to this "unwrapping" rule:
#
# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_`
# retain the datapoint type.
# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you use
# the flow style, the returned value will be unwrapped:
# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
#    :meth:`~torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type.
# 2. Inplace operations on datapoints like ``.add_()`` preserve their type. However,
# the **returned** value of inplace operations will be unwrapped into a pure
# tensor:
image = datapoints.Image([[[0, 1], [1, 0]]])
new_image = image.add_(1).mul_(2)
assert isinstance(image, torch.Tensor)
# image got transformed in-place and is still an Image datapoint, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, datapoints.Image)
print(image)
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
@@ -42,7 +42,7 @@ class BoundingBoxes(Datapoint):
    canvas_size: Tuple[int, int]

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:
    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
        bounding_boxes = tensor.as_subclass(cls)
        bounding_boxes.format = format
        bounding_boxes.canvas_size = canvas_size
@@ -14,6 +14,13 @@ _FillTypeJIT = Optional[List[float]]
class Datapoint(torch.Tensor):
    """[Beta] Base class for all datapoints.

    You probably don't want to use this class unless you're defining your own
    custom Datapoints. See
    :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for details.
    """

    @staticmethod
    def _to_tensor(
        data: Any,
@@ -25,9 +32,13 @@ class Datapoint(torch.Tensor):
        requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    @classmethod
    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
        return tensor.as_subclass(cls)

    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
        raise NotImplementedError
        return cls._wrap(tensor)

    _NO_WRAPPING_EXCEPTIONS = {
        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
@@ -22,11 +22,6 @@ class Image(Datapoint):
    ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Image:
        image = tensor.as_subclass(cls)
        return image

    def __new__(
        cls,
        data: Any,
@@ -48,10 +43,6 @@
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()
@@ -22,10 +22,6 @@ class Mask(Datapoint):
    ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Mask:
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
@@ -41,11 +37,3 @@
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(
        cls,
        other: Mask,
        tensor: torch.Tensor,
    ) -> Mask:
        return cls._wrap(tensor)
@@ -20,11 +20,6 @@ class Video(Datapoint):
    ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Video:
        video = tensor.as_subclass(cls)
        return video

    def __new__(
        cls,
        data: Any,
@@ -38,10 +33,6 @@
            raise ValueError
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()
@@ -15,7 +15,7 @@ class _LabelBase(Datapoint):
    categories: Optional[Sequence[str]]

    @classmethod
    def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
    def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:  # type: ignore[override]
        label_base = tensor.as_subclass(cls)
        label_base.categories = categories
        return label_base
@@ -47,6 +47,11 @@ def _name_to_dispatcher(name):
def register_kernel(dispatcher, datapoint_cls):
    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.

    See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
    details.
    """
    if isinstance(dispatcher, str):
        dispatcher = _name_to_dispatcher(name=dispatcher)
    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)