Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
3a7e5e38
Unverified
Commit
3a7e5e38
authored
Aug 31, 2021
by
Vasilis Vryniotis
Committed by
GitHub
Aug 31, 2021
Browse files
Refactor AutoAugment to support more augmentations. (#4338)
parent
72393f2b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
153 additions
and
155 deletions
+153
-155
torchvision/transforms/autoaugment.py
torchvision/transforms/autoaugment.py
+153
-155
No files found.
torchvision/transforms/autoaugment.py
View file @
3a7e5e38
...
...
@@ -10,6 +10,45 @@ from . import functional as F, InterpolationMode
__all__
=
[
"AutoAugmentPolicy"
,
"AutoAugment"
]
def _apply_op(img: Tensor, op_name: str, magnitude: float,
              interpolation: InterpolationMode, fill: Optional[List[float]]):
    """Apply a single named augmentation operator to ``img``.

    Args:
        img: Image handed to the matching ``F.*`` functional operator.
        op_name: Name of the operator to run. Recognized values are
            "ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
            "Brightness", "Color", "Contrast", "Sharpness", "Posterize",
            "Solarize", "AutoContrast", "Equalize" and "Invert".
        magnitude: Strength of the operator. It is passed through
            ``math.degrees`` for the shear operators, truncated with
            ``int`` for the translate operators and "Posterize", added to
            ``1.0`` for "Brightness", "Color", "Contrast" and "Sharpness",
            and forwarded unchanged to "Rotate" and "Solarize". It is not
            used by "AutoContrast", "Equalize" or "Invert".
        interpolation: Forwarded to ``F.affine`` / ``F.rotate``.
        fill: Forwarded to ``F.affine`` / ``F.rotate``.

    Returns:
        Whatever the selected functional operator returns.

    Raises:
        ValueError: If ``op_name`` is not one of the recognized names.
    """
    # --- affine-based operators (shear / translate) ---
    if op_name == "ShearX":
        shear_x = math.degrees(magnitude)
        return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[shear_x, 0.0],
                        interpolation=interpolation, fill=fill)
    if op_name == "ShearY":
        shear_y = math.degrees(magnitude)
        return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, shear_y],
                        interpolation=interpolation, fill=fill)
    if op_name == "TranslateX":
        shift_x = int(magnitude)
        return F.affine(img, angle=0.0, translate=[shift_x, 0], scale=1.0,
                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
    if op_name == "TranslateY":
        shift_y = int(magnitude)
        return F.affine(img, angle=0.0, translate=[0, shift_y], scale=1.0,
                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
    if op_name == "Rotate":
        return F.rotate(img, magnitude, interpolation=interpolation, fill=fill)

    # --- photometric operators: the factor is 1.0 shifted by the magnitude ---
    if op_name == "Brightness":
        return F.adjust_brightness(img, 1.0 + magnitude)
    if op_name == "Color":
        return F.adjust_saturation(img, 1.0 + magnitude)
    if op_name == "Contrast":
        return F.adjust_contrast(img, 1.0 + magnitude)
    if op_name == "Sharpness":
        return F.adjust_sharpness(img, 1.0 + magnitude)

    # --- operators taking the magnitude directly ---
    if op_name == "Posterize":
        return F.posterize(img, int(magnitude))
    if op_name == "Solarize":
        return F.solarize(img, magnitude)

    # --- operators that ignore the magnitude ---
    if op_name == "AutoContrast":
        return F.autocontrast(img)
    if op_name == "Equalize":
        return F.equalize(img)
    if op_name == "Invert":
        return F.invert(img)

    raise ValueError("The provided operator {} is not recognized.".format(op_name))
class
AutoAugmentPolicy
(
Enum
):
"""AutoAugment policies learned on different datasets.
Available policies are IMAGENET, CIFAR10 and SVHN.
...
...
@@ -19,9 +58,39 @@ class AutoAugmentPolicy(Enum):
SVHN
=
"svhn"
def
_get_transforms
(
# type: ignore[return]
class
AutoAugment
(
torch
.
nn
.
Module
):
r
"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def
__init__
(
self
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
fill
:
Optional
[
List
[
float
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
policy
=
policy
self
.
interpolation
=
interpolation
self
.
fill
=
fill
self
.
transforms
=
self
.
_get_transforms
(
policy
)
def
_get_transforms
(
self
,
policy
:
AutoAugmentPolicy
)
->
List
[
Tuple
[
Tuple
[
str
,
float
,
Optional
[
int
]],
Tuple
[
str
,
float
,
Optional
[
int
]]]]:
)
->
List
[
Tuple
[
Tuple
[
str
,
float
,
Optional
[
int
]],
Tuple
[
str
,
float
,
Optional
[
int
]]]]:
if
policy
==
AutoAugmentPolicy
.
IMAGENET
:
return
[
((
"Posterize"
,
0.4
,
8
),
(
"Rotate"
,
0.6
,
9
)),
...
...
@@ -106,62 +175,28 @@ def _get_transforms( # type: ignore[return]
((
"ShearY"
,
0.8
,
5
),
(
"AutoContrast"
,
0.7
,
None
)),
((
"ShearX"
,
0.7
,
2
),
(
"Invert"
,
0.1
,
None
)),
]
else
:
raise
ValueError
(
"The provided policy {} is not recognized."
.
format
(
policy
))
def
_get_magnitudes
()
->
Dict
[
str
,
Tuple
[
Optional
[
Tensor
],
Optional
[
bool
]]]:
_BINS
=
10
def
_get_magnitudes
(
self
,
num_bins
:
int
,
image_size
:
List
[
int
])
->
Dict
[
str
,
Tuple
[
Tensor
,
bool
]]:
return
{
# name: (magnitudes, signed)
"ShearX"
:
(
torch
.
linspace
(
0.0
,
0.3
,
_BINS
),
True
),
"ShearY"
:
(
torch
.
linspace
(
0.0
,
0.3
,
_BINS
),
True
),
"TranslateX"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
,
_BINS
),
True
),
"TranslateY"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
,
_BINS
),
True
),
"Rotate"
:
(
torch
.
linspace
(
0.0
,
30.0
,
_BINS
),
True
),
"Brightness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Color"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Contrast"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Sharpness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Posterize"
:
(
torch
.
tensor
([
8
,
8
,
7
,
7
,
6
,
6
,
5
,
5
,
4
,
4
]
),
False
),
"Solarize"
:
(
torch
.
linspace
(
256.0
,
0.0
,
_BINS
),
False
),
"AutoContrast"
:
(
None
,
Non
e
),
"Equalize"
:
(
None
,
Non
e
),
"Invert"
:
(
None
,
Non
e
),
"ShearX"
:
(
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"ShearY"
:
(
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"TranslateX"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
*
image_size
[
0
],
num_bins
),
True
),
"TranslateY"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
*
image_size
[
1
],
num_bins
),
True
),
"Rotate"
:
(
torch
.
linspace
(
0.0
,
30.0
,
num_bins
),
True
),
"Brightness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Color"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Contrast"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Sharpness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Posterize"
:
(
8
-
(
torch
.
arange
(
num_bins
)
/
((
num_bins
-
1
)
/
4
)).
round
().
int
(
),
False
),
"Solarize"
:
(
torch
.
linspace
(
256.0
,
0.0
,
num_bins
),
False
),
"AutoContrast"
:
(
torch
.
tensor
(
0.0
),
Fals
e
),
"Equalize"
:
(
torch
.
tensor
(
0.0
),
Fals
e
),
"Invert"
:
(
torch
.
tensor
(
0.0
),
Fals
e
),
}
class
AutoAugment
(
torch
.
nn
.
Module
):
r
"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def
__init__
(
self
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
fill
:
Optional
[
List
[
float
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
policy
=
policy
self
.
interpolation
=
interpolation
self
.
fill
=
fill
self
.
transforms
=
_get_transforms
(
policy
)
if
self
.
transforms
is
None
:
raise
ValueError
(
"The provided policy {} is not recognized."
.
format
(
policy
))
self
.
_op_meta
=
_get_magnitudes
()
@
staticmethod
def
get_params
(
transform_num
:
int
)
->
Tuple
[
int
,
Tensor
,
Tensor
]:
"""Get parameters for autoaugment transformation
...
...
@@ -175,9 +210,6 @@ class AutoAugment(torch.nn.Module):
return
policy_id
,
probs
,
signs
def
_get_op_meta
(
self
,
name
:
str
)
->
Tuple
[
Optional
[
Tensor
],
Optional
[
bool
]]:
return
self
.
_op_meta
[
name
]
def
forward
(
self
,
img
:
Tensor
)
->
Tensor
:
"""
img (PIL Image or Tensor): Image to be transformed.
...
...
@@ -196,46 +228,12 @@ class AutoAugment(torch.nn.Module):
for
i
,
(
op_name
,
p
,
magnitude_id
)
in
enumerate
(
self
.
transforms
[
transform_id
]):
if
probs
[
i
]
<=
p
:
magnitudes
,
signed
=
self
.
_get_op_meta
(
op_name
)
magnitude
=
float
(
magnitudes
[
magnitude_id
].
item
())
\
if
magnitude
s
is
not
None
and
magnitude_id
is
not
None
else
0.0
if
signed
is
not
None
and
signed
and
signs
[
i
]
==
0
:
op_meta
=
self
.
_get_magnitudes
(
10
,
F
.
get_image_size
(
img
)
)
magnitude
s
,
signed
=
op_meta
[
op_name
]
magnitude
=
float
(
magnitudes
[
magnitude_id
].
item
())
if
magnitude_id
is
not
None
else
0.0
if
signed
and
signs
[
i
]
==
0
:
magnitude
*=
-
1.0
if
op_name
==
"ShearX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
math
.
degrees
(
magnitude
),
0.0
],
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"ShearY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
math
.
degrees
(
magnitude
)],
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"TranslateX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
int
(
F
.
get_image_size
(
img
)[
0
]
*
magnitude
),
0
],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"TranslateY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
int
(
F
.
get_image_size
(
img
)[
1
]
*
magnitude
)],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"Rotate"
:
img
=
F
.
rotate
(
img
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"Brightness"
:
img
=
F
.
adjust_brightness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Color"
:
img
=
F
.
adjust_saturation
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Contrast"
:
img
=
F
.
adjust_contrast
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Sharpness"
:
img
=
F
.
adjust_sharpness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Posterize"
:
img
=
F
.
posterize
(
img
,
int
(
magnitude
))
elif
op_name
==
"Solarize"
:
img
=
F
.
solarize
(
img
,
magnitude
)
elif
op_name
==
"AutoContrast"
:
img
=
F
.
autocontrast
(
img
)
elif
op_name
==
"Equalize"
:
img
=
F
.
equalize
(
img
)
elif
op_name
==
"Invert"
:
img
=
F
.
invert
(
img
)
else
:
raise
ValueError
(
"The provided operator {} is not recognized."
.
format
(
op_name
))
img
=
_apply_op
(
img
,
op_name
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
return
img
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment