Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b7cbd12b
Commit
b7cbd12b
authored
Nov 12, 2020
by
Yeqing Li
Committed by
A. Unique TensorFlower
Nov 12, 2020
Browse files
Internal change
PiperOrigin-RevId: 342166637
parent
0a52f120
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
86 additions
and
21 deletions
+86
-21
official/vision/beta/configs/video_classification.py
official/vision/beta/configs/video_classification.py
+1
-0
official/vision/beta/dataloaders/video_input.py
official/vision/beta/dataloaders/video_input.py
+13
-6
official/vision/beta/tasks/video_classification.py
official/vision/beta/tasks/video_classification.py
+72
-15
No files found.
official/vision/beta/configs/video_classification.py
View file @
b7cbd12b
...
...
@@ -48,6 +48,7 @@ class DataConfig(cfg.DataConfig):
is_training
:
bool
=
True
cycle_length
:
int
=
10
min_image_size
:
int
=
256
is_multilabel
:
bool
=
False
def
kinetics400
(
is_training
):
...
...
official/vision/beta/dataloaders/video_input.py
View file @
b7cbd12b
...
...
@@ -146,6 +146,11 @@ def _process_label(label: tf.Tensor,
if
one_hot_label
:
# Replace label index by one hot representation.
label
=
tf
.
one_hot
(
label
,
num_classes
)
if
len
(
label
.
shape
.
as_list
())
>
1
:
label
=
tf
.
reduce_sum
(
label
,
axis
=
0
)
if
num_classes
==
1
:
# The trick for single label.
label
=
1
-
label
return
label
...
...
@@ -154,11 +159,11 @@ class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task."""
def
__init__
(
self
,
image_key
:
str
=
IMAGE_KEY
,
label_key
:
str
=
LABEL_KEY
):
self
.
_image_key
=
IMAGE_KEY
self
.
_label_key
=
LABEL_KEY
self
.
_image_key
=
image_key
self
.
_label_key
=
label_key
self
.
_context_description
=
{
# One integer stored in context.
self
.
_label_key
:
tf
.
io
.
Fixed
LenFeature
(
(),
tf
.
int64
),
self
.
_label_key
:
tf
.
io
.
Var
LenFeature
(
tf
.
int64
),
}
self
.
_sequence_description
=
{
# Each image is a string encoding JPEG.
...
...
@@ -172,7 +177,7 @@ class Decoder(decoder.Decoder):
self
.
_sequence_description
)
return
{
self
.
_image_key
:
sequences
[
self
.
_image_key
],
self
.
_label_key
:
context
[
self
.
_label_key
]
self
.
_label_key
:
tf
.
sparse
.
to_dense
(
context
[
self
.
_label_key
]
)
}
...
...
@@ -200,7 +205,6 @@ class Parser(parser.Parser):
"""Parses data for training."""
# Process image and label.
image
=
decoded_tensors
[
self
.
_image_key
]
label
=
decoded_tensors
[
self
.
_label_key
]
image
=
_process_image
(
image
=
image
,
is_training
=
True
,
...
...
@@ -210,6 +214,8 @@ class Parser(parser.Parser):
min_resize
=
self
.
_min_resize
,
crop_size
=
self
.
_crop_size
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
label
=
decoded_tensors
[
self
.
_label_key
]
label
=
_process_label
(
label
,
self
.
_one_hot_label
,
self
.
_num_classes
)
return
{
'image'
:
image
},
label
...
...
@@ -219,7 +225,6 @@ class Parser(parser.Parser):
)
->
Tuple
[
Dict
[
str
,
tf
.
Tensor
],
tf
.
Tensor
]:
"""Parses data for evaluation."""
image
=
decoded_tensors
[
self
.
_image_key
]
label
=
decoded_tensors
[
self
.
_label_key
]
image
=
_process_image
(
image
=
image
,
is_training
=
False
,
...
...
@@ -229,6 +234,8 @@ class Parser(parser.Parser):
min_resize
=
self
.
_min_resize
,
crop_size
=
self
.
_crop_size
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
label
=
decoded_tensors
[
self
.
_label_key
]
label
=
_process_label
(
label
,
self
.
_one_hot_label
,
self
.
_num_classes
)
return
{
'image'
:
image
},
label
...
...
official/vision/beta/tasks/video_classification.py
View file @
b7cbd12b
...
...
@@ -84,22 +84,41 @@ class VideoClassificationTask(base_task.Task):
Returns:
The total loss tensor.
"""
all_losses
=
{}
losses_config
=
self
.
task_config
.
losses
if
losses_config
.
one_hot
:
total_loss
=
tf
.
keras
.
losses
.
categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
,
label_smoothing
=
losses_config
.
label_smoothing
)
total_loss
=
None
if
self
.
task_config
.
train_data
.
is_multilabel
:
entropy
=
-
tf
.
reduce_mean
(
tf
.
reduce_sum
(
model_outputs
*
tf
.
math
.
log
(
model_outputs
+
1e-8
),
-
1
))
total_loss
=
tf
.
keras
.
losses
.
binary_crossentropy
(
labels
,
model_outputs
,
from_logits
=
False
)
all_losses
.
update
({
'class_loss'
:
total_loss
,
'entropy'
:
entropy
,
})
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
)
if
losses_config
.
one_hot
:
total_loss
=
tf
.
keras
.
losses
.
categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
False
,
label_smoothing
=
losses_config
.
label_smoothing
)
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
False
)
total_loss
=
tf_utils
.
safe_mean
(
total_loss
)
total_loss
=
tf_utils
.
safe_mean
(
total_loss
)
all_losses
.
update
({
'class_loss'
:
total_loss
,
})
if
aux_losses
:
all_losses
.
update
({
'reg_loss'
:
aux_losses
,
})
total_loss
+=
tf
.
add_n
(
aux_losses
)
all_losses
[
self
.
loss
]
=
total_loss
return
tot
al_loss
return
a
l
l_loss
es
def
build_metrics
(
self
,
training
=
True
):
"""Gets streaming metrics for training/validation."""
...
...
@@ -109,6 +128,20 @@ class VideoClassificationTask(base_task.Task):
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
1
,
name
=
'top_1_accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_5_accuracy'
)
]
if
self
.
task_config
.
train_data
.
is_multilabel
:
metrics
.
append
(
tf
.
keras
.
metrics
.
AUC
(
curve
=
'ROC'
,
multi_label
=
self
.
task_config
.
train_data
.
is_multilabel
,
name
=
'ROC-AUC'
))
metrics
.
append
(
tf
.
keras
.
metrics
.
RecallAtPrecision
(
0.95
,
name
=
'RecallAtPrecision95'
))
metrics
.
append
(
tf
.
keras
.
metrics
.
AUC
(
curve
=
'PR'
,
multi_label
=
self
.
task_config
.
train_data
.
is_multilabel
,
name
=
'PR-AUC'
))
else
:
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
...
...
@@ -119,6 +152,21 @@ class VideoClassificationTask(base_task.Task):
]
return
metrics
def
process_metrics
(
self
,
metrics
,
labels
,
model_outputs
):
"""Process and update metrics.
Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects. The return of function
self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
"""
for
metric
in
metrics
:
metric
.
update_state
(
labels
,
model_outputs
)
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
"""Does forward and backward.
...
...
@@ -142,8 +190,13 @@ class VideoClassificationTask(base_task.Task):
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
# Computes per-replica loss.
loss
=
self
.
build_losses
(
if
self
.
task_config
.
train_data
.
is_multilabel
:
outputs
=
tf
.
math
.
sigmoid
(
outputs
)
else
:
outputs
=
tf
.
math
.
softmax
(
outputs
)
all_losses
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
loss
=
all_losses
[
self
.
loss
]
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss
=
loss
/
num_replicas
...
...
@@ -162,7 +215,7 @@ class VideoClassificationTask(base_task.Task):
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
optimizer
.
apply_gradients
(
list
(
zip
(
grads
,
tvars
)))
logs
=
{
self
.
loss
:
loss
}
logs
=
all_losses
if
metrics
:
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
metrics
})
...
...
@@ -186,10 +239,9 @@ class VideoClassificationTask(base_task.Task):
outputs
=
self
.
inference_step
(
features
[
'image'
],
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
lo
s
s
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
lo
g
s
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
metrics
})
...
...
@@ -200,4 +252,9 @@ class VideoClassificationTask(base_task.Task):
def
inference_step
(
self
,
inputs
,
model
):
"""Performs the forward step."""
return
model
(
inputs
,
training
=
False
)
outputs
=
model
(
inputs
,
training
=
False
)
if
self
.
task_config
.
train_data
.
is_multilabel
:
outputs
=
tf
.
math
.
sigmoid
(
outputs
)
else
:
outputs
=
tf
.
math
.
softmax
(
outputs
)
return
outputs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment