Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8717bca2
Commit
8717bca2
authored
Apr 27, 2021
by
Dan Kondratyuk
Committed by
A. Unique TensorFlower
Apr 27, 2021
Browse files
Internal change
PiperOrigin-RevId: 370717047
parent
d1a3fa6a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
217 additions
and
31 deletions
+217
-31
official/vision/beta/configs/video_classification.py
official/vision/beta/configs/video_classification.py
+1
-0
official/vision/beta/dataloaders/video_input.py
official/vision/beta/dataloaders/video_input.py
+22
-1
official/vision/beta/dataloaders/video_input_test.py
official/vision/beta/dataloaders/video_input_test.py
+22
-0
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+134
-30
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+38
-0
No files found.
official/vision/beta/configs/video_classification.py
View file @
8717bca2
...
@@ -57,6 +57,7 @@ class DataConfig(cfg.DataConfig):
...
@@ -57,6 +57,7 @@ class DataConfig(cfg.DataConfig):
aug_max_aspect_ratio
:
float
=
2.0
aug_max_aspect_ratio
:
float
=
2.0
aug_min_area_ratio
:
float
=
0.49
aug_min_area_ratio
:
float
=
0.49
aug_max_area_ratio
:
float
=
1.0
aug_max_area_ratio
:
float
=
1.0
aug_type
:
Optional
[
str
]
=
None
# 'autoaug', 'randaug', or None
image_field_key
:
str
=
'image/encoded'
image_field_key
:
str
=
'image/encoded'
label_field_key
:
str
=
'clip/label/index'
label_field_key
:
str
=
'clip/label/index'
...
...
official/vision/beta/dataloaders/video_input.py
View file @
8717bca2
...
@@ -23,6 +23,7 @@ import tensorflow as tf
...
@@ -23,6 +23,7 @@ import tensorflow as tf
from
official.vision.beta.configs
import
video_classification
as
exp_cfg
from
official.vision.beta.configs
import
video_classification
as
exp_cfg
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
preprocess_ops_3d
from
official.vision.beta.ops
import
preprocess_ops_3d
IMAGE_KEY
=
'image/encoded'
IMAGE_KEY
=
'image/encoded'
...
@@ -43,6 +44,7 @@ def process_image(image: tf.Tensor,
...
@@ -43,6 +44,7 @@ def process_image(image: tf.Tensor,
max_aspect_ratio
:
float
=
2
,
max_aspect_ratio
:
float
=
2
,
min_area_ratio
:
float
=
0.49
,
min_area_ratio
:
float
=
0.49
,
max_area_ratio
:
float
=
1.0
,
max_area_ratio
:
float
=
1.0
,
augmenter
:
Optional
[
augment
.
ImageAugment
]
=
None
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Processes a serialized image tensor.
"""Processes a serialized image tensor.
...
@@ -72,6 +74,7 @@ def process_image(image: tf.Tensor,
...
@@ -72,6 +74,7 @@ def process_image(image: tf.Tensor,
max_aspect_ratio: The maximum aspect range for cropping.
max_aspect_ratio: The maximum aspect range for cropping.
min_area_ratio: The minimum area range for cropping.
min_area_ratio: The minimum area range for cropping.
max_area_ratio: The maximum area range for cropping.
max_area_ratio: The maximum area range for cropping.
augmenter: Image augmenter to distort each image.
seed: A deterministic seed to use when sampling.
seed: A deterministic seed to use when sampling.
Returns:
Returns:
...
@@ -119,6 +122,9 @@ def process_image(image: tf.Tensor,
...
@@ -119,6 +122,9 @@ def process_image(image: tf.Tensor,
(
min_aspect_ratio
,
max_aspect_ratio
),
(
min_aspect_ratio
,
max_aspect_ratio
),
(
min_area_ratio
,
max_area_ratio
))
(
min_area_ratio
,
max_area_ratio
))
image
=
preprocess_ops_3d
.
random_flip_left_right
(
image
,
seed
)
image
=
preprocess_ops_3d
.
random_flip_left_right
(
image
,
seed
)
if
augmenter
is
not
None
:
image
=
augmenter
.
distort
(
image
)
else
:
else
:
# Resize images (resize happens only if necessary to save compute).
# Resize images (resize happens only if necessary to save compute).
image
=
preprocess_ops_3d
.
resize_smallest
(
image
,
min_resize
)
image
=
preprocess_ops_3d
.
resize_smallest
(
image
,
min_resize
)
...
@@ -256,6 +262,19 @@ class Parser(parser.Parser):
...
@@ -256,6 +262,19 @@ class Parser(parser.Parser):
self
.
_audio_feature
=
input_params
.
audio_feature
self
.
_audio_feature
=
input_params
.
audio_feature
self
.
_audio_shape
=
input_params
.
audio_feature_shape
self
.
_audio_shape
=
input_params
.
audio_feature_shape
self
.
_augmenter
=
None
if
input_params
.
aug_type
is
not
None
:
aug_type
=
input_params
.
aug_type
if
aug_type
==
'autoaug'
:
logging
.
info
(
'Using AutoAugment.'
)
self
.
_augmenter
=
augment
.
AutoAugment
()
elif
aug_type
==
'randaug'
:
logging
.
info
(
'Using RandAugment.'
)
self
.
_augmenter
=
augment
.
RandAugment
()
else
:
raise
ValueError
(
'Augmentation policy {} is not supported.'
.
format
(
aug_type
))
def
_parse_train_data
(
def
_parse_train_data
(
self
,
decoded_tensors
:
Dict
[
str
,
tf
.
Tensor
]
self
,
decoded_tensors
:
Dict
[
str
,
tf
.
Tensor
]
)
->
Tuple
[
Dict
[
str
,
tf
.
Tensor
],
tf
.
Tensor
]:
)
->
Tuple
[
Dict
[
str
,
tf
.
Tensor
],
tf
.
Tensor
]:
...
@@ -274,8 +293,10 @@ class Parser(parser.Parser):
...
@@ -274,8 +293,10 @@ class Parser(parser.Parser):
min_aspect_ratio
=
self
.
_min_aspect_ratio
,
min_aspect_ratio
=
self
.
_min_aspect_ratio
,
max_aspect_ratio
=
self
.
_max_aspect_ratio
,
max_aspect_ratio
=
self
.
_max_aspect_ratio
,
min_area_ratio
=
self
.
_min_area_ratio
,
min_area_ratio
=
self
.
_min_area_ratio
,
max_area_ratio
=
self
.
_max_area_ratio
)
max_area_ratio
=
self
.
_max_area_ratio
,
augmenter
=
self
.
_augmenter
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
features
=
{
'image'
:
image
}
features
=
{
'image'
:
image
}
label
=
decoded_tensors
[
self
.
_label_key
]
label
=
decoded_tensors
[
self
.
_label_key
]
...
...
official/vision/beta/dataloaders/video_input_test.py
View file @
8717bca2
...
@@ -157,6 +157,28 @@ class VideoAndLabelParserTest(tf.test.TestCase):
...
@@ -157,6 +157,28 @@ class VideoAndLabelParserTest(tf.test.TestCase):
self
.
assertAllEqual
(
image
.
shape
,
(
2
,
224
,
224
,
3
))
self
.
assertAllEqual
(
image
.
shape
,
(
2
,
224
,
224
,
3
))
self
.
assertAllEqual
(
label
.
shape
,
(
600
,))
self
.
assertAllEqual
(
label
.
shape
,
(
600
,))
def
test_video_input_augmentation_returns_shape
(
self
):
params
=
exp_cfg
.
kinetics600
(
is_training
=
True
)
params
.
feature_shape
=
(
2
,
224
,
224
,
3
)
params
.
min_image_size
=
224
params
.
temporal_stride
=
2
params
.
aug_type
=
'autoaug'
decoder
=
video_input
.
Decoder
()
parser
=
video_input
.
Parser
(
params
).
parse_fn
(
params
.
is_training
)
seq_example
,
label
=
fake_seq_example
()
input_tensor
=
tf
.
constant
(
seq_example
.
SerializeToString
())
decoded_tensors
=
decoder
.
decode
(
input_tensor
)
output_tensor
=
parser
(
decoded_tensors
)
image_features
,
label
=
output_tensor
image
=
image_features
[
'image'
]
self
.
assertAllEqual
(
image
.
shape
,
(
2
,
224
,
224
,
3
))
self
.
assertAllEqual
(
label
.
shape
,
(
600
,))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/vision/beta/ops/augment.py
View file @
8717bca2
...
@@ -12,13 +12,13 @@
...
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""AutoAugment and RandAugment policies for enhanced image preprocessing.
"""AutoAugment and RandAugment policies for enhanced image
/video
preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandAugment Reference: https://arxiv.org/abs/1909.13719
"""
"""
import
math
import
math
from
typing
import
Any
,
List
,
Optional
,
Text
,
Tuple
,
Iterable
from
typing
import
Any
,
List
,
Iterable
,
Optional
,
Text
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -265,8 +265,8 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
...
@@ -265,8 +265,8 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
"""Apply cutout (https://arxiv.org/abs/1708.04552) to image.
"""Apply cutout (https://arxiv.org/abs/1708.04552) to image.
This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
a random location within `img`. The pixel values filled in will be of the
a random location within `image`. The pixel values filled in will be of the
value `replace`. The located where the mask will be applied is randomly
value `replace`. The location where the mask will be applied is randomly
chosen uniformly over the whole image.
chosen uniformly over the whole image.
Args:
Args:
...
@@ -279,6 +279,12 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
...
@@ -279,6 +279,12 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
Returns:
Returns:
An image Tensor that is of type uint8.
An image Tensor that is of type uint8.
"""
"""
if
image
.
shape
.
rank
not
in
[
3
,
4
]:
raise
ValueError
(
'Bad image rank: {}'
.
format
(
image
.
shape
.
rank
))
if
image
.
shape
.
rank
==
4
:
return
cutout_video
(
image
,
replace
=
replace
)
image_height
=
tf
.
shape
(
image
)[
0
]
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
image_width
=
tf
.
shape
(
image
)[
1
]
...
@@ -311,7 +317,86 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
...
@@ -311,7 +317,86 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
return
image
return
image
def
cutout_video
(
image
:
tf
.
Tensor
,
replace
:
int
=
0
)
->
tf
.
Tensor
:
"""Apply cutout (https://arxiv.org/abs/1708.04552) to a video.
This operation applies a random size 3D mask of zeros to a random location
within `image`. The mask is padded with ones out to the full video size, so
only the cutout region is zero. The pixel values filled in will be of the
value `replace`. The location where the mask will be applied is randomly
chosen uniformly over the whole image. The half-extent of the mask along each
axis (the mask spans the sampled center plus or minus this amount, clipped at
the video borders) is randomly sampled uniformly from
[0.25*height, 0.5*height], [0.25*width, 0.5*width], and [1, 0.25*depth], which
correspond to the height, width, and number of frames of the input video
tensor respectively.
Args:
image: A video Tensor of type uint8.
replace: What pixel value to fill in the image in the area that has the
cutout mask applied to it.
Returns:
A video Tensor that is of type uint8.
"""
image_depth
=
tf
.
shape
(
image
)[
0
]
image_height
=
tf
.
shape
(
image
)[
1
]
image_width
=
tf
.
shape
(
image
)[
2
]
# Sample the center location in the image where the zero mask will be applied.
cutout_center_height
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_height
,
dtype
=
tf
.
int32
)
cutout_center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
cutout_center_depth
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_depth
,
dtype
=
tf
.
int32
)
pad_size_height
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
tf
.
maximum
(
1
,
tf
.
cast
(
image_height
/
4
,
tf
.
int32
)),
maxval
=
tf
.
maximum
(
2
,
tf
.
cast
(
image_height
/
2
,
tf
.
int32
)),
dtype
=
tf
.
int32
)
pad_size_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
tf
.
maximum
(
1
,
tf
.
cast
(
image_width
/
4
,
tf
.
int32
)),
maxval
=
tf
.
maximum
(
2
,
tf
.
cast
(
image_width
/
2
,
tf
.
int32
)),
dtype
=
tf
.
int32
)
pad_size_depth
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
1
,
maxval
=
tf
.
maximum
(
2
,
tf
.
cast
(
image_depth
/
4
,
tf
.
int32
)),
dtype
=
tf
.
int32
)
lower_pad
=
tf
.
maximum
(
0
,
cutout_center_height
-
pad_size_height
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
cutout_center_height
-
pad_size_height
)
left_pad
=
tf
.
maximum
(
0
,
cutout_center_width
-
pad_size_width
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
cutout_center_width
-
pad_size_width
)
back_pad
=
tf
.
maximum
(
0
,
cutout_center_depth
-
pad_size_depth
)
forward_pad
=
tf
.
maximum
(
0
,
image_depth
-
cutout_center_depth
-
pad_size_depth
)
cutout_shape
=
[
image_depth
-
(
back_pad
+
forward_pad
),
image_height
-
(
lower_pad
+
upper_pad
),
image_width
-
(
left_pad
+
right_pad
),
]
padding_dims
=
[[
back_pad
,
forward_pad
],
[
lower_pad
,
upper_pad
],
[
left_pad
,
right_pad
]]
mask
=
tf
.
pad
(
tf
.
zeros
(
cutout_shape
,
dtype
=
image
.
dtype
),
padding_dims
,
constant_values
=
1
)
mask
=
tf
.
expand_dims
(
mask
,
-
1
)
mask
=
tf
.
tile
(
mask
,
[
1
,
1
,
1
,
3
])
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
,
image
)
return
image
def
solarize
(
image
:
tf
.
Tensor
,
threshold
:
int
=
128
)
->
tf
.
Tensor
:
def
solarize
(
image
:
tf
.
Tensor
,
threshold
:
int
=
128
)
->
tf
.
Tensor
:
"""Solarize the input image(s)."""
# For each pixel in the image, select the pixel
# For each pixel in the image, select the pixel
# if the value is less than the threshold.
# if the value is less than the threshold.
# Otherwise, subtract 255 from the pixel.
# Otherwise, subtract 255 from the pixel.
...
@@ -321,6 +406,7 @@ def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
...
@@ -321,6 +406,7 @@ def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
def
solarize_add
(
image
:
tf
.
Tensor
,
def
solarize_add
(
image
:
tf
.
Tensor
,
addition
:
int
=
0
,
addition
:
int
=
0
,
threshold
:
int
=
128
)
->
tf
.
Tensor
:
threshold
:
int
=
128
)
->
tf
.
Tensor
:
"""Additive solarize the input image(s)."""
# For each pixel in the image less than threshold
# For each pixel in the image less than threshold
# we add 'addition' amount to it and then clip the
# we add 'addition' amount to it and then clip the
# pixel value to be between 0 and 255. The value
# pixel value to be between 0 and 255. The value
...
@@ -437,10 +523,11 @@ def autocontrast(image: tf.Tensor) -> tf.Tensor:
...
@@ -437,10 +523,11 @@ def autocontrast(image: tf.Tensor) -> tf.Tensor:
# Assumes RGB for now. Scales each channel independently
# Assumes RGB for now. Scales each channel independently
# and then stacks the result.
# and then stacks the result.
s1
=
scale_channel
(
image
[:,
:,
0
])
s1
=
scale_channel
(
image
[...,
0
])
s2
=
scale_channel
(
image
[:,
:,
1
])
s2
=
scale_channel
(
image
[...,
1
])
s3
=
scale_channel
(
image
[:,
:,
2
])
s3
=
scale_channel
(
image
[...,
2
])
image
=
tf
.
stack
([
s1
,
s2
,
s3
],
2
)
image
=
tf
.
stack
([
s1
,
s2
,
s3
],
-
1
)
return
image
return
image
...
@@ -451,22 +538,39 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
...
@@ -451,22 +538,39 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
# Make image 4D for conv operation.
# Make image 4D for conv operation.
image
=
tf
.
expand_dims
(
image
,
0
)
image
=
tf
.
expand_dims
(
image
,
0
)
# SMOOTH PIL Kernel.
# SMOOTH PIL Kernel.
kernel
=
tf
.
constant
([[
1
,
1
,
1
],
[
1
,
5
,
1
],
[
1
,
1
,
1
]],
if
orig_image
.
shape
.
rank
==
3
:
dtype
=
tf
.
float32
,
kernel
=
tf
.
constant
([[
1
,
1
,
1
],
[
1
,
5
,
1
],
[
1
,
1
,
1
]],
shape
=
[
3
,
3
,
1
,
1
])
/
13.
dtype
=
tf
.
float32
,
# Tile across channel dimension.
shape
=
[
3
,
3
,
1
,
1
])
/
13.
kernel
=
tf
.
tile
(
kernel
,
[
1
,
1
,
3
,
1
])
# Tile across channel dimension.
strides
=
[
1
,
1
,
1
,
1
]
kernel
=
tf
.
tile
(
kernel
,
[
1
,
1
,
3
,
1
])
degenerate
=
tf
.
nn
.
depthwise_conv2d
(
strides
=
[
1
,
1
,
1
,
1
]
image
,
kernel
,
strides
,
padding
=
'VALID'
,
dilations
=
[
1
,
1
])
degenerate
=
tf
.
nn
.
depthwise_conv2d
(
image
,
kernel
,
strides
,
padding
=
'VALID'
,
dilations
=
[
1
,
1
])
elif
orig_image
.
shape
.
rank
==
4
:
kernel
=
tf
.
constant
([[
1
,
1
,
1
],
[
1
,
5
,
1
],
[
1
,
1
,
1
]],
dtype
=
tf
.
float32
,
shape
=
[
1
,
3
,
3
,
1
,
1
])
/
13.
strides
=
[
1
,
1
,
1
,
1
,
1
]
# Run the kernel across each channel
channels
=
tf
.
split
(
image
,
3
,
axis
=-
1
)
degenerates
=
[
tf
.
nn
.
conv3d
(
channel
,
kernel
,
strides
,
padding
=
'VALID'
,
dilations
=
[
1
,
1
,
1
,
1
,
1
])
for
channel
in
channels
]
degenerate
=
tf
.
concat
(
degenerates
,
-
1
)
else
:
raise
ValueError
(
'Bad image rank: {}'
.
format
(
image
.
shape
.
rank
))
degenerate
=
tf
.
clip_by_value
(
degenerate
,
0.0
,
255.0
)
degenerate
=
tf
.
clip_by_value
(
degenerate
,
0.0
,
255.0
)
degenerate
=
tf
.
squeeze
(
tf
.
cast
(
degenerate
,
tf
.
uint8
),
[
0
])
degenerate
=
tf
.
squeeze
(
tf
.
cast
(
degenerate
,
tf
.
uint8
),
[
0
])
# For the borders of the resulting image, fill in the values of the
# For the borders of the resulting image, fill in the values of the
# original image.
# original image.
mask
=
tf
.
ones_like
(
degenerate
)
mask
=
tf
.
ones_like
(
degenerate
)
padded_mask
=
tf
.
pad
(
mask
,
[[
1
,
1
],
[
1
,
1
],
[
0
,
0
]])
paddings
=
[[
0
,
0
]]
*
(
orig_image
.
shape
.
rank
-
3
)
padded_degenerate
=
tf
.
pad
(
degenerate
,
[[
1
,
1
],
[
1
,
1
],
[
0
,
0
]])
padded_mask
=
tf
.
pad
(
mask
,
paddings
+
[[
1
,
1
],
[
1
,
1
],
[
0
,
0
]])
padded_degenerate
=
tf
.
pad
(
degenerate
,
paddings
+
[[
1
,
1
],
[
1
,
1
],
[
0
,
0
]])
result
=
tf
.
where
(
tf
.
equal
(
padded_mask
,
1
),
padded_degenerate
,
orig_image
)
result
=
tf
.
where
(
tf
.
equal
(
padded_mask
,
1
),
padded_degenerate
,
orig_image
)
# Blend the final result.
# Blend the final result.
...
@@ -478,7 +582,7 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
...
@@ -478,7 +582,7 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
def
scale_channel
(
im
,
c
):
def
scale_channel
(
im
,
c
):
"""Scale the data in the channel to implement equalize."""
"""Scale the data in the channel to implement equalize."""
im
=
tf
.
cast
(
im
[
:,
:
,
c
],
tf
.
int32
)
im
=
tf
.
cast
(
im
[
...
,
c
],
tf
.
int32
)
# Compute the histogram of the image channel.
# Compute the histogram of the image channel.
histo
=
tf
.
histogram_fixed_width
(
im
,
[
0
,
255
],
nbins
=
256
)
histo
=
tf
.
histogram_fixed_width
(
im
,
[
0
,
255
],
nbins
=
256
)
...
@@ -510,7 +614,7 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
...
@@ -510,7 +614,7 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
s1
=
scale_channel
(
image
,
0
)
s1
=
scale_channel
(
image
,
0
)
s2
=
scale_channel
(
image
,
1
)
s2
=
scale_channel
(
image
,
1
)
s3
=
scale_channel
(
image
,
2
)
s3
=
scale_channel
(
image
,
2
)
image
=
tf
.
stack
([
s1
,
s2
,
s3
],
2
)
image
=
tf
.
stack
([
s1
,
s2
,
s3
],
-
1
)
return
image
return
image
...
@@ -523,8 +627,8 @@ def invert(image: tf.Tensor) -> tf.Tensor:
...
@@ -523,8 +627,8 @@ def invert(image: tf.Tensor) -> tf.Tensor:
def
wrap
(
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
def
wrap
(
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Returns 'image' with an extra channel set to all 1s."""
"""Returns 'image' with an extra channel set to all 1s."""
shape
=
tf
.
shape
(
image
)
shape
=
tf
.
shape
(
image
)
extended_channel
=
tf
.
ones
(
[
shape
[
0
],
shape
[
1
],
1
],
image
.
dtype
)
extended_channel
=
tf
.
expand_dims
(
tf
.
ones
(
shape
[
:
-
1
],
image
.
dtype
)
,
-
1
)
extended
=
tf
.
concat
([
image
,
extended_channel
],
axis
=
2
)
extended
=
tf
.
concat
([
image
,
extended_channel
],
axis
=
-
1
)
return
extended
return
extended
...
@@ -548,10 +652,10 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
...
@@ -548,10 +652,10 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
"""
"""
image_shape
=
tf
.
shape
(
image
)
image_shape
=
tf
.
shape
(
image
)
# Flatten the spatial dimensions.
# Flatten the spatial dimensions.
flattened_image
=
tf
.
reshape
(
image
,
[
-
1
,
image_shape
[
2
]])
flattened_image
=
tf
.
reshape
(
image
,
[
-
1
,
image_shape
[
-
1
]])
# Find all pixels where the last channel is zero.
# Find all pixels where the last channel is zero.
alpha_channel
=
tf
.
expand_dims
(
flattened_image
[
:
,
3
],
axis
=-
1
)
alpha_channel
=
tf
.
expand_dims
(
flattened_image
[
...
,
3
],
axis
=-
1
)
replace
=
tf
.
concat
([
replace
,
tf
.
ones
([
1
],
image
.
dtype
)],
0
)
replace
=
tf
.
concat
([
replace
,
tf
.
ones
([
1
],
image
.
dtype
)],
0
)
...
@@ -562,7 +666,10 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
...
@@ -562,7 +666,10 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
flattened_image
)
flattened_image
)
image
=
tf
.
reshape
(
flattened_image
,
image_shape
)
image
=
tf
.
reshape
(
flattened_image
,
image_shape
)
image
=
tf
.
slice
(
image
,
[
0
,
0
,
0
],
[
image_shape
[
0
],
image_shape
[
1
],
3
])
image
=
tf
.
slice
(
image
,
[
0
]
*
image
.
shape
.
rank
,
tf
.
concat
([
image_shape
[:
-
1
],
[
3
]],
-
1
))
return
image
return
image
...
@@ -717,7 +824,8 @@ class ImageAugment(object):
...
@@ -717,7 +824,8 @@ class ImageAugment(object):
"""Given an image tensor, returns a distorted image with the same shape.
"""Given an image tensor, returns a distorted image with the same shape.
Args:
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
image: `Tensor` of shape [height, width, 3] or
[num_frames, height, width, 3] representing an image or image sequence.
Returns:
Returns:
The augmented version of `image`.
The augmented version of `image`.
...
@@ -880,10 +988,6 @@ class AutoAugment(ImageAugment):
...
@@ -880,10 +988,6 @@ class AutoAugment(ImageAugment):
the policy.
the policy.
"""
"""
# TODO(dankondratyuk): tensorflow_addons defines custom ops, which
# for some reason are not included when building/linking
# This results in the error, "Op type not registered
# 'Addons>ImageProjectiveTransformV2' in binary" when running on borg TPUs
policy
=
[
policy
=
[
[(
'Equalize'
,
0.8
,
1
),
(
'ShearY'
,
0.8
,
4
)],
[(
'Equalize'
,
0.8
,
1
),
(
'ShearY'
,
0.8
,
4
)],
[(
'Color'
,
0.4
,
9
),
(
'Equalize'
,
0.6
,
3
)],
[(
'Color'
,
0.4
,
9
),
(
'Equalize'
,
0.6
,
3
)],
...
...
official/vision/beta/ops/augment_test.py
View file @
8717bca2
...
@@ -145,6 +145,44 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -145,6 +145,44 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
def
test_autoaugment_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
for
policy
in
self
.
AVAILABLE_POLICIES
:
augmenter
=
augment
.
AutoAugment
(
augmentation_name
=
policy
)
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
def
test_randaug_video
(
self
):
"""Smoke test with video to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
augmenter
=
augment
.
RandAugment
()
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
aug_image
.
shape
)
def
test_all_policy_ops_video
(
self
):
"""Smoke test to be sure all video augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
2
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
=
func
(
image
,
*
args
)
self
.
assertEqual
((
2
,
224
,
224
,
3
),
image
.
shape
)
def
_generate_test_policy
(
self
):
def
_generate_test_policy
(
self
):
"""Generate a test policy at random."""
"""Generate a test policy at random."""
op_list
=
list
(
augment
.
NAME_TO_FUNC
.
keys
())
op_list
=
list
(
augment
.
NAME_TO_FUNC
.
keys
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment