Commit 9bfc2db4 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 338710422
parent cdc4cad7
@@ -80,7 +80,7 @@ class VideoClassificationModel(hyperparams.Config):
       type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
   norm_activation: common.NormActivation = common.NormActivation()
   dropout_rate: float = 0.2
-  add_head_batch_norm: bool = False
+  aggregate_endpoints: bool = False


 @dataclasses.dataclass
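
For orientation, a minimal sketch of setting the renamed field on this config dataclass. Only `dropout_rate` and `aggregate_endpoints` come from the hunk above, the import path comes from the factory change below, and the variable name and values are illustrative:

from official.vision.beta.configs import video_classification

# Illustrative only: enable endpoint aggregation on the classification config.
model_config = video_classification.VideoClassificationModel(
    dropout_rate=0.2,
    aggregate_endpoints=True)  # pool every backbone endpoint, not just the deepest
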
@@ -19,8 +19,8 @@ import tensorflow as tf
 from official.core import registry
 from official.vision.beta.configs import video_classification as video_classification_cfg
-from official.vision.beta.modeling import backbones
 from official.vision.beta.modeling import video_classification_model
+from official.vision.beta.modeling import backbones


 _REGISTERED_MODEL_CLS = {}
@@ -88,15 +88,11 @@ def build_video_classification_model(
       model_config=model_config,
       l2_regularizer=l2_regularizer)

-  norm_activation_config = model_config.norm_activation
   model = video_classification_model.VideoClassificationModel(
       backbone=backbone,
       num_classes=num_classes,
       input_specs=input_specs,
       dropout_rate=model_config.dropout_rate,
-      kernel_regularizer=l2_regularizer,
-      add_head_batch_norm=model_config.add_head_batch_norm,
-      use_sync_bn=norm_activation_config.use_sync_bn,
-      norm_momentum=norm_activation_config.norm_momentum,
-      norm_epsilon=norm_activation_config.norm_epsilon)
+      aggregate_endpoints=model_config.aggregate_endpoints,
+      kernel_regularizer=l2_regularizer)
   return model
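
A rough usage sketch of the updated builder, assuming `build_video_classification_model` (the function edited above) accepts `input_specs`, `model_config`, `num_classes`, and `l2_regularizer` as keyword arguments; the shapes, class count, and regularizer are made-up example values:

import tensorflow as tf

from official.vision.beta.configs import video_classification as video_classification_cfg

input_specs = tf.keras.layers.InputSpec(shape=[None, 8, 112, 112, 3])
model_config = video_classification_cfg.VideoClassificationModel(
    aggregate_endpoints=True)
model = build_video_classification_model(
    input_specs=input_specs,
    model_config=model_config,
    num_classes=600,
    l2_regularizer=tf.keras.regularizers.l2(1e-4))
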
@@ -28,13 +28,10 @@ class VideoClassificationModel(tf.keras.Model):
                num_classes,
                input_specs=layers.InputSpec(shape=[None, None, None, None, 3]),
                dropout_rate=0.0,
+               aggregate_endpoints=False,
                kernel_initializer='random_uniform',
                kernel_regularizer=None,
                bias_regularizer=None,
-               add_head_batch_norm=False,
-               use_sync_bn: bool = False,
-               norm_momentum: float = 0.99,
-               norm_epsilon: float = 0.001,
                **kwargs):
     """Video Classification initialization function.
@@ -43,17 +40,13 @@ class VideoClassificationModel(tf.keras.Model):
       num_classes: `int` number of classes in classification task.
       input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
       dropout_rate: `float` rate for dropout regularization.
+      aggregate_endpoints: `bool` whether to aggregate all endpoints or use
+        only the final endpoint.
       kernel_initializer: kernel initializer for the dense layer.
       kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
         None.
       bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
         None.
-      add_head_batch_norm: `bool` whether to add a batch normalization layer
-        before pool.
-      use_sync_bn: `bool` if True, use synchronized batch normalization.
-      norm_momentum: `float` normalization momentum for the moving average.
-      norm_epsilon: `float` small float added to variance to avoid dividing by
-        zero.
       **kwargs: keyword arguments to be passed.
     """
@@ -62,31 +55,29 @@ class VideoClassificationModel(tf.keras.Model):
         'num_classes': num_classes,
         'input_specs': input_specs,
         'dropout_rate': dropout_rate,
+        'aggregate_endpoints': aggregate_endpoints,
         'kernel_initializer': kernel_initializer,
         'kernel_regularizer': kernel_regularizer,
         'bias_regularizer': bias_regularizer,
-        'add_head_batch_norm': add_head_batch_norm,
-        'use_sync_bn': use_sync_bn,
-        'norm_momentum': norm_momentum,
-        'norm_epsilon': norm_epsilon,
     }
     self._input_specs = input_specs
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
     self._backbone = backbone

-    if use_sync_bn:
-      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
-    else:
-      self._norm = tf.keras.layers.BatchNormalization
-    axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1
-
     inputs = tf.keras.Input(shape=input_specs.shape[1:])
     endpoints = backbone(inputs)
-    x = endpoints[max(endpoints.keys())]

-    if add_head_batch_norm:
-      x = self._norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
-    x = tf.keras.layers.GlobalAveragePooling3D()(x)
+    if aggregate_endpoints:
+      pooled_feats = []
+      for endpoint in endpoints.values():
+        x_pool = tf.keras.layers.GlobalAveragePooling3D()(endpoint)
+        pooled_feats.append(x_pool)
+      x = tf.concat(pooled_feats, axis=1)
+    else:
+      x = endpoints[max(endpoints.keys())]
+      x = tf.keras.layers.GlobalAveragePooling3D()(x)
+
     x = tf.keras.layers.Dropout(dropout_rate)(x)
     x = tf.keras.layers.Dense(
         num_classes, kernel_initializer=kernel_initializer,
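
To illustrate the flag's effect at the model level: with `aggregate_endpoints=True` the head global-average-pools every backbone endpoint and concatenates the pooled vectors, while with `False` it pools only the deepest endpoint. A sketch (not part of the commit) that assumes `backbone` is a 3-D backbone whose call returns a dict of endpoint tensors, such as the ResNet3D-50 built by the factory above:

import tensorflow as tf

from official.vision.beta.modeling import video_classification_model

# `backbone` is assumed to be constructed elsewhere (e.g. via the 3-D factory).
model = video_classification_model.VideoClassificationModel(
    backbone=backbone,
    num_classes=400,
    input_specs=tf.keras.layers.InputSpec(shape=[None, 8, 112, 112, 3]),
    dropout_rate=0.2,
    aggregate_endpoints=True)  # concatenate pooled features from every endpoint
logits = model(tf.zeros([2, 8, 112, 112, 3]))  # -> shape [2, 400]
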
@@ -27,11 +27,12 @@ from official.vision.beta.modeling import video_classification_model
 class VideoClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):

   @parameterized.parameters(
-      (50, 8, 112, 'relu'),
-      (50, 8, 112, 'swish'),
+      (50, 8, 112, 'relu', False),
+      (50, 8, 112, 'swish', True),
   )
   def test_resnet3d_network_creation(self, model_id, temporal_size,
-                                     spatial_size, activation):
+                                     spatial_size, activation,
+                                     aggregate_endpoints):
     """Test for creation of a ResNet3D-50 classifier."""
     input_specs = tf.keras.layers.InputSpec(
         shape=[None, temporal_size, spatial_size, spatial_size, 3])

@@ -54,6 +55,7 @@ class VideoClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
         num_classes=num_classes,
         input_specs=input_specs,
         dropout_rate=0.2,
+        aggregate_endpoints=aggregate_endpoints,
     )
     inputs = np.random.rand(2, temporal_size, spatial_size, spatial_size, 3)