Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
3a7e5e38
Unverified
Commit
3a7e5e38
authored
Aug 31, 2021
by
Vasilis Vryniotis
Committed by
GitHub
Aug 31, 2021
Browse files
Refactor AutoAugment to support more augmentations. (#4338)
parent
72393f2b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
153 additions
and
155 deletions
+153
-155
torchvision/transforms/autoaugment.py
torchvision/transforms/autoaugment.py
+153
-155
No files found.
torchvision/transforms/autoaugment.py
View file @
3a7e5e38
...
@@ -10,6 +10,45 @@ from . import functional as F, InterpolationMode
...
@@ -10,6 +10,45 @@ from . import functional as F, InterpolationMode
__all__
=
[
"AutoAugmentPolicy"
,
"AutoAugment"
]
__all__
=
[
"AutoAugmentPolicy"
,
"AutoAugment"
]
def
_apply_op
(
img
:
Tensor
,
op_name
:
str
,
magnitude
:
float
,
interpolation
:
InterpolationMode
,
fill
:
Optional
[
List
[
float
]]):
if
op_name
==
"ShearX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
math
.
degrees
(
magnitude
),
0.0
],
interpolation
=
interpolation
,
fill
=
fill
)
elif
op_name
==
"ShearY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
math
.
degrees
(
magnitude
)],
interpolation
=
interpolation
,
fill
=
fill
)
elif
op_name
==
"TranslateX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
int
(
magnitude
),
0
],
scale
=
1.0
,
interpolation
=
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"TranslateY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
int
(
magnitude
)],
scale
=
1.0
,
interpolation
=
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"Rotate"
:
img
=
F
.
rotate
(
img
,
magnitude
,
interpolation
=
interpolation
,
fill
=
fill
)
elif
op_name
==
"Brightness"
:
img
=
F
.
adjust_brightness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Color"
:
img
=
F
.
adjust_saturation
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Contrast"
:
img
=
F
.
adjust_contrast
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Sharpness"
:
img
=
F
.
adjust_sharpness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Posterize"
:
img
=
F
.
posterize
(
img
,
int
(
magnitude
))
elif
op_name
==
"Solarize"
:
img
=
F
.
solarize
(
img
,
magnitude
)
elif
op_name
==
"AutoContrast"
:
img
=
F
.
autocontrast
(
img
)
elif
op_name
==
"Equalize"
:
img
=
F
.
equalize
(
img
)
elif
op_name
==
"Invert"
:
img
=
F
.
invert
(
img
)
else
:
raise
ValueError
(
"The provided operator {} is not recognized."
.
format
(
op_name
))
return
img
class
AutoAugmentPolicy
(
Enum
):
class
AutoAugmentPolicy
(
Enum
):
"""AutoAugment policies learned on different datasets.
"""AutoAugment policies learned on different datasets.
Available policies are IMAGENET, CIFAR10 and SVHN.
Available policies are IMAGENET, CIFAR10 and SVHN.
...
@@ -19,116 +58,6 @@ class AutoAugmentPolicy(Enum):
...
@@ -19,116 +58,6 @@ class AutoAugmentPolicy(Enum):
SVHN
=
"svhn"
SVHN
=
"svhn"
def
_get_transforms
(
# type: ignore[return]
policy
:
AutoAugmentPolicy
)
->
List
[
Tuple
[
Tuple
[
str
,
float
,
Optional
[
int
]],
Tuple
[
str
,
float
,
Optional
[
int
]]]]:
if
policy
==
AutoAugmentPolicy
.
IMAGENET
:
return
[
((
"Posterize"
,
0.4
,
8
),
(
"Rotate"
,
0.6
,
9
)),
((
"Solarize"
,
0.6
,
5
),
(
"AutoContrast"
,
0.6
,
None
)),
((
"Equalize"
,
0.8
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Posterize"
,
0.6
,
7
),
(
"Posterize"
,
0.6
,
6
)),
((
"Equalize"
,
0.4
,
None
),
(
"Solarize"
,
0.2
,
4
)),
((
"Equalize"
,
0.4
,
None
),
(
"Rotate"
,
0.8
,
8
)),
((
"Solarize"
,
0.6
,
3
),
(
"Equalize"
,
0.6
,
None
)),
((
"Posterize"
,
0.8
,
5
),
(
"Equalize"
,
1.0
,
None
)),
((
"Rotate"
,
0.2
,
3
),
(
"Solarize"
,
0.6
,
8
)),
((
"Equalize"
,
0.6
,
None
),
(
"Posterize"
,
0.4
,
6
)),
((
"Rotate"
,
0.8
,
8
),
(
"Color"
,
0.4
,
0
)),
((
"Rotate"
,
0.4
,
9
),
(
"Equalize"
,
0.6
,
None
)),
((
"Equalize"
,
0.0
,
None
),
(
"Equalize"
,
0.8
,
None
)),
((
"Invert"
,
0.6
,
None
),
(
"Equalize"
,
1.0
,
None
)),
((
"Color"
,
0.6
,
4
),
(
"Contrast"
,
1.0
,
8
)),
((
"Rotate"
,
0.8
,
8
),
(
"Color"
,
1.0
,
2
)),
((
"Color"
,
0.8
,
8
),
(
"Solarize"
,
0.8
,
7
)),
((
"Sharpness"
,
0.4
,
7
),
(
"Invert"
,
0.6
,
None
)),
((
"ShearX"
,
0.6
,
5
),
(
"Equalize"
,
1.0
,
None
)),
((
"Color"
,
0.4
,
0
),
(
"Equalize"
,
0.6
,
None
)),
((
"Equalize"
,
0.4
,
None
),
(
"Solarize"
,
0.2
,
4
)),
((
"Solarize"
,
0.6
,
5
),
(
"AutoContrast"
,
0.6
,
None
)),
((
"Invert"
,
0.6
,
None
),
(
"Equalize"
,
1.0
,
None
)),
((
"Color"
,
0.6
,
4
),
(
"Contrast"
,
1.0
,
8
)),
((
"Equalize"
,
0.8
,
None
),
(
"Equalize"
,
0.6
,
None
)),
]
elif
policy
==
AutoAugmentPolicy
.
CIFAR10
:
return
[
((
"Invert"
,
0.1
,
None
),
(
"Contrast"
,
0.2
,
6
)),
((
"Rotate"
,
0.7
,
2
),
(
"TranslateX"
,
0.3
,
9
)),
((
"Sharpness"
,
0.8
,
1
),
(
"Sharpness"
,
0.9
,
3
)),
((
"ShearY"
,
0.5
,
8
),
(
"TranslateY"
,
0.7
,
9
)),
((
"AutoContrast"
,
0.5
,
None
),
(
"Equalize"
,
0.9
,
None
)),
((
"ShearY"
,
0.2
,
7
),
(
"Posterize"
,
0.3
,
7
)),
((
"Color"
,
0.4
,
3
),
(
"Brightness"
,
0.6
,
7
)),
((
"Sharpness"
,
0.3
,
9
),
(
"Brightness"
,
0.7
,
9
)),
((
"Equalize"
,
0.6
,
None
),
(
"Equalize"
,
0.5
,
None
)),
((
"Contrast"
,
0.6
,
7
),
(
"Sharpness"
,
0.6
,
5
)),
((
"Color"
,
0.7
,
7
),
(
"TranslateX"
,
0.5
,
8
)),
((
"Equalize"
,
0.3
,
None
),
(
"AutoContrast"
,
0.4
,
None
)),
((
"TranslateY"
,
0.4
,
3
),
(
"Sharpness"
,
0.2
,
6
)),
((
"Brightness"
,
0.9
,
6
),
(
"Color"
,
0.2
,
8
)),
((
"Solarize"
,
0.5
,
2
),
(
"Invert"
,
0.0
,
None
)),
((
"Equalize"
,
0.2
,
None
),
(
"AutoContrast"
,
0.6
,
None
)),
((
"Equalize"
,
0.2
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Color"
,
0.9
,
9
),
(
"Equalize"
,
0.6
,
None
)),
((
"AutoContrast"
,
0.8
,
None
),
(
"Solarize"
,
0.2
,
8
)),
((
"Brightness"
,
0.1
,
3
),
(
"Color"
,
0.7
,
0
)),
((
"Solarize"
,
0.4
,
5
),
(
"AutoContrast"
,
0.9
,
None
)),
((
"TranslateY"
,
0.9
,
9
),
(
"TranslateY"
,
0.7
,
9
)),
((
"AutoContrast"
,
0.9
,
None
),
(
"Solarize"
,
0.8
,
3
)),
((
"Equalize"
,
0.8
,
None
),
(
"Invert"
,
0.1
,
None
)),
((
"TranslateY"
,
0.7
,
9
),
(
"AutoContrast"
,
0.9
,
None
)),
]
elif
policy
==
AutoAugmentPolicy
.
SVHN
:
return
[
((
"ShearX"
,
0.9
,
4
),
(
"Invert"
,
0.2
,
None
)),
((
"ShearY"
,
0.9
,
8
),
(
"Invert"
,
0.7
,
None
)),
((
"Equalize"
,
0.6
,
None
),
(
"Solarize"
,
0.6
,
6
)),
((
"Invert"
,
0.9
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Equalize"
,
0.6
,
None
),
(
"Rotate"
,
0.9
,
3
)),
((
"ShearX"
,
0.9
,
4
),
(
"AutoContrast"
,
0.8
,
None
)),
((
"ShearY"
,
0.9
,
8
),
(
"Invert"
,
0.4
,
None
)),
((
"ShearY"
,
0.9
,
5
),
(
"Solarize"
,
0.2
,
6
)),
((
"Invert"
,
0.9
,
None
),
(
"AutoContrast"
,
0.8
,
None
)),
((
"Equalize"
,
0.6
,
None
),
(
"Rotate"
,
0.9
,
3
)),
((
"ShearX"
,
0.9
,
4
),
(
"Solarize"
,
0.3
,
3
)),
((
"ShearY"
,
0.8
,
8
),
(
"Invert"
,
0.7
,
None
)),
((
"Equalize"
,
0.9
,
None
),
(
"TranslateY"
,
0.6
,
6
)),
((
"Invert"
,
0.9
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Contrast"
,
0.3
,
3
),
(
"Rotate"
,
0.8
,
4
)),
((
"Invert"
,
0.8
,
None
),
(
"TranslateY"
,
0.0
,
2
)),
((
"ShearY"
,
0.7
,
6
),
(
"Solarize"
,
0.4
,
8
)),
((
"Invert"
,
0.6
,
None
),
(
"Rotate"
,
0.8
,
4
)),
((
"ShearY"
,
0.3
,
7
),
(
"TranslateX"
,
0.9
,
3
)),
((
"ShearX"
,
0.1
,
6
),
(
"Invert"
,
0.6
,
None
)),
((
"Solarize"
,
0.7
,
2
),
(
"TranslateY"
,
0.6
,
7
)),
((
"ShearY"
,
0.8
,
4
),
(
"Invert"
,
0.8
,
None
)),
((
"ShearX"
,
0.7
,
9
),
(
"TranslateY"
,
0.8
,
3
)),
((
"ShearY"
,
0.8
,
5
),
(
"AutoContrast"
,
0.7
,
None
)),
((
"ShearX"
,
0.7
,
2
),
(
"Invert"
,
0.1
,
None
)),
]
def
_get_magnitudes
()
->
Dict
[
str
,
Tuple
[
Optional
[
Tensor
],
Optional
[
bool
]]]:
_BINS
=
10
return
{
# name: (magnitudes, signed)
"ShearX"
:
(
torch
.
linspace
(
0.0
,
0.3
,
_BINS
),
True
),
"ShearY"
:
(
torch
.
linspace
(
0.0
,
0.3
,
_BINS
),
True
),
"TranslateX"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
,
_BINS
),
True
),
"TranslateY"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
,
_BINS
),
True
),
"Rotate"
:
(
torch
.
linspace
(
0.0
,
30.0
,
_BINS
),
True
),
"Brightness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Color"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Contrast"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Sharpness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
_BINS
),
True
),
"Posterize"
:
(
torch
.
tensor
([
8
,
8
,
7
,
7
,
6
,
6
,
5
,
5
,
4
,
4
]),
False
),
"Solarize"
:
(
torch
.
linspace
(
256.0
,
0.0
,
_BINS
),
False
),
"AutoContrast"
:
(
None
,
None
),
"Equalize"
:
(
None
,
None
),
"Invert"
:
(
None
,
None
),
}
class
AutoAugment
(
torch
.
nn
.
Module
):
class
AutoAugment
(
torch
.
nn
.
Module
):
r
"""AutoAugment data augmentation method based on
r
"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
...
@@ -156,11 +85,117 @@ class AutoAugment(torch.nn.Module):
...
@@ -156,11 +85,117 @@ class AutoAugment(torch.nn.Module):
self
.
policy
=
policy
self
.
policy
=
policy
self
.
interpolation
=
interpolation
self
.
interpolation
=
interpolation
self
.
fill
=
fill
self
.
fill
=
fill
self
.
transforms
=
self
.
_get_transforms
(
policy
)
self
.
transforms
=
_get_transforms
(
policy
)
def
_get_transforms
(
if
self
.
transforms
is
None
:
self
,
policy
:
AutoAugmentPolicy
)
->
List
[
Tuple
[
Tuple
[
str
,
float
,
Optional
[
int
]],
Tuple
[
str
,
float
,
Optional
[
int
]]]]:
if
policy
==
AutoAugmentPolicy
.
IMAGENET
:
return
[
((
"Posterize"
,
0.4
,
8
),
(
"Rotate"
,
0.6
,
9
)),
((
"Solarize"
,
0.6
,
5
),
(
"AutoContrast"
,
0.6
,
None
)),
((
"Equalize"
,
0.8
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Posterize"
,
0.6
,
7
),
(
"Posterize"
,
0.6
,
6
)),
((
"Equalize"
,
0.4
,
None
),
(
"Solarize"
,
0.2
,
4
)),
((
"Equalize"
,
0.4
,
None
),
(
"Rotate"
,
0.8
,
8
)),
((
"Solarize"
,
0.6
,
3
),
(
"Equalize"
,
0.6
,
None
)),
((
"Posterize"
,
0.8
,
5
),
(
"Equalize"
,
1.0
,
None
)),
((
"Rotate"
,
0.2
,
3
),
(
"Solarize"
,
0.6
,
8
)),
((
"Equalize"
,
0.6
,
None
),
(
"Posterize"
,
0.4
,
6
)),
((
"Rotate"
,
0.8
,
8
),
(
"Color"
,
0.4
,
0
)),
((
"Rotate"
,
0.4
,
9
),
(
"Equalize"
,
0.6
,
None
)),
((
"Equalize"
,
0.0
,
None
),
(
"Equalize"
,
0.8
,
None
)),
((
"Invert"
,
0.6
,
None
),
(
"Equalize"
,
1.0
,
None
)),
((
"Color"
,
0.6
,
4
),
(
"Contrast"
,
1.0
,
8
)),
((
"Rotate"
,
0.8
,
8
),
(
"Color"
,
1.0
,
2
)),
((
"Color"
,
0.8
,
8
),
(
"Solarize"
,
0.8
,
7
)),
((
"Sharpness"
,
0.4
,
7
),
(
"Invert"
,
0.6
,
None
)),
((
"ShearX"
,
0.6
,
5
),
(
"Equalize"
,
1.0
,
None
)),
((
"Color"
,
0.4
,
0
),
(
"Equalize"
,
0.6
,
None
)),
((
"Equalize"
,
0.4
,
None
),
(
"Solarize"
,
0.2
,
4
)),
((
"Solarize"
,
0.6
,
5
),
(
"AutoContrast"
,
0.6
,
None
)),
((
"Invert"
,
0.6
,
None
),
(
"Equalize"
,
1.0
,
None
)),
((
"Color"
,
0.6
,
4
),
(
"Contrast"
,
1.0
,
8
)),
((
"Equalize"
,
0.8
,
None
),
(
"Equalize"
,
0.6
,
None
)),
]
elif
policy
==
AutoAugmentPolicy
.
CIFAR10
:
return
[
((
"Invert"
,
0.1
,
None
),
(
"Contrast"
,
0.2
,
6
)),
((
"Rotate"
,
0.7
,
2
),
(
"TranslateX"
,
0.3
,
9
)),
((
"Sharpness"
,
0.8
,
1
),
(
"Sharpness"
,
0.9
,
3
)),
((
"ShearY"
,
0.5
,
8
),
(
"TranslateY"
,
0.7
,
9
)),
((
"AutoContrast"
,
0.5
,
None
),
(
"Equalize"
,
0.9
,
None
)),
((
"ShearY"
,
0.2
,
7
),
(
"Posterize"
,
0.3
,
7
)),
((
"Color"
,
0.4
,
3
),
(
"Brightness"
,
0.6
,
7
)),
((
"Sharpness"
,
0.3
,
9
),
(
"Brightness"
,
0.7
,
9
)),
((
"Equalize"
,
0.6
,
None
),
(
"Equalize"
,
0.5
,
None
)),
((
"Contrast"
,
0.6
,
7
),
(
"Sharpness"
,
0.6
,
5
)),
((
"Color"
,
0.7
,
7
),
(
"TranslateX"
,
0.5
,
8
)),
((
"Equalize"
,
0.3
,
None
),
(
"AutoContrast"
,
0.4
,
None
)),
((
"TranslateY"
,
0.4
,
3
),
(
"Sharpness"
,
0.2
,
6
)),
((
"Brightness"
,
0.9
,
6
),
(
"Color"
,
0.2
,
8
)),
((
"Solarize"
,
0.5
,
2
),
(
"Invert"
,
0.0
,
None
)),
((
"Equalize"
,
0.2
,
None
),
(
"AutoContrast"
,
0.6
,
None
)),
((
"Equalize"
,
0.2
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Color"
,
0.9
,
9
),
(
"Equalize"
,
0.6
,
None
)),
((
"AutoContrast"
,
0.8
,
None
),
(
"Solarize"
,
0.2
,
8
)),
((
"Brightness"
,
0.1
,
3
),
(
"Color"
,
0.7
,
0
)),
((
"Solarize"
,
0.4
,
5
),
(
"AutoContrast"
,
0.9
,
None
)),
((
"TranslateY"
,
0.9
,
9
),
(
"TranslateY"
,
0.7
,
9
)),
((
"AutoContrast"
,
0.9
,
None
),
(
"Solarize"
,
0.8
,
3
)),
((
"Equalize"
,
0.8
,
None
),
(
"Invert"
,
0.1
,
None
)),
((
"TranslateY"
,
0.7
,
9
),
(
"AutoContrast"
,
0.9
,
None
)),
]
elif
policy
==
AutoAugmentPolicy
.
SVHN
:
return
[
((
"ShearX"
,
0.9
,
4
),
(
"Invert"
,
0.2
,
None
)),
((
"ShearY"
,
0.9
,
8
),
(
"Invert"
,
0.7
,
None
)),
((
"Equalize"
,
0.6
,
None
),
(
"Solarize"
,
0.6
,
6
)),
((
"Invert"
,
0.9
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Equalize"
,
0.6
,
None
),
(
"Rotate"
,
0.9
,
3
)),
((
"ShearX"
,
0.9
,
4
),
(
"AutoContrast"
,
0.8
,
None
)),
((
"ShearY"
,
0.9
,
8
),
(
"Invert"
,
0.4
,
None
)),
((
"ShearY"
,
0.9
,
5
),
(
"Solarize"
,
0.2
,
6
)),
((
"Invert"
,
0.9
,
None
),
(
"AutoContrast"
,
0.8
,
None
)),
((
"Equalize"
,
0.6
,
None
),
(
"Rotate"
,
0.9
,
3
)),
((
"ShearX"
,
0.9
,
4
),
(
"Solarize"
,
0.3
,
3
)),
((
"ShearY"
,
0.8
,
8
),
(
"Invert"
,
0.7
,
None
)),
((
"Equalize"
,
0.9
,
None
),
(
"TranslateY"
,
0.6
,
6
)),
((
"Invert"
,
0.9
,
None
),
(
"Equalize"
,
0.6
,
None
)),
((
"Contrast"
,
0.3
,
3
),
(
"Rotate"
,
0.8
,
4
)),
((
"Invert"
,
0.8
,
None
),
(
"TranslateY"
,
0.0
,
2
)),
((
"ShearY"
,
0.7
,
6
),
(
"Solarize"
,
0.4
,
8
)),
((
"Invert"
,
0.6
,
None
),
(
"Rotate"
,
0.8
,
4
)),
((
"ShearY"
,
0.3
,
7
),
(
"TranslateX"
,
0.9
,
3
)),
((
"ShearX"
,
0.1
,
6
),
(
"Invert"
,
0.6
,
None
)),
((
"Solarize"
,
0.7
,
2
),
(
"TranslateY"
,
0.6
,
7
)),
((
"ShearY"
,
0.8
,
4
),
(
"Invert"
,
0.8
,
None
)),
((
"ShearX"
,
0.7
,
9
),
(
"TranslateY"
,
0.8
,
3
)),
((
"ShearY"
,
0.8
,
5
),
(
"AutoContrast"
,
0.7
,
None
)),
((
"ShearX"
,
0.7
,
2
),
(
"Invert"
,
0.1
,
None
)),
]
else
:
raise
ValueError
(
"The provided policy {} is not recognized."
.
format
(
policy
))
raise
ValueError
(
"The provided policy {} is not recognized."
.
format
(
policy
))
self
.
_op_meta
=
_get_magnitudes
()
def
_get_magnitudes
(
self
,
num_bins
:
int
,
image_size
:
List
[
int
])
->
Dict
[
str
,
Tuple
[
Tensor
,
bool
]]:
return
{
# name: (magnitudes, signed)
"ShearX"
:
(
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"ShearY"
:
(
torch
.
linspace
(
0.0
,
0.3
,
num_bins
),
True
),
"TranslateX"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
*
image_size
[
0
],
num_bins
),
True
),
"TranslateY"
:
(
torch
.
linspace
(
0.0
,
150.0
/
331.0
*
image_size
[
1
],
num_bins
),
True
),
"Rotate"
:
(
torch
.
linspace
(
0.0
,
30.0
,
num_bins
),
True
),
"Brightness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Color"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Contrast"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Sharpness"
:
(
torch
.
linspace
(
0.0
,
0.9
,
num_bins
),
True
),
"Posterize"
:
(
8
-
(
torch
.
arange
(
num_bins
)
/
((
num_bins
-
1
)
/
4
)).
round
().
int
(),
False
),
"Solarize"
:
(
torch
.
linspace
(
256.0
,
0.0
,
num_bins
),
False
),
"AutoContrast"
:
(
torch
.
tensor
(
0.0
),
False
),
"Equalize"
:
(
torch
.
tensor
(
0.0
),
False
),
"Invert"
:
(
torch
.
tensor
(
0.0
),
False
),
}
@
staticmethod
@
staticmethod
def
get_params
(
transform_num
:
int
)
->
Tuple
[
int
,
Tensor
,
Tensor
]:
def
get_params
(
transform_num
:
int
)
->
Tuple
[
int
,
Tensor
,
Tensor
]:
...
@@ -175,9 +210,6 @@ class AutoAugment(torch.nn.Module):
...
@@ -175,9 +210,6 @@ class AutoAugment(torch.nn.Module):
return
policy_id
,
probs
,
signs
return
policy_id
,
probs
,
signs
def
_get_op_meta
(
self
,
name
:
str
)
->
Tuple
[
Optional
[
Tensor
],
Optional
[
bool
]]:
return
self
.
_op_meta
[
name
]
def
forward
(
self
,
img
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
img
:
Tensor
)
->
Tensor
:
"""
"""
img (PIL Image or Tensor): Image to be transformed.
img (PIL Image or Tensor): Image to be transformed.
...
@@ -196,46 +228,12 @@ class AutoAugment(torch.nn.Module):
...
@@ -196,46 +228,12 @@ class AutoAugment(torch.nn.Module):
for
i
,
(
op_name
,
p
,
magnitude_id
)
in
enumerate
(
self
.
transforms
[
transform_id
]):
for
i
,
(
op_name
,
p
,
magnitude_id
)
in
enumerate
(
self
.
transforms
[
transform_id
]):
if
probs
[
i
]
<=
p
:
if
probs
[
i
]
<=
p
:
magnitudes
,
signed
=
self
.
_get_op_meta
(
op_name
)
op_meta
=
self
.
_get_magnitudes
(
10
,
F
.
get_image_size
(
img
)
)
magnitude
=
float
(
magnitudes
[
magnitude_id
].
item
())
\
magnitude
s
,
signed
=
op_meta
[
op_name
]
if
magnitude
s
is
not
None
and
magnitude_id
is
not
None
else
0.0
magnitude
=
float
(
magnitudes
[
magnitude_id
].
item
())
if
magnitude_id
is
not
None
else
0.0
if
signed
is
not
None
and
signed
and
signs
[
i
]
==
0
:
if
signed
and
signs
[
i
]
==
0
:
magnitude
*=
-
1.0
magnitude
*=
-
1.0
img
=
_apply_op
(
img
,
op_name
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
if
op_name
==
"ShearX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
math
.
degrees
(
magnitude
),
0.0
],
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"ShearY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
math
.
degrees
(
magnitude
)],
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"TranslateX"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
int
(
F
.
get_image_size
(
img
)[
0
]
*
magnitude
),
0
],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"TranslateY"
:
img
=
F
.
affine
(
img
,
angle
=
0.0
,
translate
=
[
0
,
int
(
F
.
get_image_size
(
img
)[
1
]
*
magnitude
)],
scale
=
1.0
,
interpolation
=
self
.
interpolation
,
shear
=
[
0.0
,
0.0
],
fill
=
fill
)
elif
op_name
==
"Rotate"
:
img
=
F
.
rotate
(
img
,
magnitude
,
interpolation
=
self
.
interpolation
,
fill
=
fill
)
elif
op_name
==
"Brightness"
:
img
=
F
.
adjust_brightness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Color"
:
img
=
F
.
adjust_saturation
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Contrast"
:
img
=
F
.
adjust_contrast
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Sharpness"
:
img
=
F
.
adjust_sharpness
(
img
,
1.0
+
magnitude
)
elif
op_name
==
"Posterize"
:
img
=
F
.
posterize
(
img
,
int
(
magnitude
))
elif
op_name
==
"Solarize"
:
img
=
F
.
solarize
(
img
,
magnitude
)
elif
op_name
==
"AutoContrast"
:
img
=
F
.
autocontrast
(
img
)
elif
op_name
==
"Equalize"
:
img
=
F
.
equalize
(
img
)
elif
op_name
==
"Invert"
:
img
=
F
.
invert
(
img
)
else
:
raise
ValueError
(
"The provided operator {} is not recognized."
.
format
(
op_name
))
return
img
return
img
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment