Commit fb9f35c8 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds a dummy task for multi-modal representation learning.

PiperOrigin-RevId: 374239384
parent 9f88ce51
......@@ -30,13 +30,32 @@ from official.vision.beta.modeling import factory_3d
class VideoClassificationTask(base_task.Task):
"""A task for video classification."""
def build_model(self):
"""Builds video classification model."""
common_input_shape = [
def _get_num_classes(self):
"""Gets the number of classes."""
return self.task_config.train_data.num_classes
def _get_feature_shape(self):
"""Get the common feature shape for train and eval."""
return [
d1 if d1 == d2 else None
for d1, d2 in zip(self.task_config.train_data.feature_shape,
self.task_config.validation_data.feature_shape)
]
def _get_num_test_views(self):
"""Gets number of views for test."""
num_test_clips = self.task_config.validation_data.num_test_clips
num_test_crops = self.task_config.validation_data.num_test_crops
num_test_views = num_test_clips * num_test_crops
return num_test_views
def _is_multilabel(self):
"""If the label is multi-labels."""
return self.task_config.train_data.is_multilabel
def build_model(self):
"""Builds video classification model."""
common_input_shape = self._get_feature_shape()
input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
logging.info('Build model input %r', common_input_shape)
......@@ -51,7 +70,7 @@ class VideoClassificationTask(base_task.Task):
self.task_config.model.model_type,
input_specs=input_specs,
model_config=self.task_config.model,
num_classes=self.task_config.train_data.num_classes,
num_classes=self._get_num_classes(),
l2_regularizer=l2_regularizer)
return model
......@@ -138,7 +157,7 @@ class VideoClassificationTask(base_task.Task):
all_losses = {}
losses_config = self.task_config.losses
total_loss = None
if self.task_config.train_data.is_multilabel:
if self._is_multilabel():
entropy = -tf.reduce_mean(
tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
total_loss = tf.keras.losses.binary_crossentropy(
......@@ -179,22 +198,18 @@ class VideoClassificationTask(base_task.Task):
tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
]
if self.task_config.train_data.is_multilabel:
if self._is_multilabel():
metrics.append(
tf.keras.metrics.AUC(
curve='ROC',
multi_label=self.task_config.train_data.is_multilabel,
name='ROC-AUC'))
curve='ROC', multi_label=self._is_multilabel(), name='ROC-AUC'))
metrics.append(
tf.keras.metrics.RecallAtPrecision(
0.95, name='RecallAtPrecision95'))
metrics.append(
tf.keras.metrics.AUC(
curve='PR',
multi_label=self.task_config.train_data.is_multilabel,
name='PR-AUC'))
curve='PR', multi_label=self._is_multilabel(), name='PR-AUC'))
if self.task_config.metrics.use_per_class_recall:
for i in range(self.task_config.train_data.num_classes):
for i in range(self._get_num_classes()):
metrics.append(
tf.keras.metrics.Recall(class_id=i, name=f'recall-{i}'))
else:
......@@ -250,7 +265,7 @@ class VideoClassificationTask(base_task.Task):
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
if self.task_config.train_data.is_multilabel:
if self._is_multilabel():
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
......@@ -316,13 +331,11 @@ class VideoClassificationTask(base_task.Task):
def inference_step(self, features: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step."""
outputs = model(features, training=False)
if self.task_config.train_data.is_multilabel:
if self._is_multilabel():
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
num_test_clips = self.task_config.validation_data.num_test_clips
num_test_crops = self.task_config.validation_data.num_test_crops
num_test_views = num_test_clips * num_test_crops
num_test_views = self._get_num_test_views()
if num_test_views > 1:
# Averaging output probabilities across multiples views.
outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment