"src/targets/vscode:/vscode.git/clone" did not exist on "67f77ac1572fe1ef5216f92c8318d23416e641ae"
Unverified Commit 84db2ac4 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add tuto for custom transforms and custom datapoints in gallery example (#7795)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent bf03f4ed
...@@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines): ...@@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
used within the autoclass directive. used within the autoclass directive.
""" """
if obj.__name__.endswith(("_Weights", "_QuantizedWeights")): if getattr(obj, ".__name__", "").endswith(("_Weights", "_QuantizedWeights")):
if len(obj) == 0: if len(obj) == 0:
lines[:] = ["There are no available pre-trained weights."] lines[:] = ["There are no available pre-trained weights."]
......
...@@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`. ...@@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
BoundingBoxFormat BoundingBoxFormat
BoundingBoxes BoundingBoxes
Mask Mask
Datapoint
...@@ -375,3 +375,14 @@ you can use a functional transform to build transform classes with custom behavi ...@@ -375,3 +375,14 @@ you can use a functional transform to build transform classes with custom behavi
to_pil_image to_pil_image
to_tensor to_tensor
vflip vflip
Developer tools
---------------
.. currentmodule:: torchvision.transforms.v2.functional
.. autosummary::
:toctree: generated/
:template: function.rst
register_kernel
"""
=====================================
How to write your own Datapoint class
=====================================
This guide is intended for downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
"""
# %%
import torch
import torchvision
# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2
# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# what you need to know to implement your more elaborate uses-cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
class MyDatapoint(datapoints.Datapoint):
pass
my_dp = MyDatapoint([1, 2, 3])
my_dp
# %%
# Now that we have defined our custom Datapoint class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API.
from torchvision.transforms.v2 import functional as F
@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return MyDatapoint.wrap_like(my_dp, out)
# %%
# To understand why ``wrap_like`` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
# .. note::
#
# In our call to ``register_kernel`` above we used a string
# ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
# could also have used the functional *itself*, i.e.
# ``@register_kernel(dispatcher=F.hflip, ...)``.
#
# The functionals that you can be hooked into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:
my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
# %%
# .. note::
#
# We cannot register a kernel for a transform class, we can only register a
# kernel for a **functional**. The reason we can't register a transform
# class is because one transform may internally rely on more than one
# functional, so in general we can't register a single kernel for a given
# class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# -----------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as
def hflip_my_datapoint(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return MyDatapoint.wrap_like(my_dp, out)
# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend to always define your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding `**kwargs` only should be enough).
"""
===================================
How to write your own v2 transforms
===================================
This guide explains how to write transforms that are compatible with the
torchvision transforms V2 API.
"""
# %%
import torch
import torchvision
# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2
# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example if
# you're just doing image classification, your transform will typically accept a
# single image as input, or a ``(img, label)`` input. So you can just hard-code
# your ``forward`` method to accept just that, e.g.
#
# .. code:: python
#
# class MyCustomTransform(torch.nn.Module):
# def forward(self, img, label):
# # Do some transformations
# return new_img, new_label
#
# .. note::
#
# This means that if you have a custom transform that is already compatible
# with the V1 transforms (those in ``torchvision.transforms``), it will
# still work with the V2 transforms without any change!
#
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:
class MyCustomTransform(torch.nn.Module):
def forward(self, img, bboxes, label): # we assume inputs are always structured like this
print(
f"I'm transforming an image of shape {img.shape} "
f"with bboxes = {bboxes}\n{label = }"
)
# Do some transformations. Here, we're just passing though the input
return img, bboxes, label
transforms = v2.Compose([
MyCustomTransform(),
v2.RandomResizedCrop((224, 224), antialias=True),
v2.RandomHorizontalFlip(p=1),
v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
)
label = 3
out_img, out_bboxes, out_label = transforms(img, bboxes, label)
# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
# While working with datapoint classes in your code, make sure to
# familiarize yourself with this section:
# :ref:`datapoint_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limitting.
#
# A key feature of the builtin Torchvision V2 transforms is that they can accept
# arbitrary input structure and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
structured_input = {
"img": img,
"annotations": (bboxes, label),
"something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# tensor-subclasses) plus some custom logic that is out of score here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
...@@ -3,13 +3,22 @@ ...@@ -3,13 +3,22 @@
Datapoints FAQ Datapoints FAQ
============== ==============
The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example Datapoints are Tensor subclasses introduced together with
showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need ``torchvision.transforms.v2``. This example showcases what these datapoints are
to worry about: you do not need to understand the internals of datapoints to efficiently rely on and how they behave.
``torchvision.transforms.v2``. It may however be useful for advanced users trying to implement their own datasets,
transforms, or work directly with the datapoints. .. warning::
**Intended Audience** Unless you're writing your own transforms or your own datapoints, you
probably do not need to read this guide. This is a fairly low-level topic
that most users will not need to worry about: you do not need to understand
the internals of datapoints to efficiently rely on
``torchvision.transforms.v2``. It may however be useful for advanced users
trying to implement their own datasets, transforms, or work directly with
the datapoints.
""" """
# %%
import PIL.Image import PIL.Image
import torch import torch
...@@ -35,11 +44,20 @@ image = datapoints.Image(tensor) ...@@ -35,11 +44,20 @@ image = datapoints.Image(tensor)
assert isinstance(image, torch.Tensor) assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr() assert image.data_ptr() == tensor.data_ptr()
# %% # %%
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data. # for the input data.
# #
# What can I do with a datapoint?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also works on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
# %%
#
# What datapoints are supported? # What datapoints are supported?
# ------------------------------ # ------------------------------
# #
...@@ -50,9 +68,14 @@ assert image.data_ptr() == tensor.data_ptr() ...@@ -50,9 +68,14 @@ assert image.data_ptr() == tensor.data_ptr()
# * :class:`~torchvision.datapoints.BoundingBoxes` # * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask` # * :class:`~torchvision.datapoints.Mask`
# #
# .. _datapoint_creation:
#
# How do I construct a datapoint? # How do I construct a datapoint?
# ------------------------------- # -------------------------------
# #
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor` # Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
image = datapoints.Image([[[[0, 1], [1, 0]]]]) image = datapoints.Image([[[[0, 1], [1, 0]]]])
...@@ -68,27 +91,52 @@ print(float_image) ...@@ -68,27 +91,52 @@ print(float_image)
# %% # %%
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a # In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a
# :class:`PIL.Image.Image` directly: # :class:`PIL.Image.Image` directly:
image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg")) image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
print(image.shape, image.dtype) print(image.shape, image.dtype)
# %% # %%
# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example, # Some datapoints require additional metadata to be passed in ordered to be constructed. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the # :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image alongside the actual values: # corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.
bounding_box = datapoints.BoundingBoxes( bboxes = datapoints.BoundingBoxes(
[17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:] [[17, 16, 344, 495], [0, 10, 0, 10]],
format=datapoints.BoundingBoxFormat.XYXY,
canvas_size=image.shape[-2:]
) )
print(bounding_box) print(bboxes)
# %%
# Using the ``wrap_like()`` class method
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the ``wrap_like()`` class method to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input. This API is inspired by utils like
# :func:`torch.zeros_like`:
new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size
# %% # %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it. Check the
# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
# more details.
#
# Do I have to wrap the output of the datasets myself? # Do I have to wrap the output of the datasets myself?
# ---------------------------------------------------- # ----------------------------------------------------
# #
# TODO: Move this in another guide - this is user-facing, not dev-facing.
#
# Only if you are using custom datasets. For the built-in ones, you can use # Only if you are using custom datasets. For the built-in ones, you can use
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the
# built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you # built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you
...@@ -105,8 +153,8 @@ class PennFudanDataset(torch.utils.data.Dataset): ...@@ -105,8 +153,8 @@ class PennFudanDataset(torch.utils.data.Dataset):
def __getitem__(self, item): def __getitem__(self, item):
... ...
target["boxes"] = datapoints.BoundingBoxes( target["bboxes"] = datapoints.BoundingBoxes(
boxes, bboxes,
format=datapoints.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
canvas_size=F.get_size(img), canvas_size=F.get_size(img),
) )
...@@ -147,7 +195,7 @@ def get_transform(train): ...@@ -147,7 +195,7 @@ def get_transform(train):
# %% # %%
# .. note:: # .. note::
# #
# If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in # If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in
# the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or # the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
# at least not wrapping the obsolete parts, can lead to a significant performance boost. # at least not wrapping the obsolete parts, can lead to a significant performance boost.
# #
...@@ -156,41 +204,66 @@ def get_transform(train): ...@@ -156,41 +204,66 @@ def get_transform(train):
# even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are # even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
# generated from the masks. # generated from the masks.
# #
# How do the datapoints behave inside a computation? # .. _datapoint_unwrapping_behaviour:
# --------------------------------------------------
# #
# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor` # I had a Datapoint but now I have a Tensor. Help!
# also works on datapoints. # ------------------------------------------------
# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the #
# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below): # For a lot of operations involving datapoints, we cannot safely infer whether
# the result should retain the datapoint type, so we choose to return a plain
# tensor instead of a datapoint (this might change, see note below):
assert isinstance(image, datapoints.Image) assert isinstance(bboxes, datapoints.BoundingBoxes)
new_image = image + 0 # Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# If you're writing your own custom transforms or code involving datapoints, you
# can re-wrap the output into a datapoint by just calling their constructor, or
# by using the ``.wrap_like()`` class method:
new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %% # %%
# See more details above in :ref:`datapoint_creation`.
#
# .. note::
#
# You never need to re-wrap manually if you're using the built-in transforms
# or their functional equivalents: this is automatically taken care of for
# you.
#
# .. note:: # .. note::
# #
# This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you # This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
# have any suggestions on how to better support your use-cases, please reach out to us via this issue: # have any suggestions on how to better support your use-cases, please reach out to us via this issue:
# https://github.com/pytorch/vision/issues/7319 # https://github.com/pytorch/vision/issues/7319
# #
# There are two exceptions to this rule: # There are a few exceptions to this "unwrapping" rule:
# #
# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_` # 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# retain the datapoint type. # :meth:`torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you use # the datapoint type.
# the flow style, the returned value will be unwrapped: # 2. Inplace operations on datapoints like ``.add_()`` preserve they type. However,
# the **returned** value of inplace operations will be unwrapped into a pure
# tensor:
image = datapoints.Image([[[0, 1], [1, 0]]]) image = datapoints.Image([[[0, 1], [1, 0]]])
new_image = image.add_(1).mul_(2) new_image = image.add_(1).mul_(2)
assert isinstance(image, torch.Tensor) # image got transformed in-place and is still an Image datapoint, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, datapoints.Image)
print(image) print(image)
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all() assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
...@@ -42,7 +42,7 @@ class BoundingBoxes(Datapoint): ...@@ -42,7 +42,7 @@ class BoundingBoxes(Datapoint):
canvas_size: Tuple[int, int] canvas_size: Tuple[int, int]
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
bounding_boxes = tensor.as_subclass(cls) bounding_boxes = tensor.as_subclass(cls)
bounding_boxes.format = format bounding_boxes.format = format
bounding_boxes.canvas_size = canvas_size bounding_boxes.canvas_size = canvas_size
......
...@@ -14,6 +14,13 @@ _FillTypeJIT = Optional[List[float]] ...@@ -14,6 +14,13 @@ _FillTypeJIT = Optional[List[float]]
class Datapoint(torch.Tensor): class Datapoint(torch.Tensor):
"""[Beta] Base class for all datapoints.
You probably don't want to use this class unless you're defining your own
custom Datapoints. See
:ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for details.
"""
@staticmethod @staticmethod
def _to_tensor( def _to_tensor(
data: Any, data: Any,
...@@ -25,9 +32,13 @@ class Datapoint(torch.Tensor): ...@@ -25,9 +32,13 @@ class Datapoint(torch.Tensor):
requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
@classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)
@classmethod @classmethod
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
raise NotImplementedError return cls._wrap(tensor)
_NO_WRAPPING_EXCEPTIONS = { _NO_WRAPPING_EXCEPTIONS = {
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
......
...@@ -22,11 +22,6 @@ class Image(Datapoint): ...@@ -22,11 +22,6 @@ class Image(Datapoint):
``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
""" """
@classmethod
def _wrap(cls, tensor: torch.Tensor) -> Image:
image = tensor.as_subclass(cls)
return image
def __new__( def __new__(
cls, cls,
data: Any, data: Any,
...@@ -48,10 +43,6 @@ class Image(Datapoint): ...@@ -48,10 +43,6 @@ class Image(Datapoint):
return cls._wrap(tensor) return cls._wrap(tensor)
@classmethod
def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
return cls._wrap(tensor)
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr() return self._make_repr()
......
...@@ -22,10 +22,6 @@ class Mask(Datapoint): ...@@ -22,10 +22,6 @@ class Mask(Datapoint):
``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
""" """
@classmethod
def _wrap(cls, tensor: torch.Tensor) -> Mask:
return tensor.as_subclass(cls)
def __new__( def __new__(
cls, cls,
data: Any, data: Any,
...@@ -41,11 +37,3 @@ class Mask(Datapoint): ...@@ -41,11 +37,3 @@ class Mask(Datapoint):
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor) return cls._wrap(tensor)
@classmethod
def wrap_like(
cls,
other: Mask,
tensor: torch.Tensor,
) -> Mask:
return cls._wrap(tensor)
...@@ -20,11 +20,6 @@ class Video(Datapoint): ...@@ -20,11 +20,6 @@ class Video(Datapoint):
``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
""" """
@classmethod
def _wrap(cls, tensor: torch.Tensor) -> Video:
video = tensor.as_subclass(cls)
return video
def __new__( def __new__(
cls, cls,
data: Any, data: Any,
...@@ -38,10 +33,6 @@ class Video(Datapoint): ...@@ -38,10 +33,6 @@ class Video(Datapoint):
raise ValueError raise ValueError
return cls._wrap(tensor) return cls._wrap(tensor)
@classmethod
def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
return cls._wrap(tensor)
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr() return self._make_repr()
......
...@@ -15,7 +15,7 @@ class _LabelBase(Datapoint): ...@@ -15,7 +15,7 @@ class _LabelBase(Datapoint):
categories: Optional[Sequence[str]] categories: Optional[Sequence[str]]
@classmethod @classmethod
def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: # type: ignore[override]
label_base = tensor.as_subclass(cls) label_base = tensor.as_subclass(cls)
label_base.categories = categories label_base.categories = categories
return label_base return label_base
......
...@@ -47,6 +47,11 @@ def _name_to_dispatcher(name): ...@@ -47,6 +47,11 @@ def _name_to_dispatcher(name):
def register_kernel(dispatcher, datapoint_cls): def register_kernel(dispatcher, datapoint_cls):
"""Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
details.
"""
if isinstance(dispatcher, str): if isinstance(dispatcher, str):
dispatcher = _name_to_dispatcher(name=dispatcher) dispatcher = _name_to_dispatcher(name=dispatcher)
return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment