Commit b7cbd12b authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342166637
parent 0a52f120
@@ -48,6 +48,7 @@ class DataConfig(cfg.DataConfig):
is_training: bool = True
cycle_length: int = 10
min_image_size: int = 256
is_multilabel: bool = False
def kinetics400(is_training):
......
@@ -146,6 +146,11 @@ def _process_label(label: tf.Tensor,
if one_hot_label:
# Replace label index by one hot representation.
label = tf.one_hot(label, num_classes)
if len(label.shape.as_list()) > 1:
label = tf.reduce_sum(label, axis=0)
if num_classes == 1:
# The trick for single label.
label = 1 - label
return label
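A minimal sketch of what the new multi-label branch above does, using a made-up clip tagged with classes 2 and 5 out of 10: tf.one_hot on a rank-1 label tensor yields one row per tag, and the reduce_sum collapses those rows into a single multi-hot target.

import tensorflow as tf

# Hypothetical multi-label example: one clip tagged with classes 2 and 5 of 10.
label = tf.constant([2, 5], dtype=tf.int64)
label = tf.one_hot(label, 10)             # shape [2, 10], one row per tag
if len(label.shape.as_list()) > 1:
  label = tf.reduce_sum(label, axis=0)    # shape [10], multi-hot target
print(label.numpy())                      # [0. 0. 1. 0. 0. 1. 0. 0. 0. 0.]

The num_classes == 1 branch appears to be a binary-classification convenience: tf.one_hot(0, 1) is [1.] and tf.one_hot(1, 1) is [0.], so 1 - label maps the one-hot vector back to the original 0/1 target.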
@@ -154,11 +159,11 @@ class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task."""
def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY):
self._image_key = IMAGE_KEY
self._label_key = LABEL_KEY
self._image_key = image_key
self._label_key = label_key
self._context_description = {
# One integer stored in context.
self._label_key: tf.io.FixedLenFeature((), tf.int64),
self._label_key: tf.io.VarLenFeature(tf.int64),
}
self._sequence_description = {
# Each image is a string encoding JPEG.
@@ -172,7 +177,7 @@ class Decoder(decoder.Decoder):
self._sequence_description)
return {
self._image_key: sequences[self._image_key],
self._label_key: context[self._label_key]
self._label_key: tf.sparse.to_dense(context[self._label_key])
}
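A small sketch of why the decoder now pairs VarLenFeature with tf.sparse.to_dense: a multi-label record can carry any number of label ids in its context, and VarLenFeature parses them into a SparseTensor of whatever length was stored (the feature key below is illustrative, not necessarily the dataset's real one).

import tensorflow as tf

LABEL_KEY = 'clip/label/index'  # illustrative key; substitute the dataset's actual label key

# Hypothetical serialized tf.SequenceExample with two label ids in its context.
example = tf.train.SequenceExample()
example.context.feature[LABEL_KEY].int64_list.value.extend([2, 5])

context, _ = tf.io.parse_single_sequence_example(
    example.SerializeToString(),
    context_features={LABEL_KEY: tf.io.VarLenFeature(tf.int64)})
labels = tf.sparse.to_dense(context[LABEL_KEY])  # dense rank-1 tensor: [2, 5]

A FixedLenFeature((), tf.int64), by contrast, expects exactly one value under that key and would fail to parse records that store several labels.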
@@ -200,7 +205,6 @@ class Parser(parser.Parser):
"""Parses data for training."""
# Process image and label.
image = decoded_tensors[self._image_key]
label = decoded_tensors[self._label_key]
image = _process_image(
image=image,
is_training=True,
@@ -210,6 +214,8 @@
min_resize=self._min_resize,
crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype)
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes)
return {'image': image}, label
@@ -219,7 +225,6 @@
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
"""Parses data for evaluation."""
image = decoded_tensors[self._image_key]
label = decoded_tensors[self._label_key]
image = _process_image(
image=image,
is_training=False,
@@ -229,6 +234,8 @@
min_resize=self._min_resize,
crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype)
label = decoded_tensors[self._label_key]
label = _process_label(label, self._one_hot_label, self._num_classes)
return {'image': image}, label
......
@@ -84,22 +84,41 @@ class VideoClassificationTask(base_task.Task):
Returns:
The total loss tensor.
"""
all_losses = {}
losses_config = self.task_config.losses
if losses_config.one_hot:
total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=True,
label_smoothing=losses_config.label_smoothing)
total_loss = None
if self.task_config.train_data.is_multilabel:
entropy = -tf.reduce_mean(
tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
total_loss = tf.keras.losses.binary_crossentropy(
labels, model_outputs, from_logits=False)
all_losses.update({
'class_loss': total_loss,
'entropy': entropy,
})
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
if losses_config.one_hot:
total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=False,
label_smoothing=losses_config.label_smoothing)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=False)
total_loss = tf_utils.safe_mean(total_loss)
total_loss = tf_utils.safe_mean(total_loss)
all_losses.update({
'class_loss': total_loss,
})
if aux_losses:
all_losses.update({
'reg_loss': aux_losses,
})
total_loss += tf.add_n(aux_losses)
all_losses[self.loss] = total_loss
return total_loss
return all_losses
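A rough, standalone sketch of the two loss paths above (it paraphrases build_losses and is not the task code itself): binary cross-entropy over per-class sigmoid probabilities plus a prediction-entropy diagnostic for multi-label data, and (sparse) categorical cross-entropy over softmax probabilities otherwise. Both use from_logits=False because the activations are now applied to the model outputs before the loss is built.

import tensorflow as tf

def sketch_losses(logits, labels, is_multilabel, one_hot=True, label_smoothing=0.0):
  """Illustrative stand-in for build_losses, not the actual task implementation."""
  all_losses = {}
  if is_multilabel:
    probs = tf.math.sigmoid(logits)              # independent per-class probabilities
    entropy = -tf.reduce_mean(
        tf.reduce_sum(probs * tf.math.log(probs + 1e-8), -1))
    class_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(labels, probs, from_logits=False))
    all_losses.update({'class_loss': class_loss, 'entropy': entropy})
  else:
    probs = tf.math.softmax(logits)              # mutually exclusive classes
    if one_hot:
      class_loss = tf.keras.losses.categorical_crossentropy(
          labels, probs, from_logits=False, label_smoothing=label_smoothing)
    else:
      class_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, probs, from_logits=False)
    all_losses.update({'class_loss': tf.reduce_mean(class_loss)})
  return all_losses

In the actual task, tf_utils.safe_mean and the auxiliary regularization losses are folded in as shown in the diff above.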
def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation."""
@@ -109,6 +128,20 @@ class VideoClassificationTask(base_task.Task):
tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
]
if self.task_config.train_data.is_multilabel:
metrics.append(
tf.keras.metrics.AUC(
curve='ROC',
multi_label=self.task_config.train_data.is_multilabel,
name='ROC-AUC'))
metrics.append(
tf.keras.metrics.RecallAtPrecision(
0.95, name='RecallAtPrecision95'))
metrics.append(
tf.keras.metrics.AUC(
curve='PR',
multi_label=self.task_config.train_data.is_multilabel,
name='PR-AUC'))
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
@@ -119,6 +152,21 @@ class VideoClassificationTask(base_task.Task):
]
return metrics
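A quick usage sketch of the multi-label metrics added here, with made-up scores: when multi_label=True, tf.keras.metrics.AUC computes an AUC per class and averages them, and RecallAtPrecision reports the best recall achievable at a precision of at least 0.95.

import tensorflow as tf

roc_auc = tf.keras.metrics.AUC(curve='ROC', multi_label=True, name='ROC-AUC')
pr_auc = tf.keras.metrics.AUC(curve='PR', multi_label=True, name='PR-AUC')
recall_at_p95 = tf.keras.metrics.RecallAtPrecision(0.95, name='RecallAtPrecision95')

# Made-up batch: two clips, three classes, multi-hot labels and sigmoid scores.
labels = tf.constant([[1., 0., 1.],
                      [0., 1., 0.]])
scores = tf.constant([[0.9, 0.2, 0.7],
                      [0.1, 0.8, 0.3]])
for metric in (roc_auc, pr_auc, recall_at_p95):
  metric.update_state(labels, scores)
  print(metric.name, float(metric.result()))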
def process_metrics(self, metrics, labels, model_outputs):
"""Process and update metrics.
Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects. The return of function
self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
"""
for metric in metrics:
metric.update_state(labels, model_outputs)
def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward.
@@ -142,8 +190,13 @@ class VideoClassificationTask(base_task.Task):
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss = self.build_losses(
if self.task_config.train_data.is_multilabel:
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
all_losses = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
loss = all_losses[self.loss]
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
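A toy check of the scaling convention noted in the comment above: because the gradient allreduce sums across replicas, dividing each per-replica loss by num_replicas makes the cross-replica sum equal to the mean loss (and, by linearity, the summed gradients equal the gradient of the mean loss).

import tensorflow as tf

# Assumed setup: 4 replicas, each holding its own per-replica loss value.
num_replicas = 4
per_replica_losses = tf.constant([1.0, 2.0, 3.0, 4.0])

summed_scaled = tf.reduce_sum(per_replica_losses / num_replicas)  # what the allreduce-sum sees
mean_loss = tf.reduce_mean(per_replica_losses)                    # the quantity we actually want
tf.debugging.assert_near(summed_scaled, mean_loss)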
@@ -162,7 +215,7 @@ class VideoClassificationTask(base_task.Task):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
logs = all_losses
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
@@ -186,10 +239,9 @@ class VideoClassificationTask(base_task.Task):
outputs = self.inference_step(features['image'], model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs, labels=labels,
logs = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
@@ -200,4 +252,9 @@ class VideoClassificationTask(base_task.Task):
def inference_step(self, inputs, model):
"""Performs the forward step."""
return model(inputs, training=False)
outputs = model(inputs, training=False)
if self.task_config.train_data.is_multilabel:
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
return outputs
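To make the new inference_step behaviour concrete, a toy comparison of the two activations it now applies to the logits (values are made up): sigmoid scores are independent per class and need not sum to one, while softmax yields a single distribution over classes.

import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.5]])        # one clip, three classes (made up)
multilabel_scores = tf.math.sigmoid(logits)     # ~ [[0.88, 0.27, 0.62]], independent per class
singlelabel_probs = tf.math.softmax(logits)     # ~ [[0.79, 0.04, 0.18]], sums to 1
print(float(tf.reduce_sum(multilabel_scores)))  # > 1 is expected for multi-label scores
print(float(tf.reduce_sum(singlelabel_probs)))  # ~ 1.0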