Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c9a7e0b2
Commit
c9a7e0b2
authored
Jan 12, 2022
by
A. Unique TensorFlower
Browse files
Add builder that applies bounding box-specific ops for RandAugment
PiperOrigin-RevId: 421439862
parent
49a5706c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
58 additions
and
4 deletions
+58
-4
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+0
-1
official/vision/beta/dataloaders/retinanet_input.py
official/vision/beta/dataloaders/retinanet_input.py
+10
-3
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+31
-0
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+17
-0
No files found.
official/vision/beta/configs/retinanet.py
View file @
c9a7e0b2
...
@@ -58,7 +58,6 @@ class Parser(hyperparams.Config):
...
@@ -58,7 +58,6 @@ class Parser(hyperparams.Config):
skip_crowd_during_training
:
bool
=
True
skip_crowd_during_training
:
bool
=
True
max_num_instances
:
int
=
100
max_num_instances
:
int
=
100
# Can choose AutoAugment and RandAugment.
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Keep for backward compatibility. Not used.
# Keep for backward compatibility. Not used.
...
...
official/vision/beta/dataloaders/retinanet_input.py
View file @
c9a7e0b2
...
@@ -75,7 +75,7 @@ class Parser(parser.Parser):
...
@@ -75,7 +75,7 @@ class Parser(parser.Parser):
upper-bound threshold to assign negative labels for anchors. An anchor
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
The latter is not supported, and will raise ValueError.
RandAugment.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
aug_scale_min: `float`, the minimum scale applied to `output_size` for
...
@@ -122,8 +122,16 @@ class Parser(parser.Parser):
...
@@ -122,8 +122,16 @@ class Parser(parser.Parser):
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
translate_const
=
aug_type
.
autoaug
.
translate_const
)
elif
aug_type
.
type
==
'randaug'
:
logging
.
info
(
'Using RandAugment.'
)
self
.
_augmenter
=
augment
.
RandAugment
.
build_for_detection
(
num_layers
=
aug_type
.
randaug
.
num_layers
,
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
,
exclude_ops
=
aug_type
.
randaug
.
exclude_ops
)
else
:
else
:
# TODO(b/205346436) Support RandAugment.
raise
ValueError
(
f
'Augmentation policy
{
aug_type
.
type
}
not supported.'
)
raise
ValueError
(
f
'Augmentation policy
{
aug_type
.
type
}
not supported.'
)
# Deprecated. Data Augmentation with AutoAugment.
# Deprecated. Data Augmentation with AutoAugment.
...
@@ -162,7 +170,6 @@ class Parser(parser.Parser):
...
@@ -162,7 +170,6 @@ class Parser(parser.Parser):
# Apply autoaug or randaug.
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
if
self
.
_augmenter
is
not
None
:
image
,
boxes
=
self
.
_augmenter
.
distort_with_boxes
(
image
,
boxes
)
image
,
boxes
=
self
.
_augmenter
.
distort_with_boxes
(
image
,
boxes
)
image_shape
=
tf
.
shape
(
input
=
image
)[
0
:
2
]
image_shape
=
tf
.
shape
(
input
=
image
)[
0
:
2
]
# Normalizes image with mean and std pixel values.
# Normalizes image with mean and std pixel values.
...
...
official/vision/beta/ops/augment.py
View file @
c9a7e0b2
...
@@ -1950,6 +1950,37 @@ class RandAugment(ImageAugment):
...
@@ -1950,6 +1950,37 @@ class RandAugment(ImageAugment):
op
for
op
in
self
.
available_ops
if
op
not
in
exclude_ops
op
for
op
in
self
.
available_ops
if
op
not
in
exclude_ops
]
]
@
classmethod
def
build_for_detection
(
cls
,
num_layers
:
int
=
2
,
magnitude
:
float
=
10.
,
cutout_const
:
float
=
40.
,
translate_const
:
float
=
100.
,
magnitude_std
:
float
=
0.0
,
prob_to_apply
:
Optional
[
float
]
=
None
,
exclude_ops
:
Optional
[
List
[
str
]]
=
None
):
"""Builds a RandAugment that modifies bboxes for geometric transforms."""
augmenter
=
cls
(
num_layers
=
num_layers
,
magnitude
=
magnitude
,
cutout_const
=
cutout_const
,
translate_const
=
translate_const
,
magnitude_std
=
magnitude_std
,
prob_to_apply
=
prob_to_apply
,
exclude_ops
=
exclude_ops
)
box_aware_ops_by_base_name
=
{
'Rotate'
:
'Rotate_BBox'
,
'ShearX'
:
'ShearX_BBox'
,
'ShearY'
:
'ShearY_BBox'
,
'TranslateX'
:
'TranslateX_BBox'
,
'TranslateY'
:
'TranslateY_BBox'
,
}
augmenter
.
available_ops
=
[
box_aware_ops_by_base_name
.
get
(
op_name
)
or
op_name
for
op_name
in
augmenter
.
available_ops
]
return
augmenter
def
_distort_common
(
def
_distort_common
(
self
,
self
,
image
:
tf
.
Tensor
,
image
:
tf
.
Tensor
,
...
...
official/vision/beta/ops/augment_test.py
View file @
c9a7e0b2
...
@@ -140,6 +140,23 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -140,6 +140,23 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug_build_for_detection
(
self
):
"""Smoke test to be sure there are no syntax errors built for detection."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandAugment
.
build_for_detection
()
self
.
assertCountEqual
(
augmenter
.
available_ops
,
[
'AutoContrast'
,
'Equalize'
,
'Invert'
,
'Posterize'
,
'Solarize'
,
'Color'
,
'Contrast'
,
'Brightness'
,
'Sharpness'
,
'Cutout'
,
'SolarizeAdd'
,
'Rotate_BBox'
,
'ShearX_BBox'
,
'ShearY_BBox'
,
'TranslateX_BBox'
,
'TranslateY_BBox'
])
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_all_policy_ops
(
self
):
def
test_all_policy_ops
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
"""Smoke test to be sure all augmentation functions can execute."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment