Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6d6e881a
Commit
6d6e881a
authored
Jul 20, 2022
by
A. Unique TensorFlower
Browse files
Add additional parameters for processing different image shape and label type.
PiperOrigin-RevId: 462275036
parent
62c74392
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
13 deletions
+53
-13
official/vision/configs/video_classification.py
official/vision/configs/video_classification.py
+1
-0
official/vision/dataloaders/video_input.py
official/vision/dataloaders/video_input.py
+30
-13
official/vision/dataloaders/video_input_test.py
official/vision/dataloaders/video_input_test.py
+22
-0
No files found.
official/vision/configs/video_classification.py
View file @
6d6e881a
...
@@ -41,6 +41,7 @@ class DataConfig(cfg.DataConfig):
...
@@ -41,6 +41,7 @@ class DataConfig(cfg.DataConfig):
global_batch_size
:
int
=
128
global_batch_size
:
int
=
128
data_format
:
str
=
'channels_last'
data_format
:
str
=
'channels_last'
dtype
:
str
=
'float32'
dtype
:
str
=
'float32'
label_dtype
:
str
=
'int32'
one_hot
:
bool
=
True
one_hot
:
bool
=
True
shuffle_buffer_size
:
int
=
64
shuffle_buffer_size
:
int
=
64
cache
:
bool
=
False
cache
:
bool
=
False
...
...
official/vision/dataloaders/video_input.py
View file @
6d6e881a
...
@@ -36,7 +36,8 @@ def process_image(image: tf.Tensor,
...
@@ -36,7 +36,8 @@ def process_image(image: tf.Tensor,
random_stride_range
:
int
=
0
,
random_stride_range
:
int
=
0
,
num_test_clips
:
int
=
1
,
num_test_clips
:
int
=
1
,
min_resize
:
int
=
256
,
min_resize
:
int
=
256
,
crop_size
:
int
=
224
,
crop_size
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
224
,
num_channels
:
int
=
3
,
num_crops
:
int
=
1
,
num_crops
:
int
=
1
,
zero_centering_image
:
bool
=
False
,
zero_centering_image
:
bool
=
False
,
min_aspect_ratio
:
float
=
0.5
,
min_aspect_ratio
:
float
=
0.5
,
...
@@ -64,8 +65,10 @@ def process_image(image: tf.Tensor,
...
@@ -64,8 +65,10 @@ def process_image(image: tf.Tensor,
If 1, then a single clip in the middle of the video is sampled. The clips
If 1, then a single clip in the middle of the video is sampled. The clips
are aggregated in the batch dimension.
are aggregated in the batch dimension.
min_resize: Frames are resized so that min(height, width) is min_resize.
min_resize: Frames are resized so that min(height, width) is min_resize.
crop_size: Final size of the frame after cropping the resized frames. Both
crop_size: Final size of the frame after cropping the resized frames.
height and width are the same.
Optionally, specify a tuple of (crop_height, crop_width) if
crop_height != crop_width.
num_channels: Number of channels of the clip.
num_crops: Number of crops to perform on the resized frames.
num_crops: Number of crops to perform on the resized frames.
zero_centering_image: If True, frames are normalized to values in [-1, 1].
zero_centering_image: If True, frames are normalized to values in [-1, 1].
If False, values in [0, 1].
If False, values in [0, 1].
...
@@ -78,7 +81,7 @@ def process_image(image: tf.Tensor,
...
@@ -78,7 +81,7 @@ def process_image(image: tf.Tensor,
Returns:
Returns:
Processed frames. Tensor of shape
Processed frames. Tensor of shape
[num_frames * num_test_clips, crop_size, crop_size, 3].
[num_frames * num_test_clips, crop_height, crop_width, num_channels].
"""
"""
# Validate parameters.
# Validate parameters.
if
is_training
and
num_test_clips
!=
1
:
if
is_training
and
num_test_clips
!=
1
:
...
@@ -90,6 +93,10 @@ def process_image(image: tf.Tensor,
...
@@ -90,6 +93,10 @@ def process_image(image: tf.Tensor,
raise
ValueError
(
'Random stride range should be >= 0, got {}'
.
format
(
raise
ValueError
(
'Random stride range should be >= 0, got {}'
.
format
(
random_stride_range
))
random_stride_range
))
if
isinstance
(
crop_size
,
int
):
crop_size
=
(
crop_size
,
crop_size
)
crop_height
,
crop_width
=
crop_size
# Temporal sampler.
# Temporal sampler.
if
is_training
:
if
is_training
:
if
random_stride_range
>
0
:
if
random_stride_range
>
0
:
...
@@ -113,12 +120,12 @@ def process_image(image: tf.Tensor,
...
@@ -113,12 +120,12 @@ def process_image(image: tf.Tensor,
# Decode JPEG string to tf.uint8.
# Decode JPEG string to tf.uint8.
if
image
.
dtype
==
tf
.
string
:
if
image
.
dtype
==
tf
.
string
:
image
=
preprocess_ops_3d
.
decode_jpeg
(
image
,
3
)
image
=
preprocess_ops_3d
.
decode_jpeg
(
image
,
num_channels
)
if
is_training
:
if
is_training
:
# Standard image data augmentation: random resized crop and random flip.
# Standard image data augmentation: random resized crop and random flip.
image
=
preprocess_ops_3d
.
random_crop_resize
(
image
=
preprocess_ops_3d
.
random_crop_resize
(
image
,
crop_
size
,
crop_
size
,
num_frames
,
3
,
image
,
crop_
height
,
crop_
width
,
num_frames
,
num_channels
,
(
min_aspect_ratio
,
max_aspect_ratio
),
(
min_aspect_ratio
,
max_aspect_ratio
),
(
min_area_ratio
,
max_area_ratio
))
(
min_area_ratio
,
max_area_ratio
))
image
=
preprocess_ops_3d
.
random_flip_left_right
(
image
,
seed
)
image
=
preprocess_ops_3d
.
random_flip_left_right
(
image
,
seed
)
...
@@ -129,7 +136,7 @@ def process_image(image: tf.Tensor,
...
@@ -129,7 +136,7 @@ def process_image(image: tf.Tensor,
# Resize images (resize happens only if necessary to save compute).
# Resize images (resize happens only if necessary to save compute).
image
=
preprocess_ops_3d
.
resize_smallest
(
image
,
min_resize
)
image
=
preprocess_ops_3d
.
resize_smallest
(
image
,
min_resize
)
# Crop of the frames.
# Crop of the frames.
image
=
preprocess_ops_3d
.
crop_image
(
image
,
crop_
size
,
crop_
size
,
False
,
image
=
preprocess_ops_3d
.
crop_image
(
image
,
crop_
height
,
crop_
width
,
False
,
num_crops
)
num_crops
)
# Cast the frames in float32, normalizing according to zero_centering_image.
# Cast the frames in float32, normalizing according to zero_centering_image.
...
@@ -173,15 +180,16 @@ def postprocess_image(image: tf.Tensor,
...
@@ -173,15 +180,16 @@ def postprocess_image(image: tf.Tensor,
def
process_label
(
label
:
tf
.
Tensor
,
def
process_label
(
label
:
tf
.
Tensor
,
one_hot_label
:
bool
=
True
,
one_hot_label
:
bool
=
True
,
num_classes
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
num_classes
:
Optional
[
int
]
=
None
,
label_dtype
:
tf
.
DType
=
tf
.
int32
)
->
tf
.
Tensor
:
"""Processes label Tensor."""
"""Processes label Tensor."""
# Validate parameters.
# Validate parameters.
if
one_hot_label
and
not
num_classes
:
if
one_hot_label
and
not
num_classes
:
raise
ValueError
(
raise
ValueError
(
'`num_classes` should be given when requesting one hot label.'
)
'`num_classes` should be given when requesting one hot label.'
)
# Cast to tf.int32.
# Cast to label_dtype (default = tf.int32).
label
=
tf
.
cast
(
label
,
dtype
=
tf
.
int32
)
label
=
tf
.
cast
(
label
,
dtype
=
label_dtype
)
if
one_hot_label
:
if
one_hot_label
:
# Replace label index by one hot representation.
# Replace label index by one hot representation.
...
@@ -269,7 +277,11 @@ class Parser(parser.Parser):
...
@@ -269,7 +277,11 @@ class Parser(parser.Parser):
self
.
_random_stride_range
=
input_params
.
random_stride_range
self
.
_random_stride_range
=
input_params
.
random_stride_range
self
.
_num_test_clips
=
input_params
.
num_test_clips
self
.
_num_test_clips
=
input_params
.
num_test_clips
self
.
_min_resize
=
input_params
.
min_image_size
self
.
_min_resize
=
input_params
.
min_image_size
self
.
_crop_size
=
input_params
.
feature_shape
[
1
]
crop_height
=
input_params
.
feature_shape
[
1
]
crop_width
=
input_params
.
feature_shape
[
2
]
self
.
_crop_size
=
crop_height
if
crop_height
==
crop_width
else
(
crop_height
,
crop_width
)
self
.
_num_channels
=
input_params
.
feature_shape
[
3
]
self
.
_num_crops
=
input_params
.
num_test_crops
self
.
_num_crops
=
input_params
.
num_test_crops
self
.
_zero_centering_image
=
input_params
.
zero_centering_image
self
.
_zero_centering_image
=
input_params
.
zero_centering_image
self
.
_one_hot_label
=
input_params
.
one_hot
self
.
_one_hot_label
=
input_params
.
one_hot
...
@@ -277,6 +289,7 @@ class Parser(parser.Parser):
...
@@ -277,6 +289,7 @@ class Parser(parser.Parser):
self
.
_image_key
=
image_key
self
.
_image_key
=
image_key
self
.
_label_key
=
label_key
self
.
_label_key
=
label_key
self
.
_dtype
=
tf
.
dtypes
.
as_dtype
(
input_params
.
dtype
)
self
.
_dtype
=
tf
.
dtypes
.
as_dtype
(
input_params
.
dtype
)
self
.
_label_dtype
=
tf
.
dtypes
.
as_dtype
(
input_params
.
label_dtype
)
self
.
_output_audio
=
input_params
.
output_audio
self
.
_output_audio
=
input_params
.
output_audio
self
.
_min_aspect_ratio
=
input_params
.
aug_min_aspect_ratio
self
.
_min_aspect_ratio
=
input_params
.
aug_min_aspect_ratio
self
.
_max_aspect_ratio
=
input_params
.
aug_max_aspect_ratio
self
.
_max_aspect_ratio
=
input_params
.
aug_max_aspect_ratio
...
@@ -324,6 +337,7 @@ class Parser(parser.Parser):
...
@@ -324,6 +337,7 @@ class Parser(parser.Parser):
num_test_clips
=
self
.
_num_test_clips
,
num_test_clips
=
self
.
_num_test_clips
,
min_resize
=
self
.
_min_resize
,
min_resize
=
self
.
_min_resize
,
crop_size
=
self
.
_crop_size
,
crop_size
=
self
.
_crop_size
,
num_channels
=
self
.
_num_channels
,
min_aspect_ratio
=
self
.
_min_aspect_ratio
,
min_aspect_ratio
=
self
.
_min_aspect_ratio
,
max_aspect_ratio
=
self
.
_max_aspect_ratio
,
max_aspect_ratio
=
self
.
_max_aspect_ratio
,
min_area_ratio
=
self
.
_min_area_ratio
,
min_area_ratio
=
self
.
_min_area_ratio
,
...
@@ -335,7 +349,8 @@ class Parser(parser.Parser):
...
@@ -335,7 +349,8 @@ class Parser(parser.Parser):
features
=
{
'image'
:
image
}
features
=
{
'image'
:
image
}
label
=
decoded_tensors
[
self
.
_label_key
]
label
=
decoded_tensors
[
self
.
_label_key
]
label
=
process_label
(
label
,
self
.
_one_hot_label
,
self
.
_num_classes
)
label
=
process_label
(
label
,
self
.
_one_hot_label
,
self
.
_num_classes
,
self
.
_label_dtype
)
if
self
.
_output_audio
:
if
self
.
_output_audio
:
audio
=
decoded_tensors
[
self
.
_audio_feature
]
audio
=
decoded_tensors
[
self
.
_audio_feature
]
...
@@ -361,13 +376,15 @@ class Parser(parser.Parser):
...
@@ -361,13 +376,15 @@ class Parser(parser.Parser):
num_test_clips
=
self
.
_num_test_clips
,
num_test_clips
=
self
.
_num_test_clips
,
min_resize
=
self
.
_min_resize
,
min_resize
=
self
.
_min_resize
,
crop_size
=
self
.
_crop_size
,
crop_size
=
self
.
_crop_size
,
num_channels
=
self
.
_num_channels
,
num_crops
=
self
.
_num_crops
,
num_crops
=
self
.
_num_crops
,
zero_centering_image
=
self
.
_zero_centering_image
)
zero_centering_image
=
self
.
_zero_centering_image
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
features
=
{
'image'
:
image
}
features
=
{
'image'
:
image
}
label
=
decoded_tensors
[
self
.
_label_key
]
label
=
decoded_tensors
[
self
.
_label_key
]
label
=
process_label
(
label
,
self
.
_one_hot_label
,
self
.
_num_classes
)
label
=
process_label
(
label
,
self
.
_one_hot_label
,
self
.
_num_classes
,
self
.
_label_dtype
)
if
self
.
_output_audio
:
if
self
.
_output_audio
:
audio
=
decoded_tensors
[
self
.
_audio_feature
]
audio
=
decoded_tensors
[
self
.
_audio_feature
]
...
...
official/vision/dataloaders/video_input_test.py
View file @
6d6e881a
...
@@ -191,6 +191,28 @@ class VideoAndLabelParserTest(tf.test.TestCase):
...
@@ -191,6 +191,28 @@ class VideoAndLabelParserTest(tf.test.TestCase):
self
.
assertAllEqual
(
image
.
shape
,
(
2
,
224
,
224
,
3
))
self
.
assertAllEqual
(
image
.
shape
,
(
2
,
224
,
224
,
3
))
self
.
assertAllEqual
(
label
.
shape
,
(
600
,))
self
.
assertAllEqual
(
label
.
shape
,
(
600
,))
def
test_video_input_image_shape_label_type
(
self
):
params
=
exp_cfg
.
kinetics600
(
is_training
=
True
)
params
.
feature_shape
=
(
2
,
168
,
224
,
1
)
params
.
min_image_size
=
168
params
.
label_dtype
=
'float32'
params
.
one_hot
=
False
decoder
=
video_input
.
Decoder
()
parser
=
video_input
.
Parser
(
params
).
parse_fn
(
params
.
is_training
)
seq_example
,
label
=
fake_seq_example
()
input_tensor
=
tf
.
constant
(
seq_example
.
SerializeToString
())
decoded_tensors
=
decoder
.
decode
(
input_tensor
)
output_tensor
=
parser
(
decoded_tensors
)
image_features
,
label
=
output_tensor
image
=
image_features
[
'image'
]
self
.
assertAllEqual
(
image
.
shape
,
(
2
,
168
,
224
,
1
))
self
.
assertAllEqual
(
label
.
shape
,
(
1
,))
self
.
assertDTypeEqual
(
label
,
tf
.
float32
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment