Commit 1d38cca0 authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 343441984
parent 3dae7a77
@@ -54,18 +54,30 @@ class VideoClassificationTask(base_task.Task):
         l2_regularizer=l2_regularizer)
     return model
 
+  def _get_dataset_fn(self, params):
+    if params.file_type == 'tfrecord':
+      return tf.data.TFRecordDataset
+    else:
+      raise ValueError('Unknown input file type {!r}'.format(params.file_type))
+
+  def _get_decoder_fn(self, params):
+    decoder = video_input.Decoder()
+    if self.task_config.train_data.output_audio:
+      assert self.task_config.train_data.audio_feature, 'audio feature is empty'
+      decoder.add_feature(self.task_config.train_data.audio_feature,
+                          tf.io.VarLenFeature(dtype=tf.float32))
+    return decoder.decode
+
   def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
     """Builds classification input."""
-    decoder = video_input.Decoder()
-    decoder_fn = decoder.decode
     parser = video_input.Parser(input_params=params)
     postprocess_fn = video_input.PostBatchProcessor(params)
 
     reader = input_reader.InputReader(
         params,
-        dataset_fn=tf.data.TFRecordDataset,
-        decoder_fn=decoder_fn,
+        dataset_fn=self._get_dataset_fn(params),
+        decoder_fn=self._get_decoder_fn(params),
         parser_fn=parser.parse_fn(params.is_training),
         postprocess_fn=postprocess_fn)
@@ -183,6 +195,9 @@ class VideoClassificationTask(base_task.Task):
     num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
     with tf.GradientTape() as tape:
-      outputs = model(features['image'], training=True)
+      if self.task_config.train_data.output_audio:
+        outputs = model(features, training=True)
+      else:
+        outputs = model(features['image'], training=True)
       # Casting output layer as float32 is necessary when mixed_precision is
       # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
@@ -237,7 +252,7 @@ class VideoClassificationTask(base_task.Task):
     """
     features, labels = inputs
-    outputs = self.inference_step(features['image'], model)
+    outputs = self.inference_step(features, model)
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
     logs = self.build_losses(model_outputs=outputs, labels=labels,
                              aux_losses=model.losses)
@@ -250,9 +265,12 @@ class VideoClassificationTask(base_task.Task):
       logs.update({m.name: m.result() for m in model.metrics})
     return logs
 
-  def inference_step(self, inputs, model):
+  def inference_step(self, features, model):
     """Performs the forward step."""
-    outputs = model(inputs, training=False)
+    if self.task_config.train_data.output_audio:
+      outputs = model(features, training=False)
+    else:
+      outputs = model(features['image'], training=False)
     if self.task_config.train_data.is_multilabel:
       outputs = tf.math.sigmoid(outputs)
     else:
...
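
Note: with `output_audio` enabled, `_get_decoder_fn` registers one extra variable-length float feature on the video decoder, and `train_step`/`inference_step` then pass the whole feature dict to the model rather than `features['image']` alone. The sketch below is not a copy of the task code; it only illustrates that decoder wiring, and 'AUDIO/feature/floats' is a hypothetical tf.Example key standing in for whatever `task_config.train_data.audio_feature` is set to.

    import tensorflow as tf

    # Illustrative sketch of the decoder hook added above (assumes the same
    # video_input module the task imports; its import path is omitted here).
    decoder = video_input.Decoder()

    # Register the optional audio feature. The key is hypothetical; the task
    # reads the real key from task_config.train_data.audio_feature.
    decoder.add_feature('AUDIO/feature/floats',
                        tf.io.VarLenFeature(dtype=tf.float32))

    # The bound decode method is what InputReader receives as decoder_fn.
    decode_fn = decoder.decode

The rest of the input pipeline (Parser, PostBatchProcessor, InputReader) is unchanged; `_get_dataset_fn` still returns `tf.data.TFRecordDataset` for the only supported `file_type` and now raises a `ValueError` for anything else.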