Commit 98d7c8e7 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

#movinet Change 'reshape->conv->reshape->bn->activation' pattern to...

#movinet Change 'reshape->conv->reshape->bn->activation' pattern to 'reshape->conv->bn->activation->reshape' in the MobileConv2D layer in movinet.

PiperOrigin-RevId: 425432746
parent 5a182aec
...@@ -93,10 +93,9 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -93,10 +93,9 @@ class MobileConv2D(tf.keras.layers.Layer):
data_format: Optional[str] = None, data_format: Optional[str] = None,
dilation_rate: Union[int, Sequence[int]] = (1, 1), dilation_rate: Union[int, Sequence[int]] = (1, 1),
groups: int = 1, groups: int = 1,
activation: Optional[nn_layers.Activation] = None,
use_bias: bool = True, use_bias: bool = True,
kernel_initializer: tf.keras.initializers.Initializer = 'glorot_uniform', kernel_initializer: str = 'glorot_uniform',
bias_initializer: tf.keras.initializers.Initializer = 'zeros', bias_initializer: str = 'zeros',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
...@@ -105,6 +104,8 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -105,6 +104,8 @@ class MobileConv2D(tf.keras.layers.Layer):
use_depthwise: bool = False, use_depthwise: bool = False,
use_temporal: bool = False, use_temporal: bool = False,
use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
batch_norm_op: Optional[Any] = None,
activation_op: Optional[Any] = None,
**kwargs): # pylint: disable=g-doc-args **kwargs): # pylint: disable=g-doc-args
"""Initializes mobile conv2d. """Initializes mobile conv2d.
...@@ -117,6 +118,10 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -117,6 +118,10 @@ class MobileConv2D(tf.keras.layers.Layer):
use_buffered_input: if True, the input is expected to be padded use_buffered_input: if True, the input is expected to be padded
beforehand. In effect, calling this layer will use 'valid' padding on beforehand. In effect, calling this layer will use 'valid' padding on
the temporal dimension to simulate 'causal' padding. the temporal dimension to simulate 'causal' padding.
batch_norm_op: A callable batch norm layer object. If None, no batch
norm will be applied after the convolution.
activation_op: A callable activation layer object. If None, no
activation will be applied after the convolution.
**kwargs: keyword arguments to be passed to this layer. **kwargs: keyword arguments to be passed to this layer.
Returns: Returns:
...@@ -130,7 +135,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -130,7 +135,6 @@ class MobileConv2D(tf.keras.layers.Layer):
self._data_format = data_format self._data_format = data_format
self._dilation_rate = dilation_rate self._dilation_rate = dilation_rate
self._groups = groups self._groups = groups
self._activation = activation
self._use_bias = use_bias self._use_bias = use_bias
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._bias_initializer = bias_initializer self._bias_initializer = bias_initializer
...@@ -142,6 +146,8 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -142,6 +146,8 @@ class MobileConv2D(tf.keras.layers.Layer):
self._use_depthwise = use_depthwise self._use_depthwise = use_depthwise
self._use_temporal = use_temporal self._use_temporal = use_temporal
self._use_buffered_input = use_buffered_input self._use_buffered_input = use_buffered_input
self._batch_norm_op = batch_norm_op
self._activation_op = activation_op
kernel_size = normalize_tuple(kernel_size, 2, 'kernel_size') kernel_size = normalize_tuple(kernel_size, 2, 'kernel_size')
...@@ -156,7 +162,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -156,7 +162,6 @@ class MobileConv2D(tf.keras.layers.Layer):
depth_multiplier=1, depth_multiplier=1,
data_format=data_format, data_format=data_format,
dilation_rate=dilation_rate, dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias, use_bias=use_bias,
depthwise_initializer=kernel_initializer, depthwise_initializer=kernel_initializer,
bias_initializer=bias_initializer, bias_initializer=bias_initializer,
...@@ -175,7 +180,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -175,7 +180,6 @@ class MobileConv2D(tf.keras.layers.Layer):
data_format=data_format, data_format=data_format,
dilation_rate=dilation_rate, dilation_rate=dilation_rate,
groups=groups, groups=groups,
activation=activation,
use_bias=use_bias, use_bias=use_bias,
kernel_initializer=kernel_initializer, kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer, bias_initializer=bias_initializer,
...@@ -196,7 +200,6 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -196,7 +200,6 @@ class MobileConv2D(tf.keras.layers.Layer):
'data_format': self._data_format, 'data_format': self._data_format,
'dilation_rate': self._dilation_rate, 'dilation_rate': self._dilation_rate,
'groups': self._groups, 'groups': self._groups,
'activation': self._activation,
'use_bias': self._use_bias, 'use_bias': self._use_bias,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'bias_initializer': self._bias_initializer, 'bias_initializer': self._bias_initializer,
...@@ -229,6 +232,10 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -229,6 +232,10 @@ class MobileConv2D(tf.keras.layers.Layer):
x = tf.reshape(inputs, input_shape) x = tf.reshape(inputs, input_shape)
x = self._conv(x) x = self._conv(x)
if self._batch_norm_op is not None:
x = self._batch_norm_op(x)
if self._activation_op is not None:
x = self._activation_op(x)
if self._use_temporal: if self._use_temporal:
output_shape = [ output_shape = [
...@@ -357,8 +364,20 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -357,8 +364,20 @@ class ConvBlock(tf.keras.layers.Layer):
padding = 'causal' if self._causal else 'same' padding = 'causal' if self._causal else 'same'
self._groups = input_shape[-1] if self._depthwise else 1 self._groups = input_shape[-1] if self._depthwise else 1
self._conv_temporal = None self._batch_norm = None
self._batch_norm_temporal = None
if self._use_batch_norm:
self._batch_norm = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn')
if self._conv_type != '3d' and self._kernel_size[0] > 1:
self._batch_norm_temporal = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn_temporal')
self._conv_temporal = None
if self._conv_type == '3d_2plus1d' and self._kernel_size[0] > 1: if self._conv_type == '3d_2plus1d' and self._kernel_size[0] > 1:
self._conv = nn_layers.Conv3D( self._conv = nn_layers.Conv3D(
self._filters, self._filters,
...@@ -394,6 +413,8 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -394,6 +413,8 @@ class ConvBlock(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
use_buffered_input=False, use_buffered_input=False,
batch_norm_op=self._batch_norm,
activation_op=self._activation_layer,
name='conv2d') name='conv2d')
if self._kernel_size[0] > 1: if self._kernel_size[0] > 1:
self._conv_temporal = MobileConv2D( self._conv_temporal = MobileConv2D(
...@@ -408,6 +429,8 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -408,6 +429,8 @@ class ConvBlock(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
use_buffered_input=self._use_buffered_input, use_buffered_input=self._use_buffered_input,
batch_norm_op=self._batch_norm_temporal,
activation_op=self._activation_layer,
name='conv2d_temporal') name='conv2d_temporal')
else: else:
self._conv = nn_layers.Conv3D( self._conv = nn_layers.Conv3D(
...@@ -422,37 +445,26 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -422,37 +445,26 @@ class ConvBlock(tf.keras.layers.Layer):
use_buffered_input=self._use_buffered_input, use_buffered_input=self._use_buffered_input,
name='conv3d') name='conv3d')
self._batch_norm = None
self._batch_norm_temporal = None
if self._use_batch_norm:
self._batch_norm = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn')
if self._conv_type != '3d' and self._conv_temporal is not None:
self._batch_norm_temporal = self._batch_norm_layer(
momentum=self._batch_norm_momentum,
epsilon=self._batch_norm_epsilon,
name='bn_temporal')
super(ConvBlock, self).build(input_shape) super(ConvBlock, self).build(input_shape)
def call(self, inputs): def call(self, inputs):
"""Calls the layer with the given inputs.""" """Calls the layer with the given inputs."""
x = inputs x = inputs
# bn_op and activation_op are folded into the '2plus1d' conv layer so that
# we do not explicitly call them here.
# TODO(lzyuan): clean the conv layers api once the models are re-trained.
x = self._conv(x) x = self._conv(x)
if self._batch_norm is not None: if self._batch_norm is not None and self._conv_type != '2plus1d':
x = self._batch_norm(x) x = self._batch_norm(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
if self._conv_temporal is not None: if self._conv_temporal is not None:
x = self._conv_temporal(x) x = self._conv_temporal(x)
if self._batch_norm_temporal is not None: if self._batch_norm_temporal is not None and self._conv_type != '2plus1d':
x = self._batch_norm_temporal(x) x = self._batch_norm_temporal(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
return x return x
...@@ -640,10 +652,13 @@ class StreamConvBlock(ConvBlock): ...@@ -640,10 +652,13 @@ class StreamConvBlock(ConvBlock):
if self._conv_temporal is None and self._stream_buffer is not None: if self._conv_temporal is None and self._stream_buffer is not None:
x, states = self._stream_buffer(x, states=states) x, states = self._stream_buffer(x, states=states)
# bn_op and activation_op are folded into the '2plus1d' conv layer so that
# we do not explicitly call them here.
# TODO(lzyuan): clean the conv layers api once the models are re-trained.
x = self._conv(x) x = self._conv(x)
if self._batch_norm is not None: if self._batch_norm is not None and self._conv_type != '2plus1d':
x = self._batch_norm(x) x = self._batch_norm(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
if self._conv_temporal is not None: if self._conv_temporal is not None:
...@@ -653,9 +668,9 @@ class StreamConvBlock(ConvBlock): ...@@ -653,9 +668,9 @@ class StreamConvBlock(ConvBlock):
x, states = self._stream_buffer(x, states=states) x, states = self._stream_buffer(x, states=states)
x = self._conv_temporal(x) x = self._conv_temporal(x)
if self._batch_norm_temporal is not None: if self._batch_norm_temporal is not None and self._conv_type != '2plus1d':
x = self._batch_norm_temporal(x) x = self._batch_norm_temporal(x)
if self._activation_layer is not None: if self._activation_layer is not None and self._conv_type != '2plus1d':
x = self._activation_layer(x) x = self._activation_layer(x)
return x, states return x, states
......
...@@ -64,6 +64,72 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -64,6 +64,72 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
def test_mobile_conv2d_bn(self):
  """Verifies MobileConv2D applies a folded-in batch norm op after the conv."""
  # Fresh BN statistics are mean=0, variance=1; with epsilon=1 the layer
  # divides the conv output by sqrt(1 + 1) at inference time.
  bn_layer = tf.keras.layers.BatchNormalization(
      momentum=0.9,
      epsilon=1.,
      name='bn')
  layer = movinet_layers.MobileConv2D(
      filters=3,
      kernel_size=(3, 3),
      strides=(1, 1),
      padding='same',
      kernel_initializer='ones',
      use_bias=False,
      use_depthwise=False,
      use_temporal=False,
      use_buffered_input=True,
      batch_norm_op=bn_layer,
  )

  ones_input = tf.ones([1, 2, 2, 2, 3])
  actual = layer(ones_input)

  # With an all-ones 3x3 kernel, 'same' padding, and a 2x2 spatial grid,
  # every output position sums the full 2x2x3 input window: 12. After the
  # batch norm, 12 / sqrt(2) ~= 8.48528 at every position.
  expected = tf.fill([1, 2, 2, 2, 3], 8.48528)
  self.assertEqual(actual.shape, expected.shape)
  self.assertAllClose(actual, expected)
def test_mobile_conv2d_activation(self):
  """Verifies MobileConv2D applies a folded-in activation op after the conv."""
  layer = movinet_layers.MobileConv2D(
      filters=3,
      kernel_size=(3, 3),
      strides=(1, 1),
      padding='same',
      kernel_initializer='ones',
      use_bias=False,
      use_depthwise=False,
      use_temporal=False,
      use_buffered_input=True,
      activation_op=tf.nn.relu6,
  )

  ones_input = tf.ones([1, 2, 2, 2, 3])
  actual = layer(ones_input)

  # The raw conv output is 12 at every position (all-ones 3x3 kernel with
  # 'same' padding over a 2x2x3 window of ones); relu6 clips it to 6.
  expected = tf.fill([1, 2, 2, 2, 3], 6.)
  self.assertEqual(actual.shape, expected.shape)
  self.assertAllClose(actual, expected)
def test_mobile_conv2d_temporal(self): def test_mobile_conv2d_temporal(self):
conv2d = movinet_layers.MobileConv2D( conv2d = movinet_layers.MobileConv2D(
filters=3, filters=3,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment