Commit 7aa320c5 authored by Chaochao Yan's avatar Chaochao Yan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 481234282
parent 9bcbe962
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
"""Video classification task definition.""" """Video classification task definition."""
from typing import Dict, List, Optional, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -95,31 +97,46 @@ class YT8MTask(base_task.Task): ...@@ -95,31 +97,46 @@ class YT8MTask(base_task.Task):
return dataset return dataset
def build_losses(self,
                 labels,
                 model_outputs,
                 label_weights=None,
                 aux_losses=None):
  """Sigmoid cross-entropy loss, optionally weighted per label.

  Args:
    labels: tensor containing truth labels.
    model_outputs: output logits of the classifier.
    label_weights: optional tensor of label weights, broadcastable against
      the per-element cross-entropy. When given, the model loss is the
      weighted mean of the element-wise losses.
    aux_losses: tensor containing auxiliary loss tensors, i.e. `losses` in
      keras.Model.

  Returns:
    A dict of tensors containing the total loss and the model loss.
  """
  losses_config = self.task_config.losses
  # axis=None keeps the element-wise cross-entropy un-reduced so that
  # per-label weights can be applied before averaging.
  model_loss = tf.keras.losses.binary_crossentropy(
      labels,
      model_outputs,
      from_logits=losses_config.from_logits,
      label_smoothing=losses_config.label_smoothing,
      axis=None)

  if label_weights is None:
    model_loss = tf_utils.safe_mean(model_loss)
  else:
    model_loss = model_loss * label_weights
    # Manually compute the weighted mean loss; divide_no_nan yields 0 when
    # the total weight is 0 instead of NaN.
    total_loss = tf.reduce_sum(model_loss)
    total_weight = tf.cast(
        tf.reduce_sum(label_weights), dtype=total_loss.dtype)
    model_loss = tf.math.divide_no_nan(total_loss, total_weight)

  total_loss = model_loss
  if aux_losses:
    total_loss += tf.add_n(aux_losses)
  return {'total_loss': total_loss, 'model_loss': model_loss}
def build_metrics(self, training=True): def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation. """Gets streaming metrics for training/validation.
...@@ -130,10 +147,10 @@ class YT8MTask(base_task.Task): ...@@ -130,10 +147,10 @@ class YT8MTask(base_task.Task):
top_n: A positive Integer specifying the average precision at n, or None top_n: A positive Integer specifying the average precision at n, or None
to use all provided data points. to use all provided data points.
Args: Args:
training: bool value, true for training mode, false for eval/validation. training: Bool value, true for training mode, false for eval/validation.
Returns: Returns:
list of strings that indicate metrics to be used A list of strings that indicate metrics to be used.
""" """
metrics = [] metrics = []
metric_names = ['total_loss', 'model_loss'] metric_names = ['total_loss', 'model_loss']
...@@ -149,15 +166,48 @@ class YT8MTask(base_task.Task): ...@@ -149,15 +166,48 @@ class YT8MTask(base_task.Task):
return metrics return metrics
def process_metrics(self,
                    metrics: List[tf.keras.metrics.Metric],
                    labels: tf.Tensor,
                    outputs: tf.Tensor,
                    model_losses: Optional[Dict[str, tf.Tensor]] = None,
                    label_weights: Optional[tf.Tensor] = None,
                    training: bool = True,
                    **kwargs) -> Dict[str, Tuple[tf.Tensor, ...]]:
  """Updates the given metrics and returns a log dictionary.

  Args:
    metrics: Evaluation metrics to be updated.
    labels: A tensor containing truth labels.
    outputs: Model output logits of the classifier.
    model_losses: An optional dict of model losses.
    label_weights: Optional label weights, can be broadcast into shape of
      outputs/labels.
    training: Bool indicates if in training mode.
    **kwargs: Additional input arguments.

  Returns:
    Updated dict of metrics log.
  """
  if model_losses is None:
    model_losses = {}

  logs = {}
  # In eval mode, stash (labels, outputs) for the average-precision metric,
  # which is computed outside this function.
  if not training:
    logs[self.avg_prec_metric.name] = (labels, outputs)

  for metric in metrics:
    metric.update_state(model_losses[metric.name])
    logs[metric.name] = metric.result()
  return logs
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward. """Does forward and backward.
Args: Args:
inputs: a dictionary of input tensors. output_dict = { inputs: a dictionary of input tensors. output_dict = { "video_ids":
"video_ids": batch_video_ids, batch_video_ids, "video_matrix": batch_video_matrix, "labels":
"video_matrix": batch_video_matrix, batch_labels, "num_frames": batch_frames, }
"labels": batch_labels,
"num_frames": batch_frames, }
model: the model, forward pass definition. model: the model, forward pass definition.
optimizer: the optimizer for this training step. optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects. metrics: a nested structure of metrics objects.
...@@ -167,6 +217,7 @@ class YT8MTask(base_task.Task): ...@@ -167,6 +217,7 @@ class YT8MTask(base_task.Task):
""" """
features, labels = inputs['video_matrix'], inputs['labels'] features, labels = inputs['video_matrix'], inputs['labels']
num_frames = inputs['num_frames'] num_frames = inputs['num_frames']
label_weights = inputs.get('label_weights', None)
# sample random frames / random sequence # sample random frames / random sequence
num_frames = tf.cast(num_frames, tf.float32) num_frames = tf.cast(num_frames, tf.float32)
...@@ -183,26 +234,28 @@ class YT8MTask(base_task.Task): ...@@ -183,26 +234,28 @@ class YT8MTask(base_task.Task):
# Casting output layer as float32 is necessary when mixed_precision is # Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32. # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss # Computes per-replica loss
loss, model_loss = self.build_losses( all_losses = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses) model_outputs=outputs,
labels=labels,
label_weights=label_weights,
aux_losses=model.losses)
loss = all_losses['total_loss']
# Scales loss as the default gradients allreduce performs sum inside the # Scales loss as the default gradients allreduce performs sum inside the
# optimizer. # optimizer.
scaled_loss = loss / num_replicas scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is # For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability. # scaled for numerical stability.
if isinstance(optimizer, if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss) scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars) grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is # Scales back gradient before apply_gradients when LossScaleOptimizer is
# used. # used.
if isinstance(optimizer, if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads) grads = optimizer.get_unscaled_gradients(grads)
# Apply gradient clipping. # Apply gradient clipping.
...@@ -213,12 +266,14 @@ class YT8MTask(base_task.Task): ...@@ -213,12 +266,14 @@ class YT8MTask(base_task.Task):
logs = {self.loss: loss} logs = {self.loss: loss}
all_losses = {'total_loss': loss, 'model_loss': model_loss} logs.update(
self.process_metrics(
if metrics: metrics,
for m in metrics: labels=labels,
m.update_state(all_losses[m.name]) outputs=outputs,
logs.update({m.name: m.result()}) model_losses=all_losses,
label_weights=label_weights,
training=True))
return logs return logs
...@@ -226,11 +281,9 @@ class YT8MTask(base_task.Task): ...@@ -226,11 +281,9 @@ class YT8MTask(base_task.Task):
"""Validatation step. """Validatation step.
Args: Args:
inputs: a dictionary of input tensors. output_dict = { inputs: a dictionary of input tensors. output_dict = { "video_ids":
"video_ids": batch_video_ids, batch_video_ids, "video_matrix": batch_video_matrix, "labels":
"video_matrix": batch_video_matrix, batch_labels, "num_frames": batch_frames, }
"labels": batch_labels,
"num_frames": batch_frames, }
model: the model, forward definition model: the model, forward definition
metrics: a nested structure of metrics objects. metrics: a nested structure of metrics objects.
...@@ -239,6 +292,7 @@ class YT8MTask(base_task.Task): ...@@ -239,6 +292,7 @@ class YT8MTask(base_task.Task):
""" """
features, labels = inputs['video_matrix'], inputs['labels'] features, labels = inputs['video_matrix'], inputs['labels']
num_frames = inputs['num_frames'] num_frames = inputs['num_frames']
label_weights = inputs.get('label_weights', None)
# sample random frames (None, 5, 1152) -> (None, 30, 1152) # sample random frames (None, 5, 1152) -> (None, 30, 1152)
sample_frames = self.task_config.validation_data.num_frames sample_frames = self.task_config.validation_data.num_frames
...@@ -252,23 +306,28 @@ class YT8MTask(base_task.Task): ...@@ -252,23 +306,28 @@ class YT8MTask(base_task.Task):
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.validation_data.segment_labels: if self.task_config.validation_data.segment_labels:
# workaround to ignore the unrated labels. # workaround to ignore the unrated labels.
outputs *= inputs['label_weights'] outputs *= label_weights
# remove padding # remove padding
outputs = outputs[~tf.reduce_all(labels == -1, axis=1)] outputs = outputs[~tf.reduce_all(labels == -1, axis=1)]
labels = labels[~tf.reduce_all(labels == -1, axis=1)] labels = labels[~tf.reduce_all(labels == -1, axis=1)]
loss, model_loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss} all_losses = self.build_losses(
labels=labels,
model_outputs=outputs,
label_weights=label_weights,
aux_losses=model.losses)
all_losses = {'total_loss': loss, 'model_loss': model_loss} logs = {self.loss: all_losses['total_loss']}
logs.update({self.avg_prec_metric.name: (labels, outputs)}) logs.update(
self.process_metrics(
metrics,
labels=labels,
outputs=outputs,
model_losses=all_losses,
label_weights=inputs.get('label_weights', None),
training=False))
if metrics:
for m in metrics:
m.update_state(all_losses[m.name])
logs.update({m.name: m.result()})
return logs return logs
def inference_step(self, inputs, model): def inference_step(self, inputs, model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment