Unverified Commit 37081ee6 authored by Nicolas Hug, committed by GitHub

Revamp transforms doc (#7859)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 2c44ebae
.. _datapoints:
Datapoints
==========
......
.. _datasets:
Datasets
========
......
......@@ -5,242 +5,450 @@ Transforming and augmenting images
.. currentmodule:: torchvision.transforms
Torchvision supports common computer vision transformations in the
``torchvision.transforms`` and ``torchvision.transforms.v2`` modules. Transforms
can be used to transform or augment data for training or inference in different
tasks (image classification, detection, segmentation, video classification).
.. note::

    In 0.15, we released a new set of transforms available in the
    ``torchvision.transforms.v2`` namespace, which add support for transforming
    not just images but also bounding boxes, masks, or videos. These transforms
    are fully backward compatible with the current ones, and you'll see them
    documented below with a ``v2.`` prefix. To get started with those new
    transforms, you can check out
    :ref:`sphx_glr_auto_examples_v2_transforms_plot_transforms_v2_e2e.py`, and
    see :ref:`v1_or_v2` below for a comparison of the two APIs.
Transforms can be chained together using :class:`Compose`. Most transform
classes also have a functional equivalent: :ref:`functional transforms
<functional_transforms>` give fine-grained control over the transformations,
which is useful if you have to build a more complex transformation pipeline
(e.g. in the case of segmentation tasks).
.. code:: python
# Image Classification
import torch
from torchvision.transforms import v2
H, W = 32, 32
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = transforms(img)
.. code:: python
# Detection (re-using imports and transforms from above)
from torchvision import datapoints
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
bboxes = torch.randint(0, H // 2, size=(3, 4))
bboxes[:, 2:] += bboxes[:, :2]
bboxes = datapoints.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used!
img, bboxes = transforms(img, bboxes)
# And you can pass arbitrary input structures
output_dict = transforms({"image": img, "bboxes": bboxes})
Transforms are typically passed as the ``transform`` or ``transforms`` argument
to the :ref:`Datasets <datasets>`.
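For instance, here is a minimal sketch of wiring a transform pipeline into a
dataset and a ``DataLoader`` (the dataset choice, root path, and loader
parameters are just placeholders):

.. code:: python

    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import v2

    transforms = v2.Compose([
        v2.ToImage(),  # CIFAR10 returns PIL images; convert to tensors first
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
    ])

    dataset = datasets.CIFAR10(root="data", download=True, transform=transforms)
    loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)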
.. TODO: add link to getting started guide here.
Supported input types and conventions
-------------------------------------
Most transformations accept both `PIL <https://pillow.readthedocs.io>`_ images
and tensor images, although some transformations are PIL-only and some are
tensor-only. The result of both backends (PIL or Tensors) should be very
close. In general, we recommend relying on the tensor backend :ref:`for
performance <transforms_perf>`. The :ref:`conversion transforms
<conversion_transforms>` may be used to convert to and from PIL images, or for
converting dtypes and ranges.
Tensor images are expected to be of shape ``(C, H, W)``, where ``C`` is the
number of channels, and ``H`` and ``W`` refer to height and width. Most
transforms support batched tensor input: a batch of tensor images is a tensor
of shape ``(N, C, H, W)``, where ``N`` is the number of images in the batch.
The :ref:`v2 <v1_or_v2>` transforms generally accept an arbitrary number of
leading dimensions ``(..., C, H, W)`` and can handle batched images or batched
videos.
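For instance, here is a small sketch with a batched, video-like ``uint8``
tensor (the shapes are illustrative; pure tensors are treated as images, and
the leading dimensions are preserved):

.. code:: python

    import torch
    from torchvision.transforms import v2

    # A batch of 2 videos with 16 frames each: (N, T, C, H, W)
    video_batch = torch.randint(0, 256, size=(2, 16, 3, 64, 64), dtype=torch.uint8)

    out = v2.RandomHorizontalFlip(p=1)(video_batch)
    assert out.shape == video_batch.shape  # leading dimensions are untouched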
.. _range_and_dtype:
Dtype and expected value range
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The expected range of the values of a tensor image is implicitly defined by
the tensor dtype. Tensor images with a float dtype are expected to have
values in ``[0, 1]``. Tensor images with an integer dtype are expected to
have values in ``[0, MAX_DTYPE]`` where ``MAX_DTYPE`` is the largest value
that can be represented in that dtype. Typically, images of dtype
``torch.uint8`` are expected to have values in ``[0, 255]``.
Randomized transformations will apply the same transformation to all the
images of a given batch, but they will produce different transformations
across calls. For reproducible transformations across calls, you may use
:ref:`functional transforms <functional_transforms>`.
Use :class:`~torchvision.transforms.v2.ToDtype` to convert both the dtype and
range of the inputs.
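For example, here is a minimal sketch of converting a ``uint8`` image to a
float image (the image contents are illustrative):

.. code:: python

    import torch
    from torchvision.transforms import v2

    img = torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8)  # values in [0, 255]

    # scale=True rescales the value range in addition to converting the dtype
    float_img = v2.ToDtype(torch.float32, scale=True)(img)  # values now in [0.0, 1.0]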
The following examples illustrate the use of the available transforms:

* :ref:`sphx_glr_auto_examples_others_plot_transforms.py`

.. figure:: ../source/auto_examples/others/images/sphx_glr_plot_transforms_001.png
   :align: center
   :scale: 65%

* :ref:`sphx_glr_auto_examples_others_plot_scripted_tensor_transforms.py`

.. figure:: ../source/auto_examples/others/images/sphx_glr_plot_scripted_tensor_transforms_001.png
   :align: center
   :scale: 30%

.. _v1_or_v2:

V1 or V2? Which one should I use?
---------------------------------

**TL;DR** We recommend using the ``torchvision.transforms.v2`` transforms
instead of those in ``torchvision.transforms``. They're faster and they can do
more things. Just change the import and you should be good to go.

In Torchvision 0.15 (March 2023), we released a new set of transforms available
in the ``torchvision.transforms.v2`` namespace. These transforms have a lot of
advantages compared to the v1 ones (in ``torchvision.transforms``):

- They can transform images **but also** bounding boxes, masks, or videos. This
  provides support for tasks beyond image classification: detection,
  segmentation, video classification, etc.
- They support more transforms like :class:`~torchvision.transforms.v2.CutMix`
  and :class:`~torchvision.transforms.v2.MixUp`.
- They're :ref:`faster <transforms_perf>`.
- They support arbitrary input structures (dicts, lists, tuples, etc.).
- Future improvements and features will be added to the v2 transforms only.

.. TODO: Add link to e2e example for first bullet point.

These transforms are **fully backward compatible** with the v1 ones, so if
you're already using transforms from ``torchvision.transforms``, all you need
to do is update the import to ``torchvision.transforms.v2``. In terms of
output, there might be negligible differences due to implementation
differences.
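For instance, here is a minimal sketch of the migration (aliasing the module
keeps the rest of the code unchanged):

.. code:: python

    # Before:
    # from torchvision import transforms

    # After: only the import changes
    from torchvision.transforms import v2 as transforms

    # Same classes, same parameters
    transform = transforms.RandomHorizontalFlip(p=0.5)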
To learn more about the v2 transforms, check out
:ref:`sphx_glr_auto_examples_v2_transforms_plot_transforms_v2.py`.
.. TODO: make sure link is still good!!
.. note::
The v2 transforms are still BETA, but at this point we do not expect
disruptive changes to be made to their public APIs. We're planning to make
them fully stable in version 0.17. Please submit any feedback you may have
`here <https://github.com/pytorch/vision/issues/6753>`_.
.. _transforms_perf:
Performance considerations
--------------------------
We recommend the following guidelines to get the best performance out of the
transforms:
- Rely on the v2 transforms from ``torchvision.transforms.v2``
- Use tensors instead of PIL images
- Use ``torch.uint8`` dtype, especially for resizing
- Resize with bilinear or bicubic mode
This is what a typical transform pipeline could look like:
.. code:: python
from torchvision.transforms import v2
transforms = v2.Compose([
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
v2.ToDtype(torch.uint8, scale=True), # optional, most inputs are already uint8 at this point
# ...
v2.RandomResizedCrop(size=(224, 224), antialias=True), # Or Resize(antialias=True)
# ...
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
The above should give you the best performance in a typical training environment
that relies on the :class:`torch.utils.data.DataLoader` with ``num_workers >
0``.
Transforms tend to be sensitive to the input strides / memory layout. Some
transforms will be faster with channels-first images while others prefer
channels-last. You may want to experiment a bit if you're chasing the very
best performance. Using :func:`torch.compile` on individual transforms may
also help by factoring out the memory layout variable (e.g. on
:class:`~torchvision.transforms.v2.Normalize`). Note that we're talking about
**memory layout**, not tensor shape.
Note that resize transforms like :class:`~torchvision.transforms.v2.Resize`
and :class:`~torchvision.transforms.v2.RandomResizedCrop` typically prefer
channels-last input and tend **not** to benefit from :func:`torch.compile` at
this time.
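As an illustration, here is a sketch of what such an experiment could look
like (whether channels-last or :func:`torch.compile` actually helps depends on
your transforms and hardware):

.. code:: python

    import torch
    from torchvision.transforms import v2

    normalize = v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    batch = torch.rand(8, 3, 224, 224)  # Normalize expects float input
    batch_cl = batch.to(memory_format=torch.channels_last)  # same shape, channels-last memory layout

    # v2 transforms are nn.Modules, so individual ones can be compiled
    compiled_normalize = torch.compile(normalize)
    out = compiled_normalize(batch_cl)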
.. _functional_transforms:
Transform classes, functionals, and kernels
-------------------------------------------

Transforms are available as classes like
:class:`~torchvision.transforms.v2.Resize`, but also as functionals like
:func:`~torchvision.transforms.v2.functional.resize` in the
``torchvision.transforms.v2.functional`` namespace.
This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`.

The functionals support PIL images, pure tensors, or :ref:`datapoints
<datapoints>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
valid.

.. note::

    Random transforms like :class:`~torchvision.transforms.v2.RandomCrop` will
    randomly sample some parameter each time they're called. Their functional
    counterpart (:func:`~torchvision.transforms.v2.functional.crop`) does not do
    any kind of random sampling and thus has a slightly different
    parametrization. The ``get_params()`` class method of the transform classes
    can be used to perform parameter sampling when using the functional APIs.

.. note::

    Since v0.8.0, all random transformations use the torch default random
    generator to sample their parameters. This was a backward-compatibility
    breaking change; the random state should now be set as follows:

    .. code:: python

        # Previous versions
        # import random
        # random.seed(12)

        # Now
        import torch
        torch.manual_seed(17)

    Keep in mind that the same seed for the torch random generator and the
    Python random generator will not produce the same results.

The ``torchvision.transforms.v2.functional`` namespace also contains what we
call the "kernels". These are the low-level functions that implement the
core functionalities for specific types, e.g. ``resize_bounding_boxes`` or
``resized_crop_mask``. They are public, although not documented. Check the
`code
<https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/__init__.py>`_
to see which ones are available (note that those starting with a leading
underscore are **not** public!). Kernels are only really useful if you want
:ref:`torchscript support <transforms_torchscript>` for types like bounding
boxes or masks.
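To make the class/functional distinction concrete, here is a small sketch that
samples crop parameters once (using the v1 ``RandomCrop.get_params()`` helper)
and applies the same deterministic crop to an image and its segmentation mask:

.. code:: python

    import torch
    from torchvision.transforms import RandomCrop  # used here only for get_params()
    from torchvision.transforms.v2 import functional as F

    img = torch.randint(0, 256, size=(3, 64, 64), dtype=torch.uint8)
    mask = torch.randint(0, 2, size=(1, 64, 64), dtype=torch.uint8)

    # Sample the random parameters once...
    top, left, height, width = RandomCrop.get_params(img, output_size=(32, 32))

    # ...then apply the same deterministic functional to both inputs
    img = F.crop(img, top, left, height, width)
    mask = F.crop(mask, top, left, height, width)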
.. _transforms_torchscript:

Torchscript support
-------------------

Most transform classes and functionals support torchscript. For composing
transforms, use :class:`torch.nn.Sequential` instead of ``Compose``:

.. code:: python

    transforms = torch.nn.Sequential(
        CenterCrop(10),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    )
    scripted_transforms = torch.jit.script(transforms)

Make sure to use only scriptable transformations, i.e. transformations that
work with ``torch.Tensor`` and do not require ``lambda`` functions or
``PIL.Image``.
.. warning::
v2 transforms support torchscript, but if you call ``torch.jit.script()`` on
a v2 **class** transform, you'll actually end up with its (scripted) v1
equivalent. This may lead to slightly different results between the
scripted and eager executions due to implementation differences between v1
and v2.
If you really need torchscript support for the v2 transforms, we recommend
scripting the **functionals** from the
``torchvision.transforms.v2.functional`` namespace to avoid surprises.
Also note that the functionals only support torchscript for pure tensors, which
are always treated as images. If you need torchscript support for other types
like bounding boxes or masks, you can rely on the :ref:`low-level kernels
<functional_transforms>`.
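For example, here is a sketch of scripting a v2 functional (remember that,
once scripted, it only accepts pure tensors, which are treated as images):

.. code:: python

    import torch
    from torchvision.transforms.v2 import functional as F

    scripted_resize = torch.jit.script(F.resize)

    img = torch.randint(0, 256, size=(3, 64, 64), dtype=torch.uint8)
    out = scripted_resize(img, size=[32, 32], antialias=True)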
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
V2 API reference - Recommended
------------------------------
Geometry
^^^^^^^^
Resizing
""""""""
.. autosummary::
:toctree: generated/
:template: class.rst
v2.Resize
v2.ScaleJitter
v2.RandomShortestSize
v2.RandomResize
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.resize
Cropping
""""""""
.. autosummary::
:toctree: generated/
:template: class.rst
v2.RandomCrop
v2.RandomResizedCrop
v2.RandomIoUCrop
v2.CenterCrop
v2.FiveCrop
v2.TenCrop
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.crop
v2.functional.resized_crop
v2.functional.ten_crop
v2.functional.center_crop
v2.functional.five_crop
Others
""""""
.. autosummary::
:toctree: generated/
:template: class.rst
v2.RandomHorizontalFlip
v2.RandomVerticalFlip
v2.Pad
v2.RandomZoomOut
v2.RandomRotation
v2.RandomAffine
v2.RandomPerspective
v2.ElasticTransform
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.horizontal_flip
v2.functional.vertical_flip
v2.functional.pad
v2.functional.rotate
v2.functional.affine
v2.functional.perspective
v2.functional.elastic
Color
^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
v2.ColorJitter
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
v2.Grayscale
v2.RandomGrayscale
v2.GaussianBlur
v2.RandomInvert
v2.RandomPosterize
v2.RandomSolarize
v2.RandomAdjustSharpness
v2.RandomAutocontrast
v2.RandomEqualize
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.permute_channels
v2.functional.rgb_to_grayscale
v2.functional.to_grayscale
v2.functional.gaussian_blur
v2.functional.invert
v2.functional.posterize
v2.functional.solarize
v2.functional.adjust_sharpness
v2.functional.autocontrast
v2.functional.adjust_contrast
v2.functional.equalize
v2.functional.adjust_brightness
v2.functional.adjust_saturation
v2.functional.adjust_hue
v2.functional.adjust_gamma
Composition
^^^^^^^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
v2.Compose
v2.RandomApply
v2.RandomChoice
v2.RandomOrder
Miscellaneous
^^^^^^^^^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
v2.LinearTransformation
v2.Normalize
v2.RandomErasing
v2.Lambda
v2.SanitizeBoundingBoxes
v2.ClampBoundingBoxes
v2.UniformTemporalSubsample
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.normalize
v2.functional.erase
v2.functional.clamp_bounding_boxes
v2.functional.uniform_temporal_subsample
.. _conversion_transforms:
Conversion
^^^^^^^^^^
.. note::
Beware, some of these conversion transforms below will scale the values
while performing the conversion, while some may not do any scaling. By
scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.
.. autosummary::
:toctree: generated/
:template: class.rst
v2.ToImage
v2.PILToTensor
v2.ToPILImage
v2.ToDtype
v2.ConvertBoundingBoxFormat
v2.ToPureTensor
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.to_image
v2.functional.pil_to_tensor
v2.functional.to_pil_image
v2.functional.to_dtype
v2.functional.convert_bounding_box_format
Deprecated
.. autosummary::
:toctree: generated/
:template: class.rst
v2.ToTensor
v2.functional.to_tensor
v2.ConvertImageDtype
v2.functional.convert_image_dtype
Auto-Augmentation
^^^^^^^^^^^^^^^^^
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
......@@ -252,18 +460,14 @@ The new transform can be used standalone or mixed-and-matched with existing tran
:toctree: generated/
:template: class.rst
v2.AutoAugment
v2.RandAugment
v2.TrivialAugmentWide
v2.AugMix
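For instance, here is a sketch of mixing one of these policies into a standard
v2 pipeline (the ImageNet policy is just an example choice):

.. code:: python

    import torch
    from torchvision.transforms import AutoAugmentPolicy
    from torchvision.transforms import v2

    transforms = v2.Compose([
        v2.AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
        v2.ToDtype(torch.float32, scale=True),
    ])

    img = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
    img = transforms(img)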
CutMix - MixUp
^^^^^^^^^^^^^^
CutMix and MixUp are special transforms that
are meant to be used on batches rather than on individual images, because they
......@@ -278,64 +482,126 @@ are combining pairs of images together. These can be used after the dataloader
v2.CutMix
v2.MixUp
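For instance, here is a sketch of applying CutMix to a batch coming out of the
``DataLoader`` (the batch size and ``num_classes`` are placeholders):

.. code:: python

    import torch
    from torchvision.transforms import v2

    cutmix = v2.CutMix(num_classes=10)

    images = torch.rand(4, 3, 224, 224)       # a batch of images from the DataLoader
    labels = torch.randint(0, 10, size=(4,))  # integer class labels
    images, labels = cutmix(images, labels)   # labels become soft (4, 10) targets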
Developer tools
^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.register_kernel

V1 API Reference
----------------

.. note::

    You'll find below the documentation for the existing
    ``torchvision.transforms.functional`` namespace. The
    ``torchvision.transforms.v2.functional`` namespace exists as well and can
    be used! The same functionals are present, so you simply need to change
    your import to rely on the ``v2`` namespace.

Functional transforms give you fine-grained control of the transformation
pipeline. As opposed to the transformations above, functional transforms don't
contain a random number generator for their parameters. That means you have to
specify/generate all parameters, but the functional transform will give you
reproducible results across calls.

Example: you can apply a functional transform with the same parameters to
multiple images like this:

.. code:: python

    import torchvision.transforms.functional as TF
    import random

    def my_segmentation_transforms(image, segmentation):
        if random.random() > 0.5:
            angle = random.randint(-30, 30)
            image = TF.rotate(image, angle)
            segmentation = TF.rotate(segmentation, angle)
        # more transforms ...
        return image, segmentation

Example: you can use a functional transform to build transform classes with
custom behavior:

.. code:: python

    import torchvision.transforms.functional as TF
    import random

    class MyRotationTransform:
        """Rotate by one of the given angles."""

        def __init__(self, angles):
            self.angles = angles

        def __call__(self, x):
            angle = random.choice(self.angles)
            return TF.rotate(x, angle)

    rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])

Geometry
^^^^^^^^

.. autosummary::
:toctree: generated/
:template: class.rst
Resize
RandomCrop
RandomResizedCrop
CenterCrop
FiveCrop
TenCrop
Pad
RandomRotation
RandomAffine
RandomPerspective
ElasticTransform
RandomHorizontalFlip
RandomVerticalFlip

Color
^^^^^

.. autosummary::
:toctree: generated/
:template: class.rst
ColorJitter
Grayscale
RandomGrayscale
GaussianBlur
RandomInvert
RandomPosterize
RandomSolarize
RandomAdjustSharpness
RandomAutocontrast
RandomEqualize

Composition
^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:template: class.rst
Compose
RandomApply
RandomChoice
RandomOrder

Miscellaneous
^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:template: class.rst
LinearTransformation
Normalize
RandomErasing
Lambda
Conversion
^^^^^^^^^^
.. note::
Beware, some of these conversion transforms below will scale the values
while performing the conversion, while some may not do any scaling. By
scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.
.. autosummary::
:toctree: generated/
:template: class.rst
ToPILImage
ToTensor
PILToTensor
ConvertImageDtype
Auto-Augmentation
^^^^^^^^^^^^^^^^^
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
ImageNet policies provide significant improvements when applied to other datasets.
In TorchVision we implemented 3 policies learned on the following datasets: ImageNet, CIFAR10 and SVHN.
The new transform can be used standalone or mixed-and-matched with existing transforms:
.. autosummary::
:toctree: generated/
:template: class.rst
AutoAugmentPolicy
AutoAugment
RandAugment
TrivialAugmentWide
AugMix
Functional Transforms
^^^^^^^^^^^^^^^^^^^^^
.. currentmodule:: torchvision.transforms.functional
.. autosummary::
:toctree: generated/
......@@ -376,14 +642,3 @@ you can use a functional transform to build transform classes with custom behavi
to_pil_image
to_tensor
vflip
......@@ -235,7 +235,8 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
# as a global config setting for the whole program, or as a context manager:
# as a global config setting for the whole program, or as a context manager
# (read its docs to learn more about caveats):
with datapoints.set_return_type("datapoint"):
new_bboxes = bboxes + 3
......@@ -274,13 +275,13 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type.
#
# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type.
# 2. Inplace operations on datapoints like ``.add_()`` preserve they type. However,
# the **returned** value of inplace operations will be unwrapped into a pure
# tensor:
# Inplace operations on datapoints like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of inplace operations will be a pure
# tensor:
image = datapoints.Image([[[0, 1], [1, 0]]])
......
......@@ -14,7 +14,7 @@ from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
from torchvision.transforms.v2.utils import is_pure_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
......@@ -390,7 +390,7 @@ class TestDispatchers:
assert isinstance(output, type(datapoint))
if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
assert output.format == datapoint.format
@pytest.mark.parametrize(
......@@ -445,7 +445,7 @@ class TestDispatchers:
[
info
for info in DISPATCHER_INFOS
if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
)
......@@ -532,19 +532,19 @@ class TestConvertFormatBoundingBoxes:
)
def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_format_bounding_boxes(inpt, old_format)
F.convert_bounding_box_format(inpt, old_format)
def test_pure_tensor_insufficient_metadata(self):
pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
F.convert_bounding_box_format(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
def test_datapoint_explicit_metadata(self):
datapoint = next(make_multiple_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_format_bounding_boxes(
F.convert_bounding_box_format(
datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
)
......@@ -611,7 +611,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
]
in_boxes = torch.tensor(in_boxes, device=device)
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
expected_bboxes = clamp_bounding_boxes(
datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
......@@ -627,7 +627,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
)
if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_canvas_size, canvas_size)
......@@ -681,12 +681,12 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
)
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_canvas_size, size)
......@@ -714,13 +714,13 @@ def test_correctness_pad_bounding_boxes(device, padding):
bbox = (
bbox.clone()
if format == datapoints.BoundingBoxFormat.XYXY
else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
else convert_bounding_box_format(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
)
bbox[0::2] += pad_left
bbox[1::2] += pad_up
bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
bbox = convert_bounding_box_format(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
if bbox.dtype != dtype:
# Temporary cast to original dtype
# e.g. float32 -> int
......@@ -785,9 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
]
)
bbox_xyxy = convert_format_bounding_boxes(
bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
)
bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
......@@ -808,7 +806,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
]
)
out_bbox = torch.from_numpy(out_bbox)
out_bbox = convert_format_bounding_boxes(
out_bbox = convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
)
return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)
......@@ -848,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_boxes(device, output_size):
def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
dtype = bbox.dtype
bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
bbox = convert_bounding_box_format(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
if len(output_size_) == 1:
output_size_.append(output_size_[-1])
......@@ -862,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
bbox[3].item(),
]
out_bbox = torch.tensor(out_bbox)
out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
out_bbox = convert_bounding_box_format(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
return out_bbox.to(dtype=dtype, device=bbox.device)
......
......@@ -342,7 +342,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
in_dtype = bbox.dtype
if not torch.is_floating_point(bbox):
bbox = bbox.float()
bbox_xyxy = F.convert_format_bounding_boxes(
bbox_xyxy = F.convert_bounding_box_format(
bbox.as_subclass(torch.Tensor),
old_format=format,
new_format=datapoints.BoundingBoxFormat.XYXY,
......@@ -366,7 +366,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
],
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_format_bounding_boxes(
out_bbox = F.convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
)
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
......
......@@ -374,8 +374,8 @@ DISPATCHER_INFOS = [
],
),
DispatcherInfo(
F.convert_format_bounding_boxes,
kernels={datapoints.BoundingBoxes: F.convert_format_bounding_boxes},
F.convert_bounding_box_format,
kernels={datapoints.BoundingBoxes: F.convert_bounding_box_format},
test_marks=[
skip_dispatch_datapoint,
],
......
......@@ -190,7 +190,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
in_dtype = bbox.dtype
if not torch.is_floating_point(bbox):
bbox = bbox.float()
bbox_xyxy = F.convert_format_bounding_boxes(
bbox_xyxy = F.convert_bounding_box_format(
bbox.as_subclass(torch.Tensor),
old_format=format_,
new_format=datapoints.BoundingBoxFormat.XYXY,
......@@ -214,7 +214,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
],
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_format_bounding_boxes(
out_bbox = F.convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
)
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
......@@ -227,30 +227,30 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
).reshape(bounding_boxes.shape)
def sample_inputs_convert_format_bounding_boxes():
def sample_inputs_convert_bounding_box_format():
formats = list(datapoints.BoundingBoxFormat)
for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
def reference_convert_format_bounding_boxes(bounding_boxes, old_format, new_format):
def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
return torchvision.ops.box_convert(
bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
).to(bounding_boxes.dtype)
def reference_inputs_convert_format_bounding_boxes():
for args_kwargs in sample_inputs_convert_format_bounding_boxes():
def reference_inputs_convert_bounding_box_format():
for args_kwargs in sample_inputs_convert_bounding_box_format():
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs
KERNEL_INFOS.append(
KernelInfo(
F.convert_format_bounding_boxes,
sample_inputs_fn=sample_inputs_convert_format_bounding_boxes,
reference_fn=reference_convert_format_bounding_boxes,
reference_inputs_fn=reference_inputs_convert_format_bounding_boxes,
F.convert_bounding_box_format,
sample_inputs_fn=sample_inputs_convert_bounding_box_format,
reference_fn=reference_convert_bounding_box_format,
reference_inputs_fn=reference_inputs_convert_bounding_box_format,
logs_usage=True,
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
......
......@@ -368,7 +368,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
target["image_id"] = image_id
if "boxes" in target_keys:
target["boxes"] = F.convert_format_bounding_boxes(
target["boxes"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes(
batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYWH,
......@@ -489,7 +489,7 @@ def celeba_wrapper_factory(dataset, target_keys):
target,
target_types=dataset.target_type,
type_wrappers={
"bbox": lambda item: F.convert_format_bounding_boxes(
"bbox": lambda item: F.convert_bounding_box_format(
datapoints.BoundingBoxes(
item,
format=datapoints.BoundingBoxFormat.XYWH,
......@@ -636,7 +636,7 @@ def widerface_wrapper(dataset, target_keys):
target = {key: target[key] for key in target_keys}
if "bbox" in target_keys:
target["bbox"] = F.convert_format_bounding_boxes(
target["bbox"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes(
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
),
......
......@@ -22,6 +22,13 @@ def set_return_type(return_type: str):
``torchvision`` transforms or functionals, which will always return as
output the same type that was passed as input.
.. warning::
We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
the end of your transform pipelines if you use
``set_return_type("datapoint")``. This will avoid the
``__torch_function__`` overhead in the model's ``forward()``.
Can be used as a global flag for the entire program:
.. code:: python
......
......@@ -80,7 +80,7 @@ class SimpleCopyPaste(Transform):
# There is a similar +1 in other reference implementations:
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_format_bounding_boxes(
boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
......@@ -89,7 +89,7 @@ class SimpleCopyPaste(Transform):
out_target["labels"] = torch.cat([labels, paste_labels])
# Check for degenerated boxes and remove them
boxes = F.convert_format_bounding_boxes(
boxes = F.convert_bounding_box_format(
out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
......
......@@ -76,7 +76,7 @@ class FixedSizeCrop(Transform):
width=new_width,
)
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
height_and_width = F.convert_format_bounding_boxes(
height_and_width = F.convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
......
......@@ -10,13 +10,15 @@ from torchvision.transforms.v2 import Transform
class ToTensor(Transform):
"""[BETA] Convert a PIL Image or ndarray to tensor and scale the values accordingly.
"""[BETA] [DEPRECATED] Use ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`` instead.
Convert a PIL Image or ndarray to tensor and scale the values accordingly.
.. v2betastatus:: ToTensor transform
.. warning::
:class:`v2.ToTensor` is deprecated and will be removed in a future release.
Please use instead ``v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])``.
Please use instead ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])``.
This transform does not support torchscript.
......@@ -40,7 +42,7 @@ class ToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])`."
"Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`."
)
super().__init__()
......
......@@ -1186,7 +1186,7 @@ class RandomIoUCrop(Transform):
continue
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_boxes(
xyxy_bboxes = F.convert_bounding_box_format(
bboxes.as_subclass(torch.Tensor),
bboxes.format,
datapoints.BoundingBoxFormat.XYXY,
......
......@@ -24,7 +24,7 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format
def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
return F.convert_format_bounding_boxes(inpt, new_format=self.format) # type: ignore[return-value]
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value]
class ClampBoundingBoxes(Transform):
......
......@@ -293,7 +293,9 @@ class ToDtype(Transform):
class ConvertImageDtype(Transform):
"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
"""[BETA] [DEPRECATED] Use ``v2.ToDtype(dtype, scale=True)`` instead.
Convert input image to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertImageDtype transform
......@@ -388,7 +390,7 @@ class SanitizeBoundingBoxes(Transform):
boxes = cast(
datapoints.BoundingBoxes,
F.convert_format_bounding_boxes(
F.convert_bounding_box_format(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
),
......
......@@ -4,7 +4,7 @@ from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._meta import (
clamp_bounding_boxes,
convert_format_bounding_boxes,
convert_bounding_box_format,
get_dimensions_image,
_get_dimensions_image_pil,
get_dimensions_video,
......
......@@ -17,6 +17,7 @@ def erase(
v: torch.Tensor,
inplace: bool = False,
) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomErase` for details."""
if torch.jit.is_scripting():
return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
......
......@@ -15,6 +15,7 @@ from ._utils import _get_kernel, _register_kernel_internal
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.Grayscale` for details."""
if torch.jit.is_scripting():
return rgb_to_grayscale_image(inpt, num_output_channels=num_output_channels)
......@@ -69,6 +70,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
"""Adjust brightness."""
if torch.jit.is_scripting():
return adjust_brightness_image(inpt, brightness_factor=brightness_factor)
......@@ -106,6 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
"""Adjust saturation."""
if torch.jit.is_scripting():
return adjust_saturation_image(inpt, saturation_factor=saturation_factor)
......@@ -144,6 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.RandomAutocontrast`"""
if torch.jit.is_scripting():
return adjust_contrast_image(inpt, contrast_factor=contrast_factor)
......@@ -182,6 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.RandomAdjustSharpness`"""
if torch.jit.is_scripting():
return adjust_sharpness_image(inpt, sharpness_factor=sharpness_factor)
......@@ -254,6 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
"""Adjust hue"""
if torch.jit.is_scripting():
return adjust_hue_image(inpt, hue_factor=hue_factor)
......@@ -371,6 +377,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
"""Adjust gamma."""
if torch.jit.is_scripting():
return adjust_gamma_image(inpt, gamma=gamma, gain=gain)
......@@ -410,6 +417,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomPosterize` for details."""
if torch.jit.is_scripting():
return posterize_image(inpt, bits=bits)
......@@ -443,6 +451,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomSolarize` for details."""
if torch.jit.is_scripting():
return solarize_image(inpt, threshold=threshold)
......@@ -470,6 +479,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomAutocontrast` for details."""
if torch.jit.is_scripting():
return autocontrast_image(inpt)
......@@ -519,6 +529,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
def equalize(inpt: torch.Tensor) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomEqualize` for details."""
if torch.jit.is_scripting():
return equalize_image(inpt)
......@@ -608,6 +619,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
def invert(inpt: torch.Tensor) -> torch.Tensor:
"""[BETA] See :func:`~torchvision.transforms.v2.RandomInvert`."""
if torch.jit.is_scripting():
return invert_image(inpt)
......