Commit 292ec4cb authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 342166637
parent 876b0c05
...@@ -48,6 +48,7 @@ class DataConfig(cfg.DataConfig): ...@@ -48,6 +48,7 @@ class DataConfig(cfg.DataConfig):
is_training: bool = True is_training: bool = True
cycle_length: int = 10 cycle_length: int = 10
min_image_size: int = 256 min_image_size: int = 256
is_multilabel: bool = False
def kinetics400(is_training): def kinetics400(is_training):
......
...@@ -146,6 +146,11 @@ def _process_label(label: tf.Tensor, ...@@ -146,6 +146,11 @@ def _process_label(label: tf.Tensor,
if one_hot_label: if one_hot_label:
# Replace label index by one hot representation. # Replace label index by one hot representation.
label = tf.one_hot(label, num_classes) label = tf.one_hot(label, num_classes)
if len(label.shape.as_list()) > 1:
label = tf.reduce_sum(label, axis=0)
if num_classes == 1:
# The trick for single label.
label = 1 - label
return label return label
...@@ -154,11 +159,11 @@ class Decoder(decoder.Decoder): ...@@ -154,11 +159,11 @@ class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task.""" """A tf.Example decoder for classification task."""
def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY): def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY):
self._image_key = IMAGE_KEY self._image_key = image_key
self._label_key = LABEL_KEY self._label_key = label_key
self._context_description = { self._context_description = {
# One integer stored in context. # One integer stored in context.
self._label_key: tf.io.FixedLenFeature((), tf.int64), self._label_key: tf.io.VarLenFeature(tf.int64),
} }
self._sequence_description = { self._sequence_description = {
# Each image is a string encoding JPEG. # Each image is a string encoding JPEG.
...@@ -172,7 +177,7 @@ class Decoder(decoder.Decoder): ...@@ -172,7 +177,7 @@ class Decoder(decoder.Decoder):
self._sequence_description) self._sequence_description)
return { return {
self._image_key: sequences[self._image_key], self._image_key: sequences[self._image_key],
self._label_key: context[self._label_key] self._label_key: tf.sparse.to_dense(context[self._label_key])
} }
...@@ -200,7 +205,6 @@ class Parser(parser.Parser): ...@@ -200,7 +205,6 @@ class Parser(parser.Parser):
"""Parses data for training.""" """Parses data for training."""
# Process image and label. # Process image and label.
image = decoded_tensors[self._image_key] image = decoded_tensors[self._image_key]
label = decoded_tensors[self._label_key]
image = _process_image( image = _process_image(
image=image, image=image,
is_training=True, is_training=True,
...@@ -210,6 +214,8 @@ class Parser(parser.Parser): ...@@ -210,6 +214,8 @@ class Parser(parser.Parser):
min_resize=self._min_resize, min_resize=self._min_resize,
crop_size=self._crop_size) crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype) image = tf.cast(image, dtype=self._dtype)
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes) label = _process_label(label, self._one_hot_label, self._num_classes)
return {'image': image}, label return {'image': image}, label
...@@ -219,7 +225,6 @@ class Parser(parser.Parser): ...@@ -219,7 +225,6 @@ class Parser(parser.Parser):
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses data for evaluation.""" """Parses data for evaluation."""
image = decoded_tensors[self._image_key] image = decoded_tensors[self._image_key]
label = decoded_tensors[self._label_key]
image = _process_image( image = _process_image(
image=image, image=image,
is_training=False, is_training=False,
...@@ -229,6 +234,8 @@ class Parser(parser.Parser): ...@@ -229,6 +234,8 @@ class Parser(parser.Parser):
min_resize=self._min_resize, min_resize=self._min_resize,
crop_size=self._crop_size) crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype) image = tf.cast(image, dtype=self._dtype)
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes) label = _process_label(label, self._one_hot_label, self._num_classes)
return {'image': image}, label return {'image': image}, label
......
...@@ -84,22 +84,41 @@ class VideoClassificationTask(base_task.Task): ...@@ -84,22 +84,41 @@ class VideoClassificationTask(base_task.Task):
Returns: Returns:
The total loss tensor. The total loss tensor.
""" """
all_losses = {}
losses_config = self.task_config.losses losses_config = self.task_config.losses
if losses_config.one_hot: total_loss = None
total_loss = tf.keras.losses.categorical_crossentropy( if self.task_config.train_data.is_multilabel:
labels, entropy = -tf.reduce_mean(
model_outputs, tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
from_logits=True, total_loss = tf.keras.losses.binary_crossentropy(
label_smoothing=losses_config.label_smoothing) labels, model_outputs, from_logits=False)
all_losses.update({
'class_loss': total_loss,
'entropy': entropy,
})
else: else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy( if losses_config.one_hot:
labels, model_outputs, from_logits=True) total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=False,
label_smoothing=losses_config.label_smoothing)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=False)
total_loss = tf_utils.safe_mean(total_loss) total_loss = tf_utils.safe_mean(total_loss)
all_losses.update({
'class_loss': total_loss,
})
if aux_losses: if aux_losses:
all_losses.update({
'reg_loss': aux_losses,
})
total_loss += tf.add_n(aux_losses) total_loss += tf.add_n(aux_losses)
all_losses[self.loss] = total_loss
return total_loss return all_losses
def build_metrics(self, training=True): def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation.""" """Gets streaming metrics for training/validation."""
...@@ -109,6 +128,20 @@ class VideoClassificationTask(base_task.Task): ...@@ -109,6 +128,20 @@ class VideoClassificationTask(base_task.Task):
tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'), tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy') tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
] ]
if self.task_config.train_data.is_multilabel:
metrics.append(
tf.keras.metrics.AUC(
curve='ROC',
multi_label=self.task_config.train_data.is_multilabel,
name='ROC-AUC'))
metrics.append(
tf.keras.metrics.RecallAtPrecision(
0.95, name='RecallAtPrecision95'))
metrics.append(
tf.keras.metrics.AUC(
curve='PR',
multi_label=self.task_config.train_data.is_multilabel,
name='PR-AUC'))
else: else:
metrics = [ metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
...@@ -119,6 +152,21 @@ class VideoClassificationTask(base_task.Task): ...@@ -119,6 +152,21 @@ class VideoClassificationTask(base_task.Task):
] ]
return metrics return metrics
def process_metrics(self, metrics, labels, model_outputs):
  """Updates every metric in `metrics` with the current batch.

  Called when using the custom training loop API.

  Args:
    metrics: a nested structure of metrics objects, as returned by
      `self.build_metrics`.
    labels: a tensor or a nested structure of tensors.
    model_outputs: a tensor or a nested structure of tensors, e.g. the
      output of the keras model built by `self.build_model`.
  """
  for m in metrics:
    m.update_state(labels, model_outputs)
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward. """Does forward and backward.
...@@ -142,8 +190,13 @@ class VideoClassificationTask(base_task.Task): ...@@ -142,8 +190,13 @@ class VideoClassificationTask(base_task.Task):
lambda x: tf.cast(x, tf.float32), outputs) lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss. # Computes per-replica loss.
loss = self.build_losses( if self.task_config.train_data.is_multilabel:
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
all_losses = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses) model_outputs=outputs, labels=labels, aux_losses=model.losses)
loss = all_losses[self.loss]
# Scales loss as the default gradients allreduce performs sum inside the # Scales loss as the default gradients allreduce performs sum inside the
# optimizer. # optimizer.
scaled_loss = loss / num_replicas scaled_loss = loss / num_replicas
...@@ -162,7 +215,7 @@ class VideoClassificationTask(base_task.Task): ...@@ -162,7 +215,7 @@ class VideoClassificationTask(base_task.Task):
grads = optimizer.get_unscaled_gradients(grads) grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars))) optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss} logs = all_losses
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics}) logs.update({m.name: m.result() for m in metrics})
...@@ -186,10 +239,9 @@ class VideoClassificationTask(base_task.Task): ...@@ -186,10 +239,9 @@ class VideoClassificationTask(base_task.Task):
outputs = self.inference_step(features['image'], model) outputs = self.inference_step(features['image'], model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs, labels=labels, logs = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses) aux_losses=model.losses)
logs = {self.loss: loss}
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics}) logs.update({m.name: m.result() for m in metrics})
...@@ -200,4 +252,9 @@ class VideoClassificationTask(base_task.Task): ...@@ -200,4 +252,9 @@ class VideoClassificationTask(base_task.Task):
def inference_step(self, inputs, model):
  """Performs the forward step.

  Args:
    inputs: the input batch, passed straight through to `model`.
    model: a callable Keras model; invoked with `training=False`.

  Returns:
    The model outputs with the final activation applied: sigmoid when the
    task is configured as multilabel (independent per-class probabilities),
    softmax otherwise (mutually exclusive classes).
  """
  outputs = model(inputs, training=False)
  if self.task_config.train_data.is_multilabel:
    outputs = tf.math.sigmoid(outputs)
  else:
    outputs = tf.math.softmax(outputs)
  return outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment