Commit cc26cd81 authored by panning

merge v0.16.0

parents f78f29f5 fbb4cc54
...@@ -11,7 +11,7 @@ Search for Mobile <https://arxiv.org/pdf/1807.11626.pdf>`__ paper. ...@@ -11,7 +11,7 @@ Search for Mobile <https://arxiv.org/pdf/1807.11626.pdf>`__ paper.
Model builders Model builders
-------------- --------------
The following model builders can be used to instanciate an MNASNet model. The following model builders can be used to instantiate an MNASNet model.
All the model builders internally rely on the All the model builders internally rely on the
``torchvision.models.mnasnet.MNASNet`` base class. Please refer to the `source ``torchvision.models.mnasnet.MNASNet`` base class. Please refer to the `source
code code
......
...@@ -12,7 +12,7 @@ Model builders ...@@ -12,7 +12,7 @@ Model builders
-------------- --------------
The following model builders can be used to instantiate a RetinaNet model, with or The following model builders can be used to instantiate a RetinaNet model, with or
without pre-trained weights. All the model buidlers internally rely on the without pre-trained weights. All the model builders internally rely on the
``torchvision.models.detection.retinanet.RetinaNet`` base class. Please refer to the `source code ``torchvision.models.detection.retinanet.RetinaNet`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_ for <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_ for
more details about this class. more details about this class.
......
...@@ -12,7 +12,7 @@ The SSD model is based on the `SSD: Single Shot MultiBox Detector ...@@ -12,7 +12,7 @@ The SSD model is based on the `SSD: Single Shot MultiBox Detector
Model builders Model builders
-------------- --------------
The following model builders can be used to instanciate a SSD model, with or The following model builders can be used to instantiate an SSD model, with or
without pre-trained weights. All the model builders internally rely on the without pre-trained weights. All the model builders internally rely on the
``torchvision.models.detection.SSD`` base class. Please refer to the `source ``torchvision.models.detection.SSD`` base class. Please refer to the `source
code code
......
...@@ -15,7 +15,7 @@ Model builders ...@@ -15,7 +15,7 @@ Model builders
-------------- --------------
The following model builders can be used to instantiate a SwinTransformer model (original and V2) with and without pre-trained weights. The following model builders can be used to instantiate a SwinTransformer model (original and V2) with and without pre-trained weights.
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer`` All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
more details about this class. more details about this class.
......
...@@ -11,7 +11,7 @@ Model builders ...@@ -11,7 +11,7 @@ Model builders
-------------- --------------
The following model builders can be used to instantiate a VGG model, with or The following model builders can be used to instantiate a VGG model, with or
without pre-trained weights. All the model buidlers internally rely on the without pre-trained weights. All the model builders internally rely on the
``torchvision.models.vgg.VGG`` base class. Please refer to the `source code ``torchvision.models.vgg.VGG`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_ for <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_ for
more details about this class. more details about this class.
......
Video SwinTransformer
=====================
.. currentmodule:: torchvision.models.video
The Video SwinTransformer model is based on the `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`__ paper.
.. betastatus:: video module
Model builders
--------------
The following model builders can be used to instantiate a Video SwinTransformer model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.swin_transformer.SwinTransformer3d`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
swin3d_t
swin3d_s
swin3d_b
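For instance, a minimal sketch of using one of these builders (the pre-trained weights and the ``(N, C, T, H, W)`` clip shape follow the usual torchvision video conventions and are assumed here purely for illustration):

.. code:: python

    import torch
    from torchvision.models.video import swin3d_t, Swin3D_T_Weights

    # Build the tiny variant with pre-trained weights (pass weights=None for random init)
    model = swin3d_t(weights=Swin3D_T_Weights.DEFAULT).eval()

    # A dummy video clip of shape (N, C, T, H, W) -- purely illustrative values
    clip = torch.rand(1, 3, 16, 224, 224)
    with torch.inference_mode():
        scores = model(clip)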
...@@ -5,123 +5,549 @@ Transforming and augmenting images ...@@ -5,123 +5,549 @@ Transforming and augmenting images
.. currentmodule:: torchvision.transforms .. currentmodule:: torchvision.transforms
Transforms are common image transformations available in the Torchvision supports common computer vision transformations in the
``torchvision.transforms`` module. They can be chained together using ``torchvision.transforms`` and ``torchvision.transforms.v2`` modules. Transforms
:class:`Compose`. can be used to transform or augment data for training or inference of different
Most transform classes have a function equivalent: :ref:`functional tasks (image classification, detection, segmentation, video classification).
transforms <functional_transforms>` give fine-grained control over the
transformations. .. code:: python
This is useful if you have to build a more complex transformation pipeline
(e.g. in the case of segmentation tasks). # Image Classification
import torch
Most transformations accept both `PIL <https://pillow.readthedocs.io>`_ from torchvision.transforms import v2
images and tensor images, although some transformations are :ref:`PIL-only
<transforms_pil_only>` and some are :ref:`tensor-only H, W = 32, 32
<transforms_tensor_only>`. The :ref:`conversion_transforms` may be used to img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
convert to and from PIL images.
transforms = v2.Compose([
The transformations that accept tensor images also accept batches of tensor v2.RandomResizedCrop(size=(224, 224), antialias=True),
images. A Tensor Image is a tensor with ``(C, H, W)`` shape, where ``C`` is a v2.RandomHorizontalFlip(p=0.5),
number of channels, ``H`` and ``W`` are image height and width. A batch of v2.ToDtype(torch.float32, scale=True),
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is a number v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
of images in the batch. ])
img = transforms(img)
.. code:: python
# Detection (re-using imports and transforms from above)
from torchvision import tv_tensors
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
boxes = torch.randint(0, H // 2, size=(3, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used!
img, boxes = transforms(img, boxes)
# And you can pass arbitrary input structures
output_dict = transforms({"image": img, "boxes": boxes})
Transforms are typically passed as the ``transform`` or ``transforms`` argument
to the :ref:`Datasets <datasets>`.
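For example, a minimal sketch (the dataset, root path, and pipeline below are arbitrary assumptions):

.. code:: python

    import torch
    from torchvision import datasets
    from torchvision.transforms import v2

    transforms = v2.Compose([
        v2.ToImage(),                           # PIL image -> tv_tensors.Image (a Tensor subclass)
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),
    ])

    # CIFAR10 is just a placeholder; any dataset accepting ``transform`` works the same way
    dataset = datasets.CIFAR10(root="data", download=True, transform=transforms)
    img, label = dataset[0]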
Start here
----------
Whether you're new to Torchvision transforms, or you're already experienced with
them, we encourage you to start with
:ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py` in
order to learn more about what can be done with the new v2 transforms.
Then, browse the sections below on this page for general information and
performance tips. The available transforms and functionals are listed in the
:ref:`API reference <v2_api_ref>`.
More information and tutorials can also be found in our :ref:`example gallery
<gallery>`, e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`
or :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`.
.. _conventions:
Supported input types and conventions
-------------------------------------
Most transformations accept both `PIL <https://pillow.readthedocs.io>`_ images
and tensor inputs. Both CPU and CUDA tensors are supported.
The result of both backends (PIL or Tensors) should be very
close. In general, we recommend relying on the tensor backend :ref:`for
performance <transforms_perf>`. The :ref:`conversion transforms
<conversion_transforms>` may be used to convert to and from PIL images, or for
converting dtypes and ranges.
Tensor image are expected to be of shape ``(C, H, W)``, where ``C`` is the
number of channels, and ``H`` and ``W`` refer to height and width. Most
transforms support batched tensor input. A batch of Tensor images is a tensor of
shape ``(N, C, H, W)``, where ``N`` is the number of images in the batch. The
:ref:`v2 <v1_or_v2>` transforms generally accept an arbitrary number of leading
dimensions ``(..., C, H, W)`` and can handle batched images or batched videos.
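As a rough illustration of these conventions (the shapes below are arbitrary):

.. code:: python

    import torch
    from torchvision.transforms import v2

    transform = v2.RandomHorizontalFlip(p=1.0)

    image = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)           # (C, H, W)
    batch = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)        # (N, C, H, W)
    videos = torch.randint(0, 256, (2, 16, 3, 64, 64), dtype=torch.uint8)   # (..., C, H, W)

    out_image = transform(image)
    out_batch = transform(batch)     # the v2 transforms accept any number of leading dimensions
    out_videos = transform(videos)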
.. _range_and_dtype:
Dtype and expected value range
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The expected range of the values of a tensor image is implicitly defined by The expected range of the values of a tensor image is implicitly defined by
the tensor dtype. Tensor images with a float dtype are expected to have the tensor dtype. Tensor images with a float dtype are expected to have
values in ``[0, 1)``. Tensor images with an integer dtype are expected to values in ``[0, 1]``. Tensor images with an integer dtype are expected to
have values in ``[0, MAX_DTYPE]`` where ``MAX_DTYPE`` is the largest value have values in ``[0, MAX_DTYPE]`` where ``MAX_DTYPE`` is the largest value
that can be represented in that dtype. that can be represented in that dtype. Typically, images of dtype
``torch.uint8`` are expected to have values in ``[0, 255]``.
Randomized transformations will apply the same transformation to all the Use :class:`~torchvision.transforms.v2.ToDtype` to convert both the dtype and
images of a given batch, but they will produce different transformations range of the inputs.
across calls. For reproducible transformations across calls, you may use
:ref:`functional transforms <functional_transforms>`.
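A minimal sketch of such a conversion (values and shapes are arbitrary):

.. code:: python

    import torch
    from torchvision.transforms import v2

    img_uint8 = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)  # values in [0, 255]

    # scale=True rescales the value range along with the dtype change
    img_float = v2.ToDtype(torch.float32, scale=True)(img_uint8)       # values in [0, 1]
    print(img_float.dtype, img_float.min().item(), img_float.max().item())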
The following examples illustrate the use of the available transforms: .. _v1_or_v2:
* :ref:`sphx_glr_auto_examples_plot_transforms.py` V1 or V2? Which one should I use?
---------------------------------
.. figure:: ../source/auto_examples/images/sphx_glr_plot_transforms_001.png **TL;DR** We recommend using the ``torchvision.transforms.v2`` transforms
:align: center instead of those in ``torchvision.transforms``. They're faster and they can do
:scale: 65% more things. Just change the import and you should be good to go.
In Torchvision 0.15 (March 2023), we released a new set of transforms available
in the ``torchvision.transforms.v2`` namespace. These transforms have a lot of
advantages compared to the v1 ones (in ``torchvision.transforms``):
- They can transform images **but also** bounding boxes, masks, or videos. This
provides support for tasks beyond image classification: detection, segmentation,
video classification, etc. See
:ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py`
and :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
- They support more transforms like :class:`~torchvision.transforms.v2.CutMix`
and :class:`~torchvision.transforms.v2.MixUp`. See
:ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py`.
- They're :ref:`faster <transforms_perf>`.
- They support arbitrary input structures (dicts, lists, tuples, etc.).
- Future improvements and features will be added to the v2 transforms only.
These transforms are **fully backward compatible** with the v1 ones, so if
you're already using transforms from ``torchvision.transforms``, all you need to
do is update the import to ``torchvision.transforms.v2``. In terms of
output, there might be negligible differences due to implementation differences.
.. note::
The v2 transforms are still BETA, but at this point we do not expect
disruptive changes to be made to their public APIs. We're planning to make
them fully stable in version 0.17. Please submit any feedback you may have
`here <https://github.com/pytorch/vision/issues/6753>`_.
.. _transforms_perf:
Performance considerations
--------------------------
* :ref:`sphx_glr_auto_examples_plot_scripted_tensor_transforms.py` We recommend the following guidelines to get the best performance out of the
transforms:
.. figure:: ../source/auto_examples/images/sphx_glr_plot_scripted_tensor_transforms_001.png - Rely on the v2 transforms from ``torchvision.transforms.v2``
:align: center - Use tensors instead of PIL images
:scale: 30% - Use ``torch.uint8`` dtype, especially for resizing
- Resize with bilinear or bicubic mode
.. warning:: This is what a typical transform pipeline could look like:
Since v0.8.0 all random transformations are using torch default random generator to sample random parameters. .. code:: python
It is a backward compatibility breaking change and user should set the random state as following:
from torchvision.transforms import v2
transforms = v2.Compose([
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
v2.ToDtype(torch.uint8, scale=True),  # optional, most inputs are already uint8 at this point
# ...
v2.RandomResizedCrop(size=(224, 224), antialias=True), # Or Resize(antialias=True)
# ...
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
The above should give you the best performance in a typical training environment
that relies on the :class:`torch.utils.data.DataLoader` with ``num_workers >
0``.
Transforms tend to be sensitive to the input strides / memory format. Some
transforms will be faster with channels-first images while others prefer
channels-last. Like ``torch`` operators, most transforms will preserve the
memory format of the input, but this may not always be respected due to
implementation details. You may want to experiment a bit if you're chasing the
very best performance. Using :func:`torch.compile` on individual transforms may
also help factor out the memory format variable (e.g. on
:class:`~torchvision.transforms.v2.Normalize`). Note that we're talking about
**memory format**, not :ref:`tensor shape <conventions>`.
Note that resize transforms like :class:`~torchvision.transforms.v2.Resize`
and :class:`~torchvision.transforms.v2.RandomResizedCrop` typically prefer
channels-last input and tend **not** to benefit from :func:`torch.compile` at
this time.
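If you want to experiment with the memory format, a batch can be switched to channels-last without changing its shape; this is only a sketch of the kind of thing to try:

.. code:: python

    import torch

    batch = torch.rand(8, 3, 224, 224)                    # still reported as (N, C, H, W)
    batch = batch.to(memory_format=torch.channels_last)   # only the underlying memory layout changes
    print(batch.shape, batch.is_contiguous(memory_format=torch.channels_last))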
.. code:: python .. _functional_transforms:
# Previous versions Transform classes, functionals, and kernels
# import random -------------------------------------------
# random.seed(12)
# Now Transforms are available as classes like
import torch :class:`~torchvision.transforms.v2.Resize`, but also as functionals like
torch.manual_seed(17) :func:`~torchvision.transforms.v2.functional.resize` in the
``torchvision.transforms.v2.functional`` namespace.
This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`.
Please, keep in mind that the same seed for torch random generator and Python random generator will not The functionals support PIL images, pure tensors, or :ref:`TVTensors
produce the same results. <tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(boxes)`` are
valid.
.. note::
Scriptable transforms Random transforms like :class:`~torchvision.transforms.v2.RandomCrop` will
--------------------- randomly sample some parameter each time they're called. Their functional
counterpart (:func:`~torchvision.transforms.v2.functional.crop`) does not do
any kind of random sampling and thus has a slightly different
parametrization. The ``get_params()`` class method of the transforms class
can be used to perform parameter sampling when using the functional APIs.
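A minimal sketch of that pattern (assuming ``get_params()`` is exposed on the v2 class as described above, and that an image and a matching mask must be cropped identically):

.. code:: python

    import torch
    from torchvision.transforms import v2
    from torchvision.transforms.v2 import functional as F

    img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
    mask = torch.randint(0, 2, (1, 64, 64), dtype=torch.uint8)

    # Sample the crop parameters once...
    top, left, height, width = v2.RandomCrop.get_params(img, output_size=(32, 32))

    # ...and apply exactly the same crop to both inputs via the functional
    img_cropped = F.crop(img, top, left, height, width)
    mask_cropped = F.crop(mask, top, left, height, width)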
In order to script the transformations, please use ``torch.nn.Sequential`` instead of :class:`Compose`.
The ``torchvision.transforms.v2.functional`` namespace also contains what we
call the "kernels". These are the low-level functions that implement the
core functionalities for specific types, e.g. ``resize_bounding_boxes`` or
``resized_crop_mask``. They are public, although not documented. Check the
`code
<https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/__init__.py>`_
to see which ones are available (note that those starting with a leading
underscore are **not** public!). Kernels are only really useful if you want
:ref:`torchscript support <transforms_torchscript>` for types like bounding
boxes or masks.
.. _transforms_torchscript:
Torchscript support
-------------------
Most transform classes and functionals support torchscript. For composing
transforms, use :class:`torch.nn.Sequential` instead of
:class:`~torchvision.transforms.v2.Compose`:
.. code:: python .. code:: python
transforms = torch.nn.Sequential( transforms = torch.nn.Sequential(
transforms.CenterCrop(10), CenterCrop(10),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
) )
scripted_transforms = torch.jit.script(transforms) scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and does not require .. warning::
`lambda` functions or ``PIL.Image``.
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``. v2 transforms support torchscript, but if you call ``torch.jit.script()`` on
a v2 **class** transform, you'll actually end up with its (scripted) v1
equivalent. This may lead to slightly different results between the
scripted and eager executions due to implementation differences between v1
and v2.
If you really need torchscript support for the v2 transforms, we recommend
scripting the **functionals** from the
``torchvision.transforms.v2.functional`` namespace to avoid surprises.
Compositions of transforms
-------------------------- Also note that the functionals only support torchscript for pure tensors, which
are always treated as images. If you need torchscript support for other types
like bounding boxes or masks, you can rely on the :ref:`low-level kernels
<functional_transforms>`.
For any custom transformations to be used with ``torch.jit.script``, they should
be derived from ``torch.nn.Module``.
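As a rough sketch of both points (the module below and its functional calls are illustrative assumptions, not an official recipe):

.. code:: python

    import torch
    from torchvision.transforms.v2 import functional as F

    class PreprocessTensor(torch.nn.Module):
        # Hypothetical preprocessing module built only from v2 functionals on pure tensors
        def forward(self, img: torch.Tensor) -> torch.Tensor:
            img = F.resize(img, size=[256, 256], antialias=True)
            img = F.center_crop(img, output_size=[224, 224])
            return F.to_dtype(img, torch.float32, scale=True)

    scripted = torch.jit.script(PreprocessTensor())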
See also: :ref:`sphx_glr_auto_examples_others_plot_scripted_tensor_transforms.py`.
.. _v2_api_ref:
V2 API reference - Recommended
------------------------------
Geometry
^^^^^^^^
Resizing
""""""""
.. autosummary:: .. autosummary::
:toctree: generated/ :toctree: generated/
:template: class.rst :template: class.rst
Compose v2.Resize
v2.ScaleJitter
v2.RandomShortestSize
v2.RandomResize
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.resize
Cropping
""""""""
.. autosummary::
:toctree: generated/
:template: class.rst
v2.RandomCrop
v2.RandomResizedCrop
v2.RandomIoUCrop
v2.CenterCrop
v2.FiveCrop
v2.TenCrop
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.crop
v2.functional.resized_crop
v2.functional.ten_crop
v2.functional.center_crop
v2.functional.five_crop
Others
""""""
.. autosummary::
:toctree: generated/
:template: class.rst
v2.RandomHorizontalFlip
v2.RandomVerticalFlip
v2.Pad
v2.RandomZoomOut
v2.RandomRotation
v2.RandomAffine
v2.RandomPerspective
v2.ElasticTransform
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.horizontal_flip
v2.functional.vertical_flip
v2.functional.pad
v2.functional.rotate
v2.functional.affine
v2.functional.perspective
v2.functional.elastic
Color
^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
v2.ColorJitter
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
v2.Grayscale
v2.RandomGrayscale
v2.GaussianBlur
v2.RandomInvert
v2.RandomPosterize
v2.RandomSolarize
v2.RandomAdjustSharpness
v2.RandomAutocontrast
v2.RandomEqualize
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.permute_channels
v2.functional.rgb_to_grayscale
v2.functional.to_grayscale
v2.functional.gaussian_blur
v2.functional.invert
v2.functional.posterize
v2.functional.solarize
v2.functional.adjust_sharpness
v2.functional.autocontrast
v2.functional.adjust_contrast
v2.functional.equalize
v2.functional.adjust_brightness
v2.functional.adjust_saturation
v2.functional.adjust_hue
v2.functional.adjust_gamma
Composition
^^^^^^^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
v2.Compose
v2.RandomApply
v2.RandomChoice
v2.RandomOrder
Miscellaneous
^^^^^^^^^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
v2.LinearTransformation
v2.Normalize
v2.RandomErasing
v2.Lambda
v2.SanitizeBoundingBoxes
v2.ClampBoundingBoxes
v2.UniformTemporalSubsample
Functionals
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.normalize
v2.functional.erase
v2.functional.clamp_bounding_boxes
v2.functional.uniform_temporal_subsample
.. _conversion_transforms:
Conversion
^^^^^^^^^^
.. note::
Beware, some of these conversion transforms below will scale the values
while performing the conversion, while some may not do any scaling. By
scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.
.. autosummary::
:toctree: generated/
:template: class.rst
v2.ToImage
v2.ToPureTensor
v2.PILToTensor
v2.ToPILImage
v2.ToDtype
v2.ConvertBoundingBoxFormat
Functionals
.. autosummary::
:toctree: generated/
:template: functional.rst
v2.functional.to_image
v2.functional.pil_to_tensor
v2.functional.to_pil_image
v2.functional.to_dtype
v2.functional.convert_bounding_box_format
Deprecated
.. autosummary::
:toctree: generated/
:template: class.rst
v2.ToTensor
v2.functional.to_tensor
v2.ConvertImageDtype
v2.functional.convert_image_dtype
Auto-Augmentation
^^^^^^^^^^^^^^^^^
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
Though the data augmentation policies are directly tied to the dataset they were learned on, empirical studies show that
ImageNet policies provide significant improvements when applied to other datasets.
In TorchVision we implemented 3 policies learned on the following datasets: ImageNet, CIFAR10 and SVHN.
The new transform can be used standalone or mixed-and-matched with existing transforms:
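For instance, a rough sketch (the surrounding pipeline is an assumption; ``AutoAugment`` defaults to the ImageNet policy):

.. code:: python

    import torch
    from torchvision.transforms import v2

    augment = v2.Compose([
        v2.AutoAugment(),                        # defaults to the ImageNet policy
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
    img = augment(img)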
.. autosummary::
:toctree: generated/
:template: class.rst
v2.AutoAugment
v2.RandAugment
v2.TrivialAugmentWide
v2.AugMix
CutMix - MixUp
^^^^^^^^^^^^^^
Transforms on PIL Image and torch.\*Tensor CutMix and MixUp are special transforms that
------------------------------------------ are meant to be used on batches rather than on individual images, because they
combine pairs of images. These can be used after the dataloader
(once the samples are batched), or as part of a collation function. See
:ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage examples.
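A minimal sketch of using CutMix after the dataloader (the number of classes and the batch contents are assumptions):

.. code:: python

    import torch
    from torchvision.transforms import v2

    NUM_CLASSES = 10  # assumption for this sketch

    cutmix = v2.CutMix(num_classes=NUM_CLASSES)

    images = torch.rand(4, 3, 224, 224)           # a batch coming out of the dataloader
    labels = torch.randint(0, NUM_CLASSES, (4,))

    # The labels come back as (4, NUM_CLASSES) soft targets
    images, labels = cutmix(images, labels)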
.. autosummary:: .. autosummary::
:toctree: generated/ :toctree: generated/
:template: class.rst :template: class.rst
v2.CutMix
v2.MixUp
Developer tools
^^^^^^^^^^^^^^^
.. autosummary::
:toctree: generated/
:template: function.rst
v2.functional.register_kernel
V1 API Reference
----------------
Geometry
^^^^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
Resize
RandomCrop
RandomResizedCrop
CenterCrop CenterCrop
ColorJitter
FiveCrop FiveCrop
Grayscale TenCrop
Pad Pad
RandomRotation
RandomAffine RandomAffine
RandomApply
RandomCrop
RandomGrayscale
RandomHorizontalFlip
RandomPerspective RandomPerspective
RandomResizedCrop ElasticTransform
RandomRotation RandomHorizontalFlip
RandomVerticalFlip RandomVerticalFlip
Resize
TenCrop
Color
^^^^^
.. autosummary::
:toctree: generated/
:template: class.rst
ColorJitter
Grayscale
RandomGrayscale
GaussianBlur GaussianBlur
RandomInvert RandomInvert
RandomPosterize RandomPosterize
...@@ -130,23 +556,20 @@ Transforms on PIL Image and torch.\*Tensor ...@@ -130,23 +556,20 @@ Transforms on PIL Image and torch.\*Tensor
RandomAutocontrast RandomAutocontrast
RandomEqualize RandomEqualize
Composition
.. _transforms_pil_only: ^^^^^^^^^^^
Transforms on PIL Image only
----------------------------
.. autosummary:: .. autosummary::
:toctree: generated/ :toctree: generated/
:template: class.rst :template: class.rst
Compose
RandomApply
RandomChoice RandomChoice
RandomOrder RandomOrder
.. _transforms_tensor_only: Miscellaneous
^^^^^^^^^^^^^
Transforms on torch.\*Tensor only
---------------------------------
.. autosummary:: .. autosummary::
:toctree: generated/ :toctree: generated/
...@@ -155,13 +578,17 @@ Transforms on torch.\*Tensor only ...@@ -155,13 +578,17 @@ Transforms on torch.\*Tensor only
LinearTransformation LinearTransformation
Normalize Normalize
RandomErasing RandomErasing
ConvertImageDtype Lambda
.. _conversion_transforms:
Conversion Transforms Conversion
--------------------- ^^^^^^^^^^
.. note::
Beware, some of these conversion transforms below will scale the values
while performing the conversion, while some may not do any scaling. By
scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.
.. autosummary:: .. autosummary::
:toctree: generated/ :toctree: generated/
:template: class.rst :template: class.rst
...@@ -169,20 +596,10 @@ Conversion Transforms ...@@ -169,20 +596,10 @@ Conversion Transforms
ToPILImage ToPILImage
ToTensor ToTensor
PILToTensor PILToTensor
ConvertImageDtype
Auto-Augmentation
Generic Transforms ^^^^^^^^^^^^^^^^^
------------------
.. autosummary::
:toctree: generated/
:template: class.rst
Lambda
Automatic Augmentation Transforms
---------------------------------
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models. `AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
Though the data augmentation policies are directly tied to the dataset they were learned on, empirical studies show that Though the data augmentation policies are directly tied to the dataset they were learned on, empirical studies show that
...@@ -200,57 +617,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran ...@@ -200,57 +617,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran
TrivialAugmentWide TrivialAugmentWide
AugMix AugMix
.. _functional_transforms:
Functional Transforms Functional Transforms
--------------------- ^^^^^^^^^^^^^^^^^^^^^
.. currentmodule:: torchvision.transforms.functional .. currentmodule:: torchvision.transforms.functional
Functional transforms give you fine-grained control of the transformation pipeline.
As opposed to the transformations above, functional transforms don't contain a random number
generator for their parameters.
That means you have to specify/generate all parameters, but the functional transform will give you
reproducible results across calls.
Example:
you can apply a functional transform with the same parameters to multiple images like this:
.. code:: python
import torchvision.transforms.functional as TF
import random
def my_segmentation_transforms(image, segmentation):
if random.random() > 0.5:
angle = random.randint(-30, 30)
image = TF.rotate(image, angle)
segmentation = TF.rotate(segmentation, angle)
# more transforms ...
return image, segmentation
Example:
you can use a functional transform to build transform classes with custom behavior:
.. code:: python
import torchvision.transforms.functional as TF
import random
class MyRotationTransform:
"""Rotate by one of the given angles."""
def __init__(self, angles):
self.angles = angles
def __call__(self, x):
angle = random.choice(self.angles)
return TF.rotate(x, angle)
rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])
.. autosummary:: .. autosummary::
:toctree: generated/ :toctree: generated/
:template: function.rst :template: function.rst
......
.. _tv_tensors:
TVTensors
==========
.. currentmodule:: torchvision.tv_tensors
TVTensors are :class:`torch.Tensor` subclasses which the v2 :ref:`transforms
<transforms>` use under the hood to dispatch their inputs to the appropriate
lower-level kernels. Most users do not need to manipulate TVTensors directly.
Refer to
:ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py` for
an introduction to TVTensors, or
:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py` for more advanced
info.
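As a quick, illustrative sketch, wrapping plain tensors into TVTensors is what lets the v2 transforms treat each input correctly:

.. code:: python

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    img = tv_tensors.Image(torch.randint(0, 256, (3, 100, 100), dtype=torch.uint8))
    boxes = tv_tensors.BoundingBoxes(
        [[10, 10, 50, 50]], format="XYXY", canvas_size=(100, 100)
    )

    # Thanks to the TVTensor types, the flip is applied consistently to both inputs
    out_img, out_boxes = v2.RandomHorizontalFlip(p=1.0)(img, boxes)
    print(type(out_img), type(out_boxes))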
.. autosummary::
:toctree: generated/
:template: class.rst
Image
Video
BoundingBoxFormat
BoundingBoxes
Mask
TVTensor
set_return_type
wrap
...@@ -4,7 +4,7 @@ Utils ...@@ -4,7 +4,7 @@ Utils
===== =====
The ``torchvision.utils`` module contains various utilities, mostly :ref:`for The ``torchvision.utils`` module contains various utilities, mostly :ref:`for
vizualization <sphx_glr_auto_examples_plot_visualization_utils.py>`. visualization <sphx_glr_auto_examples_others_plot_visualization_utils.py>`.
.. currentmodule:: torchvision.utils .. currentmodule:: torchvision.utils
......
...@@ -17,4 +17,4 @@ add_executable(hello-world main.cpp) ...@@ -17,4 +17,4 @@ add_executable(hello-world main.cpp)
# which also adds all the necessary torch dependencies. # which also adds all the necessary torch dependencies.
target_compile_features(hello-world PUBLIC cxx_range_for) target_compile_features(hello-world PUBLIC cxx_range_for)
target_link_libraries(hello-world TorchVision::TorchVision) target_link_libraries(hello-world TorchVision::TorchVision)
set_property(TARGET hello-world PROPERTY CXX_STANDARD 14) set_property(TARGET hello-world PROPERTY CXX_STANDARD 17)
Example gallery .. _gallery:
===============
Below is a gallery of examples Examples and tutorials
======================
../../astronaut.jpg
\ No newline at end of file
../../dog2.jpg
\ No newline at end of file
{"images": [{"file_name": "000000000001.jpg", "height": 512, "width": 512, "id": 1}, {"file_name": "000000000002.jpg", "height": 500, "width": 500, "id": 2}], "annotations": [{"segmentation": [[40.0, 511.0, 26.0, 487.0, 28.0, 438.0, 17.0, 397.0, 24.0, 346.0, 38.0, 306.0, 61.0, 250.0, 111.0, 206.0, 111.0, 187.0, 120.0, 183.0, 136.0, 159.0, 159.0, 150.0, 181.0, 148.0, 182.0, 132.0, 175.0, 132.0, 168.0, 120.0, 154.0, 102.0, 153.0, 62.0, 188.0, 35.0, 191.0, 29.0, 208.0, 20.0, 210.0, 22.0, 227.0, 16.0, 240.0, 16.0, 276.0, 31.0, 285.0, 39.0, 301.0, 88.0, 297.0, 108.0, 281.0, 128.0, 273.0, 138.0, 266.0, 138.0, 264.0, 153.0, 257.0, 162.0, 256.0, 174.0, 284.0, 197.0, 300.0, 221.0, 303.0, 236.0, 337.0, 258.0, 357.0, 306.0, 361.0, 351.0, 358.0, 511.0]], "iscrowd": 0, "image_id": 1, "bbox": [17.0, 16.0, 344.0, 495.0], "category_id": 1, "id": 1}, {"segmentation": [[0.0, 411.0, 43.0, 401.0, 99.0, 395.0, 105.0, 351.0, 124.0, 326.0, 181.0, 294.0, 227.0, 280.0, 245.0, 262.0, 259.0, 234.0, 262.0, 207.0, 271.0, 140.0, 283.0, 139.0, 301.0, 162.0, 309.0, 181.0, 341.0, 175.0, 362.0, 139.0, 369.0, 139.0, 377.0, 163.0, 378.0, 203.0, 381.0, 212.0, 380.0, 220.0, 382.0, 242.0, 404.0, 264.0, 392.0, 293.0, 384.0, 295.0, 385.0, 316.0, 399.0, 343.0, 391.0, 448.0, 452.0, 475.0, 457.0, 494.0, 436.0, 498.0, 402.0, 491.0, 369.0, 488.0, 366.0, 496.0, 319.0, 496.0, 302.0, 485.0, 226.0, 469.0, 128.0, 456.0, 74.0, 458.0, 29.0, 439.0, 0.0, 445.0]], "iscrowd": 0, "image_id": 2, "bbox": [0.0, 139.0, 457.0, 359.0], "category_id": 18, "id": 2}]}
...@@ -3,6 +3,10 @@ ...@@ -3,6 +3,10 @@
Optical Flow: Predicting movement with the RAFT model Optical Flow: Predicting movement with the RAFT model
===================================================== =====================================================
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_optical_flow.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_optical_flow.py>` to download the full example code.
Optical flow is the task of predicting movement between two images, usually two Optical flow is the task of predicting movement between two images, usually two
consecutive frames of a video. Optical flow models take two images as input, and consecutive frames of a video. Optical flow models take two images as input, and
predict a flow: the flow indicates the displacement of every single pixel in the predict a flow: the flow indicates the displacement of every single pixel in the
...@@ -42,7 +46,7 @@ def plot(imgs, **imshow_kwargs): ...@@ -42,7 +46,7 @@ def plot(imgs, **imshow_kwargs):
plt.tight_layout() plt.tight_layout()
################################### # %%
# Reading Videos Using Torchvision # Reading Videos Using Torchvision
# -------------------------------- # --------------------------------
# We will first read a video using :func:`~torchvision.io.read_video`. # We will first read a video using :func:`~torchvision.io.read_video`.
...@@ -62,7 +66,7 @@ video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_bask ...@@ -62,7 +66,7 @@ video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_bask
video_path = Path(tempfile.mkdtemp()) / "basketball.mp4" video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
_ = urlretrieve(video_url, video_path) _ = urlretrieve(video_url, video_path)
######################### # %%
# :func:`~torchvision.io.read_video` returns the video frames, audio frames and # :func:`~torchvision.io.read_video` returns the video frames, audio frames and
# the metadata associated with the video. In our case, we only need the video # the metadata associated with the video. In our case, we only need the video
# frames. # frames.
...@@ -79,11 +83,12 @@ img2_batch = torch.stack([frames[101], frames[151]]) ...@@ -79,11 +83,12 @@ img2_batch = torch.stack([frames[101], frames[151]])
plot(img1_batch) plot(img1_batch)
######################### # %%
# The RAFT model accepts RGB images. We first get the frames from # The RAFT model accepts RGB images. We first get the frames from
# :func:`~torchvision.io.read_video` and resize them to ensure their # :func:`~torchvision.io.read_video` and resize them to ensure their dimensions
# dimensions are divisible by 8. Then we use the transforms bundled into the # are divisible by 8. Note that we explicitly use ``antialias=False``, because
# weights in order to preprocess the input and rescale its values to the # this is how those models were trained. Then we use the transforms bundled into
# the weights in order to preprocess the input and rescale its values to the
# required ``[-1, 1]`` interval. # required ``[-1, 1]`` interval.
from torchvision.models.optical_flow import Raft_Large_Weights from torchvision.models.optical_flow import Raft_Large_Weights
...@@ -93,8 +98,8 @@ transforms = weights.transforms() ...@@ -93,8 +98,8 @@ transforms = weights.transforms()
def preprocess(img1_batch, img2_batch): def preprocess(img1_batch, img2_batch):
img1_batch = F.resize(img1_batch, size=[520, 960]) img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False)
img2_batch = F.resize(img2_batch, size=[520, 960]) img2_batch = F.resize(img2_batch, size=[520, 960], antialias=False)
return transforms(img1_batch, img2_batch) return transforms(img1_batch, img2_batch)
...@@ -103,7 +108,7 @@ img1_batch, img2_batch = preprocess(img1_batch, img2_batch) ...@@ -103,7 +108,7 @@ img1_batch, img2_batch = preprocess(img1_batch, img2_batch)
print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}") print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")
#################################### # %%
# Estimating Optical flow using RAFT # Estimating Optical flow using RAFT
# ---------------------------------- # ----------------------------------
# We will use our RAFT implementation from # We will use our RAFT implementation from
...@@ -124,12 +129,12 @@ list_of_flows = model(img1_batch.to(device), img2_batch.to(device)) ...@@ -124,12 +129,12 @@ list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
print(f"type = {type(list_of_flows)}") print(f"type = {type(list_of_flows)}")
print(f"length = {len(list_of_flows)} = number of iterations of the model") print(f"length = {len(list_of_flows)} = number of iterations of the model")
#################################### # %%
# The RAFT model outputs lists of predicted flows where each entry is a # The RAFT model outputs lists of predicted flows where each entry is a
# (N, 2, H, W) batch of predicted flows that corresponds to a given "iteration" # (N, 2, H, W) batch of predicted flows that corresponds to a given "iteration"
# in the model. For more details on the iterative nature of the model, please # in the model. For more details on the iterative nature of the model, please
# refer to the `original paper <https://arxiv.org/abs/2003.12039>`_. Here, we # refer to the `original paper <https://arxiv.org/abs/2003.12039>`_. Here, we
# are only interested in the final predicted flows (they are the most acccurate # are only interested in the final predicted flows (they are the most accurate
# ones), so we will just retrieve the last item in the list. # ones), so we will just retrieve the last item in the list.
# #
# As described above, a flow is a tensor with dimensions (2, H, W) (or (N, 2, H, # As described above, a flow is a tensor with dimensions (2, H, W) (or (N, 2, H,
...@@ -143,10 +148,10 @@ print(f"shape = {predicted_flows.shape} = (N, 2, H, W)") ...@@ -143,10 +148,10 @@ print(f"shape = {predicted_flows.shape} = (N, 2, H, W)")
print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}") print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}")
#################################### # %%
# Visualizing predicted flows # Visualizing predicted flows
# --------------------------- # ---------------------------
# Torchvision provides the :func:`~torchvision.utils.flow_to_image` utlity to # Torchvision provides the :func:`~torchvision.utils.flow_to_image` utility to
# convert a flow into an RGB image. It also supports batches of flows. # convert a flow into an RGB image. It also supports batches of flows.
# each "direction" in the flow will be mapped to a given RGB color. In the # each "direction" in the flow will be mapped to a given RGB color. In the
# images below, pixels with similar colors are assumed by the model to be moving # images below, pixels with similar colors are assumed by the model to be moving
...@@ -165,7 +170,7 @@ img1_batch = [(img1 + 1) / 2 for img1 in img1_batch] ...@@ -165,7 +170,7 @@ img1_batch = [(img1 + 1) / 2 for img1 in img1_batch]
grid = [[img1, flow_img] for (img1, flow_img) in zip(img1_batch, flow_imgs)] grid = [[img1, flow_img] for (img1, flow_img) in zip(img1_batch, flow_imgs)]
plot(grid) plot(grid)
#################################### # %%
# Bonus: Creating GIFs of predicted flows # Bonus: Creating GIFs of predicted flows
# --------------------------------------- # ---------------------------------------
# In the example above we have only shown the predicted flows of 2 pairs of # In the example above we have only shown the predicted flows of 2 pairs of
...@@ -186,7 +191,7 @@ plot(grid) ...@@ -186,7 +191,7 @@ plot(grid)
# output_folder = "/tmp/" # Update this to the folder of your choice # output_folder = "/tmp/" # Update this to the folder of your choice
# write_jpeg(flow_img, output_folder + f"predicted_flow_{i}.jpg") # write_jpeg(flow_img, output_folder + f"predicted_flow_{i}.jpg")
#################################### # %%
# Once the .jpg flow images are saved, you can convert them into a video or a # Once the .jpg flow images are saved, you can convert them into a video or a
# GIF using ffmpeg with e.g.: # GIF using ffmpeg with e.g.:
# #
......
...@@ -3,6 +3,10 @@ ...@@ -3,6 +3,10 @@
Repurposing masks into bounding boxes Repurposing masks into bounding boxes
===================================== =====================================
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_repurposing_annotations.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_repurposing_annotations.py>` to download the full example code.
The following example illustrates the operations available The following example illustrates the operations available
the :ref:`torchvision.ops <ops>` module for repurposing the :ref:`torchvision.ops <ops>` module for repurposing
segmentation masks into object localization annotations for different tasks segmentation masks into object localization annotations for different tasks
...@@ -20,7 +24,7 @@ import matplotlib.pyplot as plt ...@@ -20,7 +24,7 @@ import matplotlib.pyplot as plt
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
ASSETS_DIRECTORY = "assets" ASSETS_DIRECTORY = "../assets"
plt.rcParams["savefig.bbox"] = "tight" plt.rcParams["savefig.bbox"] = "tight"
...@@ -36,7 +40,7 @@ def show(imgs): ...@@ -36,7 +40,7 @@ def show(imgs):
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
#################################### # %%
# Masks # Masks
# ----- # -----
# In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package, # In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
...@@ -53,7 +57,7 @@ def show(imgs): ...@@ -53,7 +57,7 @@ def show(imgs):
# A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object # A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object
# localization tasks. # localization tasks.
#################################### # %%
# Converting Masks to Bounding Boxes # Converting Masks to Bounding Boxes
# ----------------------------------------------- # -----------------------------------------------
# For example, the :func:`~torchvision.ops.masks_to_boxes` operation can be used to # For example, the :func:`~torchvision.ops.masks_to_boxes` operation can be used to
...@@ -70,7 +74,7 @@ img = read_image(img_path) ...@@ -70,7 +74,7 @@ img = read_image(img_path)
mask = read_image(mask_path) mask = read_image(mask_path)
######################### # %%
# Here the masks are represented as a PNG Image, with floating point values. # Here the masks are represented as a PNG Image, with floating point values.
# Each pixel is encoded as different colors, with 0 being background. # Each pixel is encoded as different colors, with 0 being background.
# Notice that the spatial dimensions of image and mask match. # Notice that the spatial dimensions of image and mask match.
...@@ -79,7 +83,7 @@ print(mask.size()) ...@@ -79,7 +83,7 @@ print(mask.size())
print(img.size()) print(img.size())
print(mask) print(mask)
############################ # %%
# We get the unique colors, as these would be the object ids. # We get the unique colors, as these would be the object ids.
obj_ids = torch.unique(mask) obj_ids = torch.unique(mask)
...@@ -91,7 +95,7 @@ obj_ids = obj_ids[1:] ...@@ -91,7 +95,7 @@ obj_ids = obj_ids[1:]
# Note that this snippet would work as well if the masks were float values instead of ints. # Note that this snippet would work as well if the masks were float values instead of ints.
masks = mask == obj_ids[:, None, None] masks = mask == obj_ids[:, None, None]
######################## # %%
# Now the masks are a boolean tensor. # Now the masks are a boolean tensor.
# The first dimension in this case 3 and denotes the number of instances: there are 3 people in the image. # The first dimension in this case 3 and denotes the number of instances: there are 3 people in the image.
# The other two dimensions are height and width, which are equal to the dimensions of the image. # The other two dimensions are height and width, which are equal to the dimensions of the image.
...@@ -101,7 +105,7 @@ masks = mask == obj_ids[:, None, None] ...@@ -101,7 +105,7 @@ masks = mask == obj_ids[:, None, None]
print(masks.size()) print(masks.size())
print(masks) print(masks)
#################################### # %%
# Let us visualize an image and plot its corresponding segmentation masks. # Let us visualize an image and plot its corresponding segmentation masks.
# We will use the :func:`~torchvision.utils.draw_segmentation_masks` to draw the segmentation masks. # We will use the :func:`~torchvision.utils.draw_segmentation_masks` to draw the segmentation masks.
...@@ -113,7 +117,7 @@ for mask in masks: ...@@ -113,7 +117,7 @@ for mask in masks:
show(drawn_masks) show(drawn_masks)
#################################### # %%
# To convert the boolean masks into bounding boxes. # To convert the boolean masks into bounding boxes.
# We will use the :func:`~torchvision.ops.masks_to_boxes` from the torchvision.ops module # We will use the :func:`~torchvision.ops.masks_to_boxes` from the torchvision.ops module
# It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format. # It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format.
...@@ -124,7 +128,7 @@ boxes = masks_to_boxes(masks) ...@@ -124,7 +128,7 @@ boxes = masks_to_boxes(masks)
print(boxes.size()) print(boxes.size())
print(boxes) print(boxes)
#################################### # %%
# As the shape denotes, there are 3 boxes and in ``(xmin, ymin, xmax, ymax)`` format. # As the shape denotes, there are 3 boxes and in ``(xmin, ymin, xmax, ymax)`` format.
# These can be visualized very easily with :func:`~torchvision.utils.draw_bounding_boxes` utility # These can be visualized very easily with :func:`~torchvision.utils.draw_bounding_boxes` utility
# provided in :ref:`torchvision.utils <utils>`. # provided in :ref:`torchvision.utils <utils>`.
...@@ -134,7 +138,7 @@ from torchvision.utils import draw_bounding_boxes ...@@ -134,7 +138,7 @@ from torchvision.utils import draw_bounding_boxes
drawn_boxes = draw_bounding_boxes(img, boxes, colors="red") drawn_boxes = draw_bounding_boxes(img, boxes, colors="red")
show(drawn_boxes) show(drawn_boxes)
################################### # %%
# These boxes can now directly be used by detection models in torchvision. # These boxes can now directly be used by detection models in torchvision.
# Here is demo with a Faster R-CNN model loaded from # Here is demo with a Faster R-CNN model loaded from
# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`
...@@ -153,7 +157,7 @@ target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64) ...@@ -153,7 +157,7 @@ target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)
detection_outputs = model(img.unsqueeze(0), [target]) detection_outputs = model(img.unsqueeze(0), [target])
#################################### # %%
# Converting Segmentation Dataset to Detection Dataset # Converting Segmentation Dataset to Detection Dataset
# ---------------------------------------------------- # ----------------------------------------------------
# #
......
""" """
========================= ===================
Tensor transforms and JIT Torchscript support
========================= ===================
This example illustrates various features that are now supported by the
:ref:`image transformations <transforms>` on Tensor images. In particular, we
show how image transforms can be performed on GPU, and how one can also script
them using JIT compilation.
Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric
and presented multiple limitations due to that. Now, since v0.8.0, transforms
implementations are Tensor and PIL compatible and we can achieve the following
new features:
- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformation such as for videos
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)
.. note:: .. note::
These features are only possible with **Tensor** images. Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code.
This example illustrates `torchscript
<https://pytorch.org/docs/stable/jit.html>`_ support of the torchvision
:ref:`transforms <transforms>` on Tensor images.
""" """
# %%
from pathlib import Path from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
import torch import torch
import torchvision.transforms as T import torch.nn as nn
from torchvision.io import read_image
import torchvision.transforms as v1
from torchvision.io import read_image
plt.rcParams["savefig.bbox"] = 'tight' plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1) torch.manual_seed(1)
# If you're trying to run that on collab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
import sys
sys.path += ["../transforms"]
from helpers import plot
ASSETS_PATH = Path('../assets')
def show(imgs):
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = T.ToPILImage()(img.to('cpu'))
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# %%
# Most transforms support torchscript. For composing transforms, we use
# :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:
#################################### dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
# The :func:`~torchvision.io.read_image` function allows to read an image and dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))
# directly load it as a tensor
dog1 = read_image(str(Path('assets') / 'dog1.jpg'))
dog2 = read_image(str(Path('assets') / 'dog2.jpg'))
show([dog1, dog2])
####################################
# Transforming images on GPU
# --------------------------
# Most transforms natively support tensors on top of PIL images (to visualize
# the effect of the transforms, you may refer to see
# :ref:`sphx_glr_auto_examples_plot_transforms.py`).
# Using tensor images, we can run the transforms on GPUs if cuda is available!
import torch.nn as nn
transforms = torch.nn.Sequential( transforms = torch.nn.Sequential(
T.RandomCrop(224), v1.RandomCrop(224),
T.RandomHorizontalFlip(p=0.3), v1.RandomHorizontalFlip(p=0.3),
) )
device = 'cuda' if torch.cuda.is_available() else 'cpu' scripted_transforms = torch.jit.script(transforms)
dog1 = dog1.to(device)
dog2 = dog2.to(device)
transformed_dog1 = transforms(dog1) plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])
transformed_dog2 = transforms(dog2)
show([transformed_dog1, transformed_dog2])
####################################
# Scriptable transforms for easier deployment via torchscript # %%
# ----------------------------------------------------------- # .. warning::
# We now show how to combine image transformations and a model forward pass, #
# while using ``torch.jit.script`` to obtain a single scripted module. # Above we have used transforms from the ``torchvision.transforms``
# namespace, i.e. the "v1" transforms. The v2 transforms from the
# ``torchvision.transforms.v2`` namespace are the :ref:`recommended
# <v1_or_v2>` way to use transforms in your code.
#
# The v2 transforms also support torchscript, but if you call
# ``torch.jit.script()`` on a v2 **class** transform, you'll actually end up
# with its (scripted) v1 equivalent. This may lead to slightly different
# results between the scripted and eager executions due to implementation
# differences between v1 and v2.
#
# If you really need torchscript support for the v2 transforms, **we
# recommend scripting the functionals** from the
# ``torchvision.transforms.v2.functional`` namespace to avoid surprises.
#
# Below we now show how to combine image transformations and a model forward
# pass, while using ``torch.jit.script`` to obtain a single scripted module.
# #
# Let's define a ``Predictor`` module that transforms the input tensor and then # Let's define a ``Predictor`` module that transforms the input tensor and then
# applies an ImageNet model on it. # applies an ImageNet model on it.
...@@ -94,7 +85,7 @@ class Predictor(nn.Module): ...@@ -94,7 +85,7 @@ class Predictor(nn.Module):
super().__init__() super().__init__()
weights = ResNet18_Weights.DEFAULT weights = ResNet18_Weights.DEFAULT
self.resnet18 = resnet18(weights=weights, progress=False).eval() self.resnet18 = resnet18(weights=weights, progress=False).eval()
self.transforms = weights.transforms() self.transforms = weights.transforms(antialias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad(): with torch.no_grad():
...@@ -103,10 +94,12 @@ class Predictor(nn.Module): ...@@ -103,10 +94,12 @@ class Predictor(nn.Module):
return y_pred.argmax(dim=1) return y_pred.argmax(dim=1)
#################################### # %%
# Now, let's define scripted and non-scripted instances of ``Predictor`` and # Now, let's define scripted and non-scripted instances of ``Predictor`` and
# apply it on multiple tensor images of the same size # apply it on multiple tensor images of the same size
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = Predictor().to(device) predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device) scripted_predictor = torch.jit.script(predictor).to(device)
...@@ -115,20 +108,20 @@ batch = torch.stack([dog1, dog2]).to(device) ...@@ -115,20 +108,20 @@ batch = torch.stack([dog1, dog2]).to(device)
res = predictor(batch) res = predictor(batch)
res_scripted = scripted_predictor(batch) res_scripted = scripted_predictor(batch)
#################################### # %%
# We can verify that the prediction of the scripted and non-scripted models are # We can verify that the prediction of the scripted and non-scripted models are
# the same: # the same:
import json import json
with open(Path('assets') / 'imagenet_class_index.json') as labels_file: with open(Path('../assets') / 'imagenet_class_index.json') as labels_file:
labels = json.load(labels_file) labels = json.load(labels_file)
for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)): for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
assert pred == pred_scripted assert pred == pred_scripted
print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}") print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")
#################################### # %%
# Since the model is scripted, it can be easily dumped on disk and re-used # Since the model is scripted, it can be easily dumped on disk and re-used
import tempfile import tempfile
...@@ -139,3 +132,5 @@ with tempfile.NamedTemporaryFile() as f: ...@@ -139,3 +132,5 @@ with tempfile.NamedTemporaryFile() as f:
dumped_scripted_predictor = torch.jit.load(f.name) dumped_scripted_predictor = torch.jit.load(f.name)
res_scripted_dumped = dumped_scripted_predictor(batch) res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all() assert (res_scripted_dumped == res_scripted).all()
# %%
""" """
======================= =========
Video API Video API
======================= =========
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_video_api.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_video_api.py>` to download the full example code.
This example illustrates some of the APIs that torchvision offers for This example illustrates some of the APIs that torchvision offers for
videos, together with the examples on how to build datasets and more. videos, together with the examples on how to build datasets and more.
""" """
#################################### # %%
# 1. Introduction: building a new video object and examining the properties # 1. Introduction: building a new video object and examining the properties
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# First, we select a video to test the object out. For the sake of argument, # First, we select a video to test the object out. For the sake of argument,
# we're using one from the kinetics400 dataset. # we're using one from the kinetics400 dataset.
# To create it, we need to define the path and the stream we want to use. # To create it, we need to define the path and the stream we want to use.
###################################### # %%
# Chosen video statistics: # Chosen video statistics:
# #
# - WUzgd7C1pWA.mp4 # - WUzgd7C1pWA.mp4
...@@ -32,6 +36,7 @@ videos, together with the examples on how to build datasets and more. ...@@ -32,6 +36,7 @@ videos, together with the examples on how to build datasets and more.
import torch import torch
import torchvision import torchvision
from torchvision.datasets.utils import download_url from torchvision.datasets.utils import download_url
torchvision.set_video_backend("video_reader")
# Download the sample video # Download the sample video
download_url( download_url(
...@@ -41,7 +46,7 @@ download_url( ...@@ -41,7 +46,7 @@ download_url(
) )
video_path = "./WUzgd7C1pWA.mp4" video_path = "./WUzgd7C1pWA.mp4"
###################################### # %%
# Streams are defined in a similar fashion as torch devices. We encode them as strings in a form # Streams are defined in a similar fashion as torch devices. We encode them as strings in a form
# of ``stream_type:stream_id`` where ``stream_type`` is a string and ``stream_id`` a long int. # of ``stream_type:stream_id`` where ``stream_type`` is a string and ``stream_id`` a long int.
# The constructor accepts passing a ``stream_type`` only, in which case the stream is auto-discovered. # The constructor accepts passing a ``stream_type`` only, in which case the stream is auto-discovered.
...@@ -51,7 +56,7 @@ stream = "video" ...@@ -51,7 +56,7 @@ stream = "video"
video = torchvision.io.VideoReader(video_path, stream) video = torchvision.io.VideoReader(video_path, stream)
video.get_metadata() video.get_metadata()
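# %%
# A small sketch, assuming the same ``video_path``: the stream can also be
# given explicitly in the ``stream_type:stream_id`` form described above, e.g.
# the first video stream, which here is equivalent to the auto-discovered one.
explicit_video = torchvision.io.VideoReader(video_path, "video:0")
print(explicit_video.get_metadata())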
###################################### # %%
# Here we can see that the video has two streams - a video and an audio stream. # Here we can see that the video has two streams - a video and an audio stream.
# Currently available stream types include ['video', 'audio']. # Currently available stream types include ['video', 'audio'].
# Each descriptor consists of two parts: stream type (e.g. 'video') and a unique stream id # Each descriptor consists of two parts: stream type (e.g. 'video') and a unique stream id
...@@ -60,7 +65,7 @@ video.get_metadata() ...@@ -60,7 +65,7 @@ video.get_metadata()
# users can access the one they want. # users can access the one they want.
# If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it. # If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it.
###################################### # %%
# Let's read all the frames from the video stream. By default, the return value of # Let's read all the frames from the video stream. By default, the return value of
# ``next(video_reader)`` is a dict containing the following fields. # ``next(video_reader)`` is a dict containing the following fields.
# #
...@@ -84,7 +89,7 @@ approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0] ...@@ -84,7 +89,7 @@ approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
print("Approx total number of datapoints we can expect: ", approx_nf) print("Approx total number of datapoints we can expect: ", approx_nf)
print("Read data size: ", frames[0].size(0) * len(frames)) print("Read data size: ", frames[0].size(0) * len(frames))
###################################### # %%
# But what if we only want to read a certain time segment of the video? # But what if we only want to read a certain time segment of the video?
# That can be done easily using the combination of our ``seek`` function, and the fact that each call # That can be done easily using the combination of our ``seek`` function, and the fact that each call
# to next returns the presentation timestamp of the returned frame in seconds. # to next returns the presentation timestamp of the returned frame in seconds.
...@@ -106,7 +111,7 @@ for frame, pts in itertools.islice(video.seek(2), 10): ...@@ -106,7 +111,7 @@ for frame, pts in itertools.islice(video.seek(2), 10):
print("Total number of frames: ", len(frames)) print("Total number of frames: ", len(frames))
###################################### # %%
# Or, if we wanted to read from the 2nd to the 5th second, # Or, if we wanted to read from the 2nd to the 5th second,
# we seek to the second second of the video, # we seek to the second second of the video,
# then we use ``itertools.takewhile`` to get the # then we use ``itertools.takewhile`` to get the
...@@ -124,7 +129,7 @@ approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0] ...@@ -124,7 +129,7 @@ approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0]
print("We can expect approx: ", approx_nf) print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size()) print("Tensor size: ", frames[0].size())
#################################### # %%
# 2. Building a sample read_video function # 2. Building a sample read_video function
# ---------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------
# We can utilize the methods above to build the read video function that follows # We can utilize the methods above to build the read video function that follows
...@@ -169,21 +174,21 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au ...@@ -169,21 +174,21 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
vf, af, info, meta = example_read_video(video) vf, af, info, meta = example_read_video(video)
print(vf.size(), af.size()) print(vf.size(), af.size())
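# %%
# A rough sketch of what such a read function could look like, assuming it just
# combines ``set_current_stream``, ``seek`` and ``takewhile`` as above; the full
# ``example_read_video`` in this example (largely elided in the diff) also
# handles the audio stream and extra return values.
import itertools

import torch


def sketch_read_video(reader, start=0, end=None):
    """Read the video frames between ``start`` and ``end`` seconds."""
    if end is None:
        end = float("inf")
    reader.set_current_stream("video")
    frames, pts = [], []
    for frame in itertools.takewhile(lambda x: x['pts'] <= end, reader.seek(start)):
        frames.append(frame['data'])
        pts.append(frame['pts'])
    return torch.stack(frames, 0), pts


vframes, vpts = sketch_read_video(video, start=0, end=2)
print(vframes.size(), len(vpts))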
#################################### # %%
# 3. Building an example randomly sampled dataset (can be applied to training dataset of kinetics400) # 3. Building an example randomly sampled dataset (can be applied to training dataset of kinetics400)
# ------------------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------------------
# Cool, so now we can use the same principle to make the sample dataset. # Cool, so now we can use the same principle to make the sample dataset.
# We suggest trying out an iterable dataset for this purpose. # We suggest trying out an iterable dataset for this purpose.
# Here, we are going to build an example dataset that reads 10 randomly selected frames of video. # Here, we are going to build an example dataset that reads 10 randomly selected frames of video.
#################################### # %%
# Make sample dataset # Make sample dataset
import os import os
os.makedirs("./dataset", exist_ok=True) os.makedirs("./dataset", exist_ok=True)
os.makedirs("./dataset/1", exist_ok=True) os.makedirs("./dataset/1", exist_ok=True)
os.makedirs("./dataset/2", exist_ok=True) os.makedirs("./dataset/2", exist_ok=True)
#################################### # %%
# Download the videos # Download the videos
from torchvision.datasets.utils import download_url from torchvision.datasets.utils import download_url
download_url( download_url(
...@@ -211,7 +216,7 @@ download_url( ...@@ -211,7 +216,7 @@ download_url(
"v_SoccerJuggling_g24_c01.avi" "v_SoccerJuggling_g24_c01.avi"
) )
#################################### # %%
# Housekeeping and utilities # Housekeeping and utilities
import os import os
import random import random
...@@ -231,7 +236,7 @@ def get_samples(root, extensions=(".mp4", ".avi")): ...@@ -231,7 +236,7 @@ def get_samples(root, extensions=(".mp4", ".avi")):
_, class_to_idx = _find_classes(root) _, class_to_idx = _find_classes(root)
return make_dataset(root, class_to_idx, extensions=extensions) return make_dataset(root, class_to_idx, extensions=extensions)
#################################### # %%
# We are going to define the dataset and some basic arguments. # We are going to define the dataset and some basic arguments.
# We assume the structure of the FolderDataset, and add the following parameters: # We assume the structure of the FolderDataset, and add the following parameters:
# #
...@@ -286,7 +291,7 @@ class RandomDataset(torch.utils.data.IterableDataset): ...@@ -286,7 +291,7 @@ class RandomDataset(torch.utils.data.IterableDataset):
'end': current_pts} 'end': current_pts}
yield output yield output
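# %%
# A condensed sketch of such an iterable dataset, assuming the ``get_samples``
# helper above; ``RandomClipDataset`` and its ``clip_len`` argument are
# illustrative names, and the real ``RandomDataset`` in this example (partly
# elided in the diff) also supports ``epoch_size`` and per-frame transforms.
import random

import torch
import torchvision


class RandomClipDataset(torch.utils.data.IterableDataset):
    def __init__(self, root, clip_len=10):
        super().__init__()
        self.samples = get_samples(root)
        self.clip_len = clip_len

    def __iter__(self):
        for _ in range(len(self.samples)):
            # Pick a random video, seek to a random point and read ``clip_len`` frames.
            path, target = random.choice(self.samples)
            vid = torchvision.io.VideoReader(path, "video")
            max_seek = max(vid.get_metadata()["video"]["duration"][0] - 1, 0)
            vid.seek(random.uniform(0., max_seek))
            frames = []
            for frame in vid:
                frames.append(frame['data'])
                if len(frames) == self.clip_len:
                    break
            yield {'video': torch.stack(frames, 0), 'target': target}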
#################################### # %%
# Given a path of videos in a folder structure, i.e: # Given a path of videos in a folder structure, i.e:
# #
# - dataset # - dataset
...@@ -308,7 +313,7 @@ frame_transform = t.Compose(transforms) ...@@ -308,7 +313,7 @@ frame_transform = t.Compose(transforms)
dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform) dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform)
#################################### # %%
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=12) loader = DataLoader(dataset, batch_size=12)
data = {"video": [], 'start': [], 'end': [], 'tensorsize': []} data = {"video": [], 'start': [], 'end': [], 'tensorsize': []}
...@@ -320,7 +325,7 @@ for batch in loader: ...@@ -320,7 +325,7 @@ for batch in loader:
data['tensorsize'].append(batch['video'][i].size()) data['tensorsize'].append(batch['video'][i].size())
print(data) print(data)
#################################### # %%
# 4. Data Visualization # 4. Data Visualization
# ---------------------------------- # ----------------------------------
# Example of visualized video # Example of visualized video
...@@ -333,7 +338,7 @@ for i in range(16): ...@@ -333,7 +338,7 @@ for i in range(16):
plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0)) plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0))
plt.axis("off") plt.axis("off")
#################################### # %%
# Clean up the video and dataset: # Clean up the video and dataset:
import os import os
import shutil import shutil
......
...@@ -3,6 +3,10 @@ ...@@ -3,6 +3,10 @@
Visualization utilities Visualization utilities
======================= =======================
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_visualization_utils.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_visualization_utils.py>` to download the full example code.
This example illustrates some of the utilities that torchvision offers for This example illustrates some of the utilities that torchvision offers for
visualizing images, bounding boxes, segmentation masks and keypoints. visualizing images, bounding boxes, segmentation masks and keypoints.
""" """
...@@ -30,7 +34,7 @@ def show(imgs): ...@@ -30,7 +34,7 @@ def show(imgs):
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
#################################### # %%
# Visualizing a grid of images # Visualizing a grid of images
# ---------------------------- # ----------------------------
# The :func:`~torchvision.utils.make_grid` function can be used to create a # The :func:`~torchvision.utils.make_grid` function can be used to create a
...@@ -41,14 +45,14 @@ from torchvision.utils import make_grid ...@@ -41,14 +45,14 @@ from torchvision.utils import make_grid
from torchvision.io import read_image from torchvision.io import read_image
from pathlib import Path from pathlib import Path
dog1_int = read_image(str(Path('assets') / 'dog1.jpg')) dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
dog2_int = read_image(str(Path('assets') / 'dog2.jpg')) dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
dog_list = [dog1_int, dog2_int] dog_list = [dog1_int, dog2_int]
grid = make_grid(dog_list) grid = make_grid(dog_list)
show(grid) show(grid)
#################################### # %%
# Visualizing bounding boxes # Visualizing bounding boxes
# -------------------------- # --------------------------
# We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an # We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
...@@ -64,7 +68,7 @@ result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5) ...@@ -64,7 +68,7 @@ result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5)
show(result) show(result)
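# %%
# A minimal sketch of the call above with explicit values, since the ``boxes``
# definition is elided in the hunk: boxes are expected in ``(xmin, ymin, xmax,
# ymax)`` format, and the coordinates below are made up for illustration.
import torch
from torchvision.utils import draw_bounding_boxes

example_boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
result = draw_bounding_boxes(dog1_int, example_boxes, colors=["blue", "yellow"], width=5)
show(result)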
##################################### # %%
# Naturally, we can also plot bounding boxes produced by torchvision detection # Naturally, we can also plot bounding boxes produced by torchvision detection
# models. Here is a demo with a Faster R-CNN model loaded from # models. Here is a demo with a Faster R-CNN model loaded from
# :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`
...@@ -85,7 +89,7 @@ model = model.eval() ...@@ -85,7 +89,7 @@ model = model.eval()
outputs = model(images) outputs = model(images)
print(outputs) print(outputs)
##################################### # %%
# Let's plot the boxes detected by our model. We will only plot the boxes with a # Let's plot the boxes detected by our model. We will only plot the boxes with a
# score greater than a given threshold. # score greater than a given threshold.
...@@ -96,7 +100,7 @@ dogs_with_boxes = [ ...@@ -96,7 +100,7 @@ dogs_with_boxes = [
] ]
show(dogs_with_boxes) show(dogs_with_boxes)
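# %%
# A sketch of that thresholding, assuming a cut-off of 0.8 and the ``outputs``
# returned by the Faster R-CNN model above:
score_threshold = .8
filtered_dogs_with_boxes = [
    draw_bounding_boxes(dog_int, boxes=out['boxes'][out['scores'] > score_threshold], width=4)
    for dog_int, out in zip(dog_list, outputs)
]
show(filtered_dogs_with_boxes)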
##################################### # %%
# Visualizing segmentation masks # Visualizing segmentation masks
# ------------------------------ # ------------------------------
# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to # The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
...@@ -125,7 +129,7 @@ batch = torch.stack([transforms(d) for d in dog_list]) ...@@ -125,7 +129,7 @@ batch = torch.stack([transforms(d) for d in dog_list])
output = model(batch)['out'] output = model(batch)['out']
print(output.shape, output.min().item(), output.max().item()) print(output.shape, output.min().item(), output.max().item())
##################################### # %%
# As we can see above, the output of the segmentation model is a tensor of shape # As we can see above, the output of the segmentation model is a tensor of shape
# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and # ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
# we can normalize them into ``[0, 1]`` by using a softmax. After the softmax, # we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
...@@ -147,7 +151,7 @@ dog_and_boat_masks = [ ...@@ -147,7 +151,7 @@ dog_and_boat_masks = [
show(dog_and_boat_masks) show(dog_and_boat_masks)
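# %%
# The normalization mentioned above boils down to a softmax over the class
# dimension; a one-line sketch, assuming ``output`` is the
# ``(batch_size, num_classes, H, W)`` tensor printed earlier:
import torch

normalized_masks = torch.nn.functional.softmax(output, dim=1)
print(normalized_masks.shape, normalized_masks.min().item(), normalized_masks.max().item())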
##################################### # %%
# As expected, the model is confident about the dog class, but not so much for # As expected, the model is confident about the dog class, but not so much for
# the boat class. # the boat class.
# #
...@@ -162,7 +166,7 @@ print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}") ...@@ -162,7 +166,7 @@ print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
show([m.float() for m in boolean_dog_masks]) show([m.float() for m in boolean_dog_masks])
##################################### # %%
# The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you # The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
# can read it as the following query: "For which pixels is 'dog' the most likely # can read it as the following query: "For which pixels is 'dog' the most likely
# class?" # class?"
...@@ -184,11 +188,11 @@ dogs_with_masks = [ ...@@ -184,11 +188,11 @@ dogs_with_masks = [
] ]
show(dogs_with_masks) show(dogs_with_masks)
##################################### # %%
# We can plot more than one mask per image! Remember that the model returned as # We can plot more than one mask per image! Remember that the model returned as
# many masks as there are classes. Let's ask the same query as above, but this # many masks as there are classes. Let's ask the same query as above, but this
# time for *all* classes, not just the dog class: "For each pixel and each class # time for *all* classes, not just the dog class: "For each pixel and each class
# C, is class C the most most likely class?" # C, is class C the most likely class?"
# #
# This one is a bit more involved, so we'll first show how to do it with a # This one is a bit more involved, so we'll first show how to do it with a
# single image, and then we'll generalize to the batch # single image, and then we'll generalize to the batch
...@@ -204,7 +208,7 @@ print(f"dog1_all_classes_masks = {dog1_all_classes_masks.shape}, dtype = {dog1_a ...@@ -204,7 +208,7 @@ print(f"dog1_all_classes_masks = {dog1_all_classes_masks.shape}, dtype = {dog1_a
dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6) dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6)
show(dog_with_all_masks) show(dog_with_all_masks)
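# %%
# A sketch of how the ``dog1_all_classes_masks`` used just above could be
# built, assuming the softmax-normalized ``normalized_masks``: for every class
# C we keep the pixels where C is the argmax, yielding one boolean mask per class.
import torch

num_classes = normalized_masks.shape[1]
dog1_masks = normalized_masks[0]          # (num_classes, H, W) masks for the first image
dog1_all_classes_masks = dog1_masks.argmax(dim=0) == torch.arange(num_classes)[:, None, None]
print(dog1_all_classes_masks.shape, dog1_all_classes_masks.dtype)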
##################################### # %%
# We can see in the image above that only 2 masks were drawn: the mask for the # We can see in the image above that only 2 masks were drawn: the mask for the
# background and the mask for the dog. This is because the model thinks that # background and the mask for the dog. This is because the model thinks that
# only these 2 classes are the most likely ones across all the pixels. If the # only these 2 classes are the most likely ones across all the pixels. If the
...@@ -231,7 +235,7 @@ dogs_with_masks = [ ...@@ -231,7 +235,7 @@ dogs_with_masks = [
show(dogs_with_masks) show(dogs_with_masks)
##################################### # %%
# .. _instance_seg_output: # .. _instance_seg_output:
# #
# Instance segmentation models # Instance segmentation models
...@@ -265,7 +269,7 @@ model = model.eval() ...@@ -265,7 +269,7 @@ model = model.eval()
output = model(images) output = model(images)
print(output) print(output)
##################################### # %%
# Let's break this down. For each image in the batch, the model outputs some # Let's break this down. For each image in the batch, the model outputs some
# detections (or instances). The number of detections varies for each input # detections (or instances). The number of detections varies for each input
# image. Each instance is described by its bounding box, its label, its score # image. Each instance is described by its bounding box, its label, its score
...@@ -288,7 +292,7 @@ dog1_masks = dog1_output['masks'] ...@@ -288,7 +292,7 @@ dog1_masks = dog1_output['masks']
print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, " print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
f"min = {dog1_masks.min()}, max = {dog1_masks.max()}") f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")
##################################### # %%
# Here the masks correspond to probabilities indicating, for each pixel, how # Here the masks correspond to probabilities indicating, for each pixel, how
# likely it is to belong to the predicted label of that instance. Those # likely it is to belong to the predicted label of that instance. Those
# predicted labels correspond to the 'labels' element in the same output dict. # predicted labels correspond to the 'labels' element in the same output dict.
...@@ -297,7 +301,7 @@ print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, " ...@@ -297,7 +301,7 @@ print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
print("For the first dog, the following instances were detected:") print("For the first dog, the following instances were detected:")
print([weights.meta["categories"][label] for label in dog1_output['labels']]) print([weights.meta["categories"][label] for label in dog1_output['labels']])
##################################### # %%
# Interestingly, the model detects two persons in the image. Let's go ahead and # Interestingly, the model detects two persons in the image. Let's go ahead and
# plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks` # plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
# expects boolean masks, we need to convert those probabilities into boolean # expects boolean masks, we need to convert those probabilities into boolean
...@@ -315,14 +319,14 @@ dog1_bool_masks = dog1_bool_masks.squeeze(1) ...@@ -315,14 +319,14 @@ dog1_bool_masks = dog1_bool_masks.squeeze(1)
show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9)) show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9))
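# %%
# A short sketch of that probability-to-boolean conversion, assuming a 0.5
# threshold on ``dog1_output['masks']``; the singleton mask dimension is then
# squeezed away exactly as in the line above.
proba_threshold = 0.5
dog1_bool_masks = dog1_output['masks'] > proba_threshold   # bool, shape (num_instances, 1, H, W)
print(dog1_bool_masks.shape, dog1_bool_masks.dtype)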
##################################### # %%
# The model seems to have properly detected the dog, but it also confused trees # The model seems to have properly detected the dog, but it also confused trees
# with people. Looking more closely at the scores will help us plotting more # with people. Looking more closely at the scores will help us plot more
# relevant masks: # relevant masks:
print(dog1_output['scores']) print(dog1_output['scores'])
##################################### # %%
# Clearly the model is more confident about the dog detection than it is about # Clearly the model is more confident about the dog detection than it is about
# the people detections. That's good news. When plotting the masks, we can ask # the people detections. That's good news. When plotting the masks, we can ask
# for only those that have a good score. Let's use a score threshold of .75 # for only those that have a good score. Let's use a score threshold of .75
...@@ -341,12 +345,12 @@ dogs_with_masks = [ ...@@ -341,12 +345,12 @@ dogs_with_masks = [
] ]
show(dogs_with_masks) show(dogs_with_masks)
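# %%
# A sketch of that score-based filtering, assuming the ``output`` list from the
# Mask R-CNN model and the 0.5 probability threshold used earlier:
from torchvision.utils import draw_segmentation_masks

score_threshold = .75
boolean_masks = [
    out['masks'][out['scores'] > score_threshold] > proba_threshold
    for out in output
]
filtered_dogs_with_masks = [
    draw_segmentation_masks(img, mask.squeeze(1))
    for img, mask in zip(dog_list, boolean_masks)
]
show(filtered_dogs_with_masks)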
##################################### # %%
# The two 'people' masks in the first image were not selected because they have # The two 'people' masks in the first image were not selected because they have
# a lower score than the score threshold. Similarly in the second image, the # a lower score than the score threshold. Similarly, in the second image, the
# instance with class 15 (which corresponds to 'bench') was not selected. # instance with class 15 (which corresponds to 'bench') was not selected.
##################################### # %%
# .. _keypoint_output: # .. _keypoint_output:
# #
# Visualizing keypoints # Visualizing keypoints
...@@ -360,7 +364,7 @@ show(dogs_with_masks) ...@@ -360,7 +364,7 @@ show(dogs_with_masks)
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.io import read_image from torchvision.io import read_image
person_int = read_image(str(Path("assets") / "person1.jpg")) person_int = read_image(str(Path("../assets") / "person1.jpg"))
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms() transforms = weights.transforms()
...@@ -373,7 +377,7 @@ model = model.eval() ...@@ -373,7 +377,7 @@ model = model.eval()
outputs = model([person_float]) outputs = model([person_float])
print(outputs) print(outputs)
##################################### # %%
# As we see, the output contains a list of dictionaries. # As we see, the output contains a list of dictionaries.
# The output list is of length batch_size. # The output list is of length batch_size.
# We currently have just a single image, so the length of the list is 1. # We currently have just a single image, so the length of the list is 1.
...@@ -388,7 +392,7 @@ scores = outputs[0]['scores'] ...@@ -388,7 +392,7 @@ scores = outputs[0]['scores']
print(kpts) print(kpts)
print(scores) print(scores)
##################################### # %%
# The KeypointRCNN model detects there are two instances in the image. # The KeypointRCNN model detects there are two instances in the image.
# If you plot the boxes by using :func:`~draw_bounding_boxes` # If you plot the boxes by using :func:`~draw_bounding_boxes`
# you would recognize that they are the person and the surfboard. # you would recognize that they are the person and the surfboard.
...@@ -402,7 +406,7 @@ keypoints = kpts[idx] ...@@ -402,7 +406,7 @@ keypoints = kpts[idx]
print(keypoints) print(keypoints)
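# %%
# A sketch of how that person instance could have been selected, assuming a
# detection-score threshold of 0.75 applied to the ``scores`` printed above;
# the actual ``idx`` computation is elided in the hunk and may differ.
import torch

detect_threshold = 0.75
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]
print(keypoints.shape)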
##################################### # %%
# Great, now we have the keypoints corresponding to the person. # Great, now we have the keypoints corresponding to the person.
# Each keypoint is represented by x, y coordinates and the visibility. # Each keypoint is represented by x, y coordinates and the visibility.
# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints. # We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
...@@ -413,7 +417,7 @@ from torchvision.utils import draw_keypoints ...@@ -413,7 +417,7 @@ from torchvision.utils import draw_keypoints
res = draw_keypoints(person_int, keypoints, colors="blue", radius=3) res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res) show(res)
##################################### # %%
# As we see, the keypoints appear as colored circles over the image. # As we see, the keypoints appear as colored circles over the image.
# The COCO keypoints for a person are ordered and represent the following list. # The COCO keypoints for a person are ordered and represent the following list.
...@@ -424,7 +428,7 @@ coco_keypoints = [ ...@@ -424,7 +428,7 @@ coco_keypoints = [
"left_knee", "right_knee", "left_ankle", "right_ankle", "left_knee", "right_knee", "left_ankle", "right_ankle",
] ]
##################################### # %%
# What if we are interested in joining the keypoints? # What if we are interested in joining the keypoints?
# This is especially useful for pose detection or action recognition. # This is especially useful for pose detection or action recognition.
# We can join the keypoints easily using the `connectivity` parameter. # We can join the keypoints easily using the `connectivity` parameter.
...@@ -450,7 +454,7 @@ connect_skeleton = [ ...@@ -450,7 +454,7 @@ connect_skeleton = [
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16) (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
] ]
##################################### # %%
# We pass the above list to the connectivity parameter to connect the keypoints. # We pass the above list to the connectivity parameter to connect the keypoints.
# #
......
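# %%
# A sketch of that final call, assuming the ``connect_skeleton`` list and the
# ``keypoints`` selected above; the styling arguments are illustrative.
from torchvision.utils import draw_keypoints

res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)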