Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fb9f35c8
Commit
fb9f35c8
authored
May 17, 2021
by
Yeqing Li
Committed by
A. Unique TensorFlower
May 17, 2021
Browse files
Adds a dummy task for multi-modal representation learning.
PiperOrigin-RevId: 374239384
parent
9f88ce51
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
18 deletions
+31
-18
official/vision/beta/tasks/video_classification.py
official/vision/beta/tasks/video_classification.py
+31
-18
No files found.
official/vision/beta/tasks/video_classification.py
View file @
fb9f35c8
...
...
@@ -30,13 +30,32 @@ from official.vision.beta.modeling import factory_3d
class
VideoClassificationTask
(
base_task
.
Task
):
"""A task for video classification."""
def
build_model
(
self
):
"""Builds video classification model."""
common_input_shape
=
[
def
_get_num_classes
(
self
):
"""Gets the number of classes."""
return
self
.
task_config
.
train_data
.
num_classes
def
_get_feature_shape
(
self
):
"""Get the common feature shape for train and eval."""
return
[
d1
if
d1
==
d2
else
None
for
d1
,
d2
in
zip
(
self
.
task_config
.
train_data
.
feature_shape
,
self
.
task_config
.
validation_data
.
feature_shape
)
]
def
_get_num_test_views
(
self
):
"""Gets number of views for test."""
num_test_clips
=
self
.
task_config
.
validation_data
.
num_test_clips
num_test_crops
=
self
.
task_config
.
validation_data
.
num_test_crops
num_test_views
=
num_test_clips
*
num_test_crops
return
num_test_views
def
_is_multilabel
(
self
):
"""If the label is multi-labels."""
return
self
.
task_config
.
train_data
.
is_multilabel
def
build_model
(
self
):
"""Builds video classification model."""
common_input_shape
=
self
.
_get_feature_shape
()
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
common_input_shape
)
logging
.
info
(
'Build model input %r'
,
common_input_shape
)
...
...
@@ -51,7 +70,7 @@ class VideoClassificationTask(base_task.Task):
self
.
task_config
.
model
.
model_type
,
input_specs
=
input_specs
,
model_config
=
self
.
task_config
.
model
,
num_classes
=
self
.
task_config
.
train_data
.
num_classes
,
num_classes
=
self
.
_get_
num_classes
()
,
l2_regularizer
=
l2_regularizer
)
return
model
...
...
@@ -138,7 +157,7 @@ class VideoClassificationTask(base_task.Task):
all_losses
=
{}
losses_config
=
self
.
task_config
.
losses
total_loss
=
None
if
self
.
task_config
.
train_data
.
is_multilabel
:
if
self
.
_
is_multilabel
()
:
entropy
=
-
tf
.
reduce_mean
(
tf
.
reduce_sum
(
model_outputs
*
tf
.
math
.
log
(
model_outputs
+
1e-8
),
-
1
))
total_loss
=
tf
.
keras
.
losses
.
binary_crossentropy
(
...
...
@@ -179,22 +198,18 @@ class VideoClassificationTask(base_task.Task):
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
1
,
name
=
'top_1_accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_5_accuracy'
)
]
if
self
.
task_config
.
train_data
.
is_multilabel
:
if
self
.
_
is_multilabel
()
:
metrics
.
append
(
tf
.
keras
.
metrics
.
AUC
(
curve
=
'ROC'
,
multi_label
=
self
.
task_config
.
train_data
.
is_multilabel
,
name
=
'ROC-AUC'
))
curve
=
'ROC'
,
multi_label
=
self
.
_is_multilabel
(),
name
=
'ROC-AUC'
))
metrics
.
append
(
tf
.
keras
.
metrics
.
RecallAtPrecision
(
0.95
,
name
=
'RecallAtPrecision95'
))
metrics
.
append
(
tf
.
keras
.
metrics
.
AUC
(
curve
=
'PR'
,
multi_label
=
self
.
task_config
.
train_data
.
is_multilabel
,
name
=
'PR-AUC'
))
curve
=
'PR'
,
multi_label
=
self
.
_is_multilabel
(),
name
=
'PR-AUC'
))
if
self
.
task_config
.
metrics
.
use_per_class_recall
:
for
i
in
range
(
self
.
task_config
.
train_data
.
num_classes
):
for
i
in
range
(
self
.
_get_
num_classes
()
):
metrics
.
append
(
tf
.
keras
.
metrics
.
Recall
(
class_id
=
i
,
name
=
f
'recall-
{
i
}
'
))
else
:
...
...
@@ -250,7 +265,7 @@ class VideoClassificationTask(base_task.Task):
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
# Computes per-replica loss.
if
self
.
task_config
.
train_data
.
is_multilabel
:
if
self
.
_
is_multilabel
()
:
outputs
=
tf
.
math
.
sigmoid
(
outputs
)
else
:
outputs
=
tf
.
math
.
softmax
(
outputs
)
...
...
@@ -316,13 +331,11 @@ class VideoClassificationTask(base_task.Task):
def
inference_step
(
self
,
features
:
tf
.
Tensor
,
model
:
tf
.
keras
.
Model
):
"""Performs the forward step."""
outputs
=
model
(
features
,
training
=
False
)
if
self
.
task_config
.
train_data
.
is_multilabel
:
if
self
.
_
is_multilabel
()
:
outputs
=
tf
.
math
.
sigmoid
(
outputs
)
else
:
outputs
=
tf
.
math
.
softmax
(
outputs
)
num_test_clips
=
self
.
task_config
.
validation_data
.
num_test_clips
num_test_crops
=
self
.
task_config
.
validation_data
.
num_test_crops
num_test_views
=
num_test_clips
*
num_test_crops
num_test_views
=
self
.
_get_num_test_views
()
if
num_test_views
>
1
:
# Averaging output probabilities across multiples views.
outputs
=
tf
.
reshape
(
outputs
,
[
-
1
,
num_test_views
,
outputs
.
shape
[
-
1
]])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment