OpenDAS / vision / Commits / f4f685dd

Commit f4f685dd (Unverified), authored Aug 15, 2023 by Nicolas Hug, committed by GitHub Aug 15, 2023

More datapoints docs and comments (#7830)

Co-authored-by: vfdev <vfdev.5@gmail.com>
Parent: 6c44ceb5

Changes: 6 changed files with 87 additions and 50 deletions
gallery/plot_custom_datapoints.py (+1, -5)
gallery/plot_datapoints.py (+57, -25)
test/test_datapoints.py (+1, -0)
torchvision/datapoints/_datapoint.py (+11, -17)
torchvision/datapoints/_torch_function_helpers.py (+8, -1)
torchvision/transforms/v2/functional/_utils.py (+9, -2)
gallery/plot_custom_datapoints.py
...
...
@@ -3,7 +3,7 @@
How to write your own Datapoint class
=====================================
This guide is intended for downstream library maintainers. We explain how to
This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
...
...
@@ -68,10 +68,6 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# could also have used the functional *itself*, i.e.
# ``@register_kernel(functional=F.hflip, ...)``.
#
# The functionals that you can be hooked into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:
...
...
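A rough end-to-end sketch of the registration flow this file documents, assuming the datapoints-era API referenced in the hunk above (``datapoints.Datapoint``, ``F.register_kernel`` accepting the functional itself, ``wrap_like``); ``MyDatapoint`` and ``hflip_my_datapoint`` mirror the gallery's names, and exact signatures may vary between torchvision versions:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class MyDatapoint(datapoints.Datapoint):
    # A custom datapoint: a plain Tensor subclass with no extra metadata.
    pass


# Register a kernel so that F.hflip knows how to handle MyDatapoint inputs.
# The functional is passed positionally to avoid relying on keyword names.
@F.register_kernel(F.hflip, MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    out = my_dp.flip(-1)  # flip along the last (width) dimension
    return MyDatapoint.wrap_like(my_dp, out)  # re-wrap so the type is preserved


my_dp = MyDatapoint(torch.rand(3, 5, 5))
out = F.hflip(my_dp)
assert isinstance(out, MyDatapoint)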
gallery/plot_datapoints.py
...
...
@@ -48,26 +48,22 @@ assert image.data_ptr() == tensor.data_ptr()
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# What can I do with a datapoint?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also works on datapoints. See
# any ``torch.*`` operator will also work on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
# %%
#
# What datapoints are supported?
# ------------------------------
#
# So far :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# .. _datapoint_creation:
#
# How do I construct a datapoint?
...
...
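A quick illustration of the point this hunk reorganizes (datapoints are tensors, and Image/Video/BoundingBoxes/Mask are the four provided types); a minimal sketch assuming the ``torchvision.datapoints`` namespace of this commit's era, using only the metadata-free types:

import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.zeros(32, 32, dtype=torch.uint8))

# Datapoints *are* tensors, so everything a plain Tensor supports works here too.
assert isinstance(img, torch.Tensor)
print(img.sum(), img.dtype, img.shape)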
@@ -209,9 +205,8 @@ def get_transform(train):
# I had a Datapoint but now I have a Tensor. Help!
# ------------------------------------------------
#
# For a lot of operations involving datapoints, we cannot safely infer whether
# the result should retain the datapoint type, so we choose to return a plain
# tensor instead of a datapoint (this might change, see note below):
# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
# will return a pure Tensor:
assert isinstance(bboxes, datapoints.BoundingBoxes)
...
...
@@ -219,32 +214,69 @@ assert isinstance(bboxes, datapoints.BoundingBoxes)
# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)
assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# .. note::
#
# This behavior only affects native ``torch`` operations. If you are using
# the built-in ``torchvision`` transforms or functionals, you will always get
# as output the same type that you passed as input (pure ``Tensor`` or
# ``Datapoint``).
# %%
# If you're writing your own custom transforms or code involving datapoints, you
# can re-wrap the output into a datapoint by just calling their constructor, or
# by using the ``.wrap_like()`` class method:
# But I want a Datapoint back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
# constructor, or by using the ``.wrap_like()`` class method (see more details
# above in :ref:`datapoint_creation`):
new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# See more details above in :ref:`datapoint_creation`.
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
# as a global config setting for the whole program, or as a context manager:
with datapoints.set_return_type("datapoint"):
    new_bboxes = bboxes + 3

assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# .. note::
# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.datapoints.Datapoint` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
# This doesn't matter for built-in ``torchvision`` transforms because we can
# avoid the overhead there, but it could be a problem in your model's
# ``forward``.
#
# You never need to re-wrap manually if you're using the built-in transforms
# or their functional equivalents: this is automatically taken care of for
# you.
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
# the way, even model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
# desirable.
#
# .. note::
#
# This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
# This behaviour is something we're actively seeking feedback on. If you find this surprising or if you
# have any suggestions on how to better support your use-cases, please reach out to us via this issue:
# https://github.com/pytorch/vision/issues/7319
#
# Exceptions
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
#
# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
...
...
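To make the re-wrapping and ``set_return_type`` options above concrete, a small hedged sketch of a hypothetical custom transform (``add_noise`` is an illustrative name, not a torchvision API):

import torch
from torchvision import datapoints


def add_noise(img: datapoints.Image, std: float = 0.1) -> datapoints.Image:
    # A toy "custom transform": the addition below returns a pure Tensor,
    # so we re-wrap the result to keep the Image type for downstream v2 transforms.
    out = img + std * torch.randn_like(img)
    return datapoints.Image.wrap_like(img, out)


img = datapoints.Image(torch.rand(3, 8, 8))
out = add_noise(img)
assert isinstance(out, datapoints.Image)

# The same effect, using the context manager discussed above instead of re-wrapping:
with datapoints.set_return_type("datapoint"):
    assert isinstance(img + 1, datapoints.Image)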
test/test_datapoints.py
...
...
@@ -101,6 +101,7 @@ def test_to_datapoint_reference(make_input, return_type):
    assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
    assert tensor_to.dtype is dp.dtype
    assert type(tensor) is torch.Tensor


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
...
...
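The added assertion guards the case where a pure Tensor is converted with a datapoint as the dtype/device reference. A hedged sketch of that behaviour under the default return type (this mirrors a reading of the test, not its full body):

import torch
from torchvision import datapoints

dp = datapoints.Image(torch.rand(3, 5, 5, dtype=torch.float64))
tensor = torch.rand(3, 5, 5)  # float32

# Tensor.to(other) uses `other` only as a dtype/device reference; by default
# the result should stay a pure Tensor rather than becoming an Image.
tensor_to = tensor.to(dp)

assert tensor_to.dtype is dp.dtype
assert type(tensor_to) is torch.Tensor  # default return type: no wrapping
assert type(tensor) is torch.Tensor     # the original input is untouched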
torchvision/datapoints/_datapoint.py
...
...
@@ -66,19 +66,12 @@ class Datapoint(torch.Tensor):
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call.
The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
use case, this has two downsides:
Why do we override this? Because the base implementation in torch.Tensor would preserve the Datapoint type
of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
"Datapoints FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).
1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
``return cls(func(*args, **kwargs))``, will fail for them.
2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in _FORCE_TORCHFUNCTION_SUBCLASS
Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
"""
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
...
...
@@ -89,12 +82,13 @@ class Datapoint(torch.Tensor):
        must_return_subclass = _must_return_subclass()
        if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
            # We also require the primary operand, i.e. `args[0]`, to be
            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
            # be wrapped into a `datapoints.Image`.
            # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
            # in test_to_datapoint_reference().
            # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
            # the computation by walking the MRO upwards. For example,
            # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
            # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
            # be wrapped into an `Image`.
            return cls._wrap_output(output, args, kwargs)
        if not must_return_subclass and isinstance(output, cls):
...
...
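For readers unfamiliar with the protocol discussed in these comments, a generic, self-contained sketch of a Tensor subclass whose ``__torch_function__`` unwraps outputs to pure tensors except for a small allow-list — the same shape of logic described above, but not torchvision's actual implementation:

import torch
from torch._C import DisableTorchFunctionSubclass


class UnwrappingTensor(torch.Tensor):
    # Ops that should keep the subclass type (a stand-in for torchvision's
    # _FORCE_TORCHFUNCTION_SUBCLASS set).
    _KEEP_TYPE = {torch.Tensor.clone, torch.Tensor.to}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        kwargs = kwargs or {}
        # Run the op with the subclass handling disabled; outputs are then plain tensors
        # (in-place ops are a known exception, handled by the unwrap branch below).
        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs)
        # Keep the subclass only for allow-listed ops, and only when the primary
        # operand is an instance of this class (mirrors the guard explained above).
        if func in cls._KEEP_TYPE and isinstance(args[0], cls):
            return output.as_subclass(cls) if isinstance(output, torch.Tensor) else output
        # Otherwise, unwrap to a plain Tensor.
        if isinstance(output, cls):
            return output.as_subclass(torch.Tensor)
        return output


t = torch.rand(3).as_subclass(UnwrappingTensor)
assert type(t + 1) is torch.Tensor          # unwrapped by default
assert type(t.clone()) is UnwrappingTensor  # allow-listed op keeps the type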
torchvision/datapoints/_torch_function_helpers.py
...
...
@@ -18,12 +18,18 @@ class _ReturnTypeCM:
def set_return_type(return_type: str):
    """Set the return type of torch operations on datapoints.

    This only affects the behaviour of torch operations. It has no effect on
    ``torchvision`` transforms or functionals, which will always return as
    output the same type that was passed as input.

    Can be used as a global flag for the entire program:

    .. code:: python

        set_return_type("datapoints")
        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor (default behaviour)

        set_return_type("datapoints")
        img + 2  # This is an Image

    or as a context manager to restrict the scope:
...
...
@@ -31,6 +37,7 @@ def set_return_type(return_type: str):
    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor

        with set_return_type("datapoints"):
            img + 2  # This is an Image

        img + 2  # This is a pure Tensor
...
...
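The docstring above shows ``set_return_type`` working both as a global flag and as a context manager. A generic sketch of that dual-use pattern follows; the class name ``_ReturnTypeCM`` comes from the hunk header, but this body is an illustrative guess, not the real implementation:

# Module-level default; the name is illustrative.
_RETURN_TYPE = "tensor"


class _ReturnTypeCM:
    def __init__(self, previous_return_type):
        self._previous = previous_return_type

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Restore whatever was active before set_return_type() was called.
        global _RETURN_TYPE
        _RETURN_TYPE = self._previous
        return False


def set_return_type(return_type: str):
    global _RETURN_TYPE
    previous = _RETURN_TYPE
    _RETURN_TYPE = return_type.lower()
    # Returning a context manager makes both usages work:
    #   set_return_type("datapoint")             # global flag, never exited
    #   with set_return_type("datapoint"): ...   # scoped, restored on exit
    return _ReturnTypeCM(previous)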
torchvision/transforms/v2/functional/_utils.py
...
...
@@ -19,8 +19,15 @@ _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
def _kernel_datapoint_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # We always pass datapoints as pure tensors to the kernels to avoid going through the
        # Tensor.__torch_function__ logic, which is costly.
        # If you're wondering whether we could / should get rid of this wrapper,
        # the answer is no: we want to pass pure Tensors to avoid the overhead
        # of the __torch_function__ machinery. Note that this is always valid,
        # regardless of whether we override __torch_function__ in our base class
        # or not.
        # Also, even if we didn't call `as_subclass` here, we would still need
        # this wrapper to call wrap_like(), because the Datapoint type would be
        # lost after the first operation due to our own __torch_function__
        # logic.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return type(inpt).wrap_like(inpt, output)
...
...
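To spell out what the wrapper above does around each kernel call (``as_subclass`` to drop the subclass before the computation, ``wrap_like`` to restore it afterwards), a small hedged usage sketch; ``flip`` stands in for a real kernel:

import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 4, 4))

# Drop the subclass so the computation skips the __torch_function__ overhead.
plain = img.as_subclass(torch.Tensor)
assert type(plain) is torch.Tensor
assert plain.data_ptr() == img.data_ptr()  # same storage, only the Python type changes

out = plain.flip(-1)                  # stand-in for the actual kernel computation
out = type(img).wrap_like(img, out)   # restore the datapoint type (and any metadata)
assert isinstance(out, datapoints.Image)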