Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
40cd0e14
Commit
40cd0e14
authored
Jan 15, 2021
by
Yin Cui
Committed by
A. Unique TensorFlower
Jan 15, 2021
Browse files
Internal change
PiperOrigin-RevId: 352087472
parent
2b949afd
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
45 additions
and
21 deletions
+45
-21
official/vision/beta/configs/experiments/video_classification/k400_slowonly8x8_tpu.yaml
...xperiments/video_classification/k400_slowonly8x8_tpu.yaml
+6
-2
official/vision/beta/configs/video_classification.py
official/vision/beta/configs/video_classification.py
+1
-0
official/vision/beta/dataloaders/video_input.py
official/vision/beta/dataloaders/video_input.py
+24
-12
official/vision/beta/ops/preprocess_ops_3d.py
official/vision/beta/ops/preprocess_ops_3d.py
+7
-7
official/vision/beta/tasks/video_classification.py
official/vision/beta/tasks/video_classification.py
+7
-0
No files found.
official/vision/beta/configs/experiments/video_classification/k400_slowonly8x8_tpu.yaml
View file @
40cd0e14
# SlowOnly video classification on Kinetics-400. Expected performance to be updated.
#
# --experiment_type=video_classification_kinetics400
# Expected accuracy: 71.5% top-1, 89.5% top-5.
runtime
:
distribution_strategy
:
'
tpu'
mixed_precision_dtype
:
'
bfloat16'
...
...
@@ -61,8 +64,9 @@ task:
-
256
-
3
temporal_stride
:
8
num_test_clips
:
1
global_batch_size
:
32
num_test_clips
:
10
num_test_crops
:
3
global_batch_size
:
64
dtype
:
'
bfloat16'
drop_remainder
:
false
trainer
:
...
...
official/vision/beta/configs/video_classification.py
View file @
40cd0e14
...
...
@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
feature_shape
:
Tuple
[
int
,
...]
=
(
64
,
224
,
224
,
3
)
temporal_stride
:
int
=
1
num_test_clips
:
int
=
1
num_test_crops
:
int
=
1
num_classes
:
int
=
-
1
num_channels
:
int
=
3
num_examples
:
int
=
-
1
...
...
official/vision/beta/dataloaders/video_input.py
View file @
40cd0e14
...
...
@@ -34,8 +34,9 @@ def _process_image(image: tf.Tensor,
num_frames
:
int
=
32
,
stride
:
int
=
1
,
num_test_clips
:
int
=
1
,
min_resize
:
int
=
224
,
crop_size
:
int
=
200
,
min_resize
:
int
=
256
,
crop_size
:
int
=
224
,
num_crops
:
int
=
1
,
zero_centering_image
:
bool
=
False
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Processes a serialized image tensor.
...
...
@@ -54,6 +55,7 @@ def _process_image(image: tf.Tensor,
min_resize: Frames are resized so that min(height, width) is min_resize.
crop_size: Final size of the frame after cropping the resized frames. Both
height and width are the same.
num_crops: Number of crops to perform on the resized frames.
zero_centering_image: If True, frames are normalized to values in [-1, 1].
If False, values in [0, 1].
seed: A deterministic seed to use when sampling.
...
...
@@ -93,8 +95,9 @@ def _process_image(image: tf.Tensor,
seed
)
image
=
preprocess_ops_3d
.
random_flip_left_right
(
image
,
seed
)
else
:
# Central crop of the frames.
image
=
preprocess_ops_3d
.
crop_image
(
image
,
crop_size
,
crop_size
,
False
)
# Crop of the frames.
image
=
preprocess_ops_3d
.
crop_image
(
image
,
crop_size
,
crop_size
,
False
,
num_crops
)
# Cast the frames in float32, normalizing according to zero_centering_image.
return
preprocess_ops_3d
.
normalize_image
(
image
,
zero_centering_image
)
...
...
@@ -103,7 +106,8 @@ def _process_image(image: tf.Tensor,
def
_postprocess_image
(
image
:
tf
.
Tensor
,
is_training
:
bool
=
True
,
num_frames
:
int
=
32
,
num_test_clips
:
int
=
1
)
->
tf
.
Tensor
:
num_test_clips
:
int
=
1
,
num_test_crops
:
int
=
1
)
->
tf
.
Tensor
:
"""Processes a batched Tensor of frames.
The same parameters used in process should be used here.
...
...
@@ -117,15 +121,19 @@ def _postprocess_image(image: tf.Tensor,
will sample multiple linearly spaced clips within each video at test time.
If 1, then a single clip in the middle of the video is sampled. The clips
are aggreagated in the batch dimension.
num_test_crops: Number of test crops (1 by default). If more than 1, there
are multiple crops for each clip at test time. If 1, there is a single
central crop. The crops are aggreagated in the batch dimension.
Returns:
Processed frames. Tensor of shape
[batch * num_test_clips, num_frames, height, width, 3].
[batch * num_test_clips
* num_test_crops
, num_frames, height, width, 3].
"""
if
num_test_clips
>
1
and
not
is_training
:
# In this case, multiple clips are merged together in batch dimenstion which
# will be B * num_test_clips.
image
=
tf
.
reshape
(
image
,
(
-
1
,
num_frames
)
+
image
.
shape
[
2
:])
num_views
=
num_test_clips
*
num_test_crops
if
num_views
>
1
and
not
is_training
:
# In this case, multiple views are merged together in batch dimenstion which
# will be batch * num_views.
image
=
tf
.
reshape
(
image
,
[
-
1
,
num_frames
]
+
image
.
shape
[
2
:].
as_list
())
return
image
...
...
@@ -207,6 +215,7 @@ class Parser(parser.Parser):
self
.
_num_test_clips
=
input_params
.
num_test_clips
self
.
_min_resize
=
input_params
.
min_image_size
self
.
_crop_size
=
input_params
.
feature_shape
[
1
]
self
.
_num_crops
=
input_params
.
num_test_crops
self
.
_one_hot_label
=
input_params
.
one_hot
self
.
_num_classes
=
input_params
.
num_classes
self
.
_image_key
=
image_key
...
...
@@ -260,7 +269,8 @@ class Parser(parser.Parser):
stride
=
self
.
_stride
,
num_test_clips
=
self
.
_num_test_clips
,
min_resize
=
self
.
_min_resize
,
crop_size
=
self
.
_crop_size
)
crop_size
=
self
.
_crop_size
,
num_crops
=
self
.
_num_crops
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
features
=
{
'image'
:
image
}
...
...
@@ -286,6 +296,7 @@ class PostBatchProcessor(object):
self
.
_num_frames
=
input_params
.
feature_shape
[
0
]
self
.
_num_test_clips
=
input_params
.
num_test_clips
self
.
_num_test_crops
=
input_params
.
num_test_crops
def
__call__
(
self
,
features
:
Dict
[
str
,
tf
.
Tensor
],
label
:
tf
.
Tensor
)
->
Tuple
[
Dict
[
str
,
tf
.
Tensor
],
tf
.
Tensor
]:
...
...
@@ -296,6 +307,7 @@ class PostBatchProcessor(object):
image
=
features
[
key
],
is_training
=
self
.
_is_training
,
num_frames
=
self
.
_num_frames
,
num_test_clips
=
self
.
_num_test_clips
)
num_test_clips
=
self
.
_num_test_clips
,
num_test_crops
=
self
.
_num_test_crops
)
return
features
,
label
official/vision/beta/ops/preprocess_ops_3d.py
View file @
40cd0e14
...
...
@@ -151,19 +151,19 @@ def crop_image(frames: tf.Tensor,
target_height
:
int
,
target_width
:
int
,
random
:
bool
=
False
,
num_
view
s
:
int
=
1
,
num_
crop
s
:
int
=
1
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Crops the image sequence of images.
If requested size is bigger than image size, image is padded with 0. If not
random cropping, a central crop is performed.
random cropping, a central crop is performed
if num_crops is 1
.
Args:
frames: A Tensor of dimension [timesteps, in_height, in_width, channels].
target_height: Target cropped image height.
target_width: Target cropped image width.
random: A boolean indicating if crop should be randomized.
num_
view
s: Number of
views to crop in evaluation
.
num_
crop
s: Number of
crops (support 1 for central crop and 3 for 3-crop)
.
seed: A deterministic seed to use when random cropping.
Returns:
...
...
@@ -181,13 +181,13 @@ def crop_image(frames: tf.Tensor,
frames
=
tf
.
image
.
random_crop
(
frames
,
(
seq_len
,
target_height
,
target_width
,
channels
),
seed
)
else
:
if
num_
view
s
==
1
:
if
num_
crop
s
==
1
:
# Central crop or pad.
frames
=
tf
.
image
.
resize_with_crop_or_pad
(
frames
,
target_height
,
target_width
)
elif
num_
view
s
==
3
:
# Three-
view
evaluation.
elif
num_
crop
s
==
3
:
# Three-
crop
evaluation.
shape
=
tf
.
shape
(
frames
)
static_shape
=
frames
.
shape
.
as_list
()
seq_len
=
shape
[
0
]
if
static_shape
[
0
]
is
None
else
static_shape
[
0
]
...
...
@@ -224,7 +224,7 @@ def crop_image(frames: tf.Tensor,
else
:
raise
NotImplementedError
(
f
"Only 1
crop and 3
crop are supported. Found
{
num_
view
s
!
r
}
."
)
f
"Only 1
-
crop and 3
-
crop are supported. Found
{
num_
crop
s
!
r
}
."
)
return
frames
...
...
official/vision/beta/tasks/video_classification.py
View file @
40cd0e14
...
...
@@ -275,4 +275,11 @@ class VideoClassificationTask(base_task.Task):
outputs
=
tf
.
math
.
sigmoid
(
outputs
)
else
:
outputs
=
tf
.
math
.
softmax
(
outputs
)
num_test_clips
=
self
.
task_config
.
validation_data
.
num_test_clips
num_test_crops
=
self
.
task_config
.
validation_data
.
num_test_crops
num_test_views
=
num_test_clips
*
num_test_crops
if
num_test_views
>
1
:
# Averaging output probabilities across multiples views.
outputs
=
tf
.
reshape
(
outputs
,
[
-
1
,
num_test_views
,
outputs
.
shape
[
-
1
]])
outputs
=
tf
.
reduce_mean
(
outputs
,
axis
=
1
)
return
outputs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment