Commit ae871f41 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 356580296
parent 1596bb28
......@@ -83,6 +83,7 @@ def build_video_classification_model(
num_classes: int,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds the video classification model."""
input_specs_dict = {'image': input_specs}
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
......@@ -91,7 +92,7 @@ def build_video_classification_model(
model = video_classification_model.VideoClassificationModel(
backbone=backbone,
num_classes=num_classes,
input_specs=input_specs,
input_specs=input_specs_dict,
dropout_rate=model_config.dropout_rate,
aggregate_endpoints=model_config.aggregate_endpoints,
kernel_regularizer=l2_regularizer)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Build video classification models."""
# Import libraries
from typing import Mapping
import tensorflow as tf
layers = tf.keras.layers
......@@ -24,11 +24,11 @@ class VideoClassificationModel(tf.keras.Model):
"""A video classification class builder."""
def __init__(self,
backbone,
num_classes,
input_specs=layers.InputSpec(shape=[None, None, None, None, 3]),
dropout_rate=0.0,
aggregate_endpoints=False,
backbone: tf.keras.Model,
num_classes: int,
input_specs: Mapping[str, tf.keras.layers.InputSpec] = None,
dropout_rate: float = 0.0,
aggregate_endpoints: bool = False,
kernel_initializer='random_uniform',
kernel_regularizer=None,
bias_regularizer=None,
......@@ -49,6 +49,10 @@ class VideoClassificationModel(tf.keras.Model):
None.
**kwargs: keyword arguments to be passed.
"""
if not input_specs:
input_specs = {
'image': layers.InputSpec(shape=[None, None, None, None, 3])
}
self._self_setattr_tracking = False
self._config_dict = {
'backbone': backbone,
......@@ -65,8 +69,10 @@ class VideoClassificationModel(tf.keras.Model):
self._bias_regularizer = bias_regularizer
self._backbone = backbone
inputs = tf.keras.Input(shape=input_specs.shape[1:])
endpoints = backbone(inputs)
inputs = {
k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
}
endpoints = backbone(inputs['image'])
if aggregate_endpoints:
pooled_feats = []
......
......@@ -53,7 +53,7 @@ class VideoClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
model = video_classification_model.VideoClassificationModel(
backbone=backbone,
num_classes=num_classes,
input_specs=input_specs,
input_specs={'image': input_specs},
dropout_rate=0.2,
aggregate_endpoints=aggregate_endpoints,
)
......
......@@ -195,10 +195,7 @@ class VideoClassificationTask(base_task.Task):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
if self.task_config.train_data.output_audio:
outputs = model(features, training=True)
else:
outputs = model(features['image'], training=True)
outputs = model(features, training=True)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(
......@@ -267,10 +264,7 @@ class VideoClassificationTask(base_task.Task):
def inference_step(self, features, model):
"""Performs the forward step."""
if self.task_config.train_data.output_audio:
outputs = model(features, training=False)
else:
outputs = model(features['image'], training=False)
outputs = model(features, training=False)
if self.task_config.train_data.is_multilabel:
outputs = tf.math.sigmoid(outputs)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment