Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
68f301f7
Commit
68f301f7
authored
Dec 11, 2020
by
Hye Soo Yang
Committed by
A. Unique TensorFlower
Dec 11, 2020
Browse files
[keras_cv] Support custom policy for `AutoAugment`.
PiperOrigin-RevId: 347069098
parent
590201bb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
22 deletions
+143
-22
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+62
-21
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+81
-1
No files found.
official/vision/beta/ops/augment.py
View file @
68f301f7
...
...
@@ -18,9 +18,10 @@ AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
"""
import
math
from
typing
import
Any
,
List
,
Optional
,
Text
,
Tuple
,
Iterable
import
numpy
as
np
import
tensorflow
as
tf
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Text
,
Tuple
from
tensorflow.python.keras.layers.preprocessing
import
image_preprocessing
as
image_ops
...
...
@@ -732,7 +733,8 @@ class AutoAugment(ImageAugment):
def
__init__
(
self
,
augmentation_name
:
Text
=
'v0'
,
policies
:
Optional
[
Dict
[
Text
,
Any
]]
=
None
,
policies
:
Optional
[
Iterable
[
Iterable
[
Tuple
[
Text
,
float
,
float
]]]]
=
None
,
cutout_const
:
float
=
100
,
translate_const
:
float
=
250
):
"""Applies the AutoAugment policy to images.
...
...
@@ -745,17 +747,32 @@ class AutoAugment(ImageAugment):
the COCO dataset. `v1`, `v2` and `v3` are additional good policies found
on the COCO dataset that have slight variation in what operations were
used during the search procedure along with how many operations are
applied in parallel to a single image (2 vs 3).
applied in parallel to a single image (2 vs 3). Make sure to set
`policies` to `None` (the default) if you want to set options using
`augmentation_name`.
policies: list of lists of tuples in the form `(func, prob, level)`,
`func` is a string name of the augmentation function, `prob` is the
probability of applying the `func` operation, `level` is the input
argument for `func`.
probability of applying the `func` operation, `level` (or magnitude) is
the input argument for `func`. For example:
```
[[('Equalize', 0.9, 3), ('Color', 0.7, 8)],
[('Invert', 0.6, 5), ('Rotate', 0.2, 9), ('ShearX', 0.1, 2)], ...]
```
The outer-most list must be 3-d. The number of operations in a
sub-policy can vary from one sub-policy to another.
If you provide `policies` as input, any option set with
`augmentation_name` will get overriden as they are mutually exclusive.
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
Raises:
ValueError if `augmentation_name` is unsupported.
"""
super
(
AutoAugment
,
self
).
__init__
()
if
policies
is
None
:
self
.
augmentation_name
=
augmentation_name
self
.
cutout_const
=
float
(
cutout_const
)
self
.
translate_const
=
float
(
translate_const
)
self
.
available_policies
=
{
'v0'
:
self
.
policy_v0
(),
'test'
:
self
.
policy_test
(),
...
...
@@ -765,14 +782,31 @@ class AutoAugment(ImageAugment):
'reduced_imagenet'
:
self
.
policy_reduced_imagenet
(),
}
if
not
policies
:
if
augmentation_name
not
in
self
.
available_policies
:
raise
ValueError
(
'Invalid augmentation_name: {}'
.
format
(
augmentation_name
))
self
.
augmentation_name
=
augmentation_name
self
.
policies
=
self
.
available_policies
[
augmentation_name
]
self
.
cutout_const
=
float
(
cutout_const
)
self
.
translate_const
=
float
(
translate_const
)
else
:
self
.
_check_policy_shape
(
policies
)
self
.
policies
=
policies
def
_check_policy_shape
(
self
,
policies
):
"""Checks dimension and shape of the custom policy.
Args:
policies: List of list of tuples in the form `(func, prob, level)`. Must
have shape of `(:, :, 3)`.
Raises:
ValueError if the shape of `policies` is unexpected.
"""
in_shape
=
np
.
array
(
policies
).
shape
if
len
(
in_shape
)
!=
3
or
in_shape
[
-
1
:]
!=
(
3
,):
raise
ValueError
(
'Wrong shape detected for custom policy. Expected '
'(:, :, 3) but got {}.'
.
format
(
in_shape
))
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies the AutoAugment policy to `image`.
...
...
@@ -803,9 +837,15 @@ class AutoAugment(ImageAugment):
tf_policies
=
[]
for
policy
in
self
.
policies
:
tf_policy
=
[]
assert_ranges
=
[]
# Link string name to the correct python function and make sure the
# correct argument is passed into that function.
for
policy_info
in
policy
:
_
,
prob
,
level
=
policy_info
assert_ranges
.
append
(
tf
.
Assert
(
tf
.
less_equal
(
prob
,
1.
),
[
prob
]))
assert_ranges
.
append
(
tf
.
Assert
(
tf
.
less_equal
(
level
,
int
(
_MAX_LEVEL
)),
[
level
]))
policy_info
=
list
(
policy_info
)
+
[
replace_value
,
self
.
cutout_const
,
self
.
translate_const
]
...
...
@@ -821,6 +861,7 @@ class AutoAugment(ImageAugment):
return
final_policy
with
tf
.
control_dependencies
(
assert_ranges
):
tf_policies
.
append
(
make_final_policy
(
tf_policy
))
image
=
select_and_apply_random_policy
(
tf_policies
,
image
)
...
...
official/vision/beta/ops/augment_test.py
View file @
68f301f7
...
...
@@ -19,6 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
random
from
absl.testing
import
parameterized
import
tensorflow
as
tf
...
...
@@ -86,7 +87,16 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllEqual
(
image
,
augment
.
rotate
(
image
,
degrees
))
class
AutoaugmentTest
(
tf
.
test
.
TestCase
):
class
AutoaugmentTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
AVAILABLE_POLICIES
=
[
'v0'
,
'test'
,
'simple'
,
'reduced_cifar10'
,
'svhn'
,
'reduced_imagenet'
,
]
AVAILABLE_POLICIES
=
[
'v0'
,
...
...
@@ -135,6 +145,76 @@ class AutoaugmentTest(tf.test.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
def
_generate_test_policy
(
self
):
"""Generate a test policy at random."""
op_list
=
list
(
augment
.
NAME_TO_FUNC
.
keys
())
size
=
6
prob
=
[
round
(
random
.
uniform
(
0.
,
1.
),
1
)
for
_
in
range
(
size
)]
mag
=
[
round
(
random
.
uniform
(
0
,
10
))
for
_
in
range
(
size
)]
policy
=
[]
for
i
in
range
(
0
,
size
,
2
):
policy
.
append
([(
op_list
[
i
],
prob
[
i
],
mag
[
i
]),
(
op_list
[
i
+
1
],
prob
[
i
+
1
],
mag
[
i
+
1
])])
return
policy
def
test_custom_policy
(
self
):
"""Test autoaugment with a custom policy."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
augmenter
=
augment
.
AutoAugment
(
policies
=
self
.
_generate_test_policy
())
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
@
parameterized
.
named_parameters
(
{
'testcase_name'
:
'_OutOfRangeProb'
,
'sub_policy'
:
(
'Equalize'
,
1.1
,
3
),
'value'
:
'1.1'
},
{
'testcase_name'
:
'_OutOfRangeMag'
,
'sub_policy'
:
(
'Equalize'
,
0.9
,
11
),
'value'
:
'11'
},
)
def
test_invalid_custom_sub_policy
(
self
,
sub_policy
,
value
):
"""Test autoaugment with out-of-range values in the custom policy."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
policy
=
self
.
_generate_test_policy
()
policy
[
0
][
0
]
=
sub_policy
augmenter
=
augment
.
AutoAugment
(
policies
=
policy
)
with
self
.
assertRaisesRegex
(
tf
.
errors
.
InvalidArgumentError
,
r
'Expected \'tf.Tensor\(False, shape=\(\), dtype=bool\)\' to be true. '
r
'Summarized data: ({})'
.
format
(
value
)):
augmenter
.
distort
(
image
)
def
test_invalid_custom_policy_ndim
(
self
):
"""Test autoaugment with wrong dimension in the custom policy."""
policy
=
[[(
'Equalize'
,
0.8
,
1
),
(
'Shear'
,
0.8
,
4
)],
[(
'TranslateY'
,
0.6
,
3
),
(
'Rotate'
,
0.9
,
3
)]]
policy
=
[[
policy
]]
with
self
.
assertRaisesRegex
(
ValueError
,
r
'Expected \(:, :, 3\) but got \(1, 1, 2, 2, 3\).'
):
augment
.
AutoAugment
(
policies
=
policy
)
def
test_invalid_custom_policy_shape
(
self
):
"""Test autoaugment with wrong shape in the custom policy."""
policy
=
[[(
'Equalize'
,
0.8
,
1
,
1
),
(
'Shear'
,
0.8
,
4
,
1
)],
[(
'TranslateY'
,
0.6
,
3
,
1
),
(
'Rotate'
,
0.9
,
3
,
1
)]]
with
self
.
assertRaisesRegex
(
ValueError
,
r
'Expected \(:, :, 3\) but got \(2, 2, 4\)'
):
augment
.
AutoAugment
(
policies
=
policy
)
def
test_invalid_custom_policy_key
(
self
):
"""Test autoaugment with invalid key in the custom policy."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
policy
=
[[(
'AAAAA'
,
0.8
,
1
),
(
'Shear'
,
0.8
,
4
)],
[(
'TranslateY'
,
0.6
,
3
),
(
'Rotate'
,
0.9
,
3
)]]
augmenter
=
augment
.
AutoAugment
(
policies
=
policy
)
with
self
.
assertRaisesRegex
(
KeyError
,
'
\'
AAAAA
\'
'
):
augmenter
.
distort
(
image
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment