OpenDAS / vision / Commits / 224cbc83

Commit 224cbc83 (Unverified)
Authored Aug 25, 2023 by Nicolas Hug; committed by GitHub on Aug 25, 2023
Parent: 92e4e9c6

Rewrite transforms v2 e2e example (#7881)

Showing 3 changed files with 142 additions and 103 deletions:

  gallery/v2_transforms/helpers.py                   +21   -8
  gallery/v2_transforms/plot_transforms_v2.py        +11   -7
  gallery/v2_transforms/plot_transforms_v2_e2e.py   +110  -88

gallery/v2_transforms/helpers.py    View file @ 224cbc83

 import matplotlib.pyplot as plt
-from torchvision.utils import draw_bounding_boxes
+import torch
+from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
+from torchvision import datapoints
+from torchvision.transforms.v2 import functional as F


 def plot(imgs):
...
@@ -12,20 +15,30 @@ def plot(imgs):
     _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
     for row_idx, row in enumerate(imgs):
         for col_idx, img in enumerate(row):
-            bboxes = None
+            boxes = None
+            masks = None
             if isinstance(img, tuple):
-                bboxes = img[1]
-                img = img[0]
-                if isinstance(bboxes, dict):
-                    bboxes = bboxes['bboxes']
+                img, target = img
+                if isinstance(target, dict):
+                    boxes = target.get("boxes")
+                    masks = target.get("masks")
+                elif isinstance(target, datapoints.BoundingBoxes):
+                    boxes = target
+                else:
+                    raise ValueError(f"Unexpected target type: {type(target)}")
+            img = F.to_image(img)
             if img.dtype.is_floating_point and img.min() < 0:
                 # Poor man's re-normalization for the colors to be OK-ish. This
                 # is useful for images coming out of Normalize()
                 img -= img.min()
                 img /= img.max()
-            if bboxes is not None:
-                img = draw_bounding_boxes(img, bboxes, colors="yellow", width=3)
+
+            img = F.to_dtype(img, torch.uint8, scale=True)
+            if boxes is not None:
+                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
+            if masks is not None:
+                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
+
             ax = axs[row_idx, col_idx]
             ax.imshow(img.permute(1, 2, 0).numpy())
             ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
...
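
For readers following the change, the updated helper can now be driven with plain images, (image, BoundingBoxes) pairs, or (image, target-dict) pairs. Below is a minimal sketch of such a call. It is not part of the commit: the image, box, and mask values are made up, and it assumes a torchvision build matching this commit, where the torchvision.datapoints API (BoundingBoxes, Mask) is available.

    import torch
    from torchvision import datapoints
    from helpers import plot  # the helper modified above

    # A dummy 3x256x256 uint8 image, one box, and one square mask.
    img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
    boxes = datapoints.BoundingBoxes(
        [[20, 20, 120, 120]], format="XYXY", canvas_size=(256, 256)
    )
    mask = torch.zeros(1, 256, 256, dtype=torch.uint8)
    mask[:, 40:100, 40:100] = 1
    masks = datapoints.Mask(mask)

    # Each entry may be a plain image, an (image, BoundingBoxes) pair, or an
    # (image, dict) pair with optional "boxes" and "masks" keys.
    plot([img, (img, boxes), (img, {"boxes": boxes, "masks": masks})])
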
gallery/v2_transforms/plot_transforms_v2.py    View file @ 224cbc83

...
@@ -90,7 +90,7 @@ plot([img, out])

 from torchvision import datapoints  # we'll describe this a bit later, bare with us

-bboxes = datapoints.BoundingBoxes(
+boxes = datapoints.BoundingBoxes(
     [
         [15, 10, 370, 510],
         [275, 340, 510, 510],
...
@@ -103,9 +103,10 @@ transforms = v2.Compose([
     v2.RandomPhotometricDistort(p=1),
     v2.RandomHorizontalFlip(p=1),
 ])
-out_img, out_bboxes = transforms(img, bboxes)
-plot([(img, bboxes), (out_img, out_bboxes)])
+out_img, out_boxes = transforms(img, boxes)
+print(type(boxes), type(out_boxes))
+plot([(img, boxes), (out_img, out_boxes)])

 # %%
 #
...
@@ -119,6 +120,9 @@ plot([(img, bboxes), (out_img, out_bboxes)])
 # answer these in the next sections.

 # %%
 #
+# .. _what_are_datapoints:
+#
 # What are Datapoints?
 # --------------------
 #
...
@@ -151,7 +155,7 @@ print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
 #
 # Above, we've seen two examples: one where we passed a single image as input
 # i.e. ``out = transforms(img)``, and one where we passed both an image and
-# bounding boxes, i.e. ``out_img, out_bboxes = transforms(img, bboxes)``.
+# bounding boxes, i.e. ``out_img, out_boxes = transforms(img, boxes)``.
 #
 # In fact, transforms support **arbitrary input structures**. The input can be a
 # single image, a tuple, an arbitrarily nested dictionary... pretty much
...
@@ -160,15 +164,15 @@ print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
 # we're getting the same structure as output:

 target = {
-    "bboxes": bboxes,
-    "labels": torch.arange(bboxes.shape[0]),
+    "boxes": boxes,
+    "labels": torch.arange(boxes.shape[0]),
     "this_is_ignored": ("arbitrary", {"structure": "!"})
 }

 # Re-using the transforms and definitions from above.
 out_img, out_target = transforms(img, target)

-plot([(img, target["bboxes"]), (out_img, out_target["bboxes"])])
+plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
 print(f"{out_target['this_is_ignored']}")

 # %%
...
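
To see the renamed variables in context, here is a self-contained sketch of the snippet these hunks modify. It is not part of the commit: the tutorial operates on a real photo, which is replaced here by a random tensor, while the box coordinates are the ones visible in the diff.

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    img = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)
    boxes = datapoints.BoundingBoxes(
        [[15, 10, 370, 510], [275, 340, 510, 510]],
        format="XYXY",
        canvas_size=(512, 512),
    )

    transforms = v2.Compose([
        v2.RandomPhotometricDistort(p=1),
        v2.RandomHorizontalFlip(p=1),
    ])
    out_img, out_boxes = transforms(img, boxes)
    # The BoundingBoxes datapoint type is preserved through the pipeline,
    # which is what the print() added in the hunk above demonstrates.
    print(type(boxes), type(out_boxes))
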
gallery/v2_transforms/plot_transforms_v2_e2e.py    View file @ 224cbc83

 """
-==================================================
-Transforms v2: End-to-end object detection example
-==================================================
+===============================================================
+Transforms v2: End-to-end object detection/segmentation example
+===============================================================

 .. note::
     Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_v2_e2e.ipynb>`_
     or :ref:`go to the end <sphx_glr_download_auto_examples_v2_transforms_plot_transforms_v2_e2e.py>` to download the full example code.

-Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
-``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
-showcases an end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models``
-as well as the new ``torchvision.transforms.v2`` v2 API.
+Object detection and segmentation tasks are natively supported:
+``torchvision.transforms.v2`` enables jointly transforming images, videos,
+bounding boxes, and masks.
+
+This example showcases an end-to-end instance segmentation training case using
+Torchvision utils from ``torchvision.datasets``, ``torchvision.models`` and
+``torchvision.transforms.v2``. Everything covered here can be applied similarly
+to object detection or semantic segmentation tasks.
 """

+# %%
 import pathlib

-import PIL.Image
-
 import torch
 import torch.utils.data

-from torchvision import models, datasets
-import torchvision.transforms.v2 as transforms
-
-
-def show(sample):
-    import matplotlib.pyplot as plt
-
-    from torchvision.transforms.v2 import functional as F
-    from torchvision.utils import draw_bounding_boxes
-
-    image, target = sample
-    if isinstance(image, PIL.Image.Image):
-        image = F.to_image(image)
-    image = F.to_dtype(image, torch.uint8, scale=True)
-    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
-
-    fig, ax = plt.subplots()
-    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
-    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
-    fig.tight_layout()
-
-    fig.show()
+from torchvision import models, datasets, datapoints
+from torchvision.transforms import v2

+torch.manual_seed(0)
+
+# This loads fake data for illustration purposes of this example. In practice, you'll have
+# to replace this with the proper data.
+# If you're trying to run that on collab, you can download the assets and the
+# helpers from https://github.com/pytorch/vision/tree/main/gallery/
+ROOT = pathlib.Path("../assets") / "coco"
+IMAGES_PATH = str(ROOT / "images")
+ANNOTATIONS_PATH = str(ROOT / "instances.json")
+from helpers import plot

 # %%
+# Dataset preparation
+# -------------------
+#
 # We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
-# returns, and we'll see how to convert it to a format that is compatible with our new transforms.
-
-
-def load_example_coco_detection_dataset(**kwargs):
-    # This loads fake data for illustration purposes of this example. In practice, you'll have
-    # to replace this with the proper data
-    root = pathlib.Path("../assets") / "coco"
-    return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs)
-
-
-dataset = load_example_coco_detection_dataset()
+# returns.

+dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)

 sample = dataset[0]
-image, target = sample
-print(type(image))
-print(type(target), type(target[0]), list(target[0].keys()))
+img, target = sample
+print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")

 # %%
-# The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and second one a list of
-# dictionaries, which each containing the annotations for a single object instance. As is, this format is not compatible
-# with the ``torchvision.transforms.v2``, nor with the models. To overcome that, we provide the
+# Torchvision datasets preserve the data structure and types as it was intended
+# by the datasets authors. So by default, the output structure may not always be
+# compatible with the models or the transforms.
+#
+# To overcome that, we can use the
 # :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
-# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
-# also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
-# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary
-# items down the line, but you can pass the ``target_type`` parameter for fine-grained control.
+# :class:`~torchvision.datasets.CocoDetection`, this changes the target
+# structure to a single dictionary of lists:

-dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
+dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))

 sample = dataset[0]
-image, target = sample
-print(type(image))
-print(type(target), list(target.keys()))
-print(type(target["boxes"]), type(target["labels"]))
+img, target = sample
+print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
+print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }")

 # %%
+# We used the ``target_keys`` parameter to specify the kind of output we're
+# interested in. Our dataset now returns a target which is dict where the values
+# are :ref:`Datapoints <what_are_datapoints>` (all are :class:`torch.Tensor`
+# subclasses). We're dropped all unncessary keys from the previous output, but
+# if you need any of the original keys e.g. "image_id", you can still ask for
+# it.
+#
+# .. note::
+#
+#     If you just want to do detection, you don't need and shouldn't pass
+#     "masks" in ``target_keys``: if masks are present in the sample, they will
+#     be transformed, slowing down your transformations unnecessarily.
+#
 # As baseline, let's have a look at a sample without transformations:

-show(sample)
+plot([dataset[0], dataset[1]])
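
The comment above notes that original keys such as "image_id" can still be requested. A short sketch of what that looks like, reusing the fake-COCO paths defined earlier; it is not part of the commit and assumes "image_id" is among the keys the wrapper accepts, as its documentation describes.

    extra = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)
    extra = datasets.wrap_dataset_for_transforms_v2(
        extra, target_keys=("boxes", "labels", "masks", "image_id")
    )
    img, target = extra[0]
    print(sorted(target.keys()))  # expect ['boxes', 'image_id', 'labels', 'masks']
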
 # %%
-# With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way it is done in
-# ``torchvision.transforms`` v1, but now handles bounding boxes and masks without any extra configuration.
+# Transforms
+# ----------
+#
+# Let's now define our pre-processing transforms. All the transforms know how
+# to handle images, bouding boxes and masks when relevant.
+#
+# Transforms are typically passed as the ``transforms`` parameter of the
+# dataset so that they can leverage multi-processing from the
+# :class:`torch.utils.data.DataLoader`.

-transform = transforms.Compose(
+transforms = v2.Compose(
     [
-        transforms.RandomPhotometricDistort(),
-        transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
-        transforms.RandomIoUCrop(),
-        transforms.RandomHorizontalFlip(),
-        transforms.ToImage(),
-        transforms.ConvertImageDtype(torch.float32),
-        transforms.SanitizeBoundingBoxes(),
+        v2.ToImage(),
+        v2.RandomPhotometricDistort(p=1),
+        v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
+        v2.RandomIoUCrop(),
+        v2.RandomHorizontalFlip(p=1),
+        v2.SanitizeBoundingBoxes(),
+        v2.ToDtype(torch.float32, scale=True),
     ]
 )

+dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
+dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])

 # %%
-# .. note::
-#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it
-#    should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
-#    the corresponding labels and optionally masks. It is particularly critical to add it if
-#    :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
+# A few things are worth noting here:
+#
+# - We're converting the PIL image into a
+#   :class:`~torchvision.transforms.v2.Image` object. This isn't strictly
+#   necessary, but relying on Tensors (here: a Tensor subclass) will
+#   :ref:`generally be faster <transforms_perf>`.
+# - We are calling :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` to
+#   make sure we remove degenerate bounding boxes, as well as their
+#   corresponding labels and masks.
+#   :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` should be placed
+#   at least once at the end of a detection pipeline; it is particularly
+#   critical if :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
 #
 # Let's look how the sample looks like with our augmentation pipeline in place:

-dataset = load_example_coco_detection_dataset(transforms=transform)
-dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
-
-torch.manual_seed(3141)
-sample = dataset[0]
-
 # sphinx_gallery_thumbnail_number = 2
-show(sample)
+plot([dataset[0], dataset[1]])

 # %%
-# We can see that the color of the image was distorted, we zoomed out on it (off center) and flipped it horizontally.
-# In all of this, the bounding box was transformed accordingly. And without any further ado, we can start training.
+# We can see that the color of the images were distorted, zoomed in or out, and flipped.
+# The bounding boxes and the masks were transformed accordingly. And without any further ado, we can start training.
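
Since the transforms are attached to the dataset, every indexing call re-runs the pipeline. For debugging it can help to apply the composed pipeline to a single wrapped sample by hand; the sketch below is not part of the commit and assumes the transforms object and the paths defined above.

    raw = datasets.wrap_dataset_for_transforms_v2(
        datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH),
        target_keys=["boxes", "labels", "masks"],
    )
    img, target = transforms(*raw[0])
    print(type(img), img.dtype)          # a datapoints.Image with dtype torch.float32
    print(type(target["boxes"]), type(target["masks"]))
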
+#
+# Data loading and training loop
+# ------------------------------
+#
+# Below we're using Mask-RCNN which is an instance segmentation model, but
+# everything we've covered in this tutorial also applies to object detection and
+# semantic segmentation tasks.

 data_loader = torch.utils.data.DataLoader(
     dataset,
     batch_size=2,
-    # We need a custom collation function here, since the object detection models expect a
-    # sequence of images and target dictionaries. The default collation function tries to
-    # `torch.stack` the individual elements, which fails in general for object detection,
-    # because the number of object instances varies between the samples. This is the same for
-    # `torchvision.transforms` v1
+    # We need a custom collation function here, since the object detection
+    # models expect a sequence of images and target dictionaries. The default
+    # collation function tries to torch.stack() the individual elements,
+    # which fails in general for object detection, because the number of bouding
+    # boxes varies between the images of a same batch.
     collate_fn=lambda batch: tuple(zip(*batch)),
 )

-model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()
+model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

-for images, targets in data_loader:
-    loss_dict = model(images, targets)
-    print(loss_dict)
+for imgs, targets in data_loader:
+    loss_dict = model(imgs, targets)
     # Put your training logic here
-    break
+
+    print(f"{[img.shape for img in imgs] = }")
+    print(f"{[type(target) for target in targets] = }")
+    for name, loss_val in loss_dict.items():
+        print(f"{name:<20}{loss_val:.3f}")
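
The custom collate_fn deserves a closing remark: the default collation would try to torch.stack() the samples, which cannot work when image sizes and per-image box counts differ. A minimal, self-contained sketch of what the lambda above produces, with made-up tensors standing in for real samples:

    import torch

    images = [torch.rand(3, 480, 640), torch.rand(3, 400, 300)]
    targets = [
        {"boxes": torch.zeros(3, 4), "labels": torch.zeros(3, dtype=torch.int64)},  # 3 objects
        {"boxes": torch.zeros(5, 4), "labels": torch.zeros(5, dtype=torch.int64)},  # 5 objects
    ]

    batch = list(zip(images, targets))   # what the DataLoader hands to collate_fn
    imgs, tgts = tuple(zip(*batch))      # the lambda just transposes the batch
    print(len(imgs), [t["boxes"].shape for t in tgts])
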