# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Contains modules related to Inception networks.""" from typing import Callable, Dict, Optional, Sequence, Set, Text, Tuple, Type, Union import tensorflow as tf from official.projects.s3d.modeling import net_utils from official.vision.beta.modeling.layers import nn_blocks_3d INCEPTION_V1_CONV_ENDPOINTS = [ 'Conv2d_1a_7x7', 'Conv2d_2c_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'Mixed_5b', 'Mixed_5c' ] # Mapping from endpoint to branch filters. The endpoint shapes below are # specific for input 64x224x224. INCEPTION_V1_ARCH_SKELETON = [ ('Mixed_3b', [[64], [96, 128], [16, 32], [32]]), # 32x28x28x256 ('Mixed_3c', [[128], [128, 192], [32, 96], [64]]), # 32x28x28x480 ('MaxPool_4a_3x3', [[3, 3, 3], [2, 2, 2]]), # 16x14x14x480 ('Mixed_4b', [[192], [96, 208], [16, 48], [64]]), # 16x14x14x512 ('Mixed_4c', [[160], [112, 224], [24, 64], [64]]), # 16x14x14x512 ('Mixed_4d', [[128], [128, 256], [24, 64], [64]]), # 16x14x14x512 ('Mixed_4e', [[112], [144, 288], [32, 64], [64]]), # 16x14x14x528 ('Mixed_4f', [[256], [160, 320], [32, 128], [128]]), # 16x14x14x832 ('MaxPool_5a_2x2', [[2, 2, 2], [2, 2, 2]]), # 8x7x7x832 ('Mixed_5b', [[256], [160, 320], [32, 128], [128]]), # 8x7x7x832 ('Mixed_5c', [[384], [192, 384], [48, 128], [128]]), # 8x7x7x1024 ] INCEPTION_V1_LOCAL_SKELETON = [ ('MaxPool_5a_2x2_local', [[2, 2, 2], [2, 2, 2]]), # 8x7x7x832 ('Mixed_5b_local', [[256], [160, 320], [32, 128], [128]]), # 8x7x7x832 ('Mixed_5c_local', [[384], [192, 384], [48, 128], [128]]), # 8x7x7x1024 ] initializers = tf.keras.initializers regularizers = tf.keras.regularizers def inception_v1_stem_cells( inputs: tf.Tensor, depth_multiplier: float, final_endpoint: Text, temporal_conv_endpoints: Optional[Set[Text]] = None, self_gating_endpoints: Optional[Set[Text]] = None, temporal_conv_type: Text = '3d', first_temporal_kernel_size: int = 7, use_sync_bn: bool = False, norm_momentum: float = 0.999, norm_epsilon: float = 0.001, temporal_conv_initializer: Union[ Text, initializers.Initializer] = initializers.TruncatedNormal( mean=0.0, stddev=0.01), kernel_initializer: Union[Text, initializers.Initializer] = 'truncated_normal', kernel_regularizer: Union[Text, regularizers.Regularizer] = 'l2', parameterized_conv_layer: Type[ net_utils.ParameterizedConvLayer] = net_utils.ParameterizedConvLayer, layer_naming_fn: Callable[[Text], Text] = lambda end_point: None, ) -> Tuple[tf.Tensor, Dict[Text, tf.Tensor]]: """Stem cells used in the original I3D/S3D model. Args: inputs: A 5-D float tensor of size [batch_size, num_frames, height, width, channels]. depth_multiplier: A float to reduce/increase number of channels. final_endpoint: Specifies the endpoint to construct the network up to. It can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3', 'MaxPool_3a_3x3']. temporal_conv_endpoints: Specifies the endpoints where to perform temporal convolution. self_gating_endpoints: Specifies the endpoints where to perform self gating. temporal_conv_type: '3d' for I3D model and '2+1d' for S3D model. first_temporal_kernel_size: temporal kernel size of the first convolution layer. use_sync_bn: If True, use synchronized batch normalization. norm_momentum: A `float` of normalization momentum for the moving average. norm_epsilon: A `float` added to variance to avoid dividing by zero. temporal_conv_initializer: Weight initializer for temporal convolution inside the cell. It only applies to 2+1d and 1+2d cases. kernel_initializer: Weight initializer for convolutional layers other than temporal convolution. kernel_regularizer: Weight regularizer for all convolutional layers. parameterized_conv_layer: class for parameterized conv layer. layer_naming_fn: function to customize conv / pooling layer names given endpoint name of the block. This is mainly used to creat model that is compatible with TF1 checkpoints. Returns: A dictionary from components of the network to the corresponding activation. """ if temporal_conv_endpoints is None: temporal_conv_endpoints = set() if self_gating_endpoints is None: self_gating_endpoints = set() if use_sync_bn: batch_norm = tf.keras.layers.experimental.SyncBatchNormalization else: batch_norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': bn_axis = -1 else: bn_axis = 1 end_points = {} # batch_size x 32 x 112 x 112 x 64 end_point = 'Conv2d_1a_7x7' net = tf.keras.layers.Conv3D( filters=net_utils.apply_depth_multiplier(64, depth_multiplier), kernel_size=[first_temporal_kernel_size, 7, 7], strides=[2, 2, 2], padding='same', use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name=layer_naming_fn(end_point))( inputs) net = batch_norm( axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon, scale=False, gamma_initializer='ones', name=layer_naming_fn(end_point + '/BatchNorm'))( net) net = tf.nn.relu(net) end_points[end_point] = net if final_endpoint == end_point: return net, end_points # batch_size x 32 x 56 x 56 x 64 end_point = 'MaxPool_2a_3x3' net = tf.keras.layers.MaxPool3D( pool_size=[1, 3, 3], strides=[1, 2, 2], padding='same', name=layer_naming_fn(end_point))( net) end_points[end_point] = net if final_endpoint == end_point: return net, end_points # batch_size x 32 x 56 x 56 x 64 end_point = 'Conv2d_2b_1x1' net = tf.keras.layers.Conv3D( filters=net_utils.apply_depth_multiplier(64, depth_multiplier), strides=[1, 1, 1], kernel_size=[1, 1, 1], padding='same', use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name=layer_naming_fn(end_point))( net) net = batch_norm( axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon, scale=False, gamma_initializer='ones', name=layer_naming_fn(end_point + '/BatchNorm'))( net) net = tf.nn.relu(net) end_points[end_point] = net if final_endpoint == end_point: return net, end_points # batch_size x 32 x 56 x 56 x 192 end_point = 'Conv2d_2c_3x3' if end_point not in temporal_conv_endpoints: temporal_conv_type = '2d' net = parameterized_conv_layer( conv_type=temporal_conv_type, kernel_size=3, filters=net_utils.apply_depth_multiplier(192, depth_multiplier), strides=[1, 1, 1], rates=[1, 1, 1], use_sync_bn=use_sync_bn, norm_momentum=norm_momentum, norm_epsilon=norm_epsilon, temporal_conv_initializer=temporal_conv_initializer, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name=layer_naming_fn(end_point))( net) if end_point in self_gating_endpoints: net = nn_blocks_3d.SelfGating( filters=net_utils.apply_depth_multiplier(192, depth_multiplier), name=layer_naming_fn(end_point + '/self_gating'))( net) end_points[end_point] = net if final_endpoint == end_point: return net, end_points # batch_size x 32 x 28 x 28 x 192 end_point = 'MaxPool_3a_3x3' net = tf.keras.layers.MaxPool3D( pool_size=[1, 3, 3], strides=[1, 2, 2], padding='same', name=layer_naming_fn(end_point))( net) end_points[end_point] = net return net, end_points def _construct_branch_3_layers( channels: int, swap_pool_and_1x1x1: bool, pool_type: Text, batch_norm_layer: tf.keras.layers.Layer, kernel_initializer: Union[Text, initializers.Initializer], kernel_regularizer: Union[Text, regularizers.Regularizer], ): """Helper function for Branch 3 inside Inception module.""" kernel_size = [1, 3, 3] if pool_type == '2d' else [3] * 3 conv = tf.keras.layers.Conv3D( filters=channels, kernel_size=[1, 1, 1], padding='same', use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer) activation = tf.keras.layers.Activation('relu') pool = tf.keras.layers.MaxPool3D( pool_size=kernel_size, strides=[1, 1, 1], padding='same') if swap_pool_and_1x1x1: branch_3_layers = [conv, batch_norm_layer, activation, pool] else: branch_3_layers = [pool, conv, batch_norm_layer, activation] return branch_3_layers class InceptionV1CellLayer(tf.keras.layers.Layer): """A single Tensorflow 2 cell used in the original I3D/S3D model.""" def __init__( self, branch_filters: Sequence[Sequence[int]], conv_type: Text = '3d', temporal_dilation_rate: int = 1, swap_pool_and_1x1x1: bool = False, use_self_gating_on_branch: bool = False, use_self_gating_on_cell: bool = False, use_sync_bn: bool = False, norm_momentum: float = 0.999, norm_epsilon: float = 0.001, temporal_conv_initializer: Union[ Text, initializers.Initializer] = initializers.TruncatedNormal( mean=0.0, stddev=0.01), kernel_initializer: Union[Text, initializers.Initializer] = 'truncated_normal', kernel_regularizer: Union[Text, regularizers.Regularizer] = 'l2', parameterized_conv_layer: Type[ net_utils.ParameterizedConvLayer] = net_utils.ParameterizedConvLayer, **kwargs): """A cell structure inspired by Inception V1. Args: branch_filters: Specifies the number of filters in four branches (Branch_0, Branch_1, Branch_2, Branch_3). Single number for Branch_0 and Branch_3. For Branch_1 and Branch_2, each need to specify two numbers, one for 1x1x1 and one for 3x3x3. conv_type: The type of parameterized convolution. Currently, we support '2d', '3d', '2+1d', '1+2d'. temporal_dilation_rate: The dilation rate for temporal convolution. swap_pool_and_1x1x1: A boolean flag indicates that whether to swap the order of convolution and max pooling in Branch_3. use_self_gating_on_branch: Whether or not to apply self gating on each branch of the inception cell. use_self_gating_on_cell: Whether or not to apply self gating on each cell after the concatenation of all branches. use_sync_bn: If True, use synchronized batch normalization. norm_momentum: A `float` of normalization momentum for the moving average. norm_epsilon: A `float` added to variance to avoid dividing by zero. temporal_conv_initializer: Weight initializer for temporal convolution inside the cell. It only applies to 2+1d and 1+2d cases. kernel_initializer: Weight initializer for convolutional layers other than temporal convolution. kernel_regularizer: Weight regularizer for all convolutional layers. parameterized_conv_layer: class for parameterized conv layer. **kwargs: keyword arguments to be passed. Returns: out_tensor: A 5-D float tensor of size [batch_size, num_frames, height, width, channels]. """ super(InceptionV1CellLayer, self).__init__(**kwargs) self._branch_filters = branch_filters self._conv_type = conv_type self._temporal_dilation_rate = temporal_dilation_rate self._swap_pool_and_1x1x1 = swap_pool_and_1x1x1 self._use_self_gating_on_branch = use_self_gating_on_branch self._use_self_gating_on_cell = use_self_gating_on_cell self._use_sync_bn = use_sync_bn self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon self._temporal_conv_initializer = temporal_conv_initializer self._kernel_initializer = kernel_initializer self._kernel_regularizer = kernel_regularizer self._parameterized_conv_layer = parameterized_conv_layer if use_sync_bn: self._norm = tf.keras.layers.experimental.SyncBatchNormalization else: self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._channel_axis = -1 else: self._channel_axis = 1 def _build_branch_params(self): branch_0_params = [ # Conv3D dict( filters=self._branch_filters[0][0], kernel_size=[1, 1, 1], padding='same', use_bias=False, kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer), # norm dict( axis=self._channel_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, scale=False, gamma_initializer='ones'), # relu dict(), ] branch_1_params = [ # Conv3D dict( filters=self._branch_filters[1][0], kernel_size=[1, 1, 1], padding='same', use_bias=False, kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer), # norm dict( axis=self._channel_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, scale=False, gamma_initializer='ones'), # relu dict(), # ParameterizedConvLayer dict( conv_type=self._conv_type, kernel_size=3, filters=self._branch_filters[1][1], strides=[1, 1, 1], rates=[self._temporal_dilation_rate, 1, 1], use_sync_bn=self._use_sync_bn, norm_momentum=self._norm_momentum, norm_epsilon=self._norm_epsilon, temporal_conv_initializer=self._temporal_conv_initializer, kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer), ] branch_2_params = [ # Conv3D dict( filters=self._branch_filters[2][0], kernel_size=[1, 1, 1], padding='same', use_bias=False, kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer), # norm dict( axis=self._channel_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, scale=False, gamma_initializer='ones'), # relu dict(), # ParameterizedConvLayer dict( conv_type=self._conv_type, kernel_size=3, filters=self._branch_filters[2][1], strides=[1, 1, 1], rates=[self._temporal_dilation_rate, 1, 1], use_sync_bn=self._use_sync_bn, norm_momentum=self._norm_momentum, norm_epsilon=self._norm_epsilon, temporal_conv_initializer=self._temporal_conv_initializer, kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer) ] branch_3_params = [ # Conv3D dict( filters=self._branch_filters[3][0], kernel_size=[1, 1, 1], padding='same', use_bias=False, kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer), # norm dict( axis=self._channel_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, scale=False, gamma_initializer='ones'), # relu dict(), # pool dict( pool_size=([1, 3, 3] if self._conv_type == '2d' else [3] * 3), strides=[1, 1, 1], padding='same') ] if self._use_self_gating_on_branch: branch_0_params.append(dict(filters=self._branch_filters[0][0])) branch_1_params.append(dict(filters=self._branch_filters[1][1])) branch_2_params.append(dict(filters=self._branch_filters[2][1])) branch_3_params.append(dict(filters=self._branch_filters[3][0])) out_gating_params = [] if self._use_self_gating_on_cell: out_channels = ( self._branch_filters[0][0] + self._branch_filters[1][1] + self._branch_filters[2][1] + self._branch_filters[3][0]) out_gating_params.append(dict(filters=out_channels)) return [ branch_0_params, branch_1_params, branch_2_params, branch_3_params, out_gating_params ] def build(self, input_shape): branch_params = self._build_branch_params() self._branch_0_layers = [ tf.keras.layers.Conv3D(**branch_params[0][0]), self._norm(**branch_params[0][1]), tf.keras.layers.Activation('relu', **branch_params[0][2]), ] self._branch_1_layers = [ tf.keras.layers.Conv3D(**branch_params[1][0]), self._norm(**branch_params[1][1]), tf.keras.layers.Activation('relu', **branch_params[1][2]), self._parameterized_conv_layer(**branch_params[1][3]), ] self._branch_2_layers = [ tf.keras.layers.Conv3D(**branch_params[2][0]), self._norm(**branch_params[2][1]), tf.keras.layers.Activation('relu', **branch_params[2][2]), self._parameterized_conv_layer(**branch_params[2][3]) ] if self._swap_pool_and_1x1x1: self._branch_3_layers = [ tf.keras.layers.Conv3D(**branch_params[3][0]), self._norm(**branch_params[3][1]), tf.keras.layers.Activation('relu', **branch_params[3][2]), tf.keras.layers.MaxPool3D(**branch_params[3][3]), ] else: self._branch_3_layers = [ tf.keras.layers.MaxPool3D(**branch_params[3][3]), tf.keras.layers.Conv3D(**branch_params[3][0]), self._norm(**branch_params[3][1]), tf.keras.layers.Activation('relu', **branch_params[3][2]), ] if self._use_self_gating_on_branch: self._branch_0_layers.append( nn_blocks_3d.SelfGating(**branch_params[0][-1])) self._branch_1_layers.append( nn_blocks_3d.SelfGating(**branch_params[1][-1])) self._branch_2_layers.append( nn_blocks_3d.SelfGating(**branch_params[2][-1])) self._branch_3_layers.append( nn_blocks_3d.SelfGating(**branch_params[3][-1])) if self._use_self_gating_on_cell: self.cell_self_gating = nn_blocks_3d.SelfGating(**branch_params[4][0]) super(InceptionV1CellLayer, self).build(input_shape) def call(self, inputs): x = inputs for layer in self._branch_0_layers: x = layer(x) branch_0 = x x = inputs for layer in self._branch_1_layers: x = layer(x) branch_1 = x x = inputs for layer in self._branch_2_layers: x = layer(x) branch_2 = x x = inputs for layer in self._branch_3_layers: x = layer(x) branch_3 = x out_tensor = tf.concat([branch_0, branch_1, branch_2, branch_3], axis=self._channel_axis) if self._use_self_gating_on_cell: out_tensor = self.cell_self_gating(out_tensor) return out_tensor