OpenDAS / vision / Commits

Unverified commit 25ec3f26, authored Aug 30, 2023 by Nicolas Hug, committed by GitHub on Aug 30, 2023

tv_tensor -> TVTensor where it matters (#7904)

Co-authored-by: Philip Meier <github.pmeier@posteo.de>

parent d5f4cc38

Showing 9 changed files with 84 additions and 65 deletions (+84 -65)
docs/source/transforms.rst                              +1  -1
docs/source/tv_tensors.rst                              +5  -4
gallery/transforms/plot_custom_transforms.py            +2  -2
gallery/transforms/plot_custom_tv_tensors.py            +3  -3
gallery/transforms/plot_transforms_getting_started.py   +7  -7
gallery/transforms/plot_tv_tensors.py                  +21 -21
test/test_tv_tensors.py                                +33 -20
torchvision/tv_tensors/_torch_function_helpers.py      +11  -6
torchvision/tv_tensors/_tv_tensor.py                    +1  -1
docs/source/transforms.rst

@@ -183,7 +183,7 @@ Transforms are available as classes like
 This is very much like the :mod:`torch.nn` package which defines both classes
 and functional equivalents in :mod:`torch.nn.functional`.
-The functionals support PIL images, pure tensors, or :ref:`tv_tensors
+The functionals support PIL images, pure tensors, or :ref:`TVTensors
 <tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
 valid.
...
docs/source/tv_tensors.rst

@@ -5,10 +5,11 @@ TVTensors
 .. currentmodule:: torchvision.tv_tensors

-TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
-dispatch their inputs to the appropriate lower-level kernels. Most users do not
-need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
-see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
+TVTensors are :class:`torch.Tensor` subclasses which the v2 :ref:`transforms
+<transforms>` use under the hood to dispatch their inputs to the appropriate
+lower-level kernels. Most users do not need to manipulate TVTensors directly and
+can simply rely on dataset wrapping - see e.g.
+:ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.

 .. autosummary::
     :toctree: generated/
...
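The dispatch idea this doc change describes - transforms pick a lower-level kernel based on the *class* of each input - can be sketched without torch at all. Everything below (`PlainTensor`, `Image`, `BoundingBoxes`, `resize`) is a hypothetical stand-in for illustration, not the torchvision API:

```python
class PlainTensor(list):
    """Stand-in for torch.Tensor: just a flat list of numbers."""

class Image(PlainTensor):
    """Stand-in for tv_tensors.Image."""

class BoundingBoxes(PlainTensor):
    """Stand-in for tv_tensors.BoundingBoxes."""

def resize(x, scale):
    # Dispatch to a lower-level "kernel" based on the class of the input:
    # box coordinates get scaled, images would get resampled.
    if isinstance(x, BoundingBoxes):
        return BoundingBoxes(v * scale for v in x)
    return type(x)(v for v in x)

# Both resize(image) and resize(boxes) are valid calls, as in the docs above.
boxes = resize(BoundingBoxes([10, 20, 30, 40]), scale=2)
print(type(boxes).__name__, list(boxes))  # BoundingBoxes [20, 40, 60, 80]
```

This is why subclassing matters: the same function name serves every input type, and the class alone decides which kernel runs.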
gallery/transforms/plot_custom_transforms.py

@@ -74,7 +74,7 @@ out_img, out_bboxes, out_label = transforms(img, bboxes, label)
 print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
 # %%
 # .. note::
-#     While working with tv_tensor classes in your code, make sure to
+#     While working with TVTensor classes in your code, make sure to
 #     familiarize yourself with this section:
 #     :ref:`tv_tensor_unwrapping_behaviour`
 #
...
@@ -111,7 +111,7 @@ print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
 # In brief, the core logic is to unpack the input into a flat list using `pytree
 # <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
 # then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all tv_tensors are
+# based on the **class** of the entries, as all TVTensors are
 # tensor-subclasses) plus some custom logic that is out of scope here - check the
 # code for details. The (potentially transformed) entries are then repacked and
 # returned, in the same structure as the input.
...
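The flatten-transform-repack logic described in that hunk can be sketched in plain Python. The names here (`Image`, `BoundingBoxes`, `apply_transform`) are hypothetical stand-ins, and the flattening is one level deep for brevity; the real implementation uses `torch.utils._pytree` to handle arbitrary nesting:

```python
class Image(list):
    """Hypothetical stand-in for a TVTensor image."""

class BoundingBoxes(list):
    """Hypothetical stand-in for TVTensor bounding boxes."""

def apply_transform(sample, fn, transformable=(Image, BoundingBoxes)):
    # 1. Unpack the input into a flat list (one level deep, for illustration).
    flat = list(sample)
    # 2. Transform only the entries whose **class** says they are transformable;
    #    everything else (labels, strings, ...) passes through untouched.
    out = [type(x)(fn(v) for v in x) if isinstance(x, transformable) else x
           for x in flat]
    # 3. Repack the entries in the same structure as the input.
    return tuple(out)

img, boxes, label = Image([1, 2]), BoundingBoxes([10, 20]), 7
out_img, out_boxes, out_label = apply_transform((img, boxes, label), lambda v: v * 2)
```

Note how the label survives unchanged while both tensor-like entries are transformed and keep their class - the behaviour the note above refers to.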
gallery/transforms/plot_custom_tv_tensors.py

 """
-=====================================
+====================================
 How to write your own TVTensor class
-=====================================
+====================================

 .. note::
     Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
     or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.

 This guide is intended for advanced users and downstream library maintainers. We explain how to
-write your own tv_tensor class, and how to make it compatible with the built-in
+write your own TVTensor class, and how to make it compatible with the built-in
 Torchvision v2 transforms. Before continuing, make sure you have read
 :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
 """
...
gallery/transforms/plot_transforms_getting_started.py

@@ -115,7 +115,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 # segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
 # passed them to the transforms in exactly the same way.
 #
-# By now you likely have a few questions: what are these tv_tensors, how do we
+# By now you likely have a few questions: what are these TVTensors, how do we
 # use them, and what is the expected input/output of those transforms? We'll
 # answer these in the next sections.
...
@@ -126,7 +126,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 # What are TVTensors?
 # --------------------
 #
-# TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
+# TVTensors are :class:`torch.Tensor` subclasses. The available TVTensors are
 # :class:`~torchvision.tv_tensors.Image`,
 # :class:`~torchvision.tv_tensors.BoundingBoxes`,
 # :class:`~torchvision.tv_tensors.Mask`, and
...
@@ -134,7 +134,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
-# or any ``torch.*`` operator will also work on a tv_tensor:
+# or any ``torch.*`` operator will also work on a TVTensor:

 img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
...
@@ -146,7 +146,7 @@ print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
 # transform a given input, the transforms first look at the **class** of the
 # object, and dispatch to the appropriate implementation accordingly.
 #
-# You don't need to know much more about tv_tensors at this point, but advanced
+# You don't need to know much more about TVTensors at this point, but advanced
 # users who want to learn more can refer to
 # :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
 #
...
@@ -234,9 +234,9 @@ print(f"{out_target['this_is_ignored']}")
 # Torchvision also supports datasets for object detection or segmentation like
 # :class:`torchvision.datasets.CocoDetection`. Those datasets predate
 # the existence of the :mod:`torchvision.transforms.v2` module and of the
-# tv_tensors, so they don't return tv_tensors out of the box.
+# TVTensors, so they don't return TVTensors out of the box.
 #
-# An easy way to force those datasets to return tv_tensors and to make them
+# An easy way to force those datasets to return TVTensors and to make them
 # compatible with v2 transforms is to use the
 # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
 #
...
@@ -246,7 +246,7 @@ print(f"{out_target['this_is_ignored']}")
 #
 #     dataset = CocoDetection(..., transforms=my_transforms)
 #     dataset = wrap_dataset_for_transforms_v2(dataset)
-#     # Now the dataset returns tv_tensors!
+#     # Now the dataset returns TVTensors!
 #
 # Using your own datasets
 # ^^^^^^^^^^^^^^^^^^^^^^^
...
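Conceptually, the dataset wrapping mentioned in the last two hunks intercepts `__getitem__` and rewraps the raw outputs into typed subclasses so class-based transforms can dispatch on them. A torch-free sketch of that idea, where every name (`Image`, `BoundingBoxes`, `RawDataset`, `WrappedDataset`) is a hypothetical stand-in rather than the torchvision implementation:

```python
class Image(list):
    """Stand-in for tv_tensors.Image."""

class BoundingBoxes(list):
    """Stand-in for tv_tensors.BoundingBoxes."""

class RawDataset:
    """Predates typed outputs: returns plain lists, like CocoDetection does."""
    def __getitem__(self, i):
        return [i, i + 1], {"boxes": [0, 0, 10, 10], "label": 3}

class WrappedDataset:
    """Rewraps raw outputs into typed subclasses on the way out."""
    def __init__(self, dataset):
        self.dataset = dataset
    def __getitem__(self, i):
        img, target = self.dataset[i]
        target = dict(target, boxes=BoundingBoxes(target["boxes"]))
        return Image(img), target

img, target = WrappedDataset(RawDataset())[0]
```

The wrapped dataset is a drop-in replacement: same items, same structure, but the entries now carry the class information that v2-style transforms dispatch on.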
gallery/transforms/plot_tv_tensors.py

@@ -9,18 +9,18 @@ TVTensors FAQ
 TVTensors are Tensor subclasses introduced together with
-``torchvision.transforms.v2``. This example showcases what these tv_tensors are
+``torchvision.transforms.v2``. This example showcases what these TVTensors are
 and how they behave.

 .. warning::

-    **Intended Audience** Unless you're writing your own transforms or your own tv_tensors, you
+    **Intended Audience** Unless you're writing your own transforms or your own TVTensors, you
     probably do not need to read this guide. This is a fairly low-level topic
     that most users will not need to worry about: you do not need to understand
-    the internals of tv_tensors to efficiently rely on
+    the internals of TVTensors to efficiently rely on
     ``torchvision.transforms.v2``. It may however be useful for advanced users
     trying to implement their own datasets, transforms, or work directly with
-    the tv_tensors.
+    the TVTensors.
 """

 # %%
...
@@ -31,8 +31,8 @@ from torchvision import tv_tensors
 # %%
-# What are tv_tensors?
-# --------------------
+# What are TVTensors?
+# -------------------
 #
 # TVTensors are zero-copy tensor subclasses:
...
@@ -46,31 +46,31 @@ assert image.data_ptr() == tensor.data_ptr()
 # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
 # for the input data.
 #
-# :mod:`torchvision.tv_tensors` supports four types of tv_tensors:
+# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
 #
 # * :class:`~torchvision.tv_tensors.Image`
 # * :class:`~torchvision.tv_tensors.Video`
 # * :class:`~torchvision.tv_tensors.BoundingBoxes`
 # * :class:`~torchvision.tv_tensors.Mask`
 #
-# What can I do with a tv_tensor?
-# -------------------------------
+# What can I do with a TVTensor?
+# ------------------------------
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
-# any ``torch.*`` operator will also work on tv_tensors. See
+# any ``torch.*`` operator will also work on TVTensors. See
 # :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.

 # %%
 # .. _tv_tensor_creation:
 #
-# How do I construct a tv_tensor?
-# -------------------------------
+# How do I construct a TVTensor?
+# ------------------------------
 #
 # Using the constructor
 # ^^^^^^^^^^^^^^^^^^^^^
 #
-# Each tv_tensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
+# Each TVTensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`

 image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
 print(image)
...
@@ -92,7 +92,7 @@ image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
 print(image.shape, image.dtype)

 # %%
-# Some tv_tensors require additional metadata to be passed in order to be constructed. For example,
+# Some TVTensors require additional metadata to be passed in order to be constructed. For example,
 # :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
 # corresponding image (``canvas_size``) alongside the actual values. These
 # metadata are required to properly transform the bounding boxes.
...
@@ -109,7 +109,7 @@ print(bboxes)
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
-# into a tv_tensor. This is useful when you already have an object of the
+# into a TVTensor. This is useful when you already have an object of the
 # desired type, which typically happens when writing transforms: you just want
 # to wrap the output like the input.
...
@@ -125,7 +125,7 @@ assert new_bboxes.canvas_size == bboxes.canvas_size
 # .. _tv_tensor_unwrapping_behaviour:
 #
 # I had a TVTensor but now I have a Tensor. Help!
-# ------------------------------------------------
+# -----------------------------------------------
 #
 # By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
 # will return a pure Tensor:
...
@@ -151,7 +151,7 @@ assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # But I want a TVTensor back!
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
-# You can re-wrap a pure tensor into a tv_tensor by just calling the tv_tensor
+# You can re-wrap a pure tensor into a TVTensor by just calling the TVTensor
 # constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
 # (see more details above in :ref:`tv_tensor_creation`):
...
@@ -164,7 +164,7 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # as a global config setting for the whole program, or as a context manager
 # (read its docs to learn more about caveats):

-with tv_tensors.set_return_type("tv_tensor"):
+with tv_tensors.set_return_type("TVTensor"):
     new_bboxes = bboxes + 3
 assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
...
@@ -203,9 +203,9 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # There are a few exceptions to this "unwrapping" rule:
 # :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
 # :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
-# the tv_tensor type.
+# the TVTensor type.
 #
-# Inplace operations on tv_tensors like ``obj.add_()`` will preserve the type of
+# Inplace operations on TVTensors like ``obj.add_()`` will preserve the type of
 # ``obj``. However, the **returned** value of inplace operations will be a pure
 # tensor:
...
@@ -213,7 +213,7 @@ image = tv_tensors.Image([[[0, 1], [1, 0]]])
 new_image = image.add_(1).mul_(2)

-# image got transformed in-place and is still an Image tv_tensor, but new_image
+# image got transformed in-place and is still a TVTensor Image, but new_image
 # is a Tensor. They share the same underlying data and they're equal, just
 # different classes.
 assert isinstance(image, tv_tensors.Image)
...
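The "unwrapping" rule and `set_return_type` behaviour documented above can be mimicked without torch. This sketch uses a plain `int` subclass and a module-level flag; it only illustrates the idea, since the real mechanism is `__torch_function__` on `torch.Tensor` subclasses, and `BoxCoord`/`_RETURN_SUBCLASS` are invented names:

```python
_RETURN_SUBCLASS = False  # analogue of set_return_type's global state

class BoxCoord(int):
    """Hypothetical int subclass standing in for a TVTensor."""
    def __add__(self, other):
        result = int(self) + int(other)
        # By default the subclass is "unwrapped" and the plain base type
        # comes back; re-wrapping only happens when the flag says so.
        return BoxCoord(result) if _RETURN_SUBCLASS else result

x = BoxCoord(4)
plain = x + 3           # unwrapped: a plain int
_RETURN_SUBCLASS = True
wrapped = x + 3         # re-wrapped: still a BoxCoord
```

Both results hold the same value; only the class differs, which is exactly the situation the "I had a TVTensor but now I have a Tensor" section describes.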
test/test_tv_tensors.py

@@ -91,7 +91,7 @@ def test_to_wrapping(make_input):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_to_tv_tensor_reference(make_input, return_type):
     tensor = torch.rand((3, 16, 16), dtype=torch.float64)
     dp = make_input()
...
@@ -99,13 +99,13 @@ def test_to_tv_tensor_reference(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         tensor_to = tensor.to(dp)
-    assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert tensor_to.dtype is dp.dtype
     assert type(tensor) is torch.Tensor


 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_clone_wrapping(make_input, return_type):
     dp = make_input()
...
@@ -117,7 +117,7 @@ def test_clone_wrapping(make_input, return_type):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_requires_grad__wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float)
...
@@ -132,7 +132,7 @@ def test_requires_grad__wrapping(make_input, return_type):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_detach_wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float).requires_grad_(True)
...
@@ -142,7 +142,7 @@ def test_detach_wrapping(make_input, return_type):
     assert type(dp_detached) is type(dp)


-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_force_subclass_with_metadata(return_type):
     # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
     # Largely the same as above, we additionally check that the metadata is preserved
...
@@ -151,27 +151,27 @@ def test_force_subclass_with_metadata(return_type):
     tv_tensors.set_return_type(return_type)
     bbox = bbox.clone()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     bbox = bbox.to(torch.float64)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     bbox = bbox.detach()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     assert not bbox.requires_grad
     bbox.requires_grad_(True)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
         assert bbox.requires_grad
     tv_tensors.set_return_type("tensor")


 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_other_op_no_wrapping(make_input, return_type):
     dp = make_input()
...
@@ -179,7 +179,7 @@ def test_other_op_no_wrapping(make_input, return_type):
     # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
     output = dp * 2

-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)


 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...
@@ -200,7 +200,7 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_inplace_op_no_wrapping(make_input, return_type):
     dp = make_input()
     original_type = type(dp)
...
@@ -208,7 +208,7 @@ def test_inplace_op_no_wrapping(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         output = dp.add_(0)

-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert type(dp) is original_type
...
@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 @pytest.mark.parametrize(
     "op",
     (
...
@@ -267,8 +267,8 @@ def test_usual_operations(make_input, return_type, op):
     dp = make_input()
     with tv_tensors.set_return_type(return_type):
         out = op(dp)
-    assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
-    if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
+    assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
+    if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
         assert hasattr(out, "format")
         assert hasattr(out, "canvas_size")
...
@@ -286,16 +286,16 @@ def test_set_return_type():
     assert type(img + 3) is torch.Tensor

-    with tv_tensors.set_return_type("tv_tensor"):
+    with tv_tensors.set_return_type("TVTensor"):
         assert type(img + 3) is tv_tensors.Image
     assert type(img + 3) is torch.Tensor

-    tv_tensors.set_return_type("tv_tensor")
+    tv_tensors.set_return_type("TVTensor")
     assert type(img + 3) is tv_tensors.Image

     with tv_tensors.set_return_type("tensor"):
         assert type(img + 3) is torch.Tensor
-        with tv_tensors.set_return_type("tv_tensor"):
+        with tv_tensors.set_return_type("TVTensor"):
             assert type(img + 3) is tv_tensors.Image
         assert type(img + 3) is torch.Tensor
     tv_tensors.set_return_type("tensor")
     assert type(img + 3) is torch.Tensor
...
@@ -305,3 +305,16 @@ def test_set_return_type():
     assert type(img + 3) is tv_tensors.Image
     tv_tensors.set_return_type("tensor")
+
+
+def test_return_type_input():
+    img = make_image()
+
+    # Case-insensitive
+    with tv_tensors.set_return_type("tvtensor"):
+        assert type(img + 3) is tv_tensors.Image
+
+    with pytest.raises(ValueError, match="return_type must be"):
+        tv_tensors.set_return_type("typo")
+
+    tv_tensors.set_return_type("tensor")
torchvision/tv_tensors/_torch_function_helpers.py
View file @ 25ec3f26
...
@@ -16,7 +16,7 @@ class _ReturnTypeCM:
 def set_return_type(return_type: str):
-    """[BETA] Set the return type of torch operations on tv_tensors.
+    """[BETA] Set the return type of torch operations on :class:`~torchvision.tv_tensors.TVTensor`.

     This only affects the behaviour of torch operations. It has no effect on
     ``torchvision`` transforms or functionals, which will always return as
...
@@ -26,7 +26,7 @@ def set_return_type(return_type: str):
     We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
     the end of your transform pipelines if you use
-    ``set_return_type("dataptoint")``. This will avoid the
+    ``set_return_type("TVTensor")``. This will avoid the
     ``__torch_function__`` overhead in the models ``forward()``.

     Can be used as a global flag for the entire program:
...
@@ -36,7 +36,7 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor (default behaviour)

-        set_return_type("tv_tensors")
+        set_return_type("TVTensor")
         img + 2  # This is an Image

     or as a context manager to restrict the scope:
...
@@ -45,16 +45,21 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor

-        with set_return_type("tv_tensors"):
+        with set_return_type("TVTensor"):
             img + 2  # This is an Image
         img + 2  # This is a pure Tensor

     Args:
-        return_type (str): Can be "tv_tensor" or "tensor". Default is "tensor".
+        return_type (str): Can be "TVTensor" or "Tensor" (case-insensitive).
+            Default is "Tensor" (i.e. pure :class:`torch.Tensor`).
     """
     global _TORCHFUNCTION_SUBCLASS
     to_restore = _TORCHFUNCTION_SUBCLASS
-    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tv_tensor": True}[return_type.lower()]
+
+    try:
+        _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tvtensor": True}[return_type.lower()]
+    except KeyError:
+        raise ValueError(f"return_type must be 'TVTensor' or 'Tensor', got {return_type}") from None

     return _ReturnTypeCM(to_restore)
...
torchvision/tv_tensors/_tv_tensor.py
View file @ 25ec3f26
...
@@ -13,7 +13,7 @@ D = TypeVar("D", bound="TVTensor")
 class TVTensor(torch.Tensor):
-    """[Beta] Base class for all tv_tensors.
+    """[Beta] Base class for all TVTensors.

     You probably don't want to use this class unless you're defining your own
     custom TVTensors. See
...