Unverified commit f4f685dd, authored by Nicolas Hug and committed by GitHub

More datapoints docs and comments (#7830)


Co-authored-by: vfdev <vfdev.5@gmail.com>
parent 6c44ceb5
@@ -3,7 +3,7 @@

How to write your own Datapoint class
=====================================

This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
@@ -68,10 +68,6 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# could also have used the functional *itself*, i.e.
# ``@register_kernel(functional=F.hflip, ...)``.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:
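#
# A hedged, self-contained sketch of that whole flow (the ``datapoint_cls``
# parameter name and the kernel body are assumptions based on the snippets in
# this guide, not verbatim gallery code):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional import register_kernel


class MyDatapoint(datapoints.Datapoint):
    pass


@register_kernel(functional=F.hflip, datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    # Flip along the last (width) dimension, then re-wrap the pure-Tensor
    # result so the MyDatapoint type is preserved.
    out = my_dp.as_subclass(torch.Tensor).flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)


my_dp = MyDatapoint(torch.rand(3, 5, 5))
out = F.hflip(my_dp)  # dispatches to hflip_my_datapoint
assert isinstance(out, MyDatapoint)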
...
@@ -48,26 +48,22 @@ assert image.data_ptr() == tensor.data_ptr()
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# What can I do with a datapoint?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
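#
# For instance, a quick hedged illustration (shapes are arbitrary):

import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 5, 5))

img.sum()          # plain Tensor methods work on datapoints
torch.mul(img, 2)  # and so do torch.* operators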
# %%
#
# What datapoints are supported?
# ------------------------------
#
# So far :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
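# %%
# As a hedged preview of constructing each type (the ``BoundingBoxes``
# metadata below is made up; construction is covered properly in the next
# section):

img = datapoints.Image(torch.rand(3, 32, 32))
video = datapoints.Video(torch.rand(4, 3, 32, 32))
boxes = datapoints.BoundingBoxes(
    [[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 32)
)
mask = datapoints.Mask(torch.zeros(32, 32, dtype=torch.uint8))

# %%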
# .. _datapoint_creation:
#
# How do I construct a datapoint?
@@ -209,9 +205,8 @@ def get_transform(train):
# I had a Datapoint but now I have a Tensor. Help!
# ------------------------------------------------
#
# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
# will return a pure Tensor:
assert isinstance(bboxes, datapoints.BoundingBoxes)
@@ -219,32 +214,69 @@ assert isinstance(bboxes, datapoints.BoundingBoxes)
# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# .. note::
#
#    This behavior only affects native ``torch`` operations. If you are using
#    the built-in ``torchvision`` transforms or functionals, the output will
#    always have the same type as the input (pure ``Tensor`` or
#    ``Datapoint``).
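#
# For example, a hedged sketch with the v2 functionals (assuming
# ``horizontal_flip`` dispatches on the ``bboxes`` defined above):

from torchvision.transforms.v2 import functional as F

flipped = F.horizontal_flip(bboxes)
assert isinstance(flipped, datapoints.BoundingBoxes)  # Datapoint in, Datapoint out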
# %%
# But I want a Datapoint back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
# constructor, or by using the ``.wrap_like()`` class method (see more details
# above in :ref:`datapoint_creation`):
new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# Alternatively, you can use :func:`~torchvision.datapoints.set_return_type`
# as a global config setting for the whole program, or as a context manager:
with datapoints.set_return_type("datapoint"):
    new_bboxes = bboxes + 3
    assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.datapoints.Datapoint` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
# This doesn't matter for built-in ``torchvision`` transforms because we can
# avoid the overhead there, but it could be a problem in your model's
# ``forward``.
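#
# A hedged micro-benchmark sketch of that overhead (absolute numbers are
# entirely machine-dependent; this is an illustration, not a claim about
# magnitudes):

import timeit

t = torch.rand(3, 224, 224)
img = datapoints.Image(t)

plain = timeit.timeit(lambda: t + 1, number=1000)
wrapped = timeit.timeit(lambda: img + 1, number=1000)
# `wrapped` is typically higher: every op on a Datapoint first goes through
# the __torch_function__ protocol before reaching the Tensor kernel.

# %%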
#
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
# the way, even the model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
# desirable.
#
# .. note::
#
#    This behaviour is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
#    https://github.com/pytorch/vision/issues/7319
#
# Exceptions
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
#
# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
...
@@ -101,6 +101,7 @@ def test_to_datapoint_reference(make_input, return_type):
assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
assert tensor_to.dtype is dp.dtype
assert type(tensor) is torch.Tensor
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
...
@@ -66,19 +66,12 @@ class Datapoint(torch.Tensor):
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call.

Why do we override this? Because the base implementation in torch.Tensor would preserve the Datapoint type
of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
"Datapoints FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).

Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
"""
if not all(issubclass(cls, t) for t in types):
    return NotImplemented
@@ -89,12 +82,13 @@ class Datapoint(torch.Tensor):
must_return_subclass = _must_return_subclass()
if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
    # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
    # in test_to_datapoint_reference().
    # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
    # the computation by walking the MRO upwards. For example,
    # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
    # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
    # be wrapped into an `Image`.
    return cls._wrap_output(output, args, kwargs)

if not must_return_subclass and isinstance(output, cls):
...
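A hedged sketch of the behaviour this guard preserves (it mirrors
test_to_datapoint_reference() above; the default return type is assumed):

    import torch
    from torchvision import datapoints

    image = datapoints.Image(torch.rand(3, 5, 5))
    tensor = torch.rand(3, 5, 5)

    # `.to` is in _FORCE_TORCHFUNCTION_SUBCLASS, but here args[0] is the
    # pure Tensor, so the guard keeps the result a plain Tensor:
    assert type(tensor.to(image)) is torch.Tensor

    # With the Image as the primary operand, the subclass is retained:
    assert type(image.to(torch.float64)) is datapoints.Image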
@@ -18,12 +18,18 @@ class _ReturnTypeCM:
def set_return_type(return_type: str):
    """Set the return type of torch operations on datapoints.

    This only affects the behaviour of torch operations. It has no effect on
    ``torchvision`` transforms or functionals, which will always return as
    output the same type that was passed as input.

    Can be used as a global flag for the entire program:

    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor (default behaviour)

        set_return_type("datapoints")
        img + 2  # This is an Image

    or as a context manager to restrict the scope:
@@ -31,6 +37,7 @@ def set_return_type(return_type: str):
    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor

        with set_return_type("datapoints"):
            img + 2  # This is an Image
        img + 2  # This is a pure Tensor
...
@@ -19,8 +19,15 @@ _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
def _kernel_datapoint_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # If you're wondering whether we could / should get rid of this wrapper,
        # the answer is no: we want to pass pure Tensors to avoid the overhead
        # of the __torch_function__ machinery. Note that this is always valid,
        # regardless of whether we override __torch_function__ in our base class
        # or not.
        # Also, even if we didn't call `as_subclass` here, we would still need
        # this wrapper to call wrap_like(), because the Datapoint type would be
        # lost after the first operation due to our own __torch_function__
        # logic.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return type(inpt).wrap_like(inpt, output)
...