Commit ae871f41 authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 356580296
parent 1596bb28
@@ -83,6 +83,7 @@ def build_video_classification_model(
     num_classes: int,
     l2_regularizer: tf.keras.regularizers.Regularizer = None):
   """Builds the video classification model."""
+  input_specs_dict = {'image': input_specs}
   backbone = backbones.factory.build_backbone(
       input_specs=input_specs,
       model_config=model_config,
@@ -91,7 +92,7 @@ def build_video_classification_model(
   model = video_classification_model.VideoClassificationModel(
       backbone=backbone,
       num_classes=num_classes,
-      input_specs=input_specs,
+      input_specs=input_specs_dict,
       dropout_rate=model_config.dropout_rate,
       aggregate_endpoints=model_config.aggregate_endpoints,
       kernel_regularizer=l2_regularizer)
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Build video classification models."""
-# Import libraries
+from typing import Mapping
 import tensorflow as tf

 layers = tf.keras.layers
@@ -24,11 +24,11 @@ class VideoClassificationModel(tf.keras.Model):
   """A video classification class builder."""

   def __init__(self,
-               backbone,
-               num_classes,
-               input_specs=layers.InputSpec(shape=[None, None, None, None, 3]),
-               dropout_rate=0.0,
-               aggregate_endpoints=False,
+               backbone: tf.keras.Model,
+               num_classes: int,
+               input_specs: Mapping[str, tf.keras.layers.InputSpec] = None,
+               dropout_rate: float = 0.0,
+               aggregate_endpoints: bool = False,
                kernel_initializer='random_uniform',
                kernel_regularizer=None,
                bias_regularizer=None,
@@ -49,6 +49,10 @@ class VideoClassificationModel(tf.keras.Model):
         None.
       **kwargs: keyword arguments to be passed.
     """
+    if not input_specs:
+      input_specs = {
+          'image': layers.InputSpec(shape=[None, None, None, None, 3])
+      }
     self._self_setattr_tracking = False
     self._config_dict = {
         'backbone': backbone,
@@ -65,8 +69,10 @@ class VideoClassificationModel(tf.keras.Model):
     self._bias_regularizer = bias_regularizer
     self._backbone = backbone

-    inputs = tf.keras.Input(shape=input_specs.shape[1:])
-    endpoints = backbone(inputs)
+    inputs = {
+        k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
+    }
+    endpoints = backbone(inputs['image'])

     if aggregate_endpoints:
       pooled_feats = []
......
@@ -53,7 +53,7 @@ class VideoClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
     model = video_classification_model.VideoClassificationModel(
         backbone=backbone,
         num_classes=num_classes,
-        input_specs=input_specs,
+        input_specs={'image': input_specs},
         dropout_rate=0.2,
         aggregate_endpoints=aggregate_endpoints,
     )
......
@@ -195,10 +195,7 @@ class VideoClassificationTask(base_task.Task):
     num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
     with tf.GradientTape() as tape:
-      if self.task_config.train_data.output_audio:
-        outputs = model(features, training=True)
-      else:
-        outputs = model(features['image'], training=True)
+      outputs = model(features, training=True)
       # Casting output layer as float32 is necessary when mixed_precision is
       # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
       outputs = tf.nest.map_structure(
@@ -267,10 +264,7 @@ class VideoClassificationTask(base_task.Task):
   def inference_step(self, features, model):
     """Performs the forward step."""
-    if self.task_config.train_data.output_audio:
-      outputs = model(features, training=False)
-    else:
-      outputs = model(features['image'], training=False)
+    outputs = model(features, training=False)
     if self.task_config.train_data.is_multilabel:
       outputs = tf.math.sigmoid(outputs)
     else:
......
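
The net effect of this commit is that VideoClassificationModel and VideoClassificationTask exchange feature dicts keyed by 'image' instead of bare tensors: input_specs is a mapping of name to InputSpec, Keras Inputs are built per named spec, and the task passes the whole features dict to the model. A minimal, self-contained sketch of that calling convention is below; the pooling "backbone", the spec shapes, and the class count are illustrative stand-ins, not code from this commit.

import tensorflow as tf

layers = tf.keras.layers

# input_specs is now a mapping from feature name to InputSpec with shape
# (batch, time, height, width, channels); 'image' is the key the model
# feeds to its backbone.
input_specs = {
    'image': layers.InputSpec(shape=[None, 8, 32, 32, 3]),
}

# Same construction as in VideoClassificationModel.__init__: one Keras
# Input per named spec, dropping the batch dimension.
inputs = {
    k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
}

# Stand-in for the real backbone + classifier head: pool over
# (time, height, width) and project to logits.
pooled = layers.GlobalAveragePooling3D()(inputs['image'])
logits = layers.Dense(10)(pooled)
model = tf.keras.Model(inputs=inputs, outputs=logits)

# Train and inference steps now pass the whole features dict, mirroring
# `outputs = model(features, training=...)` in VideoClassificationTask.
features = {'image': tf.zeros([2, 8, 32, 32, 3])}
outputs = model(features, training=False)
print(outputs.shape)  # (2, 10)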