Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1d38cca0
"vscode:/vscode.git/clone" did not exist on "e66c29261a8b8db6214ddebdc727e7b247be74df"
Commit
1d38cca0
authored
Nov 19, 2020
by
Yeqing Li
Committed by
A. Unique TensorFlower
Nov 19, 2020
Browse files
Internal change
PiperOrigin-RevId: 343441984
parent
3dae7a77
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
8 deletions
+26
-8
official/vision/beta/tasks/video_classification.py
official/vision/beta/tasks/video_classification.py
+26
-8
No files found.
official/vision/beta/tasks/video_classification.py
View file @
1d38cca0
...
...
@@ -54,18 +54,30 @@ class VideoClassificationTask(base_task.Task):
l2_regularizer
=
l2_regularizer
)
return
model
def
_get_dataset_fn
(
self
,
params
):
if
params
.
file_type
==
'tfrecord'
:
return
tf
.
data
.
TFRecordDataset
else
:
raise
ValueError
(
'Unknown input file type {!r}'
.
format
(
params
.
file_type
))
def
_get_decoder_fn
(
self
,
params
):
decoder
=
video_input
.
Decoder
()
if
self
.
task_config
.
train_data
.
output_audio
:
assert
self
.
task_config
.
train_data
.
audio_feature
,
'audio feature is empty'
decoder
.
add_feature
(
self
.
task_config
.
train_data
.
audio_feature
,
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
))
return
decoder
.
decode
def
build_inputs
(
self
,
params
:
exp_cfg
.
DataConfig
,
input_context
=
None
):
"""Builds classification input."""
decoder
=
video_input
.
Decoder
()
decoder_fn
=
decoder
.
decode
parser
=
video_input
.
Parser
(
input_params
=
params
)
postprocess_fn
=
video_input
.
PostBatchProcessor
(
params
)
reader
=
input_reader
.
InputReader
(
params
,
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
decoder_fn
=
decoder_fn
,
dataset_fn
=
self
.
_get_dataset_fn
(
params
)
,
decoder_fn
=
self
.
_get_
decoder_fn
(
params
)
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
),
postprocess_fn
=
postprocess_fn
)
...
...
@@ -183,6 +195,9 @@ class VideoClassificationTask(base_task.Task):
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
with
tf
.
GradientTape
()
as
tape
:
if
self
.
task_config
.
train_data
.
output_audio
:
outputs
=
model
(
features
,
training
=
True
)
else
:
outputs
=
model
(
features
[
'image'
],
training
=
True
)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
...
...
@@ -237,7 +252,7 @@ class VideoClassificationTask(base_task.Task):
"""
features
,
labels
=
inputs
outputs
=
self
.
inference_step
(
features
[
'image'
]
,
model
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
logs
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
...
...
@@ -250,9 +265,12 @@ class VideoClassificationTask(base_task.Task):
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
model
.
metrics
})
return
logs
def
inference_step
(
self
,
input
s
,
model
):
def
inference_step
(
self
,
feature
s
,
model
):
"""Performs the forward step."""
outputs
=
model
(
inputs
,
training
=
False
)
if
self
.
task_config
.
train_data
.
output_audio
:
outputs
=
model
(
features
,
training
=
False
)
else
:
outputs
=
model
(
features
[
'image'
],
training
=
False
)
if
self
.
task_config
.
train_data
.
is_multilabel
:
outputs
=
tf
.
math
.
sigmoid
(
outputs
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment