OpenDAS / vision · Commits

Commit 3e4e353d (unverified)
Authored Aug 01, 2023 by Nicolas Hug; committed via GitHub on Aug 01, 2023
Parent: edde8255

Commit message: Cutmix -> CutMix (#7784)

Showing 9 changed files with 29 additions and 29 deletions (+29, -29)
Changed files:

docs/source/transforms.rst                      +2 / -2
references/classification/transforms.py         +8 / -8
test/test_prototype_transforms.py               +2 / -2
test/test_transforms_v2_refactored.py           +3 / -3
torchvision/prototype/transforms/__init__.py    +1 / -1
torchvision/prototype/transforms/_augment.py    +3 / -3
torchvision/transforms/v2/__init__.py           +1 / -1
torchvision/transforms/v2/_augment.py           +7 / -7
torchvision/transforms/v2/_utils.py             +2 / -2
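To make the effect of the rename concrete, here is a minimal usage sketch of the renamed v2 classes. The batch shapes and num_classes=10 are illustrative values, not taken from this commit; the calling convention (an already batched image tensor plus an integer label tensor) is the one documented in the diffs below.

```python
import torch
from torchvision.transforms import v2 as transforms

# Illustrative batch: 4 RGB images and their integer class labels.
imgs_batch = torch.rand(4, 3, 224, 224)
labels_batch = torch.randint(0, 10, (4,))

# After this commit the transforms are spelled CutMix / MixUp (previously Cutmix / Mixup).
cutmix = transforms.CutMix(num_classes=10)
mixup = transforms.MixUp(num_classes=10)

mixed_imgs, mixed_labels = mixup(imgs_batch, labels_batch)
print(mixed_labels.shape)  # labels come back one-hot encoded and blended: (4, 10)
```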
docs/source/transforms.rst  (view file @ 3e4e353d)

@@ -274,8 +274,8 @@ are combining pairs of images together. These can be used after the dataloader
     :toctree: generated/
     :template: class.rst

-    v2.Cutmix
-    v2.Mixup
+    v2.CutMix
+    v2.MixUp

 .. _functional_transforms:
references/classification/transforms.py  (view file @ 3e4e353d)

@@ -13,15 +13,15 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
     mixup_cutmix = []
     if mixup_alpha > 0:
         mixup_cutmix.append(
-            transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
+            transforms_module.MixUp(alpha=mixup_alpha, num_categories=num_categories)
             if use_v2
-            else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
+            else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
         )
     if cutmix_alpha > 0:
         mixup_cutmix.append(
-            transforms_module.Cutmix(alpha=mixup_alpha, num_categories=num_categories)
+            transforms_module.CutMix(alpha=mixup_alpha, num_categories=num_categories)
             if use_v2
-            else RandomCutmix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
+            else RandomCutMix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
         )
     if not mixup_cutmix:
         return None

@@ -29,8 +29,8 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
     return transforms_module.RandomChoice(mixup_cutmix)


-class RandomMixup(torch.nn.Module):
-    """Randomly apply Mixup to the provided batch and targets.
+class RandomMixUp(torch.nn.Module):
+    """Randomly apply MixUp to the provided batch and targets.
     The class implements the data augmentations as described in the paper
     `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

@@ -112,8 +112,8 @@ class RandomMixup(torch.nn.Module):
         return s


-class RandomCutmix(torch.nn.Module):
-    """Randomly apply Cutmix to the provided batch and targets.
+class RandomCutMix(torch.nn.Module):
+    """Randomly apply CutMix to the provided batch and targets.
     The class implements the data augmentations as described in the paper
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
     <https://arxiv.org/abs/1905.04899>`_.
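The get_mixup_cutmix helper above returns a RandomChoice over whichever of MixUp/CutMix is enabled, or None when both alphas are zero. As a hedged sketch of how such a batch-level transform is typically wired into a DataLoader via a collate function (the names make_collate_fn and dataset below are placeholders, not part of this diff):

```python
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


def make_collate_fn(mixup_cutmix):
    # mixup_cutmix is the transform returned by get_mixup_cutmix(...); it can be None.
    if mixup_cutmix is None:
        return default_collate

    def collate_fn(batch):
        # default_collate stacks a list of (img, label) samples into
        # (imgs_batch, labels_batch), which is the structure MixUp / CutMix expect.
        return mixup_cutmix(*default_collate(batch))

    return collate_fn


# loader = DataLoader(dataset, batch_size=32, collate_fn=make_collate_fn(mixup_cutmix))
```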
test/test_prototype_transforms.py  (view file @ 3e4e353d)

@@ -60,8 +60,8 @@ def parametrize(transforms_with_inputs):
             ],
         )
         for transform in [
-            transforms.RandomMixup(alpha=1.0),
-            transforms.RandomCutmix(alpha=1.0),
+            transforms.RandomMixUp(alpha=1.0),
+            transforms.RandomCutMix(alpha=1.0),
         ]
     ]
 )
test/test_transforms_v2_refactored.py  (view file @ 3e4e353d)

@@ -1914,7 +1914,7 @@ class TestCutMixMixUp:
         def __len__(self):
             return self.size

-    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
     def test_supported_input_structure(self, T):
         batch_size = 32

@@ -1964,7 +1964,7 @@ class TestCutMixMixUp:
             check_output(img, target)

     @needs_cuda
-    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
     def test_cpu_vs_gpu(self, T):
         num_classes = 10
         batch_size = 3

@@ -1976,7 +1976,7 @@ class TestCutMixMixUp:
         _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)

-    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
     def test_error(self, T):
         num_classes = 10
torchvision/prototype/transforms/__init__.py  (view file @ 3e4e353d)

 from ._presets import StereoMatching  # usort: skip

-from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste
+from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
 from ._geometry import FixedSizeCrop
 from ._misc import PermuteDimensions, TransposeDimensions
 from ._type_conversion import LabelToOneHot
torchvision/prototype/transforms/_augment.py  (view file @ 3e4e353d)

@@ -14,7 +14,7 @@ from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size


-class _BaseMixupCutmix(_RandomApplyTransform):
+class _BaseMixUpCutMix(_RandomApplyTransform):
     def __init__(self, alpha: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.alpha = alpha

@@ -38,7 +38,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
             return proto_datapoints.OneHotLabel.wrap_like(inpt, output)


-class RandomMixup(_BaseMixupCutmix):
+class RandomMixUp(_BaseMixUpCutMix):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

@@ -60,7 +60,7 @@ class RandomMixup(_BaseMixupCutmix):
             return inpt


-class RandomCutmix(_BaseMixupCutmix):
+class RandomCutMix(_BaseMixUpCutMix):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))  # type: ignore[arg-type]
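Both prototype subclasses draw the mixing coefficient lam from self._dist in _get_params. The diff does not show how _dist is constructed; in the MixUp and CutMix papers the coefficient is drawn from a Beta(alpha, alpha) distribution, so the standalone sketch below assumes that and only mirrors the shape of the sampling call.

```python
import torch

alpha = 1.0  # the transform's alpha hyper-parameter (illustrative value)

# Assumed construction: Beta(alpha, alpha), as in the MixUp / CutMix papers.
dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

# Same call shape as in _get_params: sample(()) yields a single draw in [0, 1].
lam = float(dist.sample(()))
assert 0.0 <= lam <= 1.0
```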
torchvision/transforms/v2/__init__.py  (view file @ 3e4e353d)

@@ -4,7 +4,7 @@ from . import functional, utils  # usort: skip
 from ._transform import Transform  # usort: skip

-from ._augment import Cutmix, Mixup, RandomErasing
+from ._augment import CutMix, MixUp, RandomErasing
 from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
torchvision/transforms/v2/_augment.py  (view file @ 3e4e353d)

@@ -140,7 +140,7 @@ class RandomErasing(_RandomApplyTransform):
             return inpt


-class _BaseMixupCutmix(Transform):
+class _BaseMixUpCutMix(Transform):
     def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
         super().__init__()
         self.alpha = float(alpha)

@@ -203,10 +203,10 @@ class _BaseMixupCutmix(Transform):
             return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))


-class Mixup(_BaseMixupCutmix):
+class MixUp(_BaseMixUpCutMix):
     """[BETA] Apply MixUp to the provided batch of images and labels.

-    .. v2betastatus:: Mixup transform
+    .. v2betastatus:: MixUp transform

     Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

@@ -227,7 +227,7 @@ class Mixup(_BaseMixupCutmix):
         num_classes (int): number of classes in the batch. Used for one-hot-encoding.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
-            common scenario where this transform is called as ``Mixup()(imgs_batch, labels_batch)``.
+            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
             It can also be a callable that takes the same input as the transform, and returns the labels.
     """

@@ -252,10 +252,10 @@ class Mixup(_BaseMixupCutmix):
             return inpt


-class Cutmix(_BaseMixupCutmix):
+class CutMix(_BaseMixUpCutMix):
     """[BETA] Apply CutMix to the provided batch of images and labels.

-    .. v2betastatus:: Cutmix transform
+    .. v2betastatus:: CutMix transform

     Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
     <https://arxiv.org/abs/1905.04899>`_.

@@ -277,7 +277,7 @@ class Cutmix(_BaseMixupCutmix):
         num_classes (int): number of classes in the batch. Used for one-hot-encoding.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
-            common scenario where this transform is called as ``Cutmix()(imgs_batch, labels_batch)``.
+            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
             It can also be a callable that takes the same input as the transform, and returns the labels.
     """
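The context line `label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))` in the second hunk above is where the base class blends one-hot labels: each sample's label is mixed with the label of the neighbouring sample in the batch, weighted by lam. A small self-contained check of that arithmetic (using out-of-place mul/add and a made-up lam for illustration):

```python
import torch

lam = 0.3  # illustrative mixing coefficient; normally sampled per batch
label = torch.nn.functional.one_hot(torch.tensor([0, 1, 2]), num_classes=3).float()

# Same expression as in _BaseMixUpCutMix, without the in-place ops.
mixed = label.roll(1, 0).mul(1.0 - lam).add(label.mul(lam))

# Equivalent formulation: lam * label + (1 - lam) * rolled label.
expected = lam * label + (1.0 - lam) * label.roll(1, 0)
assert torch.allclose(mixed, expected)
```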
torchvision/transforms/v2/_utils.py  (view file @ 3e4e353d)

@@ -89,7 +89,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
     This heuristic covers three cases:

     1. The input is tuple or list whose second item is a labels tensor. This happens for already batched
-       classification inputs for Mixup and Cutmix (typically after the Dataloder).
+       classification inputs for MixUp and CutMix (typically after the Dataloder).
     2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
        under a label-like (see below) key. This happens for the inputs of detection models.
     3. The input is a dictionary that is structured as the one from 2.

@@ -103,7 +103,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
     if isinstance(inputs, (tuple, list)):
         inputs = inputs[1]

-    # Mixup, Cutmix
+    # MixUp, CutMix
     if isinstance(inputs, torch.Tensor):
         return inputs
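For readers unfamiliar with this helper, the docstring above describes the default heuristic for locating labels. The following is a simplified, hedged re-statement of that logic, not the actual _find_labels_default_heuristic implementation; in particular the set of label-like keys is an assumption for illustration.

```python
import torch


def find_labels_sketch(inputs):
    # Case 1: an already batched (imgs, labels) pair, e.g. the structure MixUp and
    # CutMix receive right after the DataLoader. Keep only the second item.
    if isinstance(inputs, (tuple, list)):
        inputs = inputs[1]

    # MixUp, CutMix: the second item is the labels tensor itself.
    if isinstance(inputs, torch.Tensor):
        return inputs

    # Cases 2 and 3: a dict holding the labels under a label-like key
    # ("labels" / "label" are assumed representative keys here).
    if isinstance(inputs, dict):
        for key, value in inputs.items():
            if key.lower() in {"labels", "label"}:
                return value

    raise ValueError("Could not infer the labels in the input.")
```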