Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9b4ec8df
Unverified
Commit
9b4ec8df
authored
Jul 31, 2023
by
Nicolas Hug
Committed by
GitHub
Jul 31, 2023
Browse files
Add gallery example for MixUp and CutMix (#7772)
parent
8d4e8793
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
171 additions
and
13 deletions
+171
-13
docs/source/transforms.rst
docs/source/transforms.rst
+4
-4
gallery/plot_cutmix_mixup.py
gallery/plot_cutmix_mixup.py
+146
-2
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+1
-1
torchvision/transforms/v2/_augment.py
torchvision/transforms/v2/_augment.py
+20
-6
No files found.
docs/source/transforms.rst
View file @
9b4ec8df
...
@@ -261,13 +261,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran
...
@@ -261,13 +261,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AugMix
AugMix
v2.AugMix
v2.AugMix
Cut
m
ix - Mix
u
p
Cut
M
ix - Mix
U
p
--------------
--------------
Cut
m
ix and Mix
u
p are special transforms that
Cut
M
ix and Mix
U
p are special transforms that
are meant to be used on batches rather than on individual images, because they
are meant to be used on batches rather than on individual images, because they
are combining pairs of images together. These can be used after the dataloader
,
are combining pairs of images together. These can be used after the dataloader
or part of a collation function. See
(once the samples are batched),
or part of a collation function. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
.. autosummary::
.. autosummary::
...
...
gallery/plot_cutmix_mixup.py
View file @
9b4ec8df
"""
"""
===========================
===========================
How to use Cut
m
ix and Mix
u
p
How to use Cut
M
ix and Mix
U
p
===========================
===========================
TODO
:class:`~torchvision.transforms.v2.Cutmix` and
:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies
that can improve classification accuracy.
These transforms are slightly different from the rest of the Torchvision
transforms, because they expect
**batches** of samples as input, not individual images. In this example we'll
explain how to use them: after the ``DataLoader``, or as part of a collation
function.
"""
"""
# %%
import
torch
import
torchvision
from
torchvision.datasets
import
FakeData
# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision
.
disable_beta_transforms_warning
()
from
torchvision.transforms
import
v2
NUM_CLASSES
=
100
# %%
# Pre-processing pipeline
# -----------------------
#
# We'll use a simple but typical image classification pipeline:
preproc
=
v2
.
Compose
([
v2
.
PILToTensor
(),
v2
.
RandomResizedCrop
(
size
=
(
224
,
224
),
antialias
=
True
),
v2
.
RandomHorizontalFlip
(
p
=
0.5
),
v2
.
ToDtype
(
torch
.
float32
,
scale
=
True
),
# to float32 in [0, 1]
v2
.
Normalize
(
mean
=
(
0.485
,
0.456
,
0.406
),
std
=
(
0.229
,
0.224
,
0.225
)),
# typically from ImageNet
])
dataset
=
FakeData
(
size
=
1000
,
num_classes
=
NUM_CLASSES
,
transform
=
preproc
)
img
,
label
=
dataset
[
0
]
print
(
f
"
{
type
(
img
)
=
}
,
{
img
.
dtype
=
}
,
{
img
.
shape
=
}
,
{
label
=
}
"
)
# %%
#
# One important thing to note is that neither CutMix nor MixUp are part of this
# pre-processing pipeline. We'll add them a bit later once we define the
# DataLoader. Just as a refresher, this is what the DataLoader and training loop
# would look like if we weren't using CutMix or MixUp:
from
torch.utils.data
import
DataLoader
dataloader
=
DataLoader
(
dataset
,
batch_size
=
4
,
shuffle
=
True
)
for
images
,
labels
in
dataloader
:
print
(
f
"
{
images
.
shape
=
}
,
{
labels
.
shape
=
}
"
)
print
(
labels
.
dtype
)
# <rest of the training loop here>
break
# %%
# %%
# Where to use MixUp and CutMix
# -----------------------------
#
# After the DataLoader
# ^^^^^^^^^^^^^^^^^^^^
#
# Now let's add CutMix and MixUp. The simplest way to do this right after the
# DataLoader: the Dataloader has already batched the images and labels for us,
# and this is exactly what these transforms expect as input:
dataloader
=
DataLoader
(
dataset
,
batch_size
=
4
,
shuffle
=
True
)
cutmix
=
v2
.
Cutmix
(
num_classes
=
NUM_CLASSES
)
mixup
=
v2
.
Mixup
(
num_classes
=
NUM_CLASSES
)
cutmix_or_mixup
=
v2
.
RandomChoice
([
cutmix
,
mixup
])
for
images
,
labels
in
dataloader
:
print
(
f
"Before CutMix/MixUp:
{
images
.
shape
=
}
,
{
labels
.
shape
=
}
"
)
images
,
labels
=
cutmix_or_mixup
(
images
,
labels
)
print
(
f
"After CutMix/MixUp:
{
images
.
shape
=
}
,
{
labels
.
shape
=
}
"
)
# <rest of the training loop here>
break
# %%
#
# Note how the labels were also transformed: we went from a batched label of
# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
# transformed labels can still be passed as-is to a loss function like
# :func:`torch.nn.functional.cross_entropy`.
#
# As part of the collation function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Passing the transforms after the DataLoader is the simplest way to use CutMix
# and MixUp, but one disadvantage is that it does not take advantage of the
# DataLoader multi-processing. For that, we can pass those transforms as part of
# the collation function (refer to the `PyTorch docs
# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn
# more about collation).
from
torch.utils.data
import
default_collate
def
collate_fn
(
batch
):
return
cutmix_or_mixup
(
*
default_collate
(
batch
))
dataloader
=
DataLoader
(
dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
2
,
collate_fn
=
collate_fn
)
for
images
,
labels
in
dataloader
:
print
(
f
"
{
images
.
shape
=
}
,
{
labels
.
shape
=
}
"
)
# No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
# <rest of the training loop here>
break
# %%
# Non-standard input format
# -------------------------
#
# So far we've used a typical sample structure where we pass ``(images,
# labels)`` as inputs. MixUp and CutMix will magically work by default with most
# common sample structures: tuples where the second parameter is a tensor label,
# or dict with a "label[s]" key. Look at the documentation of the
# ``labels_getter`` parameter for more details.
#
# If your samples have a different structure, you can still use CutMix and MixUp
# by passing a callable to the ``labels_getter`` parameter. For example:
batch
=
{
"imgs"
:
torch
.
rand
(
4
,
3
,
224
,
224
),
"target"
:
{
"classes"
:
torch
.
randint
(
0
,
NUM_CLASSES
,
size
=
(
4
,)),
"some_other_key"
:
"this is going to be passed-through"
}
}
def
labels_getter
(
batch
):
return
batch
[
"target"
][
"classes"
]
out
=
v2
.
Cutmix
(
num_classes
=
NUM_CLASSES
,
labels_getter
=
labels_getter
)(
batch
)
print
(
f
"
{
out
[
'imgs'
].
shape
=
}
,
{
out
[
'target'
][
'classes'
].
shape
=
}
"
)
test/test_transforms_v2_refactored.py
View file @
9b4ec8df
...
@@ -1922,7 +1922,7 @@ class TestCutMixMixUp:
...
@@ -1922,7 +1922,7 @@ class TestCutMixMixUp:
dataset
=
self
.
DummyDataset
(
size
=
batch_size
,
num_classes
=
num_classes
)
dataset
=
self
.
DummyDataset
(
size
=
batch_size
,
num_classes
=
num_classes
)
cutmix_mixup
=
T
(
alpha
=
0.5
,
num_classes
=
num_classes
)
cutmix_mixup
=
T
(
num_classes
=
num_classes
)
dl
=
DataLoader
(
dataset
,
batch_size
=
batch_size
)
dl
=
DataLoader
(
dataset
,
batch_size
=
batch_size
)
...
...
torchvision/transforms/v2/_augment.py
View file @
9b4ec8df
...
@@ -141,9 +141,9 @@ class RandomErasing(_RandomApplyTransform):
...
@@ -141,9 +141,9 @@ class RandomErasing(_RandomApplyTransform):
class
_BaseMixupCutmix
(
Transform
):
class
_BaseMixupCutmix
(
Transform
):
def
__init__
(
self
,
*
,
alpha
:
float
=
1
,
num_classes
:
int
,
labels_getter
=
"default"
)
->
None
:
def
__init__
(
self
,
*
,
alpha
:
float
=
1
.0
,
num_classes
:
int
,
labels_getter
=
"default"
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
alpha
=
alpha
self
.
alpha
=
float
(
alpha
)
self
.
_dist
=
torch
.
distributions
.
Beta
(
torch
.
tensor
([
alpha
]),
torch
.
tensor
([
alpha
]))
self
.
_dist
=
torch
.
distributions
.
Beta
(
torch
.
tensor
([
alpha
]),
torch
.
tensor
([
alpha
]))
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
...
@@ -204,13 +204,20 @@ class _BaseMixupCutmix(Transform):
...
@@ -204,13 +204,20 @@ class _BaseMixupCutmix(Transform):
class
Mixup
(
_BaseMixupCutmix
):
class
Mixup
(
_BaseMixupCutmix
):
"""[BETA] Apply Mix
u
p to the provided batch of images and labels.
"""[BETA] Apply Mix
U
p to the provided batch of images and labels.
.. v2betastatus:: Mixup transform
.. v2betastatus:: Mixup transform
Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
.. note::
This transform is meant to be used on **batches** of samples, not
individual images. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
examples.
The sample pairing is deterministic and done by matching consecutive
samples in the batch, so the batch needs to be shuffled (this is an
implementation detail, not a guaranteed convention.)
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
into a tensor of shape ``(batch_size, num_classes)``.
into a tensor of shape ``(batch_size, num_classes)``.
...
@@ -246,14 +253,21 @@ class Mixup(_BaseMixupCutmix):
...
@@ -246,14 +253,21 @@ class Mixup(_BaseMixupCutmix):
class
Cutmix
(
_BaseMixupCutmix
):
class
Cutmix
(
_BaseMixupCutmix
):
"""[BETA] Apply Cut
m
ix to the provided batch of images and labels.
"""[BETA] Apply Cut
M
ix to the provided batch of images and labels.
.. v2betastatus:: Cutmix transform
.. v2betastatus:: Cutmix transform
Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
<https://arxiv.org/abs/1905.04899>`_.
<https://arxiv.org/abs/1905.04899>`_.
See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
.. note::
This transform is meant to be used on **batches** of samples, not
individual images. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
examples.
The sample pairing is deterministic and done by matching consecutive
samples in the batch, so the batch needs to be shuffled (this is an
implementation detail, not a guaranteed convention.)
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
into a tensor of shape ``(batch_size, num_classes)``.
into a tensor of shape ``(batch_size, num_classes)``.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment