Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7aa320c5
Commit
7aa320c5
authored
Oct 14, 2022
by
Chaochao Yan
Committed by
A. Unique TensorFlower
Oct 14, 2022
Browse files
Internal change
PiperOrigin-RevId: 481234282
parent
9bcbe962
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
40 deletions
+99
-40
official/projects/yt8m/tasks/yt8m_task.py
official/projects/yt8m/tasks/yt8m_task.py
+99
-40
No files found.
official/projects/yt8m/tasks/yt8m_task.py
View file @
7aa320c5
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# limitations under the License.
# limitations under the License.
"""Video classification task definition."""
"""Video classification task definition."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -95,31 +97,46 @@ class YT8MTask(base_task.Task):
...
@@ -95,31 +97,46 @@ class YT8MTask(base_task.Task):
return
dataset
return
dataset
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
):
def
build_losses
(
self
,
labels
,
model_outputs
,
label_weights
=
None
,
aux_losses
=
None
):
"""Sigmoid Cross Entropy.
"""Sigmoid Cross Entropy.
Args:
Args:
labels: tensor containing truth labels.
labels: tensor containing truth labels.
model_outputs: output logits of the classifier.
model_outputs: output logits of the classifier.
label_weights: optional tensor of label weights.
aux_losses: tensor containing auxiliarly loss tensors, i.e. `losses` in
aux_losses: tensor containing auxiliarly loss tensors, i.e. `losses` in
keras.Model.
keras.Model.
Returns:
Returns:
Tensors: The
total loss, model loss tensors.
A dict of tensors contains
total loss, model loss tensors.
"""
"""
losses_config
=
self
.
task_config
.
losses
losses_config
=
self
.
task_config
.
losses
model_loss
=
tf
.
keras
.
losses
.
binary_crossentropy
(
model_loss
=
tf
.
keras
.
losses
.
binary_crossentropy
(
labels
,
labels
,
model_outputs
,
model_outputs
,
from_logits
=
losses_config
.
from_logits
,
from_logits
=
losses_config
.
from_logits
,
label_smoothing
=
losses_config
.
label_smoothing
)
label_smoothing
=
losses_config
.
label_smoothing
,
axis
=
None
)
if
label_weights
is
None
:
model_loss
=
tf_utils
.
safe_mean
(
model_loss
)
model_loss
=
tf_utils
.
safe_mean
(
model_loss
)
else
:
model_loss
=
model_loss
*
label_weights
# Manutally compute weighted mean loss.
total_loss
=
tf
.
reduce_sum
(
model_loss
)
total_weight
=
tf
.
cast
(
tf
.
reduce_sum
(
label_weights
),
dtype
=
total_loss
.
dtype
)
model_loss
=
tf
.
math
.
divide_no_nan
(
total_loss
,
total_weight
)
total_loss
=
model_loss
total_loss
=
model_loss
if
aux_losses
:
if
aux_losses
:
total_loss
+=
tf
.
add_n
(
aux_losses
)
total_loss
+=
tf
.
add_n
(
aux_losses
)
return
total_loss
,
model_loss
return
{
'total_loss'
:
total_loss
,
'
model_loss
'
:
model_loss
}
def
build_metrics
(
self
,
training
=
True
):
def
build_metrics
(
self
,
training
=
True
):
"""Gets streaming metrics for training/validation.
"""Gets streaming metrics for training/validation.
...
@@ -130,10 +147,10 @@ class YT8MTask(base_task.Task):
...
@@ -130,10 +147,10 @@ class YT8MTask(base_task.Task):
top_n: A positive Integer specifying the average precision at n, or None
top_n: A positive Integer specifying the average precision at n, or None
to use all provided data points.
to use all provided data points.
Args:
Args:
training:
b
ool value, true for training mode, false for eval/validation.
training:
B
ool value, true for training mode, false for eval/validation.
Returns:
Returns:
list of strings that indicate metrics to be used
A
list of strings that indicate metrics to be used
.
"""
"""
metrics
=
[]
metrics
=
[]
metric_names
=
[
'total_loss'
,
'model_loss'
]
metric_names
=
[
'total_loss'
,
'model_loss'
]
...
@@ -149,15 +166,48 @@ class YT8MTask(base_task.Task):
...
@@ -149,15 +166,48 @@ class YT8MTask(base_task.Task):
return
metrics
return
metrics
def
process_metrics
(
self
,
metrics
:
List
[
tf
.
keras
.
metrics
.
Metric
],
labels
:
tf
.
Tensor
,
outputs
:
tf
.
Tensor
,
model_losses
:
Optional
[
Dict
[
str
,
tf
.
Tensor
]]
=
None
,
label_weights
:
Optional
[
tf
.
Tensor
]
=
None
,
training
:
bool
=
True
,
**
kwargs
)
->
Dict
[
str
,
Tuple
[
tf
.
Tensor
,
...]]:
"""Updates metrics.
Args:
metrics: Evaluation metrics to be updated.
labels: A tensor containing truth labels.
outputs: Model output logits of the classifier.
model_losses: An optional dict of model losses.
label_weights: Optional label weights, can be broadcast into shape of
outputs/labels.
training: Bool indicates if in training mode.
**kwargs: Additional input arguments.
Returns:
Updated dict of metrics log.
"""
if
model_losses
is
None
:
model_losses
=
{}
logs
=
{}
if
not
training
:
logs
.
update
({
self
.
avg_prec_metric
.
name
:
(
labels
,
outputs
)})
for
m
in
metrics
:
m
.
update_state
(
model_losses
[
m
.
name
])
logs
[
m
.
name
]
=
m
.
result
()
return
logs
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
"""Does forward and backward.
"""Does forward and backward.
Args:
Args:
inputs: a dictionary of input tensors. output_dict = {
inputs: a dictionary of input tensors. output_dict = { "video_ids":
"video_ids": batch_video_ids,
batch_video_ids, "video_matrix": batch_video_matrix, "labels":
"video_matrix": batch_video_matrix,
batch_labels, "num_frames": batch_frames, }
"labels": batch_labels,
"num_frames": batch_frames, }
model: the model, forward pass definition.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
metrics: a nested structure of metrics objects.
...
@@ -167,6 +217,7 @@ class YT8MTask(base_task.Task):
...
@@ -167,6 +217,7 @@ class YT8MTask(base_task.Task):
"""
"""
features
,
labels
=
inputs
[
'video_matrix'
],
inputs
[
'labels'
]
features
,
labels
=
inputs
[
'video_matrix'
],
inputs
[
'labels'
]
num_frames
=
inputs
[
'num_frames'
]
num_frames
=
inputs
[
'num_frames'
]
label_weights
=
inputs
.
get
(
'label_weights'
,
None
)
# sample random frames / random sequence
# sample random frames / random sequence
num_frames
=
tf
.
cast
(
num_frames
,
tf
.
float32
)
num_frames
=
tf
.
cast
(
num_frames
,
tf
.
float32
)
...
@@ -183,26 +234,28 @@ class YT8MTask(base_task.Task):
...
@@ -183,26 +234,28 @@ class YT8MTask(base_task.Task):
# Casting output layer as float32 is necessary when mixed_precision is
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
# Computes per-replica loss
# Computes per-replica loss
loss
,
model_loss
=
self
.
build_losses
(
all_losses
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
model_outputs
=
outputs
,
labels
=
labels
,
label_weights
=
label_weights
,
aux_losses
=
model
.
losses
)
loss
=
all_losses
[
'total_loss'
]
# Scales loss as the default gradients allreduce performs sum inside the
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# optimizer.
scaled_loss
=
loss
/
num_replicas
scaled_loss
=
loss
/
num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
# scaled for numerical stability.
if
isinstance
(
optimizer
,
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
scaled_loss
=
optimizer
.
get_scaled_loss
(
scaled_loss
)
scaled_loss
=
optimizer
.
get_scaled_loss
(
scaled_loss
)
tvars
=
model
.
trainable_variables
tvars
=
model
.
trainable_variables
grads
=
tape
.
gradient
(
scaled_loss
,
tvars
)
grads
=
tape
.
gradient
(
scaled_loss
,
tvars
)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
# used.
if
isinstance
(
optimizer
,
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
# Apply gradient clipping.
# Apply gradient clipping.
...
@@ -213,12 +266,14 @@ class YT8MTask(base_task.Task):
...
@@ -213,12 +266,14 @@ class YT8MTask(base_task.Task):
logs
=
{
self
.
loss
:
loss
}
logs
=
{
self
.
loss
:
loss
}
all_losses
=
{
'total_loss'
:
loss
,
'model_loss'
:
model_loss
}
logs
.
update
(
self
.
process_metrics
(
if
metrics
:
metrics
,
for
m
in
metrics
:
labels
=
labels
,
m
.
update_state
(
all_losses
[
m
.
name
])
outputs
=
outputs
,
logs
.
update
({
m
.
name
:
m
.
result
()})
model_losses
=
all_losses
,
label_weights
=
label_weights
,
training
=
True
))
return
logs
return
logs
...
@@ -226,11 +281,9 @@ class YT8MTask(base_task.Task):
...
@@ -226,11 +281,9 @@ class YT8MTask(base_task.Task):
"""Validatation step.
"""Validatation step.
Args:
Args:
inputs: a dictionary of input tensors. output_dict = {
inputs: a dictionary of input tensors. output_dict = { "video_ids":
"video_ids": batch_video_ids,
batch_video_ids, "video_matrix": batch_video_matrix, "labels":
"video_matrix": batch_video_matrix,
batch_labels, "num_frames": batch_frames, }
"labels": batch_labels,
"num_frames": batch_frames, }
model: the model, forward definition
model: the model, forward definition
metrics: a nested structure of metrics objects.
metrics: a nested structure of metrics objects.
...
@@ -239,6 +292,7 @@ class YT8MTask(base_task.Task):
...
@@ -239,6 +292,7 @@ class YT8MTask(base_task.Task):
"""
"""
features
,
labels
=
inputs
[
'video_matrix'
],
inputs
[
'labels'
]
features
,
labels
=
inputs
[
'video_matrix'
],
inputs
[
'labels'
]
num_frames
=
inputs
[
'num_frames'
]
num_frames
=
inputs
[
'num_frames'
]
label_weights
=
inputs
.
get
(
'label_weights'
,
None
)
# sample random frames (None, 5, 1152) -> (None, 30, 1152)
# sample random frames (None, 5, 1152) -> (None, 30, 1152)
sample_frames
=
self
.
task_config
.
validation_data
.
num_frames
sample_frames
=
self
.
task_config
.
validation_data
.
num_frames
...
@@ -252,23 +306,28 @@ class YT8MTask(base_task.Task):
...
@@ -252,23 +306,28 @@ class YT8MTask(base_task.Task):
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
if
self
.
task_config
.
validation_data
.
segment_labels
:
if
self
.
task_config
.
validation_data
.
segment_labels
:
# workaround to ignore the unrated labels.
# workaround to ignore the unrated labels.
outputs
*=
inputs
[
'
label_weights
'
]
outputs
*=
label_weights
# remove padding
# remove padding
outputs
=
outputs
[
~
tf
.
reduce_all
(
labels
==
-
1
,
axis
=
1
)]
outputs
=
outputs
[
~
tf
.
reduce_all
(
labels
==
-
1
,
axis
=
1
)]
labels
=
labels
[
~
tf
.
reduce_all
(
labels
==
-
1
,
axis
=
1
)]
labels
=
labels
[
~
tf
.
reduce_all
(
labels
==
-
1
,
axis
=
1
)]
loss
,
model_loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
all_losses
=
self
.
build_losses
(
labels
=
labels
,
model_outputs
=
outputs
,
label_weights
=
label_weights
,
aux_losses
=
model
.
losses
)
all_losses
=
{
'total_
loss
'
:
loss
,
'model_loss'
:
mode
l_loss
}
logs
=
{
self
.
loss
:
all_losses
[
'tota
l_loss
'
]
}
logs
.
update
({
self
.
avg_prec_metric
.
name
:
(
labels
,
outputs
)})
logs
.
update
(
self
.
process_metrics
(
metrics
,
labels
=
labels
,
outputs
=
outputs
,
model_losses
=
all_losses
,
label_weights
=
inputs
.
get
(
'label_weights'
,
None
),
training
=
False
))
if
metrics
:
for
m
in
metrics
:
m
.
update_state
(
all_losses
[
m
.
name
])
logs
.
update
({
m
.
name
:
m
.
result
()})
return
logs
return
logs
def
inference_step
(
self
,
inputs
,
model
):
def
inference_step
(
self
,
inputs
,
model
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment