Unverified Commit c786d755 authored by Philip Meier, committed by GitHub

add end-to-end example gallery for transforms v2 (#7302)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent ed48bb1c
@@ -5,3 +5,4 @@ sphinx-gallery>=0.11.1
sphinx==5.0.0
tabulate
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
pycocotools
../../astronaut.jpg
\ No newline at end of file
../../dog2.jpg
\ No newline at end of file
{"images": [{"file_name": "000000000001.jpg", "height": 512, "width": 512, "id": 1}, {"file_name": "000000000002.jpg", "height": 500, "width": 500, "id": 2}], "annotations": [{"segmentation": [[40.0, 511.0, 26.0, 487.0, 28.0, 438.0, 17.0, 397.0, 24.0, 346.0, 38.0, 306.0, 61.0, 250.0, 111.0, 206.0, 111.0, 187.0, 120.0, 183.0, 136.0, 159.0, 159.0, 150.0, 181.0, 148.0, 182.0, 132.0, 175.0, 132.0, 168.0, 120.0, 154.0, 102.0, 153.0, 62.0, 188.0, 35.0, 191.0, 29.0, 208.0, 20.0, 210.0, 22.0, 227.0, 16.0, 240.0, 16.0, 276.0, 31.0, 285.0, 39.0, 301.0, 88.0, 297.0, 108.0, 281.0, 128.0, 273.0, 138.0, 266.0, 138.0, 264.0, 153.0, 257.0, 162.0, 256.0, 174.0, 284.0, 197.0, 300.0, 221.0, 303.0, 236.0, 337.0, 258.0, 357.0, 306.0, 361.0, 351.0, 358.0, 511.0]], "iscrowd": 0, "image_id": 1, "bbox": [17.0, 16.0, 344.0, 495.0], "category_id": 1, "id": 1}, {"segmentation": [[0.0, 411.0, 43.0, 401.0, 99.0, 395.0, 105.0, 351.0, 124.0, 326.0, 181.0, 294.0, 227.0, 280.0, 245.0, 262.0, 259.0, 234.0, 262.0, 207.0, 271.0, 140.0, 283.0, 139.0, 301.0, 162.0, 309.0, 181.0, 341.0, 175.0, 362.0, 139.0, 369.0, 139.0, 377.0, 163.0, 378.0, 203.0, 381.0, 212.0, 380.0, 220.0, 382.0, 242.0, 404.0, 264.0, 392.0, 293.0, 384.0, 295.0, 385.0, 316.0, 399.0, 343.0, 391.0, 448.0, 452.0, 475.0, 457.0, 494.0, 436.0, 498.0, 402.0, 491.0, 369.0, 488.0, 366.0, 496.0, 319.0, 496.0, 302.0, 485.0, 226.0, 469.0, 128.0, 456.0, 74.0, 458.0, 29.0, 439.0, 0.0, 445.0]], "iscrowd": 0, "image_id": 2, "bbox": [0.0, 139.0, 457.0, 359.0], "category_id": 18, "id": 2}]}
"""
==================================================
transforms v2: End-to-end object detection example
==================================================
Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
showcases an end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models`` as
well as the new ``torchvision.transforms.v2`` API.
"""
import pathlib
from collections import defaultdict
import PIL.Image
import torch
import torch.utils.data
import torchvision
# sphinx_gallery_thumbnail_number = -1
def show(sample):
    import matplotlib.pyplot as plt

    from torchvision.transforms.v2 import functional as F
    from torchvision.utils import draw_bounding_boxes

    image, target = sample
    if isinstance(image, PIL.Image.Image):
        image = F.to_image_tensor(image)
    image = F.convert_dtype(image, torch.uint8)
    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

    fig, ax = plt.subplots()
    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()
    fig.show()
# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()
from torchvision import models, datasets
import torchvision.transforms.v2 as transforms
########################################################################################################################
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
# returns, and we'll see how to convert it to a format that is compatible with our new transforms.
def load_example_coco_detection_dataset(**kwargs):
    # This loads fake data for illustration purposes of this example. In practice, you'll have
    # to replace this with the proper data.
    root = pathlib.Path("assets") / "coco"
    return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs)
dataset = load_example_coco_detection_dataset()
sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), type(target[0]), list(target[0].keys()))
########################################################################################################################
# The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and the second one a list of
# dictionaries, each of which contains the annotations for a single object instance. As is, this format is compatible
# neither with ``torchvision.transforms.v2`` nor with the models. To overcome that, we provide the
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
# also adds the key-value pairs ``"boxes"``, ``"masks"``, and ``"labels"``, wrapped in the corresponding
# ``torchvision.datapoints``.
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), list(target.keys()))
print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
########################################################################################################################
# As a baseline, let's have a look at a sample without any transformations:
show(sample)
########################################################################################################################
# With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way as in
# ``torchvision.transforms`` v1, but the pipeline now also handles bounding boxes and masks without any extra
# configuration. Right after defining it, we will apply it to a single sample to see the joint transformation in action.
transform = transforms.Compose(
    [
        transforms.RandomPhotometricDistort(),
        transforms.RandomZoomOut(
            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
        ),
        transforms.RandomIoUCrop(),
        transforms.RandomHorizontalFlip(),
        transforms.ToImageTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.SanitizeBoundingBoxes(),
    ]
)
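########################################################################################################################
# The pipeline does not require the dataset at all: the v2 transforms accept the image and the target dictionary
# jointly, so we can apply it to the single sample loaded above. This call is an added illustration and is not part of
# the original example.
torch.manual_seed(0)
transformed_image, transformed_target = transform(image, target)
print(type(transformed_image), transformed_image.dtype)
print(type(transformed_target["boxes"]))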
########################################################################################################################
# .. note::
#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
#    should still be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as
#    well as the corresponding labels and, optionally, masks. It is particularly critical to add it if
#    :class:`~torchvision.transforms.v2.RandomIoUCrop` was used (we run a quick sanity check on a transformed sample
#    further below).
#
# Let's see how a sample looks with our augmentation pipeline in place:
dataset = load_example_coco_detection_dataset(transforms=transform)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
torch.manual_seed(3141)
sample = dataset[0]
show(sample)
########################################################################################################################
# We can see that the colors of the image were distorted, that we zoomed out on it (off-center), and that it was
# flipped horizontally. Throughout all of this, the bounding box was transformed accordingly. Let's run the quick
# sanity check mentioned in the note above, and then, without any further ado, start training.
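# The following check is an addition to the original example; it assumes the boxes are in XYXY pixel coordinates,
# which is also what ``draw_bounding_boxes`` in our ``show`` helper expects.
boxes = sample[1]["boxes"]
degenerate = (boxes[:, 2] <= boxes[:, 0]) | (boxes[:, 3] <= boxes[:, 1])
print("degenerate boxes left after the pipeline:", int(degenerate.sum()))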
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    # We need a custom collation function here, since the object detection models expect a
    # sequence of images and target dictionaries. The default collation function tries to
    # `torch.stack` the individual elements, which fails in general for object detection,
    # because the number of object instances varies between the samples (see the toy
    # illustration right below). This is the same for `torchvision.transforms` v1.
    collate_fn=lambda batch: tuple(zip(*batch)),
)
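# As a toy illustration (an addition to this example) of what the collation above does with plain Python objects:
# a list of ``(image, target)`` pairs becomes one tuple with all images and one tuple with all targets.
fake_batch = [("image_1", "target_1"), ("image_2", "target_2")]
print(tuple(zip(*fake_batch)))  # (('image_1', 'image_2'), ('target_1', 'target_2'))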
model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()
for images, targets in data_loader:
    loss_dict = model(images, targets)
    print(loss_dict)
    # Put your training logic here; a minimal sketch follows below.
    break
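########################################################################################################################
# Below is a minimal sketch of what such training logic could look like. It is an addition to the original example, and
# the optimizer choice and hyperparameters are illustrative assumptions rather than recommendations.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for images, targets in data_loader:
    # In training mode, the detection models return a dict of losses, which we sum up and backpropagate.
    loss_dict = model(images, targets)
    loss = sum(loss_dict.values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    break  # remove this break to iterate over the full (fake) dataset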