Unverified commit d5f4cc38, authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
......@@ -88,8 +88,8 @@ class CustomGalleryExampleSortKey:
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_custom_transforms.py",
"plot_datapoints.py",
"plot_custom_datapoints.py",
"plot_tv_tensors.py",
"plot_custom_tv_tensors.py",
]
def __call__(self, filename):
......
......@@ -32,7 +32,7 @@ architectures, and common image transformations for computer vision.
:caption: Package Reference
transforms
datapoints
tv_tensors
models
datasets
utils
......
......@@ -30,12 +30,12 @@ tasks (image classification, detection, segmentation, video classification).
.. code:: python
# Detection (re-using imports and transforms from above)
from torchvision import datapoints
from torchvision import tv_tensors
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
bboxes = torch.randint(0, H // 2, size=(3, 4))
bboxes[:, 2:] += bboxes[:, :2]
bboxes = datapoints.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
bboxes = tv_tensors.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used!
img, bboxes = transforms(img, bboxes)
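For context, the ``transforms`` object re-used above is assumed to be a standard v2 pipeline; a minimal sketch (not necessarily the exact transforms from the classification snippet earlier in the README) could be::

    import torch
    from torchvision.transforms import v2

    transforms = v2.Compose([
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),
    ])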
......@@ -183,8 +183,8 @@ Transforms are available as classes like
This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`.
The functionals support PIL images, pure tensors, or :ref:`datapoints
<datapoints>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
The functionals support PIL images, pure tensors, or :ref:`tv_tensors
<tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
valid.
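A minimal sketch of both calls (shapes and coordinates below are arbitrary)::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    image_tensor = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
    bboxes = tv_tensors.BoundingBoxes([[10, 10, 50, 50]], format="XYXY", canvas_size=(224, 224))

    out_image = F.resize(image_tensor, size=(128, 128))  # resized like an image
    out_bboxes = F.resize(bboxes, size=(128, 128))  # box coordinates rescaled accordingly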
.. note::
......
.. _datapoints:
.. _tv_tensors:
Datapoints
TVTensors
==========
.. currentmodule:: torchvision.datapoints
.. currentmodule:: torchvision.tv_tensors
Datapoints are tensor subclasses which the :mod:`~torchvision.transforms.v2` transforms use under the hood to
TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` transforms use under the hood to
dispatch their inputs to the appropriate lower-level kernels. Most users do not
need to manipulate datapoints directly and can simply rely on dataset wrapping -
need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
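For instance, wrapping an existing detection dataset is a one-liner (the paths below are hypothetical placeholders)::

    from torchvision import datasets
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
    dataset = wrap_dataset_for_transforms_v2(dataset)  # targets now contain TVTensors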
.. autosummary::
......@@ -19,6 +19,6 @@ see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
BoundingBoxFormat
BoundingBoxes
Mask
Datapoint
TVTensor
set_return_type
wrap
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
......@@ -22,7 +22,7 @@ def plot(imgs, row_title=None, **imshow_kwargs):
if isinstance(target, dict):
boxes = target.get("boxes")
masks = target.get("masks")
elif isinstance(target, datapoints.BoundingBoxes):
elif isinstance(target, tv_tensors.BoundingBoxes):
boxes = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
......
......@@ -13,7 +13,7 @@ torchvision transforms V2 API.
# %%
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import v2
......@@ -62,7 +62,7 @@ transforms = v2.Compose([
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
......@@ -74,9 +74,9 @@ out_img, out_bboxes, out_label = transforms(img, bboxes, label)
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
# While working with datapoint classes in your code, make sure to
# While working with tv_tensor classes in your code, make sure to
# familiarize yourself with this section:
# :ref:`datapoint_unwrapping_behaviour`
# :ref:`tv_tensor_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
......@@ -111,7 +111,7 @@ print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# based on the **class** of the entries, as all tv_tensors are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
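# A rough, simplified sketch of that flatten / transform / repack pattern (this is
# not the actual transform implementation, and ``transform_boxes_only`` is just an
# illustrative name):

from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors

def transform_boxes_only(structured_input, fn):
    flat, spec = tree_flatten(structured_input)  # arbitrary nesting -> flat list + structure spec
    # transform only the entries whose class we recognize, pass everything else through
    flat = [fn(x) if isinstance(x, tv_tensors.BoundingBoxes) else x for x in flat]
    return tree_unflatten(flat, spec)  # repack into the original structure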
......
"""
=====================================
How to write your own Datapoint class
How to write your own TVTensor class
=====================================
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_datapoints.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_datapoints.py>` to download the full example code.
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.
This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
write your own tv_tensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`.
:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
"""
# %%
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import v2
# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# :class:`~torchvision.tv_tensors.TVTensor` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
# :class:`~torchvision.tv_tensors.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/tv_tensors/_bounding_box.py>`_.
class MyDatapoint(datapoints.Datapoint):
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyDatapoint([1, 2, 3])
my_dp = MyTVTensor([1, 2, 3])
my_dp
# %%
# Now that we have defined our custom Datapoint class, we want it to be
# Now that we have defined our custom TVTensor class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API.
# operation of our MyTVTensor class, and register it to the functional API.
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return datapoints.wrap(out, like=my_dp)
return tv_tensors.wrap(out, like=my_dp)
# %%
# To understand why :func:`~torchvision.datapoints.wrap` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# To understand why :func:`~torchvision.tv_tensors.wrap` is used, see
# :ref:`tv_tensor_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
# .. note::
......@@ -67,9 +67,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# ``@register_kernel(functional=F.hflip, ...)``.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:
# ``MyTVTensor`` instance:
my_dp = MyDatapoint(torch.rand(3, 256, 256))
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
# %%
......@@ -102,10 +102,10 @@ _ = t(my_dp)
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as
def hflip_my_datapoint(my_dp): # noqa
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return datapoints.wrap(out, like=my_dp)
return tv_tensors.wrap(out, like=my_dp)
# %%
......
......@@ -23,7 +23,7 @@ import pathlib
import torch
import torch.utils.data
from torchvision import models, datasets, datapoints
from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2
torch.manual_seed(0)
......@@ -72,7 +72,7 @@ print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['m
# %%
# We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is a dict where the values
# are :ref:`Datapoints <what_are_datapoints>` (all are :class:`torch.Tensor`
# are :ref:`TVTensors <what_are_tv_tensors>` (all are :class:`torch.Tensor`
# subclasses). We've dropped all unnecessary keys from the previous output, but
# if you need any of the original keys, e.g. "image_id", you can still ask for
# it.
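# For reference, the wrapping call that produces this output is assumed to look
# roughly like the following (``dataset`` being the CocoDetection instance created
# earlier in this example):

from torchvision.datasets import wrap_dataset_for_transforms_v2

dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))
# add "image_id" (or any other original key) to target_keys if you still need it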
......@@ -103,7 +103,7 @@ transforms = v2.Compose(
[
v2.ToImage(),
v2.RandomPhotometricDistort(p=1),
v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(),
......
......@@ -88,9 +88,9 @@ plot([img, out])
#
# Let's briefly look at a detection example with bounding boxes.
from torchvision import datapoints # we'll describe this a bit later, bear with us
from torchvision import tv_tensors # we'll describe this a bit later, bear with us
boxes = datapoints.BoundingBoxes(
boxes = tv_tensors.BoundingBoxes(
[
[15, 10, 370, 510],
[275, 340, 510, 510],
......@@ -111,44 +111,44 @@ plot([(img, boxes), (out_img, out_boxes)])
# %%
#
# The example above focuses on object detection. But if we had masks
# (:class:`torchvision.datapoints.Mask`) for object segmentation or semantic
# segmentation, or videos (:class:`torchvision.datapoints.Video`), we could have
# (:class:`torchvision.tv_tensors.Mask`) for object segmentation or semantic
# segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
# passed them to the transforms in exactly the same way.
#
# By now you likely have a few questions: what are these datapoints, how do we
# By now you likely have a few questions: what are these tv_tensors, how do we
# use them, and what is the expected input/output of those transforms? We'll
# answer these in the next sections.
# %%
#
# .. _what_are_datapoints:
# .. _what_are_tv_tensors:
#
# What are Datapoints?
# What are TVTensors?
# --------------------
#
# Datapoints are :class:`torch.Tensor` subclasses. The available datapoints are
# :class:`~torchvision.datapoints.Image`,
# :class:`~torchvision.datapoints.BoundingBoxes`,
# :class:`~torchvision.datapoints.Mask`, and
# :class:`~torchvision.datapoints.Video`.
# TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
# :class:`~torchvision.tv_tensors.Image`,
# :class:`~torchvision.tv_tensors.BoundingBoxes`,
# :class:`~torchvision.tv_tensors.Mask`, and
# :class:`~torchvision.tv_tensors.Video`.
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
# or any ``torch.*`` operator will also work on a datapoint:
# or any ``torch.*`` operator will also work on a tv_tensor:
img_dp = datapoints.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
# %%
# These Datapoint classes are at the core of the transforms: in order to
# These TVTensor classes are at the core of the transforms: in order to
# transform a given input, the transforms first look at the **class** of the
# object, and dispatch to the appropriate implementation accordingly.
#
# You don't need to know much more about datapoints at this point, but advanced
# You don't need to know much more about tv_tensors at this point, but advanced
# users who want to learn more can refer to
# :ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`.
# :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
#
# What do I pass as input?
# ------------------------
......@@ -196,17 +196,17 @@ print(f"{out_target['this_is_ignored']}")
# Pure :class:`torch.Tensor` objects are, in general, treated as images (or
# as videos for video-specific transforms). Indeed, you may have noticed
# that in the code above we haven't used the
# :class:`~torchvision.datapoints.Image` class at all, and yet our images
# :class:`~torchvision.tv_tensors.Image` class at all, and yet our images
# got transformed properly. Transforms use the following logic to
# determine whether a pure Tensor should be treated as an image (or video),
# or just ignored:
#
# * If there is an :class:`~torchvision.datapoints.Image`,
# :class:`~torchvision.datapoints.Video`,
# * If there is an :class:`~torchvision.tv_tensors.Image`,
# :class:`~torchvision.tv_tensors.Video`,
# or :class:`PIL.Image.Image` instance in the input, all other pure
# tensors are passed-through.
# * If there is no :class:`~torchvision.datapoints.Image` or
# :class:`~torchvision.datapoints.Video` instance, only the first pure
# * If there is no :class:`~torchvision.tv_tensors.Image` or
# :class:`~torchvision.tv_tensors.Video` instance, only the first pure
# :class:`torch.Tensor` will be transformed as image or video, while all
# others will be passed-through. Here "first" means "first in a depth-wise
# traversal".
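# A minimal sketch of that heuristic (shapes below are arbitrary):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

img = tv_tensors.Image(torch.rand(3, 32, 32))
extra = torch.rand(3, 32, 32)  # a pure tensor riding along
out_img, out_extra = v2.RandomHorizontalFlip(p=1)(img, extra)
# Because an Image instance is present, the pure tensor is passed through untouched
assert torch.equal(out_extra, extra)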
......@@ -234,9 +234,9 @@ print(f"{out_target['this_is_ignored']}")
# Torchvision also supports datasets for object detection or segmentation like
# :class:`torchvision.datasets.CocoDetection`. Those datasets predate
# the existence of the :mod:`torchvision.transforms.v2` module and of the
# datapoints, so they don't return datapoints out of the box.
# tv_tensors, so they don't return tv_tensors out of the box.
#
# An easy way to force those datasets to return datapoints and to make them
# An easy way to force those datasets to return tv_tensors and to make them
# compatible with v2 transforms is to use the
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
#
......@@ -246,14 +246,14 @@ print(f"{out_target['this_is_ignored']}")
#
# dataset = CocoDetection(..., transforms=my_transforms)
# dataset = wrap_dataset_for_transforms_v2(dataset)
# # Now the dataset returns datapoints!
# # Now the dataset returns tv_tensors!
#
# Using your own datasets
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# If you have a custom dataset, then you'll need to convert your objects into
# the appropriate Datapoint classes. Creating Datapoint instances is very easy,
# refer to :ref:`datapoint_creation` for more details.
# the appropriate TVTensor classes. Creating TVTensor instances is very easy,
# refer to :ref:`tv_tensor_creation` for more details.
#
# There are two main places where you can implement that conversion logic:
#
......
"""
==============
Datapoints FAQ
==============
=============
TVTensors FAQ
=============
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_datapoints.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_datapoints.py>` to download the full example code.
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_tv_tensors.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_tv_tensors.py>` to download the full example code.
Datapoints are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these datapoints are
TVTensors are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these tv_tensors are
and how they behave.
.. warning::
**Intended Audience** Unless you're writing your own transforms or your own datapoints, you
**Intended Audience** Unless you're writing your own transforms or your own tv_tensors, you
probably do not need to read this guide. This is a fairly low-level topic
that most users will not need to worry about: you do not need to understand
the internals of datapoints to efficiently rely on
the internals of tv_tensors to efficiently rely on
``torchvision.transforms.v2``. It may however be useful for advanced users
# trying to implement their own datasets or transforms, or to work directly with
the datapoints.
the tv_tensors.
"""
# %%
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
# %%
# What are datapoints?
# What are tv_tensors?
# --------------------
#
# Datapoints are zero-copy tensor subclasses:
# TVTensors are zero-copy tensor subclasses:
tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor)
image = tv_tensors.Image(tensor)
assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()
......@@ -46,33 +46,33 @@ assert image.data_ptr() == tensor.data_ptr()
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.datapoints` supports four types of datapoints:
# :mod:`torchvision.tv_tensors` supports four types of tv_tensors:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
# * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.tv_tensors.Mask`
#
# What can I do with a datapoint?
# What can I do with a tv_tensor?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
# any ``torch.*`` operator will also work on tv_tensors. See
# :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.
# %%
# .. _datapoint_creation:
# .. _tv_tensor_creation:
#
# How do I construct a datapoint?
# How do I construct a tv_tensor?
# -------------------------------
#
# Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^
#
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
# Each tv_tensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
image = datapoints.Image([[[[0, 1], [1, 0]]]])
image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)
......@@ -80,64 +80,64 @@ print(image)
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters.
float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
# %%
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a
# In addition, :class:`~torchvision.tv_tensors.Image` and :class:`~torchvision.tv_tensors.Mask` can also take a
# :class:`PIL.Image.Image` directly:
image = datapoints.Image(PIL.Image.open("../assets/astronaut.jpg"))
image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)
# %%
# Some datapoints require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the
# Some tv_tensors require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.
bboxes = datapoints.BoundingBoxes(
bboxes = tv_tensors.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]],
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=image.shape[-2:]
)
print(bboxes)
# %%
# Using ``datapoints.wrap()``
# Using ``tv_tensors.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the
# You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
# into a tv_tensor. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input.
new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size
# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it.
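# For example, to override the canvas size (the values below are made up, and we
# reuse the ``bboxes`` and ``new_bboxes`` objects created just above):

bigger_canvas_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes, canvas_size=(1000, 1000))
assert bigger_canvas_bboxes.canvas_size == (1000, 1000)
assert bigger_canvas_bboxes.format == bboxes.format  # format is still inherited from ``like``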
#
# .. _datapoint_unwrapping_behaviour:
# .. _tv_tensor_unwrapping_behaviour:
#
# I had a Datapoint but now I have a Tensor. Help!
# I had a TVTensor but now I have a Tensor. Help!
# ------------------------------------------------
#
# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
# By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
# will return a pure Tensor:
assert isinstance(bboxes, datapoints.BoundingBoxes)
assert isinstance(bboxes, tv_tensors.BoundingBoxes)
# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %%
# .. note::
......@@ -145,36 +145,36 @@ assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
# This behavior only affects native ``torch`` operations. If you are using
# the built-in ``torchvision`` transforms or functionals, you will always get
# as output the same type that you passed as input (pure ``Tensor`` or
# ``Datapoint``).
# ``TVTensor``).
# %%
# But I want a Datapoint back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# But I want a TVTensor back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
# constructor, or by using the :func:`~torchvision.datapoints.wrap` function
# (see more details above in :ref:`datapoint_creation`):
# You can re-wrap a pure tensor into a tv_tensor by just calling the tv_tensor
# constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
# (see more details above in :ref:`tv_tensor_creation`):
new_bboxes = bboxes + 3
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %%
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
# Alternatively, you can use the :func:`~torchvision.tv_tensors.set_return_type`
# as a global config setting for the whole program, or as a context manager
# (read its docs to learn more about caveats):
with datapoints.set_return_type("datapoint"):
with tv_tensors.set_return_type("tv_tensor"):
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
# **For performance reasons**. :class:`~torchvision.tv_tensors.TVTensor`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.datapoints.Datapoint` object will go through the
# :class:`~torchvision.tv_tensors.TVTensor` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
......@@ -183,12 +183,12 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# ``forward``.
#
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
# preserving the :class:`~torchvision.tv_tensors.TVTensor` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.tv_tensors.Image`?
# If we were to preserve :class:`~torchvision.tv_tensors.TVTensor` types all
# the way, even model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
# being of type :class:`~torchvision.tv_tensors.Image`, and surely that's not
# desirable.
#
# .. note::
......@@ -203,22 +203,22 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type.
# the tv_tensor type.
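# For example (a small sketch reusing the imports from above):

image = tv_tensors.Image(torch.rand(3, 32, 32))
assert isinstance(image.to(torch.float64), tv_tensors.Image)  # .to() keeps the TVTensor type
assert isinstance(image.clone(), tv_tensors.Image)  # so does .clone()
assert isinstance(image + 1, torch.Tensor) and not isinstance(image + 1, tv_tensors.Image)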
#
# Inplace operations on datapoints like ``obj.add_()`` will preserve the type of
# Inplace operations on tv_tensors like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of inplace operations will be a pure
# tensor:
image = datapoints.Image([[[0, 1], [1, 0]]])
image = tv_tensors.Image([[[0, 1], [1, 0]]])
new_image = image.add_(1).mul_(2)
# image got transformed in-place and is still an Image datapoint, but new_image
# image got transformed in-place and is still an Image tv_tensor, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, datapoints.Image)
assert isinstance(image, tv_tensors.Image)
print(image)
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
......@@ -7,10 +7,10 @@ import transforms as reference_transforms
def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2
import torchvision.tv_tensors
return torchvision.transforms.v2, torchvision.datapoints
return torchvision.transforms.v2, torchvision.tv_tensors
else:
return reference_transforms, None
......@@ -28,16 +28,16 @@ class DetectionPresetTrain:
use_v2=False,
):
T, datapoints = get_modules(use_v2)
T, tv_tensors = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "datapoint":
if backend == "tv_tensor":
transforms.append(T.ToImage())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
if data_augmentation == "hflip":
transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
......@@ -54,7 +54,7 @@ class DetectionPresetTrain:
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssd":
fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
transforms += [
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=fill),
......@@ -77,7 +77,7 @@ class DetectionPresetTrain:
if use_v2:
transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBoxes(),
T.ToPureTensor(),
]
......@@ -98,10 +98,10 @@ class DetectionPresetEval:
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
elif backend == "tv_tensor":
transforms += [T.ToImage()]
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
transforms += [T.ToDtype(torch.float, scale=True)]
......
......@@ -180,8 +180,8 @@ def get_args_parser(add_help=True):
def main(args):
if args.backend.lower() == "datapoint" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
if args.backend.lower() == "tv_tensor" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the tv_tensor backend.")
if args.dataset not in ("coco", "coco_kp"):
raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
if "keypoint" in args.model and args.dataset != "coco_kp":
......
......@@ -4,11 +4,11 @@ import torch
def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2
import torchvision.tv_tensors
import v2_extras
return torchvision.transforms.v2, torchvision.datapoints, v2_extras
return torchvision.transforms.v2, torchvision.tv_tensors, v2_extras
else:
import transforms
......@@ -27,16 +27,16 @@ class SegmentationPresetTrain:
backend="pil",
use_v2=False,
):
T, datapoints, v2_extras = get_modules(use_v2)
T, tv_tensors, v2_extras = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "datapoint":
if backend == "tv_tensor":
transforms.append(T.ToImage())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]
......@@ -46,7 +46,7 @@ class SegmentationPresetTrain:
if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]
transforms += [v2_extras.PadIfSmaller(crop_size, fill={tv_tensors.Mask: 255, "others": 0})]
transforms += [T.RandomCrop(crop_size)]
......@@ -54,9 +54,9 @@ class SegmentationPresetTrain:
transforms += [T.PILToTensor()]
if use_v2:
img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
img_type = tv_tensors.Image if backend == "tv_tensor" else torch.Tensor
transforms += [
T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
T.ToDtype(dtype={img_type: torch.float32, tv_tensors.Mask: torch.int64, "others": None}, scale=True)
]
else:
# No need to explicitly convert masks as they're magically int64 already
......@@ -82,10 +82,10 @@ class SegmentationPresetEval:
backend = backend.lower()
if backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
elif backend == "tv_tensor":
transforms += [T.ToImage()]
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
if use_v2:
transforms += [T.Resize(size=(base_size, base_size))]
......
......@@ -128,7 +128,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
def main(args):
if args.backend.lower() != "pil" and not args.use_v2:
# TODO: Support tensor backend in V1?
raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
if args.use_v2 and args.dataset != "coco":
raise ValueError("v2 is only support supported for coco dataset for now.")
......
"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1."""
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import v2
......@@ -80,4 +80,4 @@ class CocoDetectionToVOCSegmentation(v2.Transform):
if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)
return image, datapoints.Mask(segmentation_mask)
return image, tv_tensors.Mask(segmentation_mask)
......@@ -19,7 +19,7 @@ import torch.testing
from PIL import Image
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import datapoints, io
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
......@@ -391,7 +391,7 @@ def make_image(
if color_space in {"GRAY_ALPHA", "RGBA"}:
data[..., -1, :, :] = max_value
return datapoints.Image(data)
return tv_tensors.Image(data)
def make_image_tensor(*args, **kwargs):
......@@ -405,7 +405,7 @@ def make_image_pil(*args, **kwargs):
def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
dtype=None,
device="cpu",
):
......@@ -415,7 +415,7 @@ def make_bounding_boxes(
return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]
format = tv_tensors.BoundingBoxFormat[format]
dtype = dtype or torch.float32
......@@ -424,21 +424,21 @@ def make_bounding_boxes(
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])
if format is datapoints.BoundingBoxFormat.XYWH:
if format is tv_tensors.BoundingBoxFormat.XYWH:
parts = (x, y, w, h)
elif format is datapoints.BoundingBoxFormat.XYXY:
elif format is tv_tensors.BoundingBoxFormat.XYXY:
x1, y1 = x, y
x2 = x1 + w
y2 = y1 + h
parts = (x1, y1, x2, y2)
elif format is datapoints.BoundingBoxFormat.CXCYWH:
elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
cx = x + w / 2
cy = y + h / 2
parts = (cx, cy, w, h)
else:
raise ValueError(f"Format {format} is not supported")
return datapoints.BoundingBoxes(
return tv_tensors.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
......@@ -446,7 +446,7 @@ def make_bounding_boxes(
def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
num_objects = 1
return datapoints.Mask(
return tv_tensors.Mask(
torch.testing.make_tensor(
(num_objects, *size),
low=0,
......@@ -459,7 +459,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
return datapoints.Mask(
return tv_tensors.Mask(
torch.testing.make_tensor(
(*batch_dims, *size),
low=0,
......@@ -471,7 +471,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
def make_video_tensor(*args, **kwargs):
......
......@@ -568,7 +568,7 @@ class DatasetTestCase(unittest.TestCase):
@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
try:
......@@ -590,7 +590,7 @@ class DatasetTestCase(unittest.TestCase):
wrapped_sample = wrapped_dataset[0]
assert tree_any(
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
lambda item: isinstance(item, (tv_tensors.TVTensor, PIL.Image.Image)), wrapped_sample
)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
......@@ -717,7 +717,7 @@ def check_transforms_v2_wrapper_spawn(dataset):
pytest.skip("Multiprocessing spawning is only checked on macOS.")
from torch.utils.data import DataLoader
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
......@@ -726,7 +726,7 @@ def check_transforms_v2_wrapper_spawn(dataset):
for wrapped_sample in dataloader:
assert tree_any(
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
)
......
......@@ -6,7 +6,7 @@ import pytest
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import datapoints
from torchvision.prototype import tv_tensors
from transforms_v2_legacy_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader
......@@ -40,7 +40,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
return datapoints.Label(data, categories=categories)
return tv_tensors.Label(data, categories=categories)
return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
......@@ -64,7 +64,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int
# since `one_hot` only supports int64
label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
data = one_hot(label, num_classes=num_categories).to(dtype)
return datapoints.OneHotLabel(data, categories=categories)
return tv_tensors.OneHotLabel(data, categories=categories)
return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
......
......@@ -3387,11 +3387,11 @@ class TestDatasetWrapper:
datasets.wrap_dataset_for_transforms_v2(dataset)
def test_subclass(self, mocker):
from torchvision import datapoints
from torchvision import tv_tensors
sentinel = object()
mocker.patch.dict(
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
tv_tensors._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
)
......
......@@ -19,12 +19,12 @@ from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.graph.utils import traverse_dps
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision._utils import sequence_to_str
from torchvision.prototype import datasets
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedImage
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.tv_tensors import Label
from torchvision.transforms.v2._utils import is_pure_tensor
......@@ -147,7 +147,7 @@ class TestCommon:
pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)}
if pure_tensors and not any(
isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
isinstance(item, (tv_tensors.Image, tv_tensors.Video, EncodedImage)) for item in sample.values()
):
raise AssertionError(
f"The values of key(s) "
......@@ -276,7 +276,7 @@ class TestUSPS:
assert "image" in sample
assert "label" in sample
assert isinstance(sample["image"], datapoints.Image)
assert isinstance(sample["image"], tv_tensors.Image)
assert isinstance(sample["label"], Label)
assert sample["image"].shape == (1, 16, 16)