Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
96674ab0
Commit
96674ab0
authored
Dec 16, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 416886349
parent
8d41d6c0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
875 additions
and
63 deletions
+875
-63
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+6
-1
official/vision/beta/dataloaders/retinanet_input.py
official/vision/beta/dataloaders/retinanet_input.py
+24
-2
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+749
-49
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+95
-11
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+1
-0
No files found.
official/vision/beta/configs/retinanet.py
View file @
96674ab0
...
...
@@ -55,9 +55,14 @@ class Parser(hyperparams.Config):
aug_rand_hflip
:
bool
=
False
aug_scale_min
:
float
=
1.0
aug_scale_max
:
float
=
1.0
aug_policy
:
Optional
[
str
]
=
None
skip_crowd_during_training
:
bool
=
True
max_num_instances
:
int
=
100
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Keep for backward compatibility. Not used.
aug_policy
:
Optional
[
str
]
=
None
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/retinanet_input.py
View file @
96674ab0
...
...
@@ -19,11 +19,13 @@ into (image, labels) tuple for RetinaNet.
"""
# Import libraries
from
absl
import
logging
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
utils
from
official.vision.beta.ops
import
anchor
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
preprocess_ops
...
...
@@ -40,6 +42,7 @@ class Parser(parser.Parser):
anchor_size
,
match_threshold
=
0.5
,
unmatched_threshold
=
0.5
,
aug_type
=
None
,
aug_rand_hflip
=
False
,
aug_scale_min
=
1.0
,
aug_scale_max
=
1.0
,
...
...
@@ -71,6 +74,8 @@ class Parser(parser.Parser):
unmatched_threshold: `float` number between 0 and 1 representing the
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. The latter is not supported, and will raise ValueError.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
...
...
@@ -108,7 +113,20 @@ class Parser(parser.Parser):
self
.
_aug_scale_min
=
aug_scale_min
self
.
_aug_scale_max
=
aug_scale_max
# Data Augmentation with AutoAugment.
# Data augmentation with AutoAugment or RandAugment.
self
.
_augmenter
=
None
if
aug_type
is
not
None
:
if
aug_type
.
type
==
'autoaug'
:
logging
.
info
(
'Using AutoAugment.'
)
self
.
_augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
else
:
# TODO(b/205346436) Support RandAugment.
raise
ValueError
(
f
'Augmentation policy
{
aug_type
.
type
}
not supported.'
)
# Deprecated. Data Augmentation with AutoAugment.
self
.
_use_autoaugment
=
use_autoaugment
self
.
_autoaugment_policy_name
=
autoaugment_policy_name
...
...
@@ -138,9 +156,13 @@ class Parser(parser.Parser):
for
k
,
v
in
attributes
.
items
():
attributes
[
k
]
=
tf
.
gather
(
v
,
indices
)
# Gets original image
and its size
.
# Gets original image.
image
=
data
[
'image'
]
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
image
,
boxes
=
self
.
_augmenter
.
distort_with_boxes
(
image
,
boxes
)
image_shape
=
tf
.
shape
(
input
=
image
)[
0
:
2
]
# Normalizes image with mean and std pixel values.
...
...
official/vision/beta/ops/augment.py
View file @
96674ab0
...
...
@@ -14,7 +14,9 @@
"""Augmentation policies for enhanced image/video preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
AutoAugment Reference:
- AutoAugment Reference: https://arxiv.org/abs/1805.09501
- AutoAugment for Object Detection Reference: https://arxiv.org/abs/1906.11172
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
...
...
@@ -25,6 +27,7 @@ RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models
"""
import
inspect
import
math
from
typing
import
Any
,
List
,
Iterable
,
Optional
,
Text
,
Tuple
...
...
@@ -702,6 +705,572 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
return
image
def
_scale_bbox_only_op_probability
(
prob
):
"""Reduce the probability of the bbox-only operation.
Probability is reduced so that we do not distort the content of too many
bounding boxes that are close to each other. The value of 3.0 was a chosen
hyper parameter when designing the autoaugment algorithm that we found
empirically to work well.
Args:
prob: Float that is the probability of applying the bbox-only operation.
Returns:
Reduced probability.
"""
return
prob
/
3.0
def
_apply_bbox_augmentation
(
image
,
bbox
,
augmentation_func
,
*
args
):
"""Applies augmentation_func to the subsection of image indicated by bbox.
Args:
image: 3D uint8 Tensor.
bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
of type float that represents the normalized coordinates between 0 and 1.
augmentation_func: Augmentation function that will be applied to the
subsection of image.
*args: Additional parameters that will be passed into augmentation_func
when it is called.
Returns:
A modified version of image, where the bbox location in the image will
have `ugmentation_func applied to it.
"""
image_height
=
tf
.
cast
(
tf
.
shape
(
image
)[
0
],
tf
.
float32
)
image_width
=
tf
.
cast
(
tf
.
shape
(
image
)[
1
],
tf
.
float32
)
min_y
=
tf
.
cast
(
image_height
*
bbox
[
0
],
tf
.
int32
)
min_x
=
tf
.
cast
(
image_width
*
bbox
[
1
],
tf
.
int32
)
max_y
=
tf
.
cast
(
image_height
*
bbox
[
2
],
tf
.
int32
)
max_x
=
tf
.
cast
(
image_width
*
bbox
[
3
],
tf
.
int32
)
image_height
=
tf
.
cast
(
image_height
,
tf
.
int32
)
image_width
=
tf
.
cast
(
image_width
,
tf
.
int32
)
# Clip to be sure the max values do not fall out of range.
max_y
=
tf
.
minimum
(
max_y
,
image_height
-
1
)
max_x
=
tf
.
minimum
(
max_x
,
image_width
-
1
)
# Get the sub-tensor that is the image within the bounding box region.
bbox_content
=
image
[
min_y
:
max_y
+
1
,
min_x
:
max_x
+
1
,
:]
# Apply the augmentation function to the bbox portion of the image.
augmented_bbox_content
=
augmentation_func
(
bbox_content
,
*
args
)
# Pad the augmented_bbox_content and the mask to match the shape of original
# image.
augmented_bbox_content
=
tf
.
pad
(
augmented_bbox_content
,
[[
min_y
,
(
image_height
-
1
)
-
max_y
],
[
min_x
,
(
image_width
-
1
)
-
max_x
],
[
0
,
0
]])
# Create a mask that will be used to zero out a part of the original image.
mask_tensor
=
tf
.
zeros_like
(
bbox_content
)
mask_tensor
=
tf
.
pad
(
mask_tensor
,
[[
min_y
,
(
image_height
-
1
)
-
max_y
],
[
min_x
,
(
image_width
-
1
)
-
max_x
],
[
0
,
0
]],
constant_values
=
1
)
# Replace the old bbox content with the new augmented content.
image
=
image
*
mask_tensor
+
augmented_bbox_content
return
image
def
_concat_bbox
(
bbox
,
bboxes
):
"""Helper function that concates bbox to bboxes along the first dimension."""
# Note if all elements in bboxes are -1 (_INVALID_BOX), then this means
# we discard bboxes and start the bboxes Tensor with the current bbox.
bboxes_sum_check
=
tf
.
reduce_sum
(
bboxes
)
bbox
=
tf
.
expand_dims
(
bbox
,
0
)
# This check will be true when it is an _INVALID_BOX
bboxes
=
tf
.
cond
(
tf
.
equal
(
bboxes_sum_check
,
-
4.0
),
lambda
:
bbox
,
lambda
:
tf
.
concat
([
bboxes
,
bbox
],
0
))
return
bboxes
def
_apply_bbox_augmentation_wrapper
(
image
,
bbox
,
new_bboxes
,
prob
,
augmentation_func
,
func_changes_bbox
,
*
args
):
"""Applies _apply_bbox_augmentation with probability prob.
Args:
image: 3D uint8 Tensor.
bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
of type float that represents the normalized coordinates between 0 and 1.
new_bboxes: 2D Tensor that is a list of the bboxes in the image after they
have been altered by aug_func. These will only be changed when
func_changes_bbox is set to true. Each bbox has 4 elements
(min_y, min_x, max_y, max_x) of type float that are the normalized
bbox coordinates between 0 and 1.
prob: Float that is the probability of applying _apply_bbox_augmentation.
augmentation_func: Augmentation function that will be applied to the
subsection of image.
func_changes_bbox: Boolean. Does augmentation_func return bbox in addition
to image.
*args: Additional parameters that will be passed into augmentation_func
when it is called.
Returns:
A tuple. Fist element is a modified version of image, where the bbox
location in the image will have augmentation_func applied to it if it is
chosen to be called with probability `prob`. The second element is a
Tensor of Tensors of length 4 that will contain the altered bbox after
applying augmentation_func.
"""
should_apply_op
=
tf
.
cast
(
tf
.
floor
(
tf
.
random
.
uniform
([],
dtype
=
tf
.
float32
)
+
prob
),
tf
.
bool
)
if
func_changes_bbox
:
augmented_image
,
bbox
=
tf
.
cond
(
should_apply_op
,
lambda
:
augmentation_func
(
image
,
bbox
,
*
args
),
lambda
:
(
image
,
bbox
))
else
:
augmented_image
=
tf
.
cond
(
should_apply_op
,
lambda
:
_apply_bbox_augmentation
(
image
,
bbox
,
augmentation_func
,
*
args
),
lambda
:
image
)
new_bboxes
=
_concat_bbox
(
bbox
,
new_bboxes
)
return
augmented_image
,
new_bboxes
def
_apply_multi_bbox_augmentation_wrapper
(
image
,
bboxes
,
prob
,
aug_func
,
func_changes_bbox
,
*
args
):
"""Checks to be sure num bboxes > 0 before calling inner function."""
num_bboxes
=
tf
.
shape
(
bboxes
)[
0
]
image
,
bboxes
=
tf
.
cond
(
tf
.
equal
(
num_bboxes
,
0
),
lambda
:
(
image
,
bboxes
),
# pylint:disable=g-long-lambda
lambda
:
_apply_multi_bbox_augmentation
(
image
,
bboxes
,
prob
,
aug_func
,
func_changes_bbox
,
*
args
))
# pylint:enable=g-long-lambda
return
image
,
bboxes
# Represents an invalid bounding box that is used for checking for padding
# lists of bounding box coordinates for a few augmentation operations
_INVALID_BOX
=
[[
-
1.0
,
-
1.0
,
-
1.0
,
-
1.0
]]
def
_apply_multi_bbox_augmentation
(
image
,
bboxes
,
prob
,
aug_func
,
func_changes_bbox
,
*
args
):
"""Applies aug_func to the image for each bbox in bboxes.
Args:
image: 3D uint8 Tensor.
bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
has 4 elements (min_y, min_x, max_y, max_x) of type float.
prob: Float that is the probability of applying aug_func to a specific
bounding box within the image.
aug_func: Augmentation function that will be applied to the
subsections of image indicated by the bbox values in bboxes.
func_changes_bbox: Boolean. Does augmentation_func return bbox in addition
to image.
*args: Additional parameters that will be passed into augmentation_func
when it is called.
Returns:
A modified version of image, where each bbox location in the image will
have augmentation_func applied to it if it is chosen to be called with
probability prob independently across all bboxes. Also the final
bboxes are returned that will be unchanged if func_changes_bbox is set to
false and if true, the new altered ones will be returned.
Raises:
ValueError if applied to video.
"""
if
image
.
shape
.
rank
==
4
:
raise
ValueError
(
'Image rank 4 is not supported'
)
# Will keep track of the new altered bboxes after aug_func is repeatedly
# applied. The -1 values are a dummy value and this first Tensor will be
# removed upon appending the first real bbox.
new_bboxes
=
tf
.
constant
(
_INVALID_BOX
)
# If the bboxes are empty, then just give it _INVALID_BOX. The result
# will be thrown away.
bboxes
=
tf
.
cond
(
tf
.
equal
(
tf
.
size
(
bboxes
),
0
),
lambda
:
tf
.
constant
(
_INVALID_BOX
),
lambda
:
bboxes
)
bboxes
=
tf
.
ensure_shape
(
bboxes
,
(
None
,
4
))
# pylint:disable=g-long-lambda
wrapped_aug_func
=
(
lambda
_image
,
bbox
,
_new_bboxes
:
_apply_bbox_augmentation_wrapper
(
_image
,
bbox
,
_new_bboxes
,
prob
,
aug_func
,
func_changes_bbox
,
*
args
))
# pylint:enable=g-long-lambda
# Setup the while_loop.
num_bboxes
=
tf
.
shape
(
bboxes
)[
0
]
# We loop until we go over all bboxes.
idx
=
tf
.
constant
(
0
)
# Counter for the while loop.
# Conditional function when to end the loop once we go over all bboxes
# images_and_bboxes contain (_image, _new_bboxes)
cond
=
lambda
_idx
,
_images_and_bboxes
:
tf
.
less
(
_idx
,
num_bboxes
)
# Shuffle the bboxes so that the augmentation order is not deterministic if
# we are not changing the bboxes with aug_func.
if
not
func_changes_bbox
:
loop_bboxes
=
tf
.
random
.
shuffle
(
bboxes
)
else
:
loop_bboxes
=
bboxes
# Main function of while_loop where we repeatedly apply augmentation on the
# bboxes in the image.
# pylint:disable=g-long-lambda
body
=
lambda
_idx
,
_images_and_bboxes
:
[
_idx
+
1
,
wrapped_aug_func
(
_images_and_bboxes
[
0
],
loop_bboxes
[
_idx
],
_images_and_bboxes
[
1
])]
# pylint:enable=g-long-lambda
_
,
(
image
,
new_bboxes
)
=
tf
.
while_loop
(
cond
,
body
,
[
idx
,
(
image
,
new_bboxes
)],
shape_invariants
=
[
idx
.
get_shape
(),
(
image
.
get_shape
(),
tf
.
TensorShape
([
None
,
4
]))])
# Either return the altered bboxes or the original ones depending on if
# we altered them in anyway.
if
func_changes_bbox
:
final_bboxes
=
new_bboxes
else
:
final_bboxes
=
bboxes
return
image
,
final_bboxes
def
_clip_bbox
(
min_y
,
min_x
,
max_y
,
max_x
):
"""Clip bounding box coordinates between 0 and 1.
Args:
min_y: Normalized bbox coordinate of type float between 0 and 1.
min_x: Normalized bbox coordinate of type float between 0 and 1.
max_y: Normalized bbox coordinate of type float between 0 and 1.
max_x: Normalized bbox coordinate of type float between 0 and 1.
Returns:
Clipped coordinate values between 0 and 1.
"""
min_y
=
tf
.
clip_by_value
(
min_y
,
0.0
,
1.0
)
min_x
=
tf
.
clip_by_value
(
min_x
,
0.0
,
1.0
)
max_y
=
tf
.
clip_by_value
(
max_y
,
0.0
,
1.0
)
max_x
=
tf
.
clip_by_value
(
max_x
,
0.0
,
1.0
)
return
min_y
,
min_x
,
max_y
,
max_x
def
_check_bbox_area
(
min_y
,
min_x
,
max_y
,
max_x
,
delta
=
0.05
):
"""Adjusts bbox coordinates to make sure the area is > 0.
Args:
min_y: Normalized bbox coordinate of type float between 0 and 1.
min_x: Normalized bbox coordinate of type float between 0 and 1.
max_y: Normalized bbox coordinate of type float between 0 and 1.
max_x: Normalized bbox coordinate of type float between 0 and 1.
delta: Float, this is used to create a gap of size 2 * delta between
bbox min/max coordinates that are the same on the boundary.
This prevents the bbox from having an area of zero.
Returns:
Tuple of new bbox coordinates between 0 and 1 that will now have a
guaranteed area > 0.
"""
height
=
max_y
-
min_y
width
=
max_x
-
min_x
def
_adjust_bbox_boundaries
(
min_coord
,
max_coord
):
# Make sure max is never 0 and min is never 1.
max_coord
=
tf
.
maximum
(
max_coord
,
0.0
+
delta
)
min_coord
=
tf
.
minimum
(
min_coord
,
1.0
-
delta
)
return
min_coord
,
max_coord
min_y
,
max_y
=
tf
.
cond
(
tf
.
equal
(
height
,
0.0
),
lambda
:
_adjust_bbox_boundaries
(
min_y
,
max_y
),
lambda
:
(
min_y
,
max_y
))
min_x
,
max_x
=
tf
.
cond
(
tf
.
equal
(
width
,
0.0
),
lambda
:
_adjust_bbox_boundaries
(
min_x
,
max_x
),
lambda
:
(
min_x
,
max_x
))
return
min_y
,
min_x
,
max_y
,
max_x
def
_rotate_bbox
(
bbox
,
image_height
,
image_width
,
degrees
):
"""Rotates the bbox coordinated by degrees.
Args:
bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
of type float that represents the normalized coordinates between 0 and 1.
image_height: Int, height of the image.
image_width: Int, height of the image.
degrees: Float, a scalar angle in degrees to rotate all images by. If
degrees is positive the image will be rotated clockwise otherwise it will
be rotated counterclockwise.
Returns:
A tensor of the same shape as bbox, but now with the rotated coordinates.
"""
image_height
,
image_width
=
(
tf
.
cast
(
image_height
,
tf
.
float32
),
tf
.
cast
(
image_width
,
tf
.
float32
))
# Convert from degrees to radians.
degrees_to_radians
=
math
.
pi
/
180.0
radians
=
degrees
*
degrees_to_radians
# Translate the bbox to the center of the image and turn the normalized 0-1
# coordinates to absolute pixel locations.
# Y coordinates are made negative as the y axis of images goes down with
# increasing pixel values, so we negate to make sure x axis and y axis points
# are in the traditionally positive direction.
min_y
=
-
tf
.
cast
(
image_height
*
(
bbox
[
0
]
-
0.5
),
tf
.
int32
)
min_x
=
tf
.
cast
(
image_width
*
(
bbox
[
1
]
-
0.5
),
tf
.
int32
)
max_y
=
-
tf
.
cast
(
image_height
*
(
bbox
[
2
]
-
0.5
),
tf
.
int32
)
max_x
=
tf
.
cast
(
image_width
*
(
bbox
[
3
]
-
0.5
),
tf
.
int32
)
coordinates
=
tf
.
stack
(
[[
min_y
,
min_x
],
[
min_y
,
max_x
],
[
max_y
,
min_x
],
[
max_y
,
max_x
]])
coordinates
=
tf
.
cast
(
coordinates
,
tf
.
float32
)
# Rotate the coordinates according to the rotation matrix clockwise if
# radians is positive, else negative
rotation_matrix
=
tf
.
stack
(
[[
tf
.
cos
(
radians
),
tf
.
sin
(
radians
)],
[
-
tf
.
sin
(
radians
),
tf
.
cos
(
radians
)]])
new_coords
=
tf
.
cast
(
tf
.
matmul
(
rotation_matrix
,
tf
.
transpose
(
coordinates
)),
tf
.
int32
)
# Find min/max values and convert them back to normalized 0-1 floats.
min_y
=
-
(
tf
.
cast
(
tf
.
reduce_max
(
new_coords
[
0
,
:]),
tf
.
float32
)
/
image_height
-
0.5
)
min_x
=
tf
.
cast
(
tf
.
reduce_min
(
new_coords
[
1
,
:]),
tf
.
float32
)
/
image_width
+
0.5
max_y
=
-
(
tf
.
cast
(
tf
.
reduce_min
(
new_coords
[
0
,
:]),
tf
.
float32
)
/
image_height
-
0.5
)
max_x
=
tf
.
cast
(
tf
.
reduce_max
(
new_coords
[
1
,
:]),
tf
.
float32
)
/
image_width
+
0.5
# Clip the bboxes to be sure the fall between [0, 1].
min_y
,
min_x
,
max_y
,
max_x
=
_clip_bbox
(
min_y
,
min_x
,
max_y
,
max_x
)
min_y
,
min_x
,
max_y
,
max_x
=
_check_bbox_area
(
min_y
,
min_x
,
max_y
,
max_x
)
return
tf
.
stack
([
min_y
,
min_x
,
max_y
,
max_x
])
def
rotate_with_bboxes
(
image
,
bboxes
,
degrees
,
replace
):
"""Equivalent of PIL Rotate that rotates the image and bbox.
Args:
image: 3D uint8 Tensor.
bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
has 4 elements (min_y, min_x, max_y, max_x) of type float.
degrees: Float, a scalar angle in degrees to rotate all images by. If
degrees is positive the image will be rotated clockwise otherwise it will
be rotated counterclockwise.
replace: A one or three value 1D tensor to fill empty pixels.
Returns:
A tuple containing a 3D uint8 Tensor that will be the result of rotating
image by degrees. The second element of the tuple is bboxes, where now
the coordinates will be shifted to reflect the rotated image.
Raises:
ValueError: If applied to video.
"""
if
image
.
shape
.
rank
==
4
:
raise
ValueError
(
'Image rank 4 is not supported'
)
# Rotate the image.
image
=
wrapped_rotate
(
image
,
degrees
,
replace
)
# Convert bbox coordinates to pixel values.
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
# pylint:disable=g-long-lambda
wrapped_rotate_bbox
=
lambda
bbox
:
_rotate_bbox
(
bbox
,
image_height
,
image_width
,
degrees
)
# pylint:enable=g-long-lambda
bboxes
=
tf
.
map_fn
(
wrapped_rotate_bbox
,
bboxes
)
return
image
,
bboxes
def
_shear_bbox
(
bbox
,
image_height
,
image_width
,
level
,
shear_horizontal
):
"""Shifts the bbox according to how the image was sheared.
Args:
bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
of type float that represents the normalized coordinates between 0 and 1.
image_height: Int, height of the image.
image_width: Int, height of the image.
level: Float. How much to shear the image.
shear_horizontal: If true then shear in X dimension else shear in
the Y dimension.
Returns:
A tensor of the same shape as bbox, but now with the shifted coordinates.
"""
image_height
,
image_width
=
(
tf
.
cast
(
image_height
,
tf
.
float32
),
tf
.
cast
(
image_width
,
tf
.
float32
))
# Change bbox coordinates to be pixels.
min_y
=
tf
.
cast
(
image_height
*
bbox
[
0
],
tf
.
int32
)
min_x
=
tf
.
cast
(
image_width
*
bbox
[
1
],
tf
.
int32
)
max_y
=
tf
.
cast
(
image_height
*
bbox
[
2
],
tf
.
int32
)
max_x
=
tf
.
cast
(
image_width
*
bbox
[
3
],
tf
.
int32
)
coordinates
=
tf
.
stack
(
[[
min_y
,
min_x
],
[
min_y
,
max_x
],
[
max_y
,
min_x
],
[
max_y
,
max_x
]])
coordinates
=
tf
.
cast
(
coordinates
,
tf
.
float32
)
# Shear the coordinates according to the translation matrix.
if
shear_horizontal
:
translation_matrix
=
tf
.
stack
(
[[
1
,
0
],
[
-
level
,
1
]])
else
:
translation_matrix
=
tf
.
stack
(
[[
1
,
-
level
],
[
0
,
1
]])
translation_matrix
=
tf
.
cast
(
translation_matrix
,
tf
.
float32
)
new_coords
=
tf
.
cast
(
tf
.
matmul
(
translation_matrix
,
tf
.
transpose
(
coordinates
)),
tf
.
int32
)
# Find min/max values and convert them back to floats.
min_y
=
tf
.
cast
(
tf
.
reduce_min
(
new_coords
[
0
,
:]),
tf
.
float32
)
/
image_height
min_x
=
tf
.
cast
(
tf
.
reduce_min
(
new_coords
[
1
,
:]),
tf
.
float32
)
/
image_width
max_y
=
tf
.
cast
(
tf
.
reduce_max
(
new_coords
[
0
,
:]),
tf
.
float32
)
/
image_height
max_x
=
tf
.
cast
(
tf
.
reduce_max
(
new_coords
[
1
,
:]),
tf
.
float32
)
/
image_width
# Clip the bboxes to be sure the fall between [0, 1].
min_y
,
min_x
,
max_y
,
max_x
=
_clip_bbox
(
min_y
,
min_x
,
max_y
,
max_x
)
min_y
,
min_x
,
max_y
,
max_x
=
_check_bbox_area
(
min_y
,
min_x
,
max_y
,
max_x
)
return
tf
.
stack
([
min_y
,
min_x
,
max_y
,
max_x
])
def
shear_with_bboxes
(
image
,
bboxes
,
level
,
replace
,
shear_horizontal
):
"""Applies Shear Transformation to the image and shifts the bboxes.
Args:
image: 3D uint8 Tensor.
bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
has 4 elements (min_y, min_x, max_y, max_x) of type float with values
between [0, 1].
level: Float. How much to shear the image. This value will be between
-0.3 to 0.3.
replace: A one or three value 1D tensor to fill empty pixels.
shear_horizontal: Boolean. If true then shear in X dimension else shear in
the Y dimension.
Returns:
A tuple containing a 3D uint8 Tensor that will be the result of shearing
image by level. The second element of the tuple is bboxes, where now
the coordinates will be shifted to reflect the sheared image.
Raises:
ValueError: If applied to video.
"""
if
image
.
shape
.
rank
==
4
:
raise
ValueError
(
'Image rank 4 is not supported'
)
if
shear_horizontal
:
image
=
shear_x
(
image
,
level
,
replace
)
else
:
image
=
shear_y
(
image
,
level
,
replace
)
# Convert bbox coordinates to pixel values.
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
# pylint:disable=g-long-lambda
wrapped_shear_bbox
=
lambda
bbox
:
_shear_bbox
(
bbox
,
image_height
,
image_width
,
level
,
shear_horizontal
)
# pylint:enable=g-long-lambda
bboxes
=
tf
.
map_fn
(
wrapped_shear_bbox
,
bboxes
)
return
image
,
bboxes
def
_shift_bbox
(
bbox
,
image_height
,
image_width
,
pixels
,
shift_horizontal
):
"""Shifts the bbox coordinates by pixels.
Args:
bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
of type float that represents the normalized coordinates between 0 and 1.
image_height: Int, height of the image.
image_width: Int, width of the image.
pixels: An int. How many pixels to shift the bbox.
shift_horizontal: Boolean. If true then shift in X dimension else shift in
Y dimension.
Returns:
A tensor of the same shape as bbox, but now with the shifted coordinates.
"""
pixels
=
tf
.
cast
(
pixels
,
tf
.
int32
)
# Convert bbox to integer pixel locations.
min_y
=
tf
.
cast
(
tf
.
cast
(
image_height
,
tf
.
float32
)
*
bbox
[
0
],
tf
.
int32
)
min_x
=
tf
.
cast
(
tf
.
cast
(
image_width
,
tf
.
float32
)
*
bbox
[
1
],
tf
.
int32
)
max_y
=
tf
.
cast
(
tf
.
cast
(
image_height
,
tf
.
float32
)
*
bbox
[
2
],
tf
.
int32
)
max_x
=
tf
.
cast
(
tf
.
cast
(
image_width
,
tf
.
float32
)
*
bbox
[
3
],
tf
.
int32
)
if
shift_horizontal
:
min_x
=
tf
.
maximum
(
0
,
min_x
-
pixels
)
max_x
=
tf
.
minimum
(
image_width
,
max_x
-
pixels
)
else
:
min_y
=
tf
.
maximum
(
0
,
min_y
-
pixels
)
max_y
=
tf
.
minimum
(
image_height
,
max_y
-
pixels
)
# Convert bbox back to floats.
min_y
=
tf
.
cast
(
min_y
,
tf
.
float32
)
/
tf
.
cast
(
image_height
,
tf
.
float32
)
min_x
=
tf
.
cast
(
min_x
,
tf
.
float32
)
/
tf
.
cast
(
image_width
,
tf
.
float32
)
max_y
=
tf
.
cast
(
max_y
,
tf
.
float32
)
/
tf
.
cast
(
image_height
,
tf
.
float32
)
max_x
=
tf
.
cast
(
max_x
,
tf
.
float32
)
/
tf
.
cast
(
image_width
,
tf
.
float32
)
# Clip the bboxes to be sure the fall between [0, 1].
min_y
,
min_x
,
max_y
,
max_x
=
_clip_bbox
(
min_y
,
min_x
,
max_y
,
max_x
)
min_y
,
min_x
,
max_y
,
max_x
=
_check_bbox_area
(
min_y
,
min_x
,
max_y
,
max_x
)
return
tf
.
stack
([
min_y
,
min_x
,
max_y
,
max_x
])
def
translate_bbox
(
image
,
bboxes
,
pixels
,
replace
,
shift_horizontal
):
"""Equivalent of PIL Translate in X/Y dimension that shifts image and bbox.
Args:
image: 3D uint8 Tensor.
bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
has 4 elements (min_y, min_x, max_y, max_x) of type float with values
between [0, 1].
pixels: An int. How many pixels to shift the image and bboxes
replace: A one or three value 1D tensor to fill empty pixels.
shift_horizontal: Boolean. If true then shift in X dimension else shift in
Y dimension.
Returns:
A tuple containing a 3D uint8 Tensor that will be the result of translating
image by pixels. The second element of the tuple is bboxes, where now
the coordinates will be shifted to reflect the shifted image.
Raises:
ValueError if applied to video.
"""
if
image
.
shape
.
rank
==
4
:
raise
ValueError
(
'Image rank 4 is not supported'
)
if
shift_horizontal
:
image
=
translate_x
(
image
,
pixels
,
replace
)
else
:
image
=
translate_y
(
image
,
pixels
,
replace
)
# Convert bbox coordinates to pixel values.
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
# pylint:disable=g-long-lambda
wrapped_shift_bbox
=
lambda
bbox
:
_shift_bbox
(
bbox
,
image_height
,
image_width
,
pixels
,
shift_horizontal
)
# pylint:enable=g-long-lambda
bboxes
=
tf
.
map_fn
(
wrapped_shift_bbox
,
bboxes
)
return
image
,
bboxes
def
translate_y_only_bboxes
(
image
:
tf
.
Tensor
,
bboxes
:
tf
.
Tensor
,
prob
:
float
,
pixels
:
int
,
replace
):
"""Apply translate_y to each bbox in the image with probability prob."""
if
bboxes
.
shape
.
rank
==
4
:
raise
ValueError
(
'translate_y_only_bboxes does not support rank 4 boxes'
)
func_changes_bbox
=
False
prob
=
_scale_bbox_only_op_probability
(
prob
)
return
_apply_multi_bbox_augmentation_wrapper
(
image
,
bboxes
,
prob
,
translate_y
,
func_changes_bbox
,
pixels
,
replace
)
def
_randomly_negate_tensor
(
tensor
):
"""With 50% prob turn the tensor negative."""
should_flip
=
tf
.
cast
(
tf
.
floor
(
tf
.
random
.
uniform
([])
+
0.5
),
tf
.
bool
)
...
...
@@ -746,29 +1315,35 @@ def _mult_to_arg(level: float, multiplier: float = 1.):
return
(
int
((
level
/
_MAX_LEVEL
)
*
multiplier
),)
def
_apply_func_with_prob
(
func
:
Any
,
image
:
tf
.
Tensor
,
args
:
Any
,
prob
:
float
):
def
_apply_func_with_prob
(
func
:
Any
,
image
:
tf
.
Tensor
,
bboxes
:
Optional
[
tf
.
Tensor
],
args
:
Any
,
prob
:
float
):
"""Apply `func` to image w/ `args` as input with probability `prob`."""
assert
isinstance
(
args
,
tuple
)
assert
inspect
.
getfullargspec
(
func
)[
0
][
1
]
==
'bboxes'
# Apply the function with probability `prob`.
should_apply_op
=
tf
.
cast
(
tf
.
floor
(
tf
.
random
.
uniform
([],
dtype
=
tf
.
float32
)
+
prob
),
tf
.
bool
)
augmented_image
=
tf
.
cond
(
should_apply_op
,
lambda
:
func
(
image
,
*
args
),
lambda
:
image
)
return
augmented_image
augmented_image
,
augmented_bboxes
=
tf
.
cond
(
should_apply_op
,
lambda
:
func
(
image
,
bboxes
,
*
args
),
lambda
:
(
image
,
bboxes
))
return
augmented_image
,
augmented_bboxes
def
select_and_apply_random_policy
(
policies
:
Any
,
image
:
tf
.
Tensor
):
def
select_and_apply_random_policy
(
policies
:
Any
,
image
:
tf
.
Tensor
,
bboxes
:
Optional
[
tf
.
Tensor
]
=
None
):
"""Select a random policy from `policies` and apply it to `image`."""
policy_to_select
=
tf
.
random
.
uniform
([],
maxval
=
len
(
policies
),
dtype
=
tf
.
int32
)
# Note that using tf.case instead of tf.conds would result in significantly
# larger graphs and would even break export for some larger policies.
for
(
i
,
policy
)
in
enumerate
(
policies
):
image
=
tf
.
cond
(
image
,
bboxes
=
tf
.
cond
(
tf
.
equal
(
i
,
policy_to_select
),
lambda
selected_policy
=
policy
:
selected_policy
(
image
),
lambda
:
image
)
return
image
lambda
selected_policy
=
policy
:
selected_policy
(
image
,
bboxes
),
lambda
:
(
image
,
bboxes
)
)
return
image
,
bboxes
NAME_TO_FUNC
=
{
...
...
@@ -788,8 +1363,35 @@ NAME_TO_FUNC = {
'TranslateX'
:
translate_x
,
'TranslateY'
:
translate_y
,
'Cutout'
:
cutout
,
'Rotate_BBox'
:
rotate_with_bboxes
,
# pylint:disable=g-long-lambda
'ShearX_BBox'
:
lambda
image
,
bboxes
,
level
,
replace
:
shear_with_bboxes
(
image
,
bboxes
,
level
,
replace
,
shear_horizontal
=
True
),
'ShearY_BBox'
:
lambda
image
,
bboxes
,
level
,
replace
:
shear_with_bboxes
(
image
,
bboxes
,
level
,
replace
,
shear_horizontal
=
False
),
'TranslateX_BBox'
:
lambda
image
,
bboxes
,
pixels
,
replace
:
translate_bbox
(
image
,
bboxes
,
pixels
,
replace
,
shift_horizontal
=
True
),
'TranslateY_BBox'
:
lambda
image
,
bboxes
,
pixels
,
replace
:
translate_bbox
(
image
,
bboxes
,
pixels
,
replace
,
shift_horizontal
=
False
),
# pylint:enable=g-long-lambda
'TranslateY_Only_BBoxes'
:
translate_y_only_bboxes
,
}
# Functions that require a `bboxes` parameter.
REQUIRE_BOXES_FUNCS
=
frozenset
({
'Rotate_BBox'
,
'ShearX_BBox'
,
'ShearY_BBox'
,
'TranslateX_BBox'
,
'TranslateY_BBox'
,
'TranslateY_Only_BBoxes'
,
})
# Functions that have a 'prob' parameter
PROB_FUNCS
=
frozenset
({
'TranslateY_Only_BBoxes'
,
})
# Functions that have a 'replace' parameter
REPLACE_FUNCS
=
frozenset
({
'Rotate'
,
...
...
@@ -798,6 +1400,12 @@ REPLACE_FUNCS = frozenset({
'ShearY'
,
'TranslateY'
,
'Cutout'
,
'Rotate_BBox'
,
'ShearX_BBox'
,
'ShearY_BBox'
,
'TranslateX_BBox'
,
'TranslateY_BBox'
,
'TranslateY_Only_BBoxes'
,
})
...
...
@@ -810,6 +1418,7 @@ def level_to_arg(cutout_const: float, translate_const: float):
solarize_add_arg
=
lambda
level
:
_mult_to_arg
(
level
,
110
)
cutout_arg
=
lambda
level
:
_mult_to_arg
(
level
,
cutout_const
)
translate_arg
=
lambda
level
:
_translate_level_to_arg
(
level
,
translate_const
)
translate_bbox_arg
=
lambda
level
:
_translate_level_to_arg
(
level
,
120
)
args
=
{
'AutoContrast'
:
no_arg
,
...
...
@@ -828,10 +1437,27 @@ def level_to_arg(cutout_const: float, translate_const: float):
'Cutout'
:
cutout_arg
,
'TranslateX'
:
translate_arg
,
'TranslateY'
:
translate_arg
,
'Rotate_BBox'
:
_rotate_level_to_arg
,
'ShearX_BBox'
:
_shear_level_to_arg
,
'ShearY_BBox'
:
_shear_level_to_arg
,
# pylint:disable=g-long-lambda
'TranslateX_BBox'
:
lambda
level
:
_translate_level_to_arg
(
level
,
translate_const
),
'TranslateY_BBox'
:
lambda
level
:
_translate_level_to_arg
(
level
,
translate_const
),
# pylint:enable=g-long-lambda
'TranslateY_Only_BBoxes'
:
translate_bbox_arg
,
}
return
args
def
bbox_wrapper
(
func
):
"""Adds a bboxes function argument to func and returns unchanged bboxes."""
def
wrapper
(
images
,
bboxes
,
*
args
,
**
kwargs
):
return
(
func
(
images
,
*
args
,
**
kwargs
),
bboxes
)
return
wrapper
def
_parse_policy_info
(
name
:
Text
,
prob
:
float
,
level
:
float
,
...
...
@@ -848,28 +1474,58 @@ def _parse_policy_info(name: Text,
args
=
level_to_arg
(
cutout_const
,
translate_const
)[
name
](
level
)
if
name
in
PROB_FUNCS
:
# Add in the prob arg if it is required for the function that is called.
args
=
tuple
([
prob
]
+
list
(
args
))
if
name
in
REPLACE_FUNCS
:
# Add in replace arg if it is required for the function that is called.
args
=
tuple
(
list
(
args
)
+
[
replace_value
])
# Add bboxes as the second positional argument for the function if it does
# not already exist.
if
'bboxes'
not
in
inspect
.
getfullargspec
(
func
)[
0
]:
func
=
bbox_wrapper
(
func
)
return
func
,
prob
,
args
class
ImageAugment
(
object
):
"""Image augmentation class for applying image distortions."""
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Given an image tensor, returns a distorted image with the same shape.
Args:
image: `Tensor` of shape [height, width, 3] or
[num_frames, height, width, 3] representing an image or image sequence.
[num_frames, height, width, 3] representing an image or image sequence.
Returns:
The augmented version of `image`.
"""
raise
NotImplementedError
()
def
distort_with_boxes
(
self
,
image
:
tf
.
Tensor
,
bboxes
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Distorts the image and bounding boxes.
Args:
image: `Tensor` of shape [height, width, 3] or
[num_frames, height, width, 3] representing an image or image sequence.
bboxes: `Tensor` of shape [num_boxes, 4] or [num_frames, num_boxes, 4]
representing bounding boxes for an image or image sequence.
Returns:
The augmented version of `image` and `bboxes`.
"""
raise
NotImplementedError
class
AutoAugment
(
ImageAugment
):
"""Applies the AutoAugment policy to images.
...
...
@@ -920,6 +1576,7 @@ class AutoAugment(ImageAugment):
self
.
cutout_const
=
float
(
cutout_const
)
self
.
translate_const
=
float
(
translate_const
)
self
.
available_policies
=
{
'detection_v0'
:
self
.
detection_policy_v0
(),
'v0'
:
self
.
policy_v0
(),
'test'
:
self
.
policy_test
(),
'simple'
:
self
.
policy_simple
(),
...
...
@@ -954,24 +1611,8 @@ class AutoAugment(ImageAugment):
raise
ValueError
(
'Wrong shape detected for custom policy. Expected '
'(:, :, 3) but got {}.'
.
format
(
in_shape
))
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies the AutoAugment policy to `image`.
AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
Returns:
A version of image that now has data augmentation applied to it based on
the `policies` pass into the function.
"""
input_image_type
=
image
.
dtype
if
input_image_type
!=
tf
.
uint8
:
image
=
tf
.
clip_by_value
(
image
,
0.0
,
255.0
)
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
uint8
)
def
_make_tf_policies
(
self
):
"""Prepares the TF functions for augmentations based on the policies."""
replace_value
=
[
128
]
*
3
# func is the string name of the augmentation function, prob is the
...
...
@@ -1000,20 +1641,64 @@ class AutoAugment(ImageAugment):
# on image.
def
make_final_policy
(
tf_policy_
):
def
final_policy
(
image_
):
def
final_policy
(
image_
,
bboxes_
):
for
func
,
prob
,
args
in
tf_policy_
:
image_
=
_apply_func_with_prob
(
func
,
image_
,
args
,
prob
)
return
image_
image_
,
bboxes_
=
_apply_func_with_prob
(
func
,
image_
,
bboxes_
,
args
,
prob
)
return
image_
,
bboxes_
return
final_policy
with
tf
.
control_dependencies
(
assert_ranges
):
tf_policies
.
append
(
make_final_policy
(
tf_policy
))
image
=
select_and_apply_random_policy
(
tf_policies
,
image
)
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
return
tf_policies
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""See base class."""
input_image_type
=
image
.
dtype
if
input_image_type
!=
tf
.
uint8
:
image
=
tf
.
clip_by_value
(
image
,
0.0
,
255.0
)
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
uint8
)
tf_policies
=
self
.
_make_tf_policies
()
image
,
_
=
select_and_apply_random_policy
(
tf_policies
,
image
,
bboxes
=
None
)
return
image
def
distort_with_boxes
(
self
,
image
:
tf
.
Tensor
,
bboxes
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""See base class."""
input_image_type
=
image
.
dtype
if
input_image_type
!=
tf
.
uint8
:
image
=
tf
.
clip_by_value
(
image
,
0.0
,
255.0
)
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
uint8
)
tf_policies
=
self
.
_make_tf_policies
()
image
,
bboxes
=
select_and_apply_random_policy
(
tf_policies
,
image
,
bboxes
)
return
image
,
bboxes
@
staticmethod
def
detection_policy_v0
():
"""Autoaugment policy that was used in AutoAugment Paper for Detection.
https://arxiv.org/pdf/1906.11172
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy
=
[
[(
'TranslateX_BBox'
,
0.6
,
4
),
(
'Equalize'
,
0.8
,
10
)],
[(
'TranslateY_Only_BBoxes'
,
0.2
,
2
),
(
'Cutout'
,
0.8
,
8
)],
[(
'Sharpness'
,
0.0
,
8
),
(
'ShearX_BBox'
,
0.4
,
0
)],
[(
'ShearY_BBox'
,
1.0
,
2
),
(
'TranslateY_Only_BBoxes'
,
0.6
,
6
)],
[(
'Rotate_BBox'
,
0.6
,
10
),
(
'Color'
,
1.0
,
6
)],
]
return
policy
@
staticmethod
def
policy_v0
():
"""Autoaugment policy that was used in AutoAugment Paper.
...
...
@@ -1211,6 +1896,10 @@ class AutoAugment(ImageAugment):
return
policy
def
_maybe_identity
(
x
:
Optional
[
tf
.
Tensor
])
->
Optional
[
tf
.
Tensor
]:
return
tf
.
identity
(
x
)
if
x
is
not
None
else
None
class
RandAugment
(
ImageAugment
):
"""Applies the RandAugment policy to images.
...
...
@@ -1261,15 +1950,12 @@ class RandAugment(ImageAugment):
op
for
op
in
self
.
available_ops
if
op
not
in
exclude_ops
]
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies the RandAugment policy to `image`.
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
Returns:
The augmented version of `image`.
"""
def
_distort_common
(
self
,
image
:
tf
.
Tensor
,
bboxes
:
Optional
[
tf
.
Tensor
]
=
None
)
->
Tuple
[
tf
.
Tensor
,
Optional
[
tf
.
Tensor
]]:
"""Distorts the image and optionally bounding boxes."""
input_image_type
=
image
.
dtype
if
input_image_type
!=
tf
.
uint8
:
...
...
@@ -1280,6 +1966,7 @@ class RandAugment(ImageAugment):
min_prob
,
max_prob
=
0.2
,
0.8
aug_image
=
image
aug_bboxes
=
bboxes
for
_
in
range
(
self
.
num_layers
):
op_to_select
=
tf
.
random
.
uniform
([],
...
...
@@ -1300,23 +1987,36 @@ class RandAugment(ImageAugment):
i
,
# pylint:disable=g-long-lambda
lambda
selected_func
=
func
,
selected_args
=
args
:
selected_func
(
image
,
*
selected_args
)))
image
,
bboxes
,
*
selected_args
)))
# pylint:enable=g-long-lambda
aug_image
=
tf
.
switch_case
(
aug_image
,
aug_bboxes
=
tf
.
switch_case
(
branch_index
=
op_to_select
,
branch_fns
=
branch_fns
,
default
=
lambda
:
tf
.
identity
(
image
))
default
=
lambda
:
(
tf
.
identity
(
image
)
,
_maybe_identity
(
bboxes
))
)
if
self
.
prob_to_apply
is
not
None
:
aug_image
=
tf
.
cond
(
aug_image
,
aug_bboxes
=
tf
.
cond
(
tf
.
random
.
uniform
(
shape
=
[],
dtype
=
tf
.
float32
)
<
self
.
prob_to_apply
,
lambda
:
tf
.
identity
(
aug_image
),
lambda
:
tf
.
identity
(
image
))
lambda
:
(
tf
.
identity
(
aug_image
),
_maybe_identity
(
aug_bboxes
)),
lambda
:
(
tf
.
identity
(
image
),
_maybe_identity
(
bboxes
)))
image
=
aug_image
bboxes
=
aug_bboxes
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
return
image
,
bboxes
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""See base class."""
image
,
_
=
self
.
_distort_common
(
image
)
return
image
def
distort_with_boxes
(
self
,
image
:
tf
.
Tensor
,
bboxes
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""See base class."""
image
,
bboxes
=
self
.
_distort_common
(
image
,
bboxes
)
return
image
,
bboxes
class
RandomErasing
(
ImageAugment
):
"""Applies RandomErasing to a single image.
...
...
official/vision/beta/ops/augment_test.py
View file @
96674ab0
...
...
@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'reduced_cifar10'
,
'svhn'
,
'reduced_imagenet'
,
]
AVAILABLE_POLICIES
=
[
'v0'
,
'test'
,
'simple'
,
'reduced_cifar10'
,
'svhn'
,
'reduced_imagenet'
,
'detection_v0'
,
]
def
test_autoaugment
(
self
):
...
...
@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_autoaugment_with_bboxes
(
self
):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
for
policy
in
self
.
AVAILABLE_POLICIES
:
augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
policy
)
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug
(
self
):
"""Smoke test to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
...
...
@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_randaug_with_bboxes
(
self
):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandAugment
()
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_all_policy_ops
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
...
...
@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
None
for
op_name
in
augment
.
NAME_TO_FUNC
.
keys
()
-
augment
.
REQUIRE_BOXES_FUNCS
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertIsNone
(
bboxes
)
def
test_all_policy_ops_with_bboxes
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
=
func
(
image
,
*
args
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
2
,
4
),
bboxes
.
shape
)
def
test_autoaugment_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
...
...
@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
def
test_autoaugment_video_with_boxes
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
2
,
4
),
dtype
=
tf
.
float32
)
for
policy
in
self
.
AVAILABLE_POLICIES
:
augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
policy
)
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
...
...
@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const
=
250
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
None
for
op_name
in
augment
.
NAME_TO_FUNC
.
keys
()
-
augment
.
REQUIRE_BOXES_FUNCS
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
self
.
assertIsNone
(
bboxes
)
def
test_all_policy_ops_video_with_bboxes
(
self
):
"""Smoke test to be sure all video augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
2
,
4
),
dtype
=
tf
.
float32
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
=
func
(
image
,
*
args
)
if
op_name
in
{
'Rotate_BBox'
,
'ShearX_BBox'
,
'ShearY_BBox'
,
'TranslateX_BBox'
,
'TranslateY_BBox'
,
'TranslateY_Only_BBoxes'
,
}:
with
self
.
assertRaises
(
ValueError
):
func
(
image
,
bboxes
,
*
args
)
else
:
image
,
bboxes
=
func
(
image
,
bboxes
,
*
args
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
2
,
2
,
4
),
bboxes
.
shape
)
def
_generate_test_policy
(
self
):
"""Generate a test policy at random."""
...
...
official/vision/beta/tasks/retinanet.py
View file @
96674ab0
...
...
@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task):
dtype
=
params
.
dtype
,
match_threshold
=
params
.
parser
.
match_threshold
,
unmatched_threshold
=
params
.
parser
.
unmatched_threshold
,
aug_type
=
params
.
parser
.
aug_type
,
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment