Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c585a515
Unverified
Commit
c585a515
authored
May 28, 2024
by
Mahdi Lamb
Committed by
GitHub
May 28, 2024
Browse files
Enable one-hot-encoded labels in MixUp and CutMix (#8427)
Co-authored-by:
Nicolas Hug
<
nh.nicolas.hug@gmail.com
>
parent
778ce48b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
23 deletions
+40
-23
test/test_transforms_v2.py
test/test_transforms_v2.py
+19
-16
torchvision/transforms/v2/_augment.py
torchvision/transforms/v2/_augment.py
+21
-7
No files found.
test/test_transforms_v2.py
View file @
c585a515
...
@@ -2169,26 +2169,30 @@ class TestAdjustBrightness:
...
@@ -2169,26 +2169,30 @@ class TestAdjustBrightness:
class
TestCutMixMixUp
:
class
TestCutMixMixUp
:
class
DummyDataset
:
class
DummyDataset
:
def
__init__
(
self
,
size
,
num_classes
):
def
__init__
(
self
,
size
,
num_classes
,
one_hot_labels
):
self
.
size
=
size
self
.
size
=
size
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
one_hot_labels
=
one_hot_labels
assert
size
<
num_classes
assert
size
<
num_classes
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
img
=
torch
.
rand
(
3
,
100
,
100
)
img
=
torch
.
rand
(
3
,
100
,
100
)
label
=
idx
# This ensures all labels in a batch are unique and makes testing easier
label
=
idx
# This ensures all labels in a batch are unique and makes testing easier
if
self
.
one_hot_labels
:
label
=
torch
.
nn
.
functional
.
one_hot
(
torch
.
tensor
(
label
),
num_classes
=
self
.
num_classes
)
return
img
,
label
return
img
,
label
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
size
return
self
.
size
@
pytest
.
mark
.
parametrize
(
"T"
,
[
transforms
.
CutMix
,
transforms
.
MixUp
])
@
pytest
.
mark
.
parametrize
(
"T"
,
[
transforms
.
CutMix
,
transforms
.
MixUp
])
def
test_supported_input_structure
(
self
,
T
):
@
pytest
.
mark
.
parametrize
(
"one_hot_labels"
,
(
True
,
False
))
def
test_supported_input_structure
(
self
,
T
,
one_hot_labels
):
batch_size
=
32
batch_size
=
32
num_classes
=
100
num_classes
=
100
dataset
=
self
.
DummyDataset
(
size
=
batch_size
,
num_classes
=
num_classes
)
dataset
=
self
.
DummyDataset
(
size
=
batch_size
,
num_classes
=
num_classes
,
one_hot_labels
=
one_hot_labels
)
cutmix_mixup
=
T
(
num_classes
=
num_classes
)
cutmix_mixup
=
T
(
num_classes
=
num_classes
)
...
@@ -2198,7 +2202,7 @@ class TestCutMixMixUp:
...
@@ -2198,7 +2202,7 @@ class TestCutMixMixUp:
img
,
target
=
next
(
iter
(
dl
))
img
,
target
=
next
(
iter
(
dl
))
input_img_size
=
img
.
shape
[
-
3
:]
input_img_size
=
img
.
shape
[
-
3
:]
assert
isinstance
(
img
,
torch
.
Tensor
)
and
isinstance
(
target
,
torch
.
Tensor
)
assert
isinstance
(
img
,
torch
.
Tensor
)
and
isinstance
(
target
,
torch
.
Tensor
)
assert
target
.
shape
==
(
batch_size
,)
assert
target
.
shape
==
(
batch_size
,
num_classes
)
if
one_hot_labels
else
(
batch_size
,)
def
check_output
(
img
,
target
):
def
check_output
(
img
,
target
):
assert
img
.
shape
==
(
batch_size
,
*
input_img_size
)
assert
img
.
shape
==
(
batch_size
,
*
input_img_size
)
...
@@ -2209,7 +2213,7 @@ class TestCutMixMixUp:
...
@@ -2209,7 +2213,7 @@ class TestCutMixMixUp:
# After Dataloader, as unpacked input
# After Dataloader, as unpacked input
img
,
target
=
next
(
iter
(
dl
))
img
,
target
=
next
(
iter
(
dl
))
assert
target
.
shape
==
(
batch_size
,)
assert
target
.
shape
==
(
batch_size
,
num_classes
)
if
one_hot_labels
else
(
batch_size
,)
img
,
target
=
cutmix_mixup
(
img
,
target
)
img
,
target
=
cutmix_mixup
(
img
,
target
)
check_output
(
img
,
target
)
check_output
(
img
,
target
)
...
@@ -2264,7 +2268,7 @@ class TestCutMixMixUp:
...
@@ -2264,7 +2268,7 @@ class TestCutMixMixUp:
with
pytest
.
raises
(
ValueError
,
match
=
"Could not infer where the labels are"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Could not infer where the labels are"
):
cutmix_mixup
({
"img"
:
imgs
,
"Nothing_else"
:
3
})
cutmix_mixup
({
"img"
:
imgs
,
"Nothing_else"
:
3
})
with
pytest
.
raises
(
ValueError
,
match
=
"labels
tensor
should be
of shape
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"labels should be
index based
"
):
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
cutmix_mixup
(
imgs
)
cutmix_mixup
(
imgs
)
...
@@ -2272,22 +2276,21 @@ class TestCutMixMixUp:
...
@@ -2272,22 +2276,21 @@ class TestCutMixMixUp:
with
pytest
.
raises
(
ValueError
,
match
=
"When using the default labels_getter"
):
with
pytest
.
raises
(
ValueError
,
match
=
"When using the default labels_getter"
):
cutmix_mixup
(
imgs
,
"not_a_tensor"
)
cutmix_mixup
(
imgs
,
"not_a_tensor"
)
with
pytest
.
raises
(
ValueError
,
match
=
"labels tensor should be of shape"
):
cutmix_mixup
(
imgs
,
torch
.
randint
(
0
,
2
,
size
=
(
2
,
3
)))
with
pytest
.
raises
(
ValueError
,
match
=
"Expected a batched input with 4 dims"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Expected a batched input with 4 dims"
):
cutmix_mixup
(
imgs
[
None
,
None
],
torch
.
randint
(
0
,
num_classes
,
size
=
(
batch_size
,)))
cutmix_mixup
(
imgs
[
None
,
None
],
torch
.
randint
(
0
,
num_classes
,
size
=
(
batch_size
,)))
with
pytest
.
raises
(
ValueError
,
match
=
"does not match the batch size of the labels"
):
with
pytest
.
raises
(
ValueError
,
match
=
"does not match the batch size of the labels"
):
cutmix_mixup
(
imgs
,
torch
.
randint
(
0
,
num_classes
,
size
=
(
batch_size
+
1
,)))
cutmix_mixup
(
imgs
,
torch
.
randint
(
0
,
num_classes
,
size
=
(
batch_size
+
1
,)))
with
pytest
.
raises
(
ValueError
,
match
=
"labels tensor should be of shape"
):
with
pytest
.
raises
(
ValueError
,
match
=
"When passing 2D labels"
):
# The purpose of this check is more about documenting the current
wrong_num_classes
=
num_classes
+
1
# behaviour of what happens on a Compose(), rather than actually
T
(
alpha
=
0.5
,
num_classes
=
num_classes
)(
imgs
,
torch
.
randint
(
0
,
2
,
size
=
(
batch_size
,
wrong_num_classes
)))
# asserting the expected behaviour. We may support Compose() in the
# future, e.g. for 2 consecutive CutMix?
with
pytest
.
raises
(
ValueError
,
match
=
"but got a tensor of shape"
):
labels
=
torch
.
randint
(
0
,
num_classes
,
size
=
(
batch_size
,))
cutmix_mixup
(
imgs
,
torch
.
randint
(
0
,
2
,
size
=
(
2
,
3
,
4
)))
transforms
.
Compose
([
cutmix_mixup
,
cutmix_mixup
])(
imgs
,
labels
)
with
pytest
.
raises
(
ValueError
,
match
=
"num_classes must be passed"
):
T
(
alpha
=
0.5
)(
imgs
,
torch
.
randint
(
0
,
num_classes
,
size
=
(
batch_size
,)))
@
pytest
.
mark
.
parametrize
(
"key"
,
(
"labels"
,
"LABELS"
,
"LaBeL"
,
"SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"
))
@
pytest
.
mark
.
parametrize
(
"key"
,
(
"labels"
,
"LABELS"
,
"LaBeL"
,
"SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"
))
...
...
torchvision/transforms/v2/_augment.py
View file @
c585a515
import
math
import
math
import
numbers
import
numbers
import
warnings
import
warnings
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Sequence
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -142,7 +142,7 @@ class RandomErasing(_RandomApplyTransform):
...
@@ -142,7 +142,7 @@ class RandomErasing(_RandomApplyTransform):
class
_BaseMixUpCutMix
(
Transform
):
class
_BaseMixUpCutMix
(
Transform
):
def
__init__
(
self
,
*
,
alpha
:
float
=
1.0
,
num_classes
:
int
,
labels_getter
=
"default"
)
->
None
:
def
__init__
(
self
,
*
,
alpha
:
float
=
1.0
,
num_classes
:
Optional
[
int
]
=
None
,
labels_getter
=
"default"
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
alpha
=
float
(
alpha
)
self
.
alpha
=
float
(
alpha
)
self
.
_dist
=
torch
.
distributions
.
Beta
(
torch
.
tensor
([
alpha
]),
torch
.
tensor
([
alpha
]))
self
.
_dist
=
torch
.
distributions
.
Beta
(
torch
.
tensor
([
alpha
]),
torch
.
tensor
([
alpha
]))
...
@@ -162,10 +162,21 @@ class _BaseMixUpCutMix(Transform):
...
@@ -162,10 +162,21 @@ class _BaseMixUpCutMix(Transform):
labels
=
self
.
_labels_getter
(
inputs
)
labels
=
self
.
_labels_getter
(
inputs
)
if
not
isinstance
(
labels
,
torch
.
Tensor
):
if
not
isinstance
(
labels
,
torch
.
Tensor
):
raise
ValueError
(
f
"The labels must be a tensor, but got
{
type
(
labels
)
}
instead."
)
raise
ValueError
(
f
"The labels must be a tensor, but got
{
type
(
labels
)
}
instead."
)
el
if
labels
.
ndim
!=
1
:
if
labels
.
ndim
not
in
(
1
,
2
)
:
raise
ValueError
(
raise
ValueError
(
f
"labels tensor should be of shape (batch_size,) "
f
"but got shape
{
labels
.
shape
}
instead."
f
"labels should be index based with shape (batch_size,) "
f
"or probability based with shape (batch_size, num_classes), "
f
"but got a tensor of shape
{
labels
.
shape
}
instead."
)
)
if
labels
.
ndim
==
2
and
self
.
num_classes
is
not
None
and
labels
.
shape
[
-
1
]
!=
self
.
num_classes
:
raise
ValueError
(
f
"When passing 2D labels, "
f
"the number of elements in last dimension must match num_classes: "
f
"
{
labels
.
shape
[
-
1
]
}
!=
{
self
.
num_classes
}
. "
f
"You can Leave num_classes to None."
)
if
labels
.
ndim
==
1
and
self
.
num_classes
is
None
:
raise
ValueError
(
"num_classes must be passed if the labels are index-based (1D)"
)
params
=
{
params
=
{
"labels"
:
labels
,
"labels"
:
labels
,
...
@@ -198,7 +209,8 @@ class _BaseMixUpCutMix(Transform):
...
@@ -198,7 +209,8 @@ class _BaseMixUpCutMix(Transform):
)
)
def
_mixup_label
(
self
,
label
:
torch
.
Tensor
,
*
,
lam
:
float
)
->
torch
.
Tensor
:
def
_mixup_label
(
self
,
label
:
torch
.
Tensor
,
*
,
lam
:
float
)
->
torch
.
Tensor
:
label
=
one_hot
(
label
,
num_classes
=
self
.
num_classes
)
if
label
.
ndim
==
1
:
label
=
one_hot
(
label
,
num_classes
=
self
.
num_classes
)
# type: ignore[arg-type]
if
not
label
.
dtype
.
is_floating_point
:
if
not
label
.
dtype
.
is_floating_point
:
label
=
label
.
float
()
label
=
label
.
float
()
return
label
.
roll
(
1
,
0
).
mul_
(
1.0
-
lam
).
add_
(
label
.
mul
(
lam
))
return
label
.
roll
(
1
,
0
).
mul_
(
1.0
-
lam
).
add_
(
label
.
mul
(
lam
))
...
@@ -223,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
...
@@ -223,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
Args:
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
...
@@ -271,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
...
@@ -271,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
Args:
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment