OpenDAS / vision / Commits

Unverified commit 25ec3f26, authored Aug 30, 2023 by Nicolas Hug, committed by GitHub on Aug 30, 2023

tv_tensor -> TVTensor where it matters (#7904)

Co-authored-by: Philip Meier <github.pmeier@posteo.de>

parent d5f4cc38

Showing 9 changed files with 84 additions and 65 deletions.
docs/source/transforms.rst (+1, -1)
docs/source/tv_tensors.rst (+5, -4)
gallery/transforms/plot_custom_transforms.py (+2, -2)
gallery/transforms/plot_custom_tv_tensors.py (+3, -3)
gallery/transforms/plot_transforms_getting_started.py (+7, -7)
gallery/transforms/plot_tv_tensors.py (+21, -21)
test/test_tv_tensors.py (+33, -20)
torchvision/tv_tensors/_torch_function_helpers.py (+11, -6)
torchvision/tv_tensors/_tv_tensor.py (+1, -1)
docs/source/transforms.rst

...
@@ -183,7 +183,7 @@ Transforms are available as classes like
 This is very much like the :mod:`torch.nn` package which defines both classes
 and functional equivalents in :mod:`torch.nn.functional`.

-The functionals support PIL images, pure tensors, or :ref:`tv_tensors
+The functionals support PIL images, pure tensors, or :ref:`TVTensors
 <tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
 valid.
...
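For context, a minimal sketch (not part of this commit) of the dual dispatch that paragraph describes; it assumes a torchvision build with the v2 transforms API:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

image = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
bboxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(224, 224)
)

# The same functional accepts both inputs; dispatch is based on their class.
print(F.resize(image, size=(112, 112)).shape)  # torch.Size([3, 112, 112])
print(F.resize(bboxes, size=(112, 112)))       # box coordinates halved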
docs/source/tv_tensors.rst

...
@@ -5,10 +5,11 @@ TVTensors
 .. currentmodule:: torchvision.tv_tensors

-TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
-dispatch their inputs to the appropriate lower-level kernels. Most users do not
-need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
-see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
+TVTensors are :class:`torch.Tensor` subclasses which the v2 :ref:`transforms
+<transforms>` use under the hood to dispatch their inputs to the appropriate
+lower-level kernels. Most users do not need to manipulate TVTensors directly and
+can simply rely on dataset wrapping - see e.g.
+:ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.

 .. autosummary::
     :toctree: generated/
...
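A quick sketch (not from the diff) of the subclass relationship the reworded paragraph emphasizes:

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 8, 8))
assert isinstance(img, torch.Tensor)         # TVTensors *are* tensors...
assert isinstance(img, tv_tensors.TVTensor)  # ...via the TVTensor base class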
gallery/transforms/plot_custom_transforms.py

...
@@ -74,7 +74,7 @@ out_img, out_bboxes, out_label = transforms(img, bboxes, label)
 print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
 # %%
 # .. note::
-#     While working with tv_tensor classes in your code, make sure to
+#     While working with TVTensor classes in your code, make sure to
 #     familiarize yourself with this section:
 #     :ref:`tv_tensor_unwrapping_behaviour`
 #
...
@@ -111,7 +111,7 @@ print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
 # In brief, the core logic is to unpack the input into a flat list using `pytree
 # <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
 # then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all tv_tensors are
+# based on the **class** of the entries, as all TVTensors are
 # tensor-subclasses) plus some custom logic that is out of scope here - check the
 # code for details. The (potentially transformed) entries are then repacked and
 # returned, in the same structure as the input.
...
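To illustrate the pytree round-trip described in that last hunk, a sketch follows (an approximation, not torchvision's actual implementation; transform_images is a hypothetical helper):

import torch
import torch.utils._pytree as pytree
from torchvision import tv_tensors

def transform_images(obj, fn):
    # Flatten arbitrary nesting, transform only the Image entries (the
    # decision is made on the class), then repack into the input structure.
    flat, spec = pytree.tree_flatten(obj)
    out = [fn(x) if isinstance(x, tv_tensors.Image) else x for x in flat]
    return pytree.tree_unflatten(out, spec)

sample = {"img": tv_tensors.Image(torch.rand(3, 4, 4)), "label": 3}
out = transform_images(sample, lambda img: tv_tensors.wrap(img * 2, like=img))
assert isinstance(out["img"], tv_tensors.Image) and out["label"] == 3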
gallery/transforms/plot_custom_tv_tensors.py

 """
-=====================================
+====================================
 How to write your own TVTensor class
-=====================================
+====================================

 .. note::
     Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
     or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.

 This guide is intended for advanced users and downstream library maintainers. We explain how to
-write your own tv_tensor class, and how to make it compatible with the built-in
+write your own TVTensor class, and how to make it compatible with the built-in
 Torchvision v2 transforms. Before continuing, make sure you have read
 :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
 """
...
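For orientation, a minimal sketch of the pattern that guide teaches (illustrative only; MyTVTensor is a made-up class):

import torch
from torchvision import tv_tensors

class MyTVTensor(tv_tensors.TVTensor):
    """A do-nothing subclass, just to show the zero-copy wrapping step."""

t = torch.rand(3, 4, 4)
my = t.as_subclass(MyTVTensor)  # zero-copy: same storage, new class
assert isinstance(my, MyTVTensor) and my.data_ptr() == t.data_ptr()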
gallery/transforms/plot_transforms_getting_started.py

...
@@ -115,7 +115,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 # segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
 # passed them to the transforms in exactly the same way.
 #
-# By now you likely have a few questions: what are these tv_tensors, how do we
+# By now you likely have a few questions: what are these TVTensors, how do we
 # use them, and what is the expected input/output of those transforms? We'll
 # answer these in the next sections.
...
@@ -126,7 +126,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 # What are TVTensors?
 # --------------------
 #
-# TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
+# TVTensors are :class:`torch.Tensor` subclasses. The available TVTensors are
 # :class:`~torchvision.tv_tensors.Image`,
 # :class:`~torchvision.tv_tensors.BoundingBoxes`,
 # :class:`~torchvision.tv_tensors.Mask`, and
...
@@ -134,7 +134,7 @@ plot([(img, boxes), (out_img, out_boxes)])
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
-# or any ``torch.*`` operator will also work on a tv_tensor:
+# or any ``torch.*`` operator will also work on a TVTensor:

 img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
...
@@ -146,7 +146,7 @@ print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
 # transform a given input, the transforms first look at the **class** of the
 # object, and dispatch to the appropriate implementation accordingly.
 #
-# You don't need to know much more about tv_tensors at this point, but advanced
+# You don't need to know much more about TVTensors at this point, but advanced
 # users who want to learn more can refer to
 # :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
 #
...
@@ -234,9 +234,9 @@ print(f"{out_target['this_is_ignored']}")
 # Torchvision also supports datasets for object detection or segmentation like
 # :class:`torchvision.datasets.CocoDetection`. Those datasets predate
 # the existence of the :mod:`torchvision.transforms.v2` module and of the
-# tv_tensors, so they don't return tv_tensors out of the box.
+# TVTensors, so they don't return TVTensors out of the box.
 #
-# An easy way to force those datasets to return tv_tensors and to make them
+# An easy way to force those datasets to return TVTensors and to make them
 # compatible with v2 transforms is to use the
 # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
 #
...
@@ -246,7 +246,7 @@ print(f"{out_target['this_is_ignored']}")
 #
 #     dataset = CocoDetection(..., transforms=my_transforms)
 #     dataset = wrap_dataset_for_transforms_v2(dataset)
-#     # Now the dataset returns tv_tensors!
+#     # Now the dataset returns TVTensors!
 #
 # Using your own datasets
 # ^^^^^^^^^^^^^^^^^^^^^^^
...
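Fleshing out that snippet, a hedged usage sketch (the paths and the transform are placeholders, not from this commit):

from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
from torchvision.transforms import v2

my_transforms = v2.RandomHorizontalFlip(p=1.0)
dataset = CocoDetection("images/", "annotations.json", transforms=my_transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)

img, target = dataset[0]
# target["boxes"] is now a tv_tensors.BoundingBoxes, so the v2 transforms
# can dispatch on it.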
gallery/transforms/plot_tv_tensors.py

@@ -9,18 +9,18 @@ TVTensors FAQ
 TVTensors are Tensor subclasses introduced together with
-``torchvision.transforms.v2``. This example showcases what these tv_tensors are
+``torchvision.transforms.v2``. This example showcases what these TVTensors are
 and how they behave.

 .. warning::

-    **Intended Audience** Unless you're writing your own transforms or your own tv_tensors, you
+    **Intended Audience** Unless you're writing your own transforms or your own TVTensors, you
     probably do not need to read this guide. This is a fairly low-level topic
     that most users will not need to worry about: you do not need to understand
-    the internals of tv_tensors to efficiently rely on
+    the internals of TVTensors to efficiently rely on
     ``torchvision.transforms.v2``. It may however be useful for advanced users
     trying to implement their own datasets, transforms, or work directly with
-    the tv_tensors.
+    the TVTensors.
 """

 # %%
...
@@ -31,8 +31,8 @@ from torchvision import tv_tensors

 # %%
-# What are tv_tensors?
-# --------------------
+# What are TVTensors?
+# -------------------
 #
 # TVTensors are zero-copy tensor subclasses:
...
@@ -46,31 +46,31 @@ assert image.data_ptr() == tensor.data_ptr()
 # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
 # for the input data.
 #
-# :mod:`torchvision.tv_tensors` supports four types of tv_tensors:
+# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
 #
 # * :class:`~torchvision.tv_tensors.Image`
 # * :class:`~torchvision.tv_tensors.Video`
 # * :class:`~torchvision.tv_tensors.BoundingBoxes`
 # * :class:`~torchvision.tv_tensors.Mask`
 #
-# What can I do with a tv_tensor?
-# -------------------------------
+# What can I do with a TVTensor?
+# ------------------------------
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
-# any ``torch.*`` operator will also work on tv_tensors. See
+# any ``torch.*`` operator will also work on TVTensors. See
 # :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.

 # %%
 # .. _tv_tensor_creation:
 #
-# How do I construct a tv_tensor?
-# -------------------------------
+# How do I construct a TVTensor?
+# ------------------------------
 #
 # Using the constructor
 # ^^^^^^^^^^^^^^^^^^^^^
 #
-# Each tv_tensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
+# Each TVTensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`

 image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
 print(image)
...
@@ -92,7 +92,7 @@ image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
 print(image.shape, image.dtype)

 # %%
-# Some tv_tensors require additional metadata to be passed in order to be constructed. For example,
+# Some TVTensors require additional metadata to be passed in order to be constructed. For example,
 # :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
 # corresponding image (``canvas_size``) alongside the actual values. These
 # metadata are required to properly transform the bounding boxes.
...
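The construction the hunk refers to looks roughly like this (a sketch with arbitrary values, not part of the diff):

import torch
from torchvision import tv_tensors

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,  # coordinate format metadata
    canvas_size=(512, 512),                    # size of the underlying image
)
print(bboxes.format, bboxes.canvas_size)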
@@ -109,7 +109,7 @@ print(bboxes)
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
-# into a tv_tensor. This is useful when you already have an object of the
+# into a TVTensor. This is useful when you already have an object of the
 # desired type, which typically happens when writing transforms: you just want
 # to wrap the output like the input.
...
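A sketch of that wrapping pattern (mirroring the elided gallery code, not part of the diff):

import torch
from torchvision import tv_tensors

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495]], format="XYXY", canvas_size=(512, 512)
)
new_bboxes = torch.tensor([[20, 20, 100, 100]])
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)  # copies format & canvas_size
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size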
@@ -125,7 +125,7 @@ assert new_bboxes.canvas_size == bboxes.canvas_size
 # .. _tv_tensor_unwrapping_behaviour:
 #
 # I had a TVTensor but now I have a Tensor. Help!
-# ------------------------------------------------
+# -----------------------------------------------
 #
 # By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
 # will return a pure Tensor:
...
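The default unwrapping behaviour, sketched (not part of the diff):

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 4, 4))
new_img = img + 1                     # a plain torch op on a TVTensor...
assert type(new_img) is torch.Tensor  # ...returns a pure Tensor by default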
@@ -151,7 +151,7 @@ assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # But I want a TVTensor back!
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
-# You can re-wrap a pure tensor into a tv_tensor by just calling the tv_tensor
+# You can re-wrap a pure tensor into a TVTensor by just calling the TVTensor
 # constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
 # (see more details above in :ref:`tv_tensor_creation`):
...
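Both re-wrapping routes, sketched (arbitrary sample values, not part of the diff):

import torch
from torchvision import tv_tensors

bboxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(64, 64)
)
out = bboxes + 3                         # pure Tensor
out = tv_tensors.wrap(out, like=bboxes)  # route 1: wrap()
assert isinstance(out, tv_tensors.BoundingBoxes)

img = tv_tensors.Image(torch.rand(3, 4, 4))
out = tv_tensors.Image(img + 1)          # route 2: call the constructor
assert isinstance(out, tv_tensors.Image)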
@@ -164,7 +164,7 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # as a global config setting for the whole program, or as a context manager
 # (read its docs to learn more about caveats):

-with tv_tensors.set_return_type("tv_tensor"):
+with tv_tensors.set_return_type("TVTensor"):
     new_bboxes = bboxes + 3
 assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
...
@@ -203,9 +203,9 @@ assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 # There are a few exceptions to this "unwrapping" rule:
 # :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
 # :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
-# the tv_tensor type.
+# the TVTensor type.
 #
-# Inplace operations on tv_tensors like ``obj.add_()`` will preserve the type of
+# Inplace operations on TVTensors like ``obj.add_()`` will preserve the type of
 # ``obj``. However, the **returned** value of inplace operations will be a pure
 # tensor:
...
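The exception list, sketched (not part of the diff):

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 4, 4))
assert type(img.clone()) is tv_tensors.Image            # clone keeps the type
assert type(img.to(torch.float64)) is tv_tensors.Image  # so does to() ...
assert type(img.detach()) is tv_tensors.Image           # ... and detach()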
@@ -213,7 +213,7 @@ image = tv_tensors.Image([[[0, 1], [1, 0]]])

 new_image = image.add_(1).mul_(2)

-# image got transformed in-place and is still an Image tv_tensor, but new_image
+# image got transformed in-place and is still a TVTensor Image, but new_image
 # is a Tensor. They share the same underlying data and they're equal, just
 # different classes.
 assert isinstance(image, tv_tensors.Image)
...
test/test_tv_tensors.py

...
@@ -91,7 +91,7 @@ def test_to_wrapping(make_input):

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_to_tv_tensor_reference(make_input, return_type):
     tensor = torch.rand((3, 16, 16), dtype=torch.float64)
     dp = make_input()
...
@@ -99,13 +99,13 @@ def test_to_tv_tensor_reference(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         tensor_to = tensor.to(dp)

-    assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert tensor_to.dtype is dp.dtype
     assert type(tensor) is torch.Tensor


 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_clone_wrapping(make_input, return_type):
     dp = make_input()
...
@@ -117,7 +117,7 @@ def test_clone_wrapping(make_input, return_type):

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_requires_grad__wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float)
...
@@ -132,7 +132,7 @@ def test_requires_grad__wrapping(make_input, return_type):

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_detach_wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float).requires_grad_(True)
...
@@ -142,7 +142,7 @@ def test_detach_wrapping(make_input, return_type):
     assert type(dp_detached) is type(dp)


-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_force_subclass_with_metadata(return_type):
     # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
     # Largely the same as above, we additionally check that the metadata is preserved
...
@@ -151,27 +151,27 @@ def test_force_subclass_with_metadata(return_type):
     tv_tensors.set_return_type(return_type)
     bbox = bbox.clone()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     bbox = bbox.to(torch.float64)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     bbox = bbox.detach()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)

     assert not bbox.requires_grad
     bbox.requires_grad_(True)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
         assert bbox.requires_grad
     tv_tensors.set_return_type("tensor")


 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_other_op_no_wrapping(make_input, return_type):
     dp = make_input()
...
@@ -179,7 +179,7 @@ def test_other_op_no_wrapping(make_input, return_type):
     # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
     output = dp * 2

-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)


 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...
@@ -200,7 +200,7 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_inplace_op_no_wrapping(make_input, return_type):
     dp = make_input()
     original_type = type(dp)
...
@@ -208,7 +208,7 @@ def test_inplace_op_no_wrapping(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         output = dp.add_(0)

-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert type(dp) is original_type
...
@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):

 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 @pytest.mark.parametrize(
     "op",
     (
...
@@ -267,8 +267,8 @@ def test_usual_operations(make_input, return_type, op):
     dp = make_input()
     with tv_tensors.set_return_type(return_type):
         out = op(dp)
-        assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
-        if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
+        assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
+        if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
             assert hasattr(out, "format")
             assert hasattr(out, "canvas_size")
...
@@ -286,16 +286,16 @@ def test_set_return_type():
     assert type(img + 3) is torch.Tensor

-    with tv_tensors.set_return_type("tv_tensor"):
+    with tv_tensors.set_return_type("TVTensor"):
         assert type(img + 3) is tv_tensors.Image
     assert type(img + 3) is torch.Tensor

-    tv_tensors.set_return_type("tv_tensor")
+    tv_tensors.set_return_type("TVTensor")
     assert type(img + 3) is tv_tensors.Image

     with tv_tensors.set_return_type("tensor"):
         assert type(img + 3) is torch.Tensor
-        with tv_tensors.set_return_type("tv_tensor"):
+        with tv_tensors.set_return_type("TVTensor"):
             assert type(img + 3) is tv_tensors.Image
         tv_tensors.set_return_type("tensor")
         assert type(img + 3) is torch.Tensor
...
@@ -305,3 +305,16 @@ def test_set_return_type():
     assert type(img + 3) is tv_tensors.Image
     tv_tensors.set_return_type("tensor")
+
+
+def test_return_type_input():
+    img = make_image()
+
+    # Case-insensitive
+    with tv_tensors.set_return_type("tvtensor"):
+        assert type(img + 3) is tv_tensors.Image
+
+    with pytest.raises(ValueError, match="return_type must be"):
+        tv_tensors.set_return_type("typo")
+
+    tv_tensors.set_return_type("tensor")
torchvision/tv_tensors/_torch_function_helpers.py

...
@@ -16,7 +16,7 @@ class _ReturnTypeCM:

 def set_return_type(return_type: str):
-    """[BETA] Set the return type of torch operations on tv_tensors.
+    """[BETA] Set the return type of torch operations on :class:`~torchvision.tv_tensors.TVTensor`.

     This only affects the behaviour of torch operations. It has no effect on
     ``torchvision`` transforms or functionals, which will always return as
...
@@ -26,7 +26,7 @@ def set_return_type(return_type: str):
         We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
         the end of your transform pipelines if you use
-        ``set_return_type("dataptoint")``. This will avoid the
+        ``set_return_type("TVTensor")``. This will avoid the
         ``__torch_function__`` overhead in the models ``forward()``.

     Can be used as a global flag for the entire program:
...
@@ -36,7 +36,7 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor (default behaviour)

-        set_return_type("tv_tensors")
+        set_return_type("TVTensor")
         img + 2  # This is an Image

     or as a context manager to restrict the scope:
...
@@ -45,16 +45,21 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor

-        with set_return_type("tv_tensors"):
+        with set_return_type("TVTensor"):
             img + 2  # This is an Image
         img + 2  # This is a pure Tensor

     Args:
-        return_type (str): Can be "tv_tensor" or "tensor". Default is "tensor".
+        return_type (str): Can be "TVTensor" or "Tensor" (case-insensitive).
+            Default is "Tensor" (i.e. pure :class:`torch.Tensor`).
     """
     global _TORCHFUNCTION_SUBCLASS
     to_restore = _TORCHFUNCTION_SUBCLASS

-    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tv_tensor": True}[return_type.lower()]
+    try:
+        _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tvtensor": True}[return_type.lower()]
+    except KeyError:
+        raise ValueError(f"return_type must be 'TVTensor' or 'Tensor', got {return_type}") from None

     return _ReturnTypeCM(to_restore)
...
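A quick sketch exercising the new validation path (it mirrors the test added above; not part of the diff):

from torchvision import tv_tensors

tv_tensors.set_return_type("tvtensor")  # accepted: matching is case-insensitive
try:
    tv_tensors.set_return_type("typo")
except ValueError as e:
    print(e)  # return_type must be 'TVTensor' or 'Tensor', got typo
tv_tensors.set_return_type("tensor")    # restore the default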
torchvision/tv_tensors/_tv_tensor.py

...
@@ -13,7 +13,7 @@ D = TypeVar("D", bound="TVTensor")

 class TVTensor(torch.Tensor):
-    """[Beta] Base class for all tv_tensors.
+    """[Beta] Base class for all TVTensors.

     You probably don't want to use this class unless you're defining your own
     custom TVTensors. See
...