Commit 7aa320c5 authored by Chaochao Yan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 481234282
parent 9bcbe962
@@ -13,6 +13,8 @@
# limitations under the License.
"""Video classification task definition."""
from typing import Dict, List, Optional, Tuple
from absl import logging
import tensorflow as tf
@@ -95,31 +97,46 @@ class YT8MTask(base_task.Task):
return dataset
def build_losses(self, labels, model_outputs, aux_losses=None):
def build_losses(self,
labels,
model_outputs,
label_weights=None,
aux_losses=None):
"""Sigmoid Cross Entropy.
Args:
labels: tensor containing truth labels.
model_outputs: output logits of the classifier.
label_weights: optional tensor of label weights.
      aux_losses: tensor containing auxiliary loss tensors, i.e. `losses` in
keras.Model.
Returns:
Tensors: The total loss, model loss tensors.
A dict of tensors contains total loss, model loss tensors.
"""
losses_config = self.task_config.losses
model_loss = tf.keras.losses.binary_crossentropy(
labels,
model_outputs,
from_logits=losses_config.from_logits,
label_smoothing=losses_config.label_smoothing)
label_smoothing=losses_config.label_smoothing,
axis=None)
if label_weights is None:
model_loss = tf_utils.safe_mean(model_loss)
else:
model_loss = model_loss * label_weights
      # Manually compute the weighted mean loss.
total_loss = tf.reduce_sum(model_loss)
total_weight = tf.cast(
tf.reduce_sum(label_weights), dtype=total_loss.dtype)
model_loss = tf.math.divide_no_nan(total_loss, total_weight)
model_loss = tf_utils.safe_mean(model_loss)
total_loss = model_loss
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss, model_loss
return {'total_loss': total_loss, 'model_loss': model_loss}
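
For reference, a minimal standalone sketch of the label-weighted reduction introduced above (not the Model Garden code itself; `weighted_mean_bce` and its inputs are illustrative): the per-element sigmoid cross entropy is scaled by `label_weights`, summed, and normalized by the total weight, with `divide_no_nan` guarding against an all-zero weight batch.

```python
import tensorflow as tf

def weighted_mean_bce(labels, logits, label_weights=None):
  # Per-element sigmoid cross entropy (the task code delegates this to
  # tf.keras.losses.binary_crossentropy with from_logits/label_smoothing).
  per_elem = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.cast(labels, tf.float32), logits=logits)
  if label_weights is None:
    return tf.reduce_mean(per_elem)
  weighted = per_elem * label_weights
  total = tf.reduce_sum(weighted)
  total_weight = tf.cast(tf.reduce_sum(label_weights), total.dtype)
  # Weighted mean; returns 0 when every weight is zero.
  return tf.math.divide_no_nan(total, total_weight)

# Two examples, three classes; the second class of the first example is unrated.
labels = tf.constant([[1., 0., 1.], [0., 1., 0.]])
logits = tf.constant([[2.0, -1.0, 0.5], [-0.5, 1.5, -2.0]])
weights = tf.constant([[1., 0., 1.], [1., 1., 1.]])
print(weighted_mean_bce(labels, logits, weights))
```
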
def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation.
@@ -130,10 +147,10 @@ class YT8MTask(base_task.Task):
top_n: A positive Integer specifying the average precision at n, or None
to use all provided data points.
Args:
training: bool value, true for training mode, false for eval/validation.
training: Bool value, true for training mode, false for eval/validation.
Returns:
list of strings that indicate metrics to be used
A list of strings that indicate metrics to be used.
"""
metrics = []
metric_names = ['total_loss', 'model_loss']
@@ -149,15 +166,48 @@ class YT8MTask(base_task.Task):
return metrics
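
As a rough illustration of the "average precision at n" described in the docstring above (a hypothetical NumPy sketch, not the YT8M eval utility that backs `self.avg_prec_metric`): only the `top_n` highest-scoring predictions are kept before precision is accumulated over the ranked list.

```python
import numpy as np

def average_precision_at_n(scores, labels, top_n=None):
  # Rank predictions by descending score, optionally keeping only the top_n.
  order = np.argsort(-np.asarray(scores, dtype=np.float64))
  if top_n is not None:
    order = order[:top_n]
  ranked = np.asarray(labels)[order]
  hits = np.cumsum(ranked)
  precisions = hits / (np.arange(len(ranked)) + 1)
  num_pos = max(int(ranked.sum()), 1)
  return float((precisions * ranked).sum() / num_pos)

print(average_precision_at_n([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 1], top_n=3))  # ~0.83
```
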
def process_metrics(self,
metrics: List[tf.keras.metrics.Metric],
labels: tf.Tensor,
outputs: tf.Tensor,
model_losses: Optional[Dict[str, tf.Tensor]] = None,
label_weights: Optional[tf.Tensor] = None,
training: bool = True,
**kwargs) -> Dict[str, Tuple[tf.Tensor, ...]]:
"""Updates metrics.
Args:
metrics: Evaluation metrics to be updated.
labels: A tensor containing truth labels.
outputs: Model output logits of the classifier.
model_losses: An optional dict of model losses.
label_weights: Optional label weights, can be broadcast into shape of
outputs/labels.
training: Bool indicates if in training mode.
**kwargs: Additional input arguments.
Returns:
Updated dict of metrics log.
"""
if model_losses is None:
model_losses = {}
logs = {}
if not training:
logs.update({self.avg_prec_metric.name: (labels, outputs)})
for m in metrics:
m.update_state(model_losses[m.name])
logs[m.name] = m.result()
return logs
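
A small usage sketch of the new `process_metrics()` hook (standalone and hypothetical; the metric names only need to match keys in the loss dict): each `tf.keras.metrics.Mean` is updated from the loss of the same name, and in eval mode the `(labels, outputs)` pair is passed through in the logs for the average-precision metric computed outside the step function.

```python
import tensorflow as tf

metrics = [tf.keras.metrics.Mean('total_loss'), tf.keras.metrics.Mean('model_loss')]
model_losses = {'total_loss': tf.constant(0.7), 'model_loss': tf.constant(0.6)}

logs = {}
for m in metrics:
  m.update_state(model_losses[m.name])  # accumulate the running mean per loss name
  logs[m.name] = m.result()
print(logs)
```
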
def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors. output_dict = {
"video_ids": batch_video_ids,
"video_matrix": batch_video_matrix,
"labels": batch_labels,
"num_frames": batch_frames, }
inputs: a dictionary of input tensors. output_dict = { "video_ids":
batch_video_ids, "video_matrix": batch_video_matrix, "labels":
batch_labels, "num_frames": batch_frames, }
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
@@ -167,6 +217,7 @@ class YT8MTask(base_task.Task):
"""
features, labels = inputs['video_matrix'], inputs['labels']
num_frames = inputs['num_frames']
label_weights = inputs.get('label_weights', None)
# sample random frames / random sequence
num_frames = tf.cast(num_frames, tf.float32)
@@ -183,26 +234,28 @@ class YT8MTask(base_task.Task):
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss
loss, model_loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
all_losses = self.build_losses(
model_outputs=outputs,
labels=labels,
label_weights=label_weights,
aux_losses=model.losses)
loss = all_losses['total_loss']
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer):
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer):
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
# Apply gradient clipping.
@@ -213,12 +266,14 @@ class YT8MTask(base_task.Task):
logs = {self.loss: loss}
all_losses = {'total_loss': loss, 'model_loss': model_loss}
if metrics:
for m in metrics:
m.update_state(all_losses[m.name])
logs.update({m.name: m.result()})
logs.update(
self.process_metrics(
metrics,
labels=labels,
outputs=outputs,
model_losses=all_losses,
label_weights=label_weights,
training=True))
return logs
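
The train step above follows the usual mixed-precision pattern; a minimal generic sketch, assuming a model, optimizer, and loss_fn supplied by the caller (all names here are illustrative): divide the loss by the replica count because the gradient allreduce sums, scale it through the LossScaleOptimizer before differentiating, and unscale the gradients before applying them.

```python
import tensorflow as tf

def train_one_step(model, optimizer, features, labels, loss_fn, num_replicas=1):
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    loss = loss_fn(labels, outputs)
    # The allreduce sums per-replica gradients, so each replica contributes loss / N.
    scaled_loss = loss / num_replicas
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)
  grads = tape.gradient(scaled_loss, model.trainable_variables)
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss
```
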
@@ -226,11 +281,9 @@ class YT8MTask(base_task.Task):
"""Validatation step.
Args:
inputs: a dictionary of input tensors. output_dict = {
"video_ids": batch_video_ids,
"video_matrix": batch_video_matrix,
"labels": batch_labels,
"num_frames": batch_frames, }
inputs: a dictionary of input tensors. output_dict = { "video_ids":
batch_video_ids, "video_matrix": batch_video_matrix, "labels":
batch_labels, "num_frames": batch_frames, }
model: the model, forward definition
metrics: a nested structure of metrics objects.
@@ -239,6 +292,7 @@ class YT8MTask(base_task.Task):
"""
features, labels = inputs['video_matrix'], inputs['labels']
num_frames = inputs['num_frames']
label_weights = inputs.get('label_weights', None)
# sample random frames (None, 5, 1152) -> (None, 30, 1152)
sample_frames = self.task_config.validation_data.num_frames
@@ -252,23 +306,28 @@ class YT8MTask(base_task.Task):
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.validation_data.segment_labels:
# workaround to ignore the unrated labels.
outputs *= inputs['label_weights']
outputs *= label_weights
# remove padding
outputs = outputs[~tf.reduce_all(labels == -1, axis=1)]
labels = labels[~tf.reduce_all(labels == -1, axis=1)]
loss, model_loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss}
all_losses = self.build_losses(
labels=labels,
model_outputs=outputs,
label_weights=label_weights,
aux_losses=model.losses)
all_losses = {'total_loss': loss, 'model_loss': model_loss}
logs = {self.loss: all_losses['total_loss']}
logs.update({self.avg_prec_metric.name: (labels, outputs)})
logs.update(
self.process_metrics(
metrics,
labels=labels,
outputs=outputs,
model_losses=all_losses,
label_weights=inputs.get('label_weights', None),
training=False))
if metrics:
for m in metrics:
m.update_state(all_losses[m.name])
logs.update({m.name: m.result()})
return logs
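
The segment-label workaround in the validation step masks unrated classes with `label_weights` and drops all-padding rows before the loss; below is a tiny standalone sketch of that padding removal, using toy tensors and `tf.boolean_mask` in place of the `[...]` indexing used above.

```python
import tensorflow as tf

labels = tf.constant([[1., 0.], [-1., -1.], [0., 1.]])   # middle row is padding
outputs = tf.constant([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]])

keep = ~tf.reduce_all(labels == -1, axis=1)               # [True, False, True]
outputs = tf.boolean_mask(outputs, keep)
labels = tf.boolean_mask(labels, keep)
```
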
def inference_step(self, inputs, model):
......