Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
96674ab0
Commit
96674ab0
authored
Dec 16, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 416886349
parent
8d41d6c0
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
875 additions
and
63 deletions
+875
-63
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+6
-1
official/vision/beta/dataloaders/retinanet_input.py
official/vision/beta/dataloaders/retinanet_input.py
+24
-2
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+749
-49
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+95
-11
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+1
-0
No files found.
official/vision/beta/configs/retinanet.py
View file @
96674ab0
...
@@ -55,9 +55,14 @@ class Parser(hyperparams.Config):
...
@@ -55,9 +55,14 @@ class Parser(hyperparams.Config):
aug_rand_hflip
:
bool
=
False
aug_rand_hflip
:
bool
=
False
aug_scale_min
:
float
=
1.0
aug_scale_min
:
float
=
1.0
aug_scale_max
:
float
=
1.0
aug_scale_max
:
float
=
1.0
aug_policy
:
Optional
[
str
]
=
None
skip_crowd_during_training
:
bool
=
True
skip_crowd_during_training
:
bool
=
True
max_num_instances
:
int
=
100
max_num_instances
:
int
=
100
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Keep for backward compatibility. Not used.
aug_policy
:
Optional
[
str
]
=
None
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/retinanet_input.py
View file @
96674ab0
...
@@ -19,11 +19,13 @@ into (image, labels) tuple for RetinaNet.
...
@@ -19,11 +19,13 @@ into (image, labels) tuple for RetinaNet.
"""
"""
# Import libraries
# Import libraries
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
utils
from
official.vision.beta.dataloaders
import
utils
from
official.vision.beta.ops
import
anchor
from
official.vision.beta.ops
import
anchor
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.ops
import
preprocess_ops
...
@@ -40,6 +42,7 @@ class Parser(parser.Parser):
...
@@ -40,6 +42,7 @@ class Parser(parser.Parser):
anchor_size
,
anchor_size
,
match_threshold
=
0.5
,
match_threshold
=
0.5
,
unmatched_threshold
=
0.5
,
unmatched_threshold
=
0.5
,
aug_type
=
None
,
aug_rand_hflip
=
False
,
aug_rand_hflip
=
False
,
aug_scale_min
=
1.0
,
aug_scale_min
=
1.0
,
aug_scale_max
=
1.0
,
aug_scale_max
=
1.0
,
...
@@ -71,6 +74,8 @@ class Parser(parser.Parser):
...
@@ -71,6 +74,8 @@ class Parser(parser.Parser):
unmatched_threshold: `float` number between 0 and 1 representing the
unmatched_threshold: `float` number between 0 and 1 representing the
upper-bound threshold to assign negative labels for anchors. An anchor
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. The latter is not supported, and will raise ValueError.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
aug_scale_min: `float`, the minimum scale applied to `output_size` for
...
@@ -108,7 +113,20 @@ class Parser(parser.Parser):
...
@@ -108,7 +113,20 @@ class Parser(parser.Parser):
self
.
_aug_scale_min
=
aug_scale_min
self
.
_aug_scale_min
=
aug_scale_min
self
.
_aug_scale_max
=
aug_scale_max
self
.
_aug_scale_max
=
aug_scale_max
# Data Augmentation with AutoAugment.
# Data augmentation with AutoAugment or RandAugment.
self
.
_augmenter
=
None
if
aug_type
is
not
None
:
if
aug_type
.
type
==
'autoaug'
:
logging
.
info
(
'Using AutoAugment.'
)
self
.
_augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
else
:
# TODO(b/205346436) Support RandAugment.
raise
ValueError
(
f
'Augmentation policy
{
aug_type
.
type
}
not supported.'
)
# Deprecated. Data Augmentation with AutoAugment.
self
.
_use_autoaugment
=
use_autoaugment
self
.
_use_autoaugment
=
use_autoaugment
self
.
_autoaugment_policy_name
=
autoaugment_policy_name
self
.
_autoaugment_policy_name
=
autoaugment_policy_name
...
@@ -138,9 +156,13 @@ class Parser(parser.Parser):
...
@@ -138,9 +156,13 @@ class Parser(parser.Parser):
for
k
,
v
in
attributes
.
items
():
for
k
,
v
in
attributes
.
items
():
attributes
[
k
]
=
tf
.
gather
(
v
,
indices
)
attributes
[
k
]
=
tf
.
gather
(
v
,
indices
)
# Gets original image
and its size
.
# Gets original image.
image
=
data
[
'image'
]
image
=
data
[
'image'
]
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
image
,
boxes
=
self
.
_augmenter
.
distort_with_boxes
(
image
,
boxes
)
image_shape
=
tf
.
shape
(
input
=
image
)[
0
:
2
]
image_shape
=
tf
.
shape
(
input
=
image
)[
0
:
2
]
# Normalizes image with mean and std pixel values.
# Normalizes image with mean and std pixel values.
...
...
official/vision/beta/ops/augment.py
View file @
96674ab0
This diff is collapsed.
Click to expand it.
official/vision/beta/ops/augment_test.py
View file @
96674ab0
...
@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'reduced_cifar10'
,
'reduced_cifar10'
,
'svhn'
,
'svhn'
,
'reduced_imagenet'
,
'reduced_imagenet'
,
]
'detection_v0'
,
AVAILABLE_POLICIES
=
[
'v0'
,
'test'
,
'simple'
,
'reduced_cifar10'
,
'svhn'
,
'reduced_imagenet'
,
]
]
def
test_autoaugment
(
self
):
def
test_autoaugment
(
self
):
...
@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_autoaugment_with_bboxes
(
self
):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
for
policy
in
self
.
AVAILABLE_POLICIES
:
augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
policy
)
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug
(
self
):
def
test_randaug
(
self
):
"""Smoke test to be sure there are no syntax errors."""
"""Smoke test to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
...
@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_randaug_with_bboxes
(
self
):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandAugment
()
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_all_policy_ops
(
self
):
def
test_all_policy_ops
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
"""Smoke test to be sure all augmentation functions can execute."""
...
@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const
=
250
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
None
for
op_name
in
augment
.
NAME_TO_FUNC
.
keys
()
-
augment
.
REQUIRE_BOXES_FUNCS
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertIsNone
(
bboxes
)
def
test_all_policy_ops_with_bboxes
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
replace_value
,
cutout_const
,
translate_const
)
translate_const
)
image
=
func
(
image
,
*
args
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
2
,
4
),
bboxes
.
shape
)
def
test_autoaugment_video
(
self
):
def
test_autoaugment_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
"""Smoke test with video to be sure there are no syntax errors."""
...
@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
def
test_autoaugment_video_with_boxes
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
2
,
4
),
dtype
=
tf
.
float32
)
for
policy
in
self
.
AVAILABLE_POLICIES
:
augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
policy
)
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug_video
(
self
):
def
test_randaug_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
...
@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const
=
250
translate_const
=
250
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
None
for
op_name
in
augment
.
NAME_TO_FUNC
.
keys
()
-
augment
.
REQUIRE_BOXES_FUNCS
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
self
.
assertIsNone
(
bboxes
)
def
test_all_policy_ops_video_with_bboxes
(
self
):
"""Smoke test to be sure all video augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
2
,
4
),
dtype
=
tf
.
float32
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
replace_value
,
cutout_const
,
translate_const
)
translate_const
)
image
=
func
(
image
,
*
args
)
if
op_name
in
{
'Rotate_BBox'
,
'ShearX_BBox'
,
'ShearY_BBox'
,
'TranslateX_BBox'
,
'TranslateY_BBox'
,
'TranslateY_Only_BBoxes'
,
}:
with
self
.
assertRaises
(
ValueError
):
func
(
image
,
bboxes
,
*
args
)
else
:
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
2
,
2
,
4
),
bboxes
.
shape
)
def
_generate_test_policy
(
self
):
def
_generate_test_policy
(
self
):
"""Generate a test policy at random."""
"""Generate a test policy at random."""
...
...
official/vision/beta/tasks/retinanet.py
View file @
96674ab0
...
@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task):
...
@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task):
dtype
=
params
.
dtype
,
dtype
=
params
.
dtype
,
match_threshold
=
params
.
parser
.
match_threshold
,
match_threshold
=
params
.
parser
.
match_threshold
,
unmatched_threshold
=
params
.
parser
.
unmatched_threshold
,
unmatched_threshold
=
params
.
parser
.
unmatched_threshold
,
aug_type
=
params
.
parser
.
aug_type
,
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment