Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
96674ab0
Commit
96674ab0
authored
Dec 16, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 416886349
parent
8d41d6c0
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
875 additions
and
63 deletions
+875
-63
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+6
-1
official/vision/beta/dataloaders/retinanet_input.py
official/vision/beta/dataloaders/retinanet_input.py
+24
-2
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+749
-49
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+95
-11
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+1
-0
No files found.
official/vision/beta/configs/retinanet.py
View file @
96674ab0
...
...
@@ -55,9 +55,14 @@ class Parser(hyperparams.Config):
aug_rand_hflip
:
bool
=
False
aug_scale_min
:
float
=
1.0
aug_scale_max
:
float
=
1.0
aug_policy
:
Optional
[
str
]
=
None
skip_crowd_during_training
:
bool
=
True
max_num_instances
:
int
=
100
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Keep for backward compatibility. Not used.
aug_policy
:
Optional
[
str
]
=
None
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/retinanet_input.py
View file @
96674ab0
...
...
@@ -19,11 +19,13 @@ into (image, labels) tuple for RetinaNet.
"""
# Import libraries
from
absl
import
logging
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
utils
from
official.vision.beta.ops
import
anchor
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
preprocess_ops
...
...
@@ -40,6 +42,7 @@ class Parser(parser.Parser):
anchor_size
,
match_threshold
=
0.5
,
unmatched_threshold
=
0.5
,
aug_type
=
None
,
aug_rand_hflip
=
False
,
aug_scale_min
=
1.0
,
aug_scale_max
=
1.0
,
...
...
@@ -71,6 +74,8 @@ class Parser(parser.Parser):
unmatched_threshold: `float` number between 0 and 1 representing the
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. The latter is not supported, and will raise ValueError.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
...
...
@@ -108,7 +113,20 @@ class Parser(parser.Parser):
self
.
_aug_scale_min
=
aug_scale_min
self
.
_aug_scale_max
=
aug_scale_max
# Data Augmentation with AutoAugment.
# Data augmentation with AutoAugment or RandAugment.
self
.
_augmenter
=
None
if
aug_type
is
not
None
:
if
aug_type
.
type
==
'autoaug'
:
logging
.
info
(
'Using AutoAugment.'
)
self
.
_augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
else
:
# TODO(b/205346436) Support RandAugment.
raise
ValueError
(
f
'Augmentation policy
{
aug_type
.
type
}
not supported.'
)
# Deprecated. Data Augmentation with AutoAugment.
self
.
_use_autoaugment
=
use_autoaugment
self
.
_autoaugment_policy_name
=
autoaugment_policy_name
...
...
@@ -138,9 +156,13 @@ class Parser(parser.Parser):
for
k
,
v
in
attributes
.
items
():
attributes
[
k
]
=
tf
.
gather
(
v
,
indices
)
# Gets original image
and its size
.
# Gets original image.
image
=
data
[
'image'
]
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
image
,
boxes
=
self
.
_augmenter
.
distort_with_boxes
(
image
,
boxes
)
image_shape
=
tf
.
shape
(
input
=
image
)[
0
:
2
]
# Normalizes image with mean and std pixel values.
...
...
official/vision/beta/ops/augment.py
View file @
96674ab0
This diff is collapsed.
Click to expand it.
official/vision/beta/ops/augment_test.py
View file @
96674ab0
...
...
@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'reduced_cifar10'
,
'svhn'
,
'reduced_imagenet'
,
]
AVAILABLE_POLICIES
=
[
'v0'
,
'test'
,
'simple'
,
'reduced_cifar10'
,
'svhn'
,
'reduced_imagenet'
,
'detection_v0'
,
]
def
test_autoaugment
(
self
):
...
...
@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_autoaugment_with_bboxes
(
self
):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
for
policy
in
self
.
AVAILABLE_POLICIES
:
augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
policy
)
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug
(
self
):
"""Smoke test to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
...
...
@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_randaug_with_bboxes
(
self
):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandAugment
()
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_all_policy_ops
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
...
...
@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
None
for
op_name
in
augment
.
NAME_TO_FUNC
.
keys
()
-
augment
.
REQUIRE_BOXES_FUNCS
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertIsNone
(
bboxes
)
def
test_all_policy_ops_with_bboxes
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
=
func
(
image
,
*
args
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
2
,
4
),
bboxes
.
shape
)
def
test_autoaugment_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
...
...
@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
def
test_autoaugment_video_with_boxes
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
2
,
4
),
dtype
=
tf
.
float32
)
for
policy
in
self
.
AVAILABLE_POLICIES
:
augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
policy
)
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
...
...
@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const
=
250
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
None
for
op_name
in
augment
.
NAME_TO_FUNC
.
keys
()
-
augment
.
REQUIRE_BOXES_FUNCS
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
self
.
assertIsNone
(
bboxes
)
def
test_all_policy_ops_video_with_bboxes
(
self
):
"""Smoke test to be sure all video augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
2
,
4
),
dtype
=
tf
.
float32
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
=
func
(
image
,
*
args
)
if
op_name
in
{
'Rotate_BBox'
,
'ShearX_BBox'
,
'ShearY_BBox'
,
'TranslateX_BBox'
,
'TranslateY_BBox'
,
'TranslateY_Only_BBoxes'
,
}:
with
self
.
assertRaises
(
ValueError
):
func
(
image
,
bboxes
,
*
args
)
else
:
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
2
,
2
,
4
),
bboxes
.
shape
)
def
_generate_test_policy
(
self
):
"""Generate a test policy at random."""
...
...
official/vision/beta/tasks/retinanet.py
View file @
96674ab0
...
...
@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task):
dtype
=
params
.
dtype
,
match_threshold
=
params
.
parser
.
match_threshold
,
unmatched_threshold
=
params
.
parser
.
unmatched_threshold
,
aug_type
=
params
.
parser
.
aug_type
,
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment