OpenDAS / vision — Commit f4f685dd (unverified)

More datapoints docs and comments (#7830)

Authored Aug 15, 2023 by Nicolas Hug; committed via GitHub on Aug 15, 2023.
Co-authored-by: vfdev <vfdev.5@gmail.com>
Parent: 6c44ceb5
Showing 6 changed files with 87 additions and 50 deletions (+87 −50):

gallery/plot_custom_datapoints.py                  +1  −5
gallery/plot_datapoints.py                         +57 −25
test/test_datapoints.py                            +1  −0
torchvision/datapoints/_datapoint.py               +11 −17
torchvision/datapoints/_torch_function_helpers.py  +8  −1
torchvision/transforms/v2/functional/_utils.py     +9  −2
gallery/plot_custom_datapoints.py

@@ -3,7 +3,7 @@
 How to write your own Datapoint class
 =====================================
 
-This guide is intended for downstream library maintainers. We explain how to
+This guide is intended for advanced users and downstream library maintainers. We explain how to
 write your own datapoint class, and how to make it compatible with the built-in
 Torchvision v2 transforms. Before continuing, make sure you have read
 :ref:`sphx_glr_auto_examples_plot_datapoints.py`.
...
@@ -68,10 +68,6 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
 # could also have used the functional *itself*, i.e.
 # ``@register_kernel(functional=F.hflip, ...)``.
 #
-# The functionals that you can be hooked into are the ones in
-# ``torchvision.transforms.v2.functional`` and they are documented in
-# :ref:`functional_transforms`.
-#
 # Now that we have registered our kernel, we can call the functional API on a
 # ``MyDatapoint`` instance:
...
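The gallery page this hunk edits is built around ``register_kernel``-style dispatch: a functional looks up a kernel by the type of its input. A stdlib-only sketch of that pattern (every name here — ``register_kernel``, ``hflip``, the list-backed ``MyDatapoint`` — is an illustrative stand-in, not torchvision's actual implementation):

```python
# Hypothetical miniature of type-based kernel dispatch, stdlib only.
_KERNEL_REGISTRY = {}

def register_kernel(functional, datapoint_cls):
    """Associate a kernel with a (functional, datapoint class) pair."""
    def decorator(kernel):
        _KERNEL_REGISTRY.setdefault(functional, {})[datapoint_cls] = kernel
        return kernel
    return decorator

def hflip(inpt):
    """Functional entry point: dispatch on the input's exact type."""
    kernel = _KERNEL_REGISTRY.get(hflip, {}).get(type(inpt))
    if kernel is None:
        raise TypeError(f"no hflip kernel registered for {type(inpt).__name__}")
    return kernel(inpt)

class MyDatapoint(list):  # stand-in for a tensor subclass
    pass

@register_kernel(hflip, MyDatapoint)
def hflip_my_datapoint(my_dp):
    return MyDatapoint(reversed(my_dp))

print(hflip(MyDatapoint([1, 2, 3])))  # [3, 2, 1]
```

Calling ``hflip`` on an unregistered type raises, which mirrors how a functional can only serve datapoint classes that registered a kernel for it.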
gallery/plot_datapoints.py

@@ -48,26 +48,22 @@ assert image.data_ptr() == tensor.data_ptr()
 # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
 # for the input data.
 #
+# :mod:`torchvision.datapoints` supports four types of datapoints:
+#
+# * :class:`~torchvision.datapoints.Image`
+# * :class:`~torchvision.datapoints.Video`
+# * :class:`~torchvision.datapoints.BoundingBoxes`
+# * :class:`~torchvision.datapoints.Mask`
+#
 # What can I do with a datapoint?
 # -------------------------------
 #
 # Datapoints look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
-# any ``torch.*`` operator will also works on datapoints. See
+# any ``torch.*`` operator will also work on datapoints. See
 # :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
 
 # %%
-#
-# What datapoints are supported?
-# ------------------------------
-#
-# So far :mod:`torchvision.datapoints` supports four types of datapoints:
-#
-# * :class:`~torchvision.datapoints.Image`
-# * :class:`~torchvision.datapoints.Video`
-# * :class:`~torchvision.datapoints.BoundingBoxes`
-# * :class:`~torchvision.datapoints.Mask`
-#
 # .. _datapoint_creation:
 #
 # How do I construct a datapoint?
...
@@ -209,9 +205,8 @@ def get_transform(train):
 # I had a Datapoint but now I have a Tensor. Help!
 # ------------------------------------------------
 #
-# For a lot of operations involving datapoints, we cannot safely infer whether
-# the result should retain the datapoint type, so we choose to return a plain
-# tensor instead of a datapoint (this might change, see note below):
+# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
+# will return a pure Tensor:

 assert isinstance(bboxes, datapoints.BoundingBoxes)
...
@@ -219,32 +214,69 @@ assert isinstance(bboxes, datapoints.BoundingBoxes)
 
 # Shift bboxes by 3 pixels in both H and W
 new_bboxes = bboxes + 3
 
-assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)
+assert isinstance(new_bboxes, torch.Tensor)
+assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
+
+# %%
+# .. note::
+#
+#    This behavior only affects native ``torch`` operations. If you are using
+#    the built-in ``torchvision`` transforms or functionals, you will always get
+#    as output the same type that you passed as input (pure ``Tensor`` or
+#    ``Datapoint``).
 
 # %%
-# If you're writing your own custom transforms or code involving datapoints, you
-# can re-wrap the output into a datapoint by just calling their constructor, or
-# by using the ``.wrap_like()`` class method:
+# But I want a Datapoint back!
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
+# constructor, or by using the ``.wrap_like()`` class method (see more details
+# above in :ref:`datapoint_creation`):
 
 new_bboxes = bboxes + 3
 new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
 assert isinstance(new_bboxes, datapoints.BoundingBoxes)
 
 # %%
-# See more details above in :ref:`datapoint_creation`.
+# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
+# as a global config setting for the whole program, or as a context manager:
+
+with datapoints.set_return_type("datapoint"):
+    new_bboxes = bboxes + 3
+assert isinstance(new_bboxes, datapoints.BoundingBoxes)
+
+# %%
+# Why is this happening?
+# ^^^^^^^^^^^^^^^^^^^^^^
 #
-# .. note::
+# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
+# classes are Tensor subclasses, so any operation involving a
+# :class:`~torchvision.datapoints.Datapoint` object will go through the
+# `__torch_function__
+# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
+# protocol. This induces a small overhead, which we want to avoid when possible.
+# This doesn't matter for built-in ``torchvision`` transforms because we can
+# avoid the overhead there, but it could be a problem in your model's
+# ``forward``.
 #
-# You never need to re-wrap manually if you're using the built-in transforms
-# or their functional equivalents: this is automatically taken care of for
-# you.
+# **The alternative isn't much better anyway.** For every operation where
+# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
+# sense, there are just as many operations where returning a pure Tensor is
+# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
+# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
+# the way, even model's logits or the output of the loss function would end up
+# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
+# desirable.
 #
 # .. note::
 #
-#    This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
+#    This behaviour is something we're actively seeking feedback on. If you find this surprising or if you
#    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
 #    https://github.com/pytorch/vision/issues/7319
 #
+# Exceptions
+# ^^^^^^^^^^
+#
 # There are a few exceptions to this "unwrapping" rule:
 #
 # 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
...
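The "unwrap by default, re-wrap with ``wrap_like``" behaviour documented in this file can be modelled without torch at all. In this stdlib-only sketch (the ``Tensorish``/``BoundingBoxesish`` names and the ``canvas_size`` metadata are illustrative stand-ins, not torchvision's real classes), an operation on the subclass deliberately returns the plain base type, and ``wrap_like`` restores the subclass along with its metadata:

```python
# Toy model of the "unwrapping" behaviour: ops return the base type,
# wrap_like() restores the subclass and carries metadata over.
class Tensorish(tuple):
    def __add__(self, scalar):
        # Like torch ops on Datapoints: the result is the *base* type.
        return tuple(x + scalar for x in self)

class BoundingBoxesish(Tensorish):
    def __new__(cls, data, canvas_size=None):
        obj = super().__new__(cls, data)
        obj.canvas_size = canvas_size  # metadata the constructor requires
        return obj

    @classmethod
    def wrap_like(cls, other, data):
        # Re-wrap `data`, copying the metadata from `other`.
        return cls(data, canvas_size=other.canvas_size)

bboxes = BoundingBoxesish((0, 0, 10, 10), canvas_size=(32, 32))
shifted = bboxes + 3                       # plain tuple: subclass was "unwrapped"
assert not isinstance(shifted, BoundingBoxesish)

rewrapped = BoundingBoxesish.wrap_like(bboxes, shifted)
assert isinstance(rewrapped, BoundingBoxesish)
assert rewrapped.canvas_size == (32, 32)   # metadata survived the round trip
```

This also illustrates why the automatic re-wrapping is hard in general: a plain constructor call would not know what ``canvas_size`` to use, which is exactly the metadata problem the real ``Datapoint`` classes face.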
test/test_datapoints.py

@@ -101,6 +101,7 @@ def test_to_datapoint_reference(make_input, return_type):
 
     assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
     assert tensor_to.dtype is dp.dtype
+    assert type(tensor) is torch.Tensor
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
...
torchvision/datapoints/_datapoint.py

@@ -66,19 +66,12 @@ class Datapoint(torch.Tensor):
         ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
         ``args`` and ``kwargs`` of the original call.
 
-        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
-        use case, this has two downsides:
-
-        1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
-           ``return cls(func(*args, **kwargs))``, will fail for them.
-        2. For most operations, there is no way of knowing if the input type is still valid for the output.
-
-        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
-        listed in _FORCE_TORCHFUNCTION_SUBCLASS
+        Why do we override this? Because the base implementation in torch.Tensor would preserve the Datapoint type
+        of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
+        "Datapoints FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).
+        Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
         """
-        # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
-        # need to reimplement the functionality.
         if not all(issubclass(cls, t) for t in types):
             return NotImplemented
...
@@ -89,12 +82,13 @@ class Datapoint(torch.Tensor):
         must_return_subclass = _must_return_subclass()
         if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
-            # We also require the primary operand, i.e. `args[0]`, to be
-            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
-            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
-            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
-            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
-            # be wrapped into a `datapoints.Image`.
+            # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
+            # in test_to_datapoint_reference().
+            # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
+            # the computation by walking the MRO upwards. For example,
+            # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
+            # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
+            # be wrapped into an `Image`.
             return cls._wrap_output(output, args, kwargs)
         if not must_return_subclass and isinstance(output, cls):
...
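The decision logic this hunk documents — unwrap by default, re-wrap only for an allow-list of operators, and only when the primary operand ``args[0]`` is an instance of the dispatching class — can be sketched without torch. All names below (``Datapointish``, ``FORCE_SUBCLASS``, ``_dispatch``) are hypothetical stand-ins mirroring the roles of ``Datapoint``, ``_FORCE_TORCHFUNCTION_SUBCLASS``, and ``__torch_function__``:

```python
# Stdlib-only sketch of the "unwrap unless allow-listed" dispatch decision.
FORCE_SUBCLASS = set()  # plays the role of _FORCE_TORCHFUNCTION_SUBCLASS

class Datapointish(list):
    @classmethod
    def _dispatch(cls, func, args, kwargs=None):
        output = func(*args, **(kwargs or {}))
        # Re-wrap only for allow-listed ops, and only when the primary
        # operand is an instance of the dispatching class -- the guard
        # that keeps a plain input from being wrapped into the subclass.
        if func in FORCE_SUBCLASS and isinstance(args[0], cls):
            return cls(output)
        return list(output)  # default: "unwrap" to the base type

def clone(x):
    return x[:]

FORCE_SUBCLASS.add(clone)  # clone is one of the re-wrapping exceptions

def double(x):
    return [2 * v for v in x]

dp = Datapointish([1, 2, 3])
assert isinstance(Datapointish._dispatch(clone, (dp,)), Datapointish)       # allow-listed: stays a Datapointish
assert not isinstance(Datapointish._dispatch(double, (dp,)), Datapointish)  # default: plain base type
```

Dropping the ``isinstance(args[0], cls)`` guard here would let ``Datapointish._dispatch(clone, (plain_list,))`` wrap a plain input's result into the subclass, the same failure mode the real guard prevents.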
torchvision/datapoints/_torch_function_helpers.py

@@ -18,12 +18,18 @@ class _ReturnTypeCM:
 def set_return_type(return_type: str):
     """Set the return type of torch operations on datapoints.
 
+    This only affects the behaviour of torch operations. It has no effect on
+    ``torchvision`` transforms or functionals, which will always return as
+    output the same type that was passed as input.
+
     Can be used as a global flag for the entire program:
 
     .. code:: python
 
-        set_return_type("datapoints")
         img = datapoints.Image(torch.rand(3, 5, 5))
+        img + 2  # This is a pure Tensor (default behaviour)
+
+        set_return_type("datapoints")
         img + 2  # This is an Image
 
     or as a context manager to restrict the scope:
...
@@ -31,6 +37,7 @@ def set_return_type(return_type: str):
     .. code:: python
 
         img = datapoints.Image(torch.rand(3, 5, 5))
+        img + 2  # This is a pure Tensor
         with set_return_type("datapoints"):
             img + 2  # This is an Image
         img + 2  # This is a pure Tensor
...
torchvision/transforms/v2/functional/_utils.py

@@ -19,8 +19,15 @@ _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
 def _kernel_datapoint_wrapper(kernel):
     @functools.wraps(kernel)
     def wrapper(inpt, *args, **kwargs):
-        # We always pass datapoints as pure tensors to the kernels to avoid going through the
-        # Tensor.__torch_function__ logic, which is costly.
+        # If you're wondering whether we could / should get rid of this wrapper,
+        # the answer is no: we want to pass pure Tensors to avoid the overhead
+        # of the __torch_function__ machinery. Note that this is always valid,
+        # regardless of whether we override __torch_function__ in our base class
+        # or not.
+        # Also, even if we didn't call `as_subclass` here, we would still need
+        # this wrapper to call wrap_like(), because the Datapoint type would be
+        # lost after the first operation due to our own __torch_function__
+        # logic.
         output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
         return type(inpt).wrap_like(inpt, output)
...
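The wrapper pattern in this hunk — downcast to the base type before calling the kernel, then restore the subclass via ``wrap_like`` — can be modelled with plain Python classes. In this stdlib-only sketch, ``list(inpt)`` plays the role of ``inpt.as_subclass(torch.Tensor)`` and the names (``Datapointish``, ``kernel_wrapper``, ``negate``) are illustrative stand-ins:

```python
import functools

class Datapointish(list):
    @classmethod
    def wrap_like(cls, other, output):
        # Restore the subclass around the kernel's plain output.
        return cls(output)

def kernel_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # Pass the base type to the kernel (no subclass dispatch inside),
        # then re-wrap the result for the caller.
        output = kernel(list(inpt), *args, **kwargs)
        return type(inpt).wrap_like(inpt, output)
    return wrapper

@kernel_wrapper
def negate(x):
    assert type(x) is list  # the kernel only ever sees the base type
    return [-v for v in x]

out = negate(Datapointish([1, 2, 3]))
assert isinstance(out, Datapointish)
assert out == [-1, -2, -3]
```

Note that ``type(inpt).wrap_like(inpt, ...)`` dispatches on the input's concrete class, so the wrapper returns whatever subclass it was given, which is the property the new comments in the diff insist on.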