Unverified commit d5f4cc38, authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
...@@ -88,8 +88,8 @@ class CustomGalleryExampleSortKey: ...@@ -88,8 +88,8 @@ class CustomGalleryExampleSortKey:
"plot_transforms_e2e.py", "plot_transforms_e2e.py",
"plot_cutmix_mixup.py", "plot_cutmix_mixup.py",
"plot_custom_transforms.py", "plot_custom_transforms.py",
"plot_datapoints.py", "plot_tv_tensors.py",
"plot_custom_datapoints.py", "plot_custom_tv_tensors.py",
] ]
def __call__(self, filename): def __call__(self, filename):
......
...@@ -32,7 +32,7 @@ architectures, and common image transformations for computer vision. ...@@ -32,7 +32,7 @@ architectures, and common image transformations for computer vision.
:caption: Package Reference :caption: Package Reference
transforms transforms
datapoints tv_tensors
models models
datasets datasets
utils utils
......
...@@ -30,12 +30,12 @@ tasks (image classification, detection, segmentation, video classification). ...@@ -30,12 +30,12 @@ tasks (image classification, detection, segmentation, video classification).
.. code:: python .. code:: python
# Detection (re-using imports and transforms from above) # Detection (re-using imports and transforms from above)
from torchvision import datapoints from torchvision import tv_tensors
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8) img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
bboxes = torch.randint(0, H // 2, size=(3, 4)) bboxes = torch.randint(0, H // 2, size=(3, 4))
bboxes[:, 2:] += bboxes[:, :2] bboxes[:, 2:] += bboxes[:, :2]
bboxes = datapoints.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W)) bboxes = tv_tensors.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used! # The same transforms can be used!
img, bboxes = transforms(img, bboxes) img, bboxes = transforms(img, bboxes)
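For context, the snippet above relies on a ``transforms`` pipeline and the ``H, W`` sizes defined earlier in that guide. A self-contained sketch under the new ``tv_tensors`` naming (the ``v2.Compose`` pipeline below is an illustrative stand-in, assuming torchvision >= 0.16)::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    H, W = 256, 256
    transforms = v2.Compose([
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),
    ])

    img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
    bboxes = torch.randint(0, H // 2, size=(3, 4))
    bboxes[:, 2:] += bboxes[:, :2]  # turn (x, y, w, h) offsets into XYXY corners
    bboxes = tv_tensors.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))

    # The same transform pipeline handles both the image and the boxes.
    img, bboxes = transforms(img, bboxes)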
...@@ -183,8 +183,8 @@ Transforms are available as classes like ...@@ -183,8 +183,8 @@ Transforms are available as classes like
This is very much like the :mod:`torch.nn` package which defines both classes This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`. and functional equivalents in :mod:`torch.nn.functional`.
The functionals support PIL images, pure tensors, or :ref:`datapoints The functionals support PIL images, pure tensors, or :ref:`tv_tensors
<datapoints>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are <tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
valid. valid.
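For example, a minimal sketch of that dual dispatch (illustrative only, assuming the renamed ``torchvision.tv_tensors`` module)::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    img = torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8)
    boxes = tv_tensors.BoundingBoxes(
        [[10, 10, 50, 50]], format="XYXY", canvas_size=(128, 128)
    )

    # The same functional dispatches on the class of its input:
    resized_img = F.resize(img, size=(64, 64), antialias=True)  # image kernel
    resized_boxes = F.resize(boxes, size=(64, 64))              # box kernel, rescales coordinates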
.. note:: .. note::
......
.. _datapoints: .. _tv_tensors:
Datapoints TVTensors
========== ==========
.. currentmodule:: torchvision.datapoints .. currentmodule:: torchvision.tv_tensors
Datapoints are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
dispatch their inputs to the appropriate lower-level kernels. Most users do not dispatch their inputs to the appropriate lower-level kernels. Most users do not
need to manipulate datapoints directly and can simply rely on dataset wrapping - need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`. see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
.. autosummary:: .. autosummary::
...@@ -19,6 +19,6 @@ see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`. ...@@ -19,6 +19,6 @@ see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
BoundingBoxFormat BoundingBoxFormat
BoundingBoxes BoundingBoxes
Mask Mask
Datapoint TVTensor
set_return_type set_return_type
wrap wrap
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
...@@ -22,7 +22,7 @@ def plot(imgs, row_title=None, **imshow_kwargs): ...@@ -22,7 +22,7 @@ def plot(imgs, row_title=None, **imshow_kwargs):
if isinstance(target, dict): if isinstance(target, dict):
boxes = target.get("boxes") boxes = target.get("boxes")
masks = target.get("masks") masks = target.get("masks")
elif isinstance(target, datapoints.BoundingBoxes): elif isinstance(target, tv_tensors.BoundingBoxes):
boxes = target boxes = target
else: else:
raise ValueError(f"Unexpected target type: {type(target)}") raise ValueError(f"Unexpected target type: {type(target)}")
......
...@@ -13,7 +13,7 @@ torchvision transforms V2 API. ...@@ -13,7 +13,7 @@ torchvision transforms V2 API.
# %% # %%
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms import v2 from torchvision.transforms import v2
...@@ -62,7 +62,7 @@ transforms = v2.Compose([ ...@@ -62,7 +62,7 @@ transforms = v2.Compose([
H, W = 256, 256 H, W = 256, 256
img = torch.rand(3, H, W) img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes( bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]), torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY", format="XYXY",
canvas_size=(H, W) canvas_size=(H, W)
...@@ -74,9 +74,9 @@ out_img, out_bboxes, out_label = transforms(img, bboxes, label) ...@@ -74,9 +74,9 @@ out_img, out_bboxes, out_label = transforms(img, bboxes, label)
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }") print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %% # %%
# .. note:: # .. note::
# While working with datapoint classes in your code, make sure to # While working with tv_tensor classes in your code, make sure to
# familiarize yourself with this section: # familiarize yourself with this section:
# :ref:`datapoint_unwrapping_behaviour` # :ref:`tv_tensor_unwrapping_behaviour`
# #
# Supporting arbitrary input structures # Supporting arbitrary input structures
# ===================================== # =====================================
...@@ -111,7 +111,7 @@ print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}") ...@@ -111,7 +111,7 @@ print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
# In brief, the core logic is to unpack the input into a flat list using `pytree # In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and # <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made # then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are # based on the **class** of the entries, as all tv_tensors are
# tensor-subclasses) plus some custom logic that is out of scope here - check the # tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and # code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input. # returned, in the same structure as the input.
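To make the pytree flattening concrete, a small illustrative sketch (not part of this commit) of passing an arbitrary nested structure to a v2 transform; entries that are not images or tv_tensors come back untouched::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    sample = {
        "image": tv_tensors.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)),
        "annotations": {
            "boxes": tv_tensors.BoundingBoxes(
                [[0, 0, 10, 10]], format="XYXY", canvas_size=(64, 64)
            ),
            "label": 3,  # plain Python objects are passed through untouched
        },
    }

    out = v2.RandomHorizontalFlip(p=1)(sample)  # the same nested structure comes back
    print(type(out["image"]), out["annotations"]["boxes"], out["annotations"]["label"])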
......
""" """
===================================== =====================================
How to write your own Datapoint class How to write your own TVTensor class
===================================== =====================================
.. note:: .. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_datapoints.ipynb>`_ Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_datapoints.py>` to download the full example code. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.
This guide is intended for advanced users and downstream library maintainers. We explain how to This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in write your own tv_tensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`. :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
""" """
# %% # %%
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms import v2 from torchvision.transforms import v2
# %% # %%
# We will create a very simple class that just inherits from the base # We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover # :class:`~torchvision.tv_tensors.TVTensor` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need # what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the # to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented # :class:`~torchvision.tv_tensors.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_. # <https://github.com/pytorch/vision/blob/main/torchvision/tv_tensors/_bounding_box.py>`_.
class MyDatapoint(datapoints.Datapoint): class MyTVTensor(tv_tensors.TVTensor):
pass pass
my_dp = MyDatapoint([1, 2, 3]) my_dp = MyTVTensor([1, 2, 3])
my_dp my_dp
# %% # %%
# Now that we have defined our custom Datapoint class, we want it to be # Now that we have defined our custom TVTensor class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API. # compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the # For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support # transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`. # via :func:`~torchvision.transforms.v2.functional.register_kernel`.
# #
# We illustrate this process below: we create a kernel for the "horizontal flip" # We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API. # operation of our MyTVTensor class, and register it to the functional API.
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint) @F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_datapoint(my_dp, *args, **kwargs): def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!") print("Flipping!")
out = my_dp.flip(-1) out = my_dp.flip(-1)
return datapoints.wrap(out, like=my_dp) return tv_tensors.wrap(out, like=my_dp)
# %% # %%
# To understand why :func:`~torchvision.datapoints.wrap` is used, see # To understand why :func:`~torchvision.tv_tensors.wrap` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now, # :ref:`tv_tensor_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`. # we will explain it below in :ref:`param_forwarding`.
# #
# .. note:: # .. note::
...@@ -67,9 +67,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs): ...@@ -67,9 +67,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# ``@register_kernel(functional=F.hflip, ...)``. # ``@register_kernel(functional=F.hflip, ...)``.
# #
# Now that we have registered our kernel, we can call the functional API on a # Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance: # ``MyTVTensor`` instance:
my_dp = MyDatapoint(torch.rand(3, 256, 256)) my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp) _ = F.hflip(my_dp)
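Pulling the fragments of this example together, a consolidated sketch under the new names (assuming torchvision >= 0.16, where ``register_kernel`` takes the ``tv_tensor_cls`` argument shown in the diff) would be::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2
    from torchvision.transforms.v2 import functional as F

    class MyTVTensor(tv_tensors.TVTensor):
        pass

    @F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
    def hflip_my_tv_tensor(my_dp, *args, **kwargs):
        print("Flipping!")
        out = my_dp.flip(-1)
        return tv_tensors.wrap(out, like=my_dp)

    my_dp = MyTVTensor(torch.rand(3, 256, 256))
    _ = F.hflip(my_dp)                       # dispatches to the registered kernel
    _ = v2.RandomHorizontalFlip(p=1)(my_dp)  # the transform class dispatches to it as well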
# %% # %%
...@@ -102,10 +102,10 @@ _ = t(my_dp) ...@@ -102,10 +102,10 @@ _ = t(my_dp)
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you # to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as # already defined and registered your own kernel as
def hflip_my_datapoint(my_dp): # noqa def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!") print("Flipping!")
out = my_dp.flip(-1) out = my_dp.flip(-1)
return datapoints.wrap(out, like=my_dp) return tv_tensors.wrap(out, like=my_dp)
# %% # %%
......
...@@ -23,7 +23,7 @@ import pathlib ...@@ -23,7 +23,7 @@ import pathlib
import torch import torch
import torch.utils.data import torch.utils.data
from torchvision import models, datasets, datapoints from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2 from torchvision.transforms import v2
torch.manual_seed(0) torch.manual_seed(0)
...@@ -72,7 +72,7 @@ print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['m ...@@ -72,7 +72,7 @@ print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['m
# %% # %%
# We used the ``target_keys`` parameter to specify the kind of output we're # We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is a dict where the values # interested in. Our dataset now returns a target which is a dict where the values
# are :ref:`Datapoints <what_are_datapoints>` (all are :class:`torch.Tensor` # are :ref:`TVTensors <what_are_tv_tensors>` (all are :class:`torch.Tensor`
# subclasses). We've dropped all unnecessary keys from the previous output, but # subclasses). We've dropped all unnecessary keys from the previous output, but
# if you need any of the original keys e.g. "image_id", you can still ask for # if you need any of the original keys e.g. "image_id", you can still ask for
# it. # it.
...@@ -103,7 +103,7 @@ transforms = v2.Compose( ...@@ -103,7 +103,7 @@ transforms = v2.Compose(
[ [
v2.ToImage(), v2.ToImage(),
v2.RandomPhotometricDistort(p=1), v2.RandomPhotometricDistort(p=1),
v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}), v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(), v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=1), v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(), v2.SanitizeBoundingBoxes(),
......
...@@ -88,9 +88,9 @@ plot([img, out]) ...@@ -88,9 +88,9 @@ plot([img, out])
# #
# Let's briefly look at a detection example with bounding boxes. # Let's briefly look at a detection example with bounding boxes.
from torchvision import datapoints # we'll describe this a bit later, bear with us from torchvision import tv_tensors # we'll describe this a bit later, bear with us
boxes = datapoints.BoundingBoxes( boxes = tv_tensors.BoundingBoxes(
[ [
[15, 10, 370, 510], [15, 10, 370, 510],
[275, 340, 510, 510], [275, 340, 510, 510],
...@@ -111,44 +111,44 @@ plot([(img, boxes), (out_img, out_boxes)]) ...@@ -111,44 +111,44 @@ plot([(img, boxes), (out_img, out_boxes)])
# %% # %%
# #
# The example above focuses on object detection. But if we had masks # The example above focuses on object detection. But if we had masks
# (:class:`torchvision.datapoints.Mask`) for object segmentation or semantic # (:class:`torchvision.tv_tensors.Mask`) for object segmentation or semantic
# segmentation, or videos (:class:`torchvision.datapoints.Video`), we could have # segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
# passed them to the transforms in exactly the same way. # passed them to the transforms in exactly the same way.
# #
# By now you likely have a few questions: what are these datapoints, how do we # By now you likely have a few questions: what are these tv_tensors, how do we
# use them, and what is the expected input/output of those transforms? We'll # use them, and what is the expected input/output of those transforms? We'll
# answer these in the next sections. # answer these in the next sections.
# %% # %%
# #
# .. _what_are_datapoints: # .. _what_are_tv_tensors:
# #
# What are Datapoints? # What are TVTensors?
# -------------------- # --------------------
# #
# Datapoints are :class:`torch.Tensor` subclasses. The available datapoints are # TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
# :class:`~torchvision.datapoints.Image`, # :class:`~torchvision.tv_tensors.Image`,
# :class:`~torchvision.datapoints.BoundingBoxes`, # :class:`~torchvision.tv_tensors.BoundingBoxes`,
# :class:`~torchvision.datapoints.Mask`, and # :class:`~torchvision.tv_tensors.Mask`, and
# :class:`~torchvision.datapoints.Video`. # :class:`~torchvision.tv_tensors.Video`.
# #
# Datapoints look and feel just like regular tensors - they **are** tensors. # TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
# or any ``torch.*`` operator will also work on a datapoint: # or any ``torch.*`` operator will also work on a tv_tensor:
img_dp = datapoints.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)) img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
print(f"{isinstance(img_dp, torch.Tensor) = }") print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }") print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
# %% # %%
# These Datapoint classes are at the core of the transforms: in order to # These TVTensor classes are at the core of the transforms: in order to
# transform a given input, the transforms first look at the **class** of the # transform a given input, the transforms first look at the **class** of the
# object, and dispatch to the appropriate implementation accordingly. # object, and dispatch to the appropriate implementation accordingly.
# #
# You don't need to know much more about datapoints at this point, but advanced # You don't need to know much more about tv_tensors at this point, but advanced
# users who want to learn more can refer to # users who want to learn more can refer to
# :ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`. # :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
# #
# What do I pass as input? # What do I pass as input?
# ------------------------ # ------------------------
...@@ -196,17 +196,17 @@ print(f"{out_target['this_is_ignored']}") ...@@ -196,17 +196,17 @@ print(f"{out_target['this_is_ignored']}")
# Pure :class:`torch.Tensor` objects are, in general, treated as images (or # Pure :class:`torch.Tensor` objects are, in general, treated as images (or
# as videos for video-specific transforms). Indeed, you may have noticed # as videos for video-specific transforms). Indeed, you may have noticed
# that in the code above we haven't used the # that in the code above we haven't used the
# :class:`~torchvision.datapoints.Image` class at all, and yet our images # :class:`~torchvision.tv_tensors.Image` class at all, and yet our images
# got transformed properly. Transforms use the following logic to # got transformed properly. Transforms use the following logic to
# determine whether a pure Tensor should be treated as an image (or video), # determine whether a pure Tensor should be treated as an image (or video),
# or just ignored: # or just ignored:
# #
# * If there is an :class:`~torchvision.datapoints.Image`, # * If there is an :class:`~torchvision.tv_tensors.Image`,
# :class:`~torchvision.datapoints.Video`, # :class:`~torchvision.tv_tensors.Video`,
# or :class:`PIL.Image.Image` instance in the input, all other pure # or :class:`PIL.Image.Image` instance in the input, all other pure
# tensors are passed-through. # tensors are passed-through.
# * If there is no :class:`~torchvision.datapoints.Image` or # * If there is no :class:`~torchvision.tv_tensors.Image` or
# :class:`~torchvision.datapoints.Video` instance, only the first pure # :class:`~torchvision.tv_tensors.Video` instance, only the first pure
# :class:`torch.Tensor` will be transformed as image or video, while all # :class:`torch.Tensor` will be transformed as image or video, while all
# others will be passed-through. Here "first" means "first in a depth-wise # others will be passed-through. Here "first" means "first in a depth-wise
# traversal". # traversal".
...@@ -234,9 +234,9 @@ print(f"{out_target['this_is_ignored']}") ...@@ -234,9 +234,9 @@ print(f"{out_target['this_is_ignored']}")
# Torchvision also supports datasets for object detection or segmentation like # Torchvision also supports datasets for object detection or segmentation like
# :class:`torchvision.datasets.CocoDetection`. Those datasets predate # :class:`torchvision.datasets.CocoDetection`. Those datasets predate
# the existence of the :mod:`torchvision.transforms.v2` module and of the # the existence of the :mod:`torchvision.transforms.v2` module and of the
# datapoints, so they don't return datapoints out of the box. # tv_tensors, so they don't return tv_tensors out of the box.
# #
# An easy way to force those datasets to return datapoints and to make them # An easy way to force those datasets to return tv_tensors and to make them
# compatible with v2 transforms is to use the # compatible with v2 transforms is to use the
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function: # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
# #
...@@ -246,14 +246,14 @@ print(f"{out_target['this_is_ignored']}") ...@@ -246,14 +246,14 @@ print(f"{out_target['this_is_ignored']}")
# #
# dataset = CocoDetection(..., transforms=my_transforms) # dataset = CocoDetection(..., transforms=my_transforms)
# dataset = wrap_dataset_for_transforms_v2(dataset) # dataset = wrap_dataset_for_transforms_v2(dataset)
# # Now the dataset returns datapoints! # # Now the dataset returns tv_tensors!
# #
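Filling in the commented sketch above, a hypothetical end-to-end usage (the paths and the transform choices are placeholders, not taken from this commit) could look like::

    import torch
    from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
    from torchvision.transforms import v2

    my_transforms = v2.Compose([
        v2.ToImage(),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),
    ])

    # "root" and "annFile" are hypothetical paths to a local COCO copy.
    dataset = CocoDetection(
        root="data/coco/train2017",
        annFile="data/coco/annotations/instances_train2017.json",
        transforms=my_transforms,
    )
    dataset = wrap_dataset_for_transforms_v2(dataset)

    img, target = dataset[0]
    # target["boxes"] is now a tv_tensors.BoundingBoxes compatible with v2 transforms.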
# Using your own datasets # Using your own datasets
# ^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^
# #
# If you have a custom dataset, then you'll need to convert your objects into # If you have a custom dataset, then you'll need to convert your objects into
# the appropriate Datapoint classes. Creating Datapoint instances is very easy, # the appropriate TVTensor classes. Creating TVTensor instances is very easy,
# refer to :ref:`datapoint_creation` for more details. # refer to :ref:`tv_tensor_creation` for more details.
# #
# There are two main places where you can implement that conversion logic: # There are two main places where you can implement that conversion logic:
# #
......
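Typically, one place for that conversion is the dataset's ``__getitem__`` method itself; a toy, runnable sketch with hypothetical random data (not part of this commit)::

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    class ToyDetectionDataset(torch.utils.data.Dataset):
        # Hypothetical dataset returning random data, just to show the conversion.
        def __init__(self, transforms=None):
            self.transforms = transforms

        def __len__(self):
            return 4

        def __getitem__(self, idx):
            img = tv_tensors.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
            boxes = tv_tensors.BoundingBoxes(
                [[4, 4, 32, 32]], format="XYXY", canvas_size=img.shape[-2:]
            )
            if self.transforms is not None:
                img, boxes = self.transforms(img, boxes)
            return img, boxes

    dataset = ToyDetectionDataset(transforms=v2.RandomHorizontalFlip(p=1))
    img, boxes = dataset[0]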
""" """
============== =============
Datapoints FAQ TVTensors FAQ
============== =============
.. note:: .. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_datapoints.ipynb>`_ Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_tv_tensors.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_datapoints.py>` to download the full example code. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_tv_tensors.py>` to download the full example code.
Datapoints are Tensor subclasses introduced together with TVTensors are Tensor subclasses introduced together with
``torchvision.transforms.v2``. This example showcases what these datapoints are ``torchvision.transforms.v2``. This example showcases what these tv_tensors are
and how they behave. and how they behave.
.. warning:: .. warning::
**Intended Audience** Unless you're writing your own transforms or your own datapoints, you **Intended Audience** Unless you're writing your own transforms or your own tv_tensors, you
probably do not need to read this guide. This is a fairly low-level topic probably do not need to read this guide. This is a fairly low-level topic
that most users will not need to worry about: you do not need to understand that most users will not need to worry about: you do not need to understand
the internals of datapoints to efficiently rely on the internals of tv_tensors to efficiently rely on
``torchvision.transforms.v2``. It may however be useful for advanced users ``torchvision.transforms.v2``. It may however be useful for advanced users
trying to implement their own datasets or transforms, or to work directly with trying to implement their own datasets or transforms, or to work directly with
the datapoints. the tv_tensors.
""" """
# %% # %%
import PIL.Image import PIL.Image
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
# %% # %%
# What are datapoints? # What are tv_tensors?
# -------------------- # --------------------
# #
# Datapoints are zero-copy tensor subclasses: # TVTensors are zero-copy tensor subclasses:
tensor = torch.rand(3, 256, 256) tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor) image = tv_tensors.Image(tensor)
assert isinstance(image, torch.Tensor) assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr() assert image.data_ptr() == tensor.data_ptr()
...@@ -46,33 +46,33 @@ assert image.data_ptr() == tensor.data_ptr() ...@@ -46,33 +46,33 @@ assert image.data_ptr() == tensor.data_ptr()
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data. # for the input data.
# #
# :mod:`torchvision.datapoints` supports four types of datapoints: # :mod:`torchvision.tv_tensors` supports four types of tv_tensors:
# #
# * :class:`~torchvision.datapoints.Image` # * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.datapoints.Video` # * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes` # * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask` # * :class:`~torchvision.tv_tensors.Mask`
# #
# What can I do with a datapoint? # What can I do with a tv_tensor?
# ------------------------------- # -------------------------------
# #
# Datapoints look and feel just like regular tensors - they **are** tensors. # TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also work on datapoints. See # any ``torch.*`` operator will also work on tv_tensors. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas. # :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.
# %% # %%
# .. _datapoint_creation: # .. _tv_tensor_creation:
# #
# How do I construct a datapoint? # How do I construct a tv_tensor?
# ------------------------------- # -------------------------------
# #
# Using the constructor # Using the constructor
# ^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^
# #
# Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor` # Each tv_tensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
image = datapoints.Image([[[[0, 1], [1, 0]]]]) image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image) print(image)
...@@ -80,64 +80,64 @@ print(image) ...@@ -80,64 +80,64 @@ print(image)
# Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad`` # Similar to other PyTorch creation ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad``
# parameters. # parameters.
float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True) float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image) print(float_image)
# %% # %%
# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a # In addition, :class:`~torchvision.tv_tensors.Image` and :class:`~torchvision.tv_tensors.Mask` can also take a
# :class:`PIL.Image.Image` directly: # :class:`PIL.Image.Image` directly:
image = datapoints.Image(PIL.Image.open("../assets/astronaut.jpg")) image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype) print(image.shape, image.dtype)
# %% # %%
# Some datapoints require additional metadata to be passed in order to be constructed. For example, # Some tv_tensors require additional metadata to be passed in order to be constructed. For example,
# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the # :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These # corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes. # metadata are required to properly transform the bounding boxes.
bboxes = datapoints.BoundingBoxes( bboxes = tv_tensors.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]], [[17, 16, 344, 495], [0, 10, 0, 10]],
format=datapoints.BoundingBoxFormat.XYXY, format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=image.shape[-2:] canvas_size=image.shape[-2:]
) )
print(bboxes) print(bboxes)
# %% # %%
# Using ``datapoints.wrap()`` # Using ``tv_tensors.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
# #
# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object # You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the # into a tv_tensor. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want # desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input. # to wrap the output like the input.
new_bboxes = torch.tensor([0, 20, 30, 40]) new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes) new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes) assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size assert new_bboxes.canvas_size == bboxes.canvas_size
# %% # %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass # The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it. # it as a parameter to override it.
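For instance, keyword arguments to ``wrap`` override the metadata taken from ``like`` (a small sketch, assuming the keyword-override behaviour documented in the API reference)::

    new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes, canvas_size=(224, 224))
    assert new_bboxes.canvas_size == (224, 224)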
# #
# .. _datapoint_unwrapping_behaviour: # .. _tv_tensor_unwrapping_behaviour:
# #
# I had a Datapoint but now I have a Tensor. Help! # I had a TVTensor but now I have a Tensor. Help!
# ------------------------------------------------ # ------------------------------------------------
# #
# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects # By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
# will return a pure Tensor: # will return a pure Tensor:
assert isinstance(bboxes, datapoints.BoundingBoxes) assert isinstance(bboxes, tv_tensors.BoundingBoxes)
# Shift bboxes by 3 pixels in both H and W # Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3 new_bboxes = bboxes + 3
assert isinstance(new_bboxes, torch.Tensor) assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, datapoints.BoundingBoxes) assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %% # %%
# .. note:: # .. note::
...@@ -145,36 +145,36 @@ assert not isinstance(new_bboxes, datapoints.BoundingBoxes) ...@@ -145,36 +145,36 @@ assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
# This behavior only affects native ``torch`` operations. If you are using # This behavior only affects native ``torch`` operations. If you are using
# the built-in ``torchvision`` transforms or functionals, you will always get # the built-in ``torchvision`` transforms or functionals, you will always get
# as output the same type that you passed as input (pure ``Tensor`` or # as output the same type that you passed as input (pure ``Tensor`` or
# ``Datapoint``). # ``TVTensor``).
# %% # %%
# But I want a Datapoint back! # But I want a TVTensor back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
# #
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint # You can re-wrap a pure tensor into a tv_tensor by just calling the tv_tensor
# constructor, or by using the :func:`~torchvision.datapoints.wrap` function # constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
# (see more details above in :ref:`datapoint_creation`): # (see more details above in :ref:`tv_tensor_creation`):
new_bboxes = bboxes + 3 new_bboxes = bboxes + 3
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes) new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes) assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
# %% # %%
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type` # Alternatively, you can use the :func:`~torchvision.tv_tensors.set_return_type`
# as a global config setting for the whole program, or as a context manager # as a global config setting for the whole program, or as a context manager
# (read its docs to learn more about caveats): # (read its docs to learn more about caveats):
with datapoints.set_return_type("datapoint"): with tv_tensors.set_return_type("tv_tensor"):
new_bboxes = bboxes + 3 new_bboxes = bboxes + 3
assert isinstance(new_bboxes, datapoints.BoundingBoxes) assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
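Used as a global setting rather than a context manager, the same call would simply be made once near program start (a minimal sketch; the accepted strings are assumed to match the context-manager example above)::

    tv_tensors.set_return_type("tv_tensor")  # from here on, ops on tv_tensors keep their class
    new_bboxes = bboxes + 3
    assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

    tv_tensors.set_return_type("tensor")     # restore the default unwrapping behaviour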
# %% # %%
# Why is this happening? # Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^
# #
# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint` # **For performance reasons**. :class:`~torchvision.tv_tensors.TVTensor`
# classes are Tensor subclasses, so any operation involving a # classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.datapoints.Datapoint` object will go through the # :class:`~torchvision.tv_tensors.TVTensor` object will go through the
# `__torch_function__ # `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_ # <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible. # protocol. This induces a small overhead, which we want to avoid when possible.
...@@ -183,12 +183,12 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes) ...@@ -183,12 +183,12 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# ``forward``. # ``forward``.
# #
# **The alternative isn't much better anyway.** For every operation where # **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.datapoints.Datapoint` type makes # preserving the :class:`~torchvision.tv_tensors.TVTensor` type makes
# sense, there are just as many operations where returning a pure Tensor is # sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`? # preferable: for example, is ``img.sum()`` still an :class:`~torchvision.tv_tensors.Image`?
# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all # If we were to preserve :class:`~torchvision.tv_tensors.TVTensor` types all
# the way, even model's logits or the output of the loss function would end up # the way, even model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.datapoints.Image`, and surely that's not # being of type :class:`~torchvision.tv_tensors.Image`, and surely that's not
# desirable. # desirable.
# #
# .. note:: # .. note::
...@@ -203,22 +203,22 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes) ...@@ -203,22 +203,22 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# There are a few exceptions to this "unwrapping" rule: # There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, # :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain # :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type. # the tv_tensor type.
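For example (an illustrative sketch of the exceptions listed above)::

    image = tv_tensors.Image(torch.rand(3, 8, 8))

    assert isinstance(image.clone(), tv_tensors.Image)            # retained
    assert isinstance(image.to(torch.float64), tv_tensors.Image)  # retained
    assert isinstance(image.detach(), tv_tensors.Image)           # retained
    assert not isinstance(image + 1, tv_tensors.Image)            # unwrapped to a pure Tensor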
# #
# Inplace operations on datapoints like ``obj.add_()`` will preserve the type of # Inplace operations on tv_tensors like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of inplace operations will be a pure # ``obj``. However, the **returned** value of inplace operations will be a pure
# tensor: # tensor:
image = datapoints.Image([[[0, 1], [1, 0]]]) image = tv_tensors.Image([[[0, 1], [1, 0]]])
new_image = image.add_(1).mul_(2) new_image = image.add_(1).mul_(2)
# image got transformed in-place and is still an Image datapoint, but new_image # image got transformed in-place and is still an Image tv_tensor, but new_image
# is a Tensor. They share the same underlying data and they're equal, just # is a Tensor. They share the same underlying data and they're equal, just
# different classes. # different classes.
assert isinstance(image, datapoints.Image) assert isinstance(image, tv_tensors.Image)
print(image) print(image)
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all() assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr() assert new_image.data_ptr() == image.data_ptr()
...@@ -7,10 +7,10 @@ import transforms as reference_transforms ...@@ -7,10 +7,10 @@ import transforms as reference_transforms
def get_modules(use_v2): def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used # We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2: if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2 import torchvision.transforms.v2
import torchvision.tv_tensors
return torchvision.transforms.v2, torchvision.datapoints return torchvision.transforms.v2, torchvision.tv_tensors
else: else:
return reference_transforms, None return reference_transforms, None
...@@ -28,16 +28,16 @@ class DetectionPresetTrain: ...@@ -28,16 +28,16 @@ class DetectionPresetTrain:
use_v2=False, use_v2=False,
): ):
T, datapoints = get_modules(use_v2) T, tv_tensors = get_modules(use_v2)
transforms = [] transforms = []
backend = backend.lower() backend = backend.lower()
if backend == "datapoint": if backend == "tv_tensor":
transforms.append(T.ToImage()) transforms.append(T.ToImage())
elif backend == "tensor": elif backend == "tensor":
transforms.append(T.PILToTensor()) transforms.append(T.PILToTensor())
elif backend != "pil": elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
if data_augmentation == "hflip": if data_augmentation == "hflip":
transforms += [T.RandomHorizontalFlip(p=hflip_prob)] transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
...@@ -54,7 +54,7 @@ class DetectionPresetTrain: ...@@ -54,7 +54,7 @@ class DetectionPresetTrain:
T.RandomHorizontalFlip(p=hflip_prob), T.RandomHorizontalFlip(p=hflip_prob),
] ]
elif data_augmentation == "ssd": elif data_augmentation == "ssd":
fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean) fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
transforms += [ transforms += [
T.RandomPhotometricDistort(), T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=fill), T.RandomZoomOut(fill=fill),
...@@ -77,7 +77,7 @@ class DetectionPresetTrain: ...@@ -77,7 +77,7 @@ class DetectionPresetTrain:
if use_v2: if use_v2:
transforms += [ transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY), T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBoxes(), T.SanitizeBoundingBoxes(),
T.ToPureTensor(), T.ToPureTensor(),
] ]
...@@ -98,10 +98,10 @@ class DetectionPresetEval: ...@@ -98,10 +98,10 @@ class DetectionPresetEval:
transforms += [T.ToImage() if use_v2 else T.PILToTensor()] transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
elif backend == "tensor": elif backend == "tensor":
transforms += [T.PILToTensor()] transforms += [T.PILToTensor()]
elif backend == "datapoint": elif backend == "tv_tensor":
transforms += [T.ToImage()] transforms += [T.ToImage()]
else: else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
transforms += [T.ToDtype(torch.float, scale=True)] transforms += [T.ToDtype(torch.float, scale=True)]
......
...@@ -180,8 +180,8 @@ def get_args_parser(add_help=True): ...@@ -180,8 +180,8 @@ def get_args_parser(add_help=True):
def main(args): def main(args):
if args.backend.lower() == "datapoint" and not args.use_v2: if args.backend.lower() == "tv_tensor" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the datapoint backend.") raise ValueError("Use --use-v2 if you want to use the tv_tensor backend.")
if args.dataset not in ("coco", "coco_kp"): if args.dataset not in ("coco", "coco_kp"):
raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}") raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
if "keypoint" in args.model and args.dataset != "coco_kp": if "keypoint" in args.model and args.dataset != "coco_kp":
......
...@@ -4,11 +4,11 @@ import torch ...@@ -4,11 +4,11 @@ import torch
def get_modules(use_v2): def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used # We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2: if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2 import torchvision.transforms.v2
import torchvision.tv_tensors
import v2_extras import v2_extras
return torchvision.transforms.v2, torchvision.datapoints, v2_extras return torchvision.transforms.v2, torchvision.tv_tensors, v2_extras
else: else:
import transforms import transforms
...@@ -27,16 +27,16 @@ class SegmentationPresetTrain: ...@@ -27,16 +27,16 @@ class SegmentationPresetTrain:
backend="pil", backend="pil",
use_v2=False, use_v2=False,
): ):
T, datapoints, v2_extras = get_modules(use_v2) T, tv_tensors, v2_extras = get_modules(use_v2)
transforms = [] transforms = []
backend = backend.lower() backend = backend.lower()
if backend == "datapoint": if backend == "tv_tensor":
transforms.append(T.ToImage()) transforms.append(T.ToImage())
elif backend == "tensor": elif backend == "tensor":
transforms.append(T.PILToTensor()) transforms.append(T.PILToTensor())
elif backend != "pil": elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))] transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]
...@@ -46,7 +46,7 @@ class SegmentationPresetTrain: ...@@ -46,7 +46,7 @@ class SegmentationPresetTrain:
if use_v2: if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally # We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`. # different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})] transforms += [v2_extras.PadIfSmaller(crop_size, fill={tv_tensors.Mask: 255, "others": 0})]
transforms += [T.RandomCrop(crop_size)] transforms += [T.RandomCrop(crop_size)]
...@@ -54,9 +54,9 @@ class SegmentationPresetTrain: ...@@ -54,9 +54,9 @@ class SegmentationPresetTrain:
transforms += [T.PILToTensor()] transforms += [T.PILToTensor()]
if use_v2: if use_v2:
img_type = datapoints.Image if backend == "datapoint" else torch.Tensor img_type = tv_tensors.Image if backend == "tv_tensor" else torch.Tensor
transforms += [ transforms += [
T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True) T.ToDtype(dtype={img_type: torch.float32, tv_tensors.Mask: torch.int64, "others": None}, scale=True)
] ]
else: else:
# No need to explicitly convert masks as they're magically int64 already # No need to explicitly convert masks as they're magically int64 already
...@@ -82,10 +82,10 @@ class SegmentationPresetEval: ...@@ -82,10 +82,10 @@ class SegmentationPresetEval:
backend = backend.lower() backend = backend.lower()
if backend == "tensor": if backend == "tensor":
transforms += [T.PILToTensor()] transforms += [T.PILToTensor()]
elif backend == "datapoint": elif backend == "tv_tensor":
transforms += [T.ToImage()] transforms += [T.ToImage()]
elif backend != "pil": elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
if use_v2: if use_v2:
transforms += [T.Resize(size=(base_size, base_size))] transforms += [T.Resize(size=(base_size, base_size))]
......
...@@ -128,7 +128,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi ...@@ -128,7 +128,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
def main(args): def main(args):
if args.backend.lower() != "pil" and not args.use_v2: if args.backend.lower() != "pil" and not args.use_v2:
# TODO: Support tensor backend in V1? # TODO: Support tensor backend in V1?
raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.") raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
if args.use_v2 and args.dataset != "coco": if args.use_v2 and args.dataset != "coco":
raise ValueError("v2 is only support supported for coco dataset for now.") raise ValueError("v2 is only support supported for coco dataset for now.")
......
"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1.""" """This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1."""
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms import v2 from torchvision.transforms import v2
...@@ -80,4 +80,4 @@ class CocoDetectionToVOCSegmentation(v2.Transform): ...@@ -80,4 +80,4 @@ class CocoDetectionToVOCSegmentation(v2.Transform):
if segmentation_mask is None: if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8) segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)
return image, datapoints.Mask(segmentation_mask) return image, tv_tensors.Mask(segmentation_mask)
...@@ -19,7 +19,7 @@ import torch.testing ...@@ -19,7 +19,7 @@ import torch.testing
from PIL import Image from PIL import Image
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import datapoints, io from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image from torchvision.transforms.v2.functional import to_image, to_pil_image
...@@ -391,7 +391,7 @@ def make_image( ...@@ -391,7 +391,7 @@ def make_image(
if color_space in {"GRAY_ALPHA", "RGBA"}: if color_space in {"GRAY_ALPHA", "RGBA"}:
data[..., -1, :, :] = max_value data[..., -1, :, :] = max_value
return datapoints.Image(data) return tv_tensors.Image(data)
def make_image_tensor(*args, **kwargs): def make_image_tensor(*args, **kwargs):
...@@ -405,7 +405,7 @@ def make_image_pil(*args, **kwargs): ...@@ -405,7 +405,7 @@ def make_image_pil(*args, **kwargs):
def make_bounding_boxes( def make_bounding_boxes(
canvas_size=DEFAULT_SIZE, canvas_size=DEFAULT_SIZE,
*, *,
format=datapoints.BoundingBoxFormat.XYXY, format=tv_tensors.BoundingBoxFormat.XYXY,
dtype=None, dtype=None,
device="cpu", device="cpu",
): ):
...@@ -415,7 +415,7 @@ def make_bounding_boxes( ...@@ -415,7 +415,7 @@ def make_bounding_boxes(
return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()]) return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])
if isinstance(format, str): if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format] format = tv_tensors.BoundingBoxFormat[format]
dtype = dtype or torch.float32 dtype = dtype or torch.float32
...@@ -424,21 +424,21 @@ def make_bounding_boxes( ...@@ -424,21 +424,21 @@ def make_bounding_boxes(
y = sample_position(h, canvas_size[0]) y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1]) x = sample_position(w, canvas_size[1])
if format is datapoints.BoundingBoxFormat.XYWH: if format is tv_tensors.BoundingBoxFormat.XYWH:
parts = (x, y, w, h) parts = (x, y, w, h)
elif format is datapoints.BoundingBoxFormat.XYXY: elif format is tv_tensors.BoundingBoxFormat.XYXY:
x1, y1 = x, y x1, y1 = x, y
x2 = x1 + w x2 = x1 + w
y2 = y1 + h y2 = y1 + h
parts = (x1, y1, x2, y2) parts = (x1, y1, x2, y2)
elif format is datapoints.BoundingBoxFormat.CXCYWH: elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
cx = x + w / 2 cx = x + w / 2
cy = y + h / 2 cy = y + h / 2
parts = (cx, cy, w, h) parts = (cx, cy, w, h)
else: else:
raise ValueError(f"Format {format} is not supported") raise ValueError(f"Format {format} is not supported")
return datapoints.BoundingBoxes( return tv_tensors.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
) )
...@@ -446,7 +446,7 @@ def make_bounding_boxes( ...@@ -446,7 +446,7 @@ def make_bounding_boxes(
def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"): def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks""" """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
num_objects = 1 num_objects = 1
return datapoints.Mask( return tv_tensors.Mask(
torch.testing.make_tensor( torch.testing.make_tensor(
(num_objects, *size), (num_objects, *size),
low=0, low=0,
...@@ -459,7 +459,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"): ...@@ -459,7 +459,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"): def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value""" """Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
return datapoints.Mask( return tv_tensors.Mask(
torch.testing.make_tensor( torch.testing.make_tensor(
(*batch_dims, *size), (*batch_dims, *size),
low=0, low=0,
...@@ -471,7 +471,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=( ...@@ -471,7 +471,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs): def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs)) return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
def make_video_tensor(*args, **kwargs): def make_video_tensor(*args, **kwargs):
......
...@@ -568,7 +568,7 @@ class DatasetTestCase(unittest.TestCase): ...@@ -568,7 +568,7 @@ class DatasetTestCase(unittest.TestCase):
@test_all_configs @test_all_configs
def test_transforms_v2_wrapper(self, config): def test_transforms_v2_wrapper(self, config):
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2 from torchvision.datasets import wrap_dataset_for_transforms_v2
try: try:
...@@ -590,7 +590,7 @@ class DatasetTestCase(unittest.TestCase): ...@@ -590,7 +590,7 @@ class DatasetTestCase(unittest.TestCase):
wrapped_sample = wrapped_dataset[0] wrapped_sample = wrapped_dataset[0]
assert tree_any( assert tree_any(
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample lambda item: isinstance(item, (tv_tensors.TVTensor, PIL.Image.Image)), wrapped_sample
) )
except TypeError as error: except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}" msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
...@@ -717,7 +717,7 @@ def check_transforms_v2_wrapper_spawn(dataset): ...@@ -717,7 +717,7 @@ def check_transforms_v2_wrapper_spawn(dataset):
pytest.skip("Multiprocessing spawning is only checked on macOS.") pytest.skip("Multiprocessing spawning is only checked on macOS.")
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2 from torchvision.datasets import wrap_dataset_for_transforms_v2
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
...@@ -726,7 +726,7 @@ def check_transforms_v2_wrapper_spawn(dataset): ...@@ -726,7 +726,7 @@ def check_transforms_v2_wrapper_spawn(dataset):
for wrapped_sample in dataloader: for wrapped_sample in dataloader:
assert tree_any( assert tree_any(
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
) )
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import torch import torch
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torchvision.prototype import datapoints from torchvision.prototype import tv_tensors
from transforms_v2_legacy_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader from transforms_v2_legacy_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader
...@@ -40,7 +40,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64): ...@@ -40,7 +40,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123 # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
return datapoints.Label(data, categories=categories) return tv_tensors.Label(data, categories=categories)
return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories) return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
...@@ -64,7 +64,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int ...@@ -64,7 +64,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int
# since `one_hot` only supports int64 # since `one_hot` only supports int64
label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device) label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
data = one_hot(label, num_classes=num_categories).to(dtype) data = one_hot(label, num_classes=num_categories).to(dtype)
return datapoints.OneHotLabel(data, categories=categories) return tv_tensors.OneHotLabel(data, categories=categories)
return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories) return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
......
...@@ -3387,11 +3387,11 @@ class TestDatasetWrapper: ...@@ -3387,11 +3387,11 @@ class TestDatasetWrapper:
datasets.wrap_dataset_for_transforms_v2(dataset) datasets.wrap_dataset_for_transforms_v2(dataset)
def test_subclass(self, mocker): def test_subclass(self, mocker):
from torchvision import datapoints from torchvision import tv_tensors
sentinel = object() sentinel = object()
mocker.patch.dict( mocker.patch.dict(
datapoints._dataset_wrapper.WRAPPER_FACTORIES, tv_tensors._dataset_wrapper.WRAPPER_FACTORIES,
clear=False, clear=False,
values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel}, values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
) )
......
...@@ -19,12 +19,12 @@ from torch.utils.data.graph_settings import get_all_graph_pipes ...@@ -19,12 +19,12 @@ from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.graph.utils import traverse_dps from torchdata.dataloader2.graph.utils import traverse_dps
from torchdata.datapipes.iter import ShardingFilter, Shuffler from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper from torchdata.datapipes.utils import StreamWrapper
from torchvision import datapoints from torchvision import tv_tensors
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import datasets from torchvision.prototype import datasets
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedImage from torchvision.prototype.datasets.utils import EncodedImage
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.tv_tensors import Label
from torchvision.transforms.v2._utils import is_pure_tensor from torchvision.transforms.v2._utils import is_pure_tensor
...@@ -147,7 +147,7 @@ class TestCommon: ...@@ -147,7 +147,7 @@ class TestCommon:
pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)} pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)}
if pure_tensors and not any( if pure_tensors and not any(
isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values() isinstance(item, (tv_tensors.Image, tv_tensors.Video, EncodedImage)) for item in sample.values()
): ):
raise AssertionError( raise AssertionError(
f"The values of key(s) " f"The values of key(s) "
...@@ -276,7 +276,7 @@ class TestUSPS: ...@@ -276,7 +276,7 @@ class TestUSPS:
assert "image" in sample assert "image" in sample
assert "label" in sample assert "label" in sample
assert isinstance(sample["image"], datapoints.Image) assert isinstance(sample["image"], tv_tensors.Image)
assert isinstance(sample["label"], Label) assert isinstance(sample["label"], Label)
assert sample["image"].shape == (1, 16, 16) assert sample["image"].shape == (1, 16, 16)