Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
3a7e5e38
"docs/vscode:/vscode.git/clone" did not exist on "4b1fb6816a800db3febec3c1a9d3a5543287d1be"
Unverified
Commit
3a7e5e38
authored
Aug 31, 2021
by
Vasilis Vryniotis
Committed by
GitHub
Aug 31, 2021
Browse files
Refactor AutoAugment to support more augmentations. (#4338)
parent
72393f2b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
153 additions
and
155 deletions
+153
-155
torchvision/transforms/autoaugment.py
torchvision/transforms/autoaugment.py
+153
-155
No files found.
torchvision/transforms/autoaugment.py
View file @
3a7e5e38
...
@@ -10,6 +10,45 @@ from . import functional as F, InterpolationMode
...
@@ -10,6 +10,45 @@ from . import functional as F, InterpolationMode
__all__
=
[
"AutoAugmentPolicy"
,
"AutoAugment"
]
__all__
=
[
"AutoAugmentPolicy"
,
"AutoAugment"
]
def
_apply_op
(
img
:
Tensor
,
op_name
:
str
,
magnitude
:
float
,
interpolation
:
InterpolationMode
,
fill
:
Optional
[
List
[
float
]]):
if
op_name
==
"ShearX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
math
.
degrees
(
magnitude
),
0.0
],
interpolation
=
interpolation
,
fill
=
fill
)
elif
op_name
==
"ShearY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
math
.
degrees
(
magnitude
)],
interpolation
=
interpolation
,
fill
=
fill
)
elif
op_name
==
"TranslateX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
int
(
magnitude
),
0
],
scale
=
1.0
,
interpolation
=
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"TranslateY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
int
(
magnitude
)],
scale
=
1.0
,
interpolation
=
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"Rotate"
:
img
=
F
.
rotate
(
img
,
magnitude
,
interpolation
=
interpolation
,
fill
=
fill
)
elif
op_name
==
"Brightness"
:
img
=
F
.
adjust_brightness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Color"
:
img
=
F
.
adjust_saturation
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Contrast"
:
img
=
F
.
adjust_contrast
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Sharpness"
:
img
=
F
.
adjust_sharpness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Posterize"
:
img
=
F
.
posterize
(
img
,
int
(
magnitude
))
elif
op_name
==
"Solarize"
:
img
=
F
.
solarize
(
img
,
magnitude
)
elif
op_name
==
"AutoContrast"
:
img
=
F
.
autocontrast
(
img
)
elif
op_name
==
"Equalize"
:
img
=
F
.
equalize
(
img
)
elif
op_name
==
"Invert"
:
img
=
F
.
invert
(
img
)
else
:
raise
ValueError
(
"The provided operator {} is not recognized."
.
format
(
op_name
))
return
img
class
AutoAugmentPolicy
(
Enum
):
class
AutoAugmentPolicy
(
Enum
):
"""AutoAugment policies learned on different datasets.
"""AutoAugment policies learned on different datasets.
Available policies are IMAGENET, CIFAR10 and SVHN.
Available policies are IMAGENET, CIFAR10 and SVHN.
...
@@ -19,9 +58,39 @@ class AutoAugmentPolicy(Enum):
...
@@ -19,9 +58,39 @@ class AutoAugmentPolicy(Enum):
SVHN
=
"svhn"
SVHN
=
"svhn"
def
_get_transforms
(
# type: ignore[return]
class
AutoAugment
(
torch
.
nn
.
Module
):
r
"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def
__init__
(
self
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
fill
:
Optional
[
List
[
float
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
policy
=
policy
self
.
interpolation
=
interpolation
self
.
fill
=
fill
self
.
transforms
=
self
.
_get_transforms
(
policy
)
def
_get_transforms
(
self
,
policy
:
AutoAugmentPolicy
policy
:
AutoAugmentPolicy
)
->
List
[
Tuple
[
Tuple
[
str
,
float
,
Optional
[
int
]],
Tuple
[
str
,
float
,
Optional
[
int
]]]]:
)
->
List
[
Tuple
[
Tuple
[
str
,
float
,
Optional
[
int
]],
Tuple
[
str
,
float
,
Optional
[
int
]]]]:
if
policy
==
AutoAugmentPolicy
.
IMAGENET
:
if
policy
==
AutoAugmentPolicy
.
IMAGENET
:
return
[
return
[
((
"Posterize"
,
0.4
,
8
),
(
"Rotate"
,
0.6
,
9
)),
((
"Posterize"
,
0.4
,
8
),
(
"Rotate"
,
0.6
,
9
)),
...
@@ -106,62 +175,28 @@ def _get_transforms( # type: ignore[return]
...
@@ -106,62 +175,28 @@ def _get_transforms( # type: ignore[return]
((
"ShearY"
,
0.8
,
5
),
(
"AutoContrast"
,
0.7
,
None
)),
((
"ShearY"
,
0.8
,
5
),
(
"AutoContrast"
,
0.7
,
None
)),
((
"ShearX"
,
0.7
,
2
),
(
"Invert"
,
0.1
,
None
)),
((
"ShearX"
,
0.7
,
2
),
(
"Invert"
,
0.1
,
None
)),
]
]
else
:
raise
ValueError
(
"The provided policy {} is not recognized."
.
format
(
policy
))
def
_get_magnitudes
(
self
,
num_bins
:
int
,
image_size
:
List
[
int
])
->
Dict
[
str
,
Tuple
[
Tensor
,
bool
]]:
def
_get_magnitudes
()
->
Dict
[
str
,
Tuple
[
Optional
[
Tensor
],
Optional
[
bool
]]]:
_BINS
=
10
return
{
return
{
# name: (magnitudes, signed)
# name: (magnitudes, signed)
"ShearX"
:
(
torch
.
linspace
(
0.0
,
0.3
,
_BINS
),
True
),
"ShearX"
:
(
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"ShearY"
:
(
torch
.
linspace
(
0.0
,
0.3
,
_BINS
),
True
),
"ShearY"
:
(
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"TranslateX"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
,
_BINS
),
True
),
"TranslateX"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
*
image_size
[
0
],
num_bins
),
True
),
"TranslateY"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
,
_BINS
),
True
),
"TranslateY"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
*
image_size
[
1
],
num_bins
),
True
),
"Rotate"
:
(
torch
.
linspace
(
0.0
,
30.0
,
_BINS
),
True
),
"Rotate"
:
(
torch
.
linspace
(
0.0
,
30.0
,
num_bins
),
True
),
"Brightness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Brightness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Color"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Color"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Contrast"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Contrast"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Sharpness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Sharpness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Posterize"
:
(
torch
.
tensor
([
8
,
8
,
7
,
7
,
6
,
6
,
5
,
5
,
4
,
4
]
),
False
),
"Posterize"
:
(
8
-
(
torch
.
arange
(
num_bins
)
/
((
num_bins
-
1
)
/
4
)).
round
().
int
(
),
False
),
"Solarize"
:
(
torch
.
linspace
(
256.0
,
0.0
,
_BINS
),
False
),
"Solarize"
:
(
torch
.
linspace
(
256.0
,
0.0
,
num_bins
),
False
),
"AutoContrast"
:
(
None
,
Non
e
),
"AutoContrast"
:
(
torch
.
tensor
(
0.0
),
Fals
e
),
"Equalize"
:
(
None
,
Non
e
),
"Equalize"
:
(
torch
.
tensor
(
0.0
),
Fals
e
),
"Invert"
:
(
None
,
Non
e
),
"Invert"
:
(
torch
.
tensor
(
0.0
),
Fals
e
),
}
}
class
AutoAugment
(
torch
.
nn
.
Module
):
r
"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def
__init__
(
self
,
policy
:
AutoAugmentPolicy
=
AutoAugmentPolicy
.
IMAGENET
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
NEAREST
,
fill
:
Optional
[
List
[
float
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
policy
=
policy
self
.
interpolation
=
interpolation
self
.
fill
=
fill
self
.
transforms
=
_get_transforms
(
policy
)
if
self
.
transforms
is
None
:
raise
ValueError
(
"The provided policy {} is not recognized."
.
format
(
policy
))
self
.
_op_meta
=
_get_magnitudes
()
@
staticmethod
@
staticmethod
def
get_params
(
transform_num
:
int
)
->
Tuple
[
int
,
Tensor
,
Tensor
]:
def
get_params
(
transform_num
:
int
)
->
Tuple
[
int
,
Tensor
,
Tensor
]:
"""Get parameters for autoaugment transformation
"""Get parameters for autoaugment transformation
...
@@ -175,9 +210,6 @@ class AutoAugment(torch.nn.Module):
...
@@ -175,9 +210,6 @@ class AutoAugment(torch.nn.Module):
return
policy_id
,
probs
,
signs
return
policy_id
,
probs
,
signs
def
_get_op_meta
(
self
,
name
:
str
)
->
Tuple
[
Optional
[
Tensor
],
Optional
[
bool
]]:
return
self
.
_op_meta
[
name
]
def
forward
(
self
,
img
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
img
:
Tensor
)
->
Tensor
:
"""
"""
img (PIL Image or Tensor): Image to be transformed.
img (PIL Image or Tensor): Image to be transformed.
...
@@ -196,46 +228,12 @@ class AutoAugment(torch.nn.Module):
...
@@ -196,46 +228,12 @@ class AutoAugment(torch.nn.Module):
for
i
,
(
op_name
,
p
,
magnitude_id
)
in
enumerate
(
self
.
transforms
[
transform_id
]):
for
i
,
(
op_name
,
p
,
magnitude_id
)
in
enumerate
(
self
.
transforms
[
transform_id
]):
if
probs
[
i
]
<=
p
:
if
probs
[
i
]
<=
p
:
magnitudes
,
signed
=
self
.
_get_op_meta
(
op_name
)
op_meta
=
self
.
_get_magnitudes
(
10
,
F
.
get_image_size
(
img
)
)
magnitude
=
float
(
magnitudes
[
magnitude_id
].
item
())
\
magnitude
s
,
signed
=
op_meta
[
op_name
]
if
magnitude
s
is
not
None
and
magnitude_id
is
not
None
else
0.0
magnitude
=
float
(
magnitudes
[
magnitude_id
].
item
())
if
magnitude_id
is
not
None
else
0.0
if
signed
is
not
None
and
signed
and
signs
[
i
]
==
0
:
if
signed
and
signs
[
i
]
==
0
:
magnitude
*=
-
1.0
magnitude
*=
-
1.0
img
=
_apply_op
(
img
,
op_name
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
if
op_name
==
"ShearX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
math
.
degrees
(
magnitude
),
0.0
],
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"ShearY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
math
.
degrees
(
magnitude
)],
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"TranslateX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
int
(
F
.
get_image_size
(
img
)[
0
]
*
magnitude
),
0
],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"TranslateY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
int
(
F
.
get_image_size
(
img
)[
1
]
*
magnitude
)],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"Rotate"
:
img
=
F
.
rotate
(
img
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"Brightness"
:
img
=
F
.
adjust_brightness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Color"
:
img
=
F
.
adjust_saturation
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Contrast"
:
img
=
F
.
adjust_contrast
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Sharpness"
:
img
=
F
.
adjust_sharpness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Posterize"
:
img
=
F
.
posterize
(
img
,
int
(
magnitude
))
elif
op_name
==
"Solarize"
:
img
=
F
.
solarize
(
img
,
magnitude
)
elif
op_name
==
"AutoContrast"
:
img
=
F
.
autocontrast
(
img
)
elif
op_name
==
"Equalize"
:
img
=
F
.
equalize
(
img
)
elif
op_name
==
"Invert"
:
img
=
F
.
invert
(
img
)
else
:
raise
ValueError
(
"The provided operator {} is not recognized."
.
format
(
op_name
))
return
img
return
img
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment