Commit ae82b280 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Create a spatial pyramid pooling layer.

PiperOrigin-RevId: 404855140
parent 073170ba
...@@ -970,3 +970,213 @@ class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin): ...@@ -970,3 +970,213 @@ class Conv3D(tf.keras.layers.Conv3D, CausalConvMixin):
"""Computes the spatial output shape from the input shape.""" """Computes the spatial output shape from the input shape."""
shape = super(Conv3D, self)._spatial_output_shape(spatial_input_shape) shape = super(Conv3D, self)._spatial_output_shape(spatial_input_shape)
return self._buffered_spatial_output_shape(shape) return self._buffered_spatial_output_shape(shape)
@tf.keras.utils.register_keras_serializable(package='Vision')
class SpatialPyramidPooling(tf.keras.layers.Layer):
"""Implements the Atrous Spatial Pyramid Pooling.
References:
[Rethinking Atrous Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1706.05587.pdf)
[Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
"""
def __init__(
self,
output_channels: int,
dilation_rates: List[int],
pool_kernel_size: Optional[List[int]] = None,
use_sync_bn: bool = False,
batchnorm_momentum: float = 0.99,
batchnorm_epsilon: float = 0.001,
activation: str = 'relu',
dropout: float = 0.5,
kernel_initializer: str = 'GlorotUniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False,
**kwargs):
"""Initializes `SpatialPyramidPooling`.
Args:
output_channels: Number of channels produced by SpatialPyramidPooling.
dilation_rates: A list of integers for parallel dilated conv.
pool_kernel_size: A list of integers or None. If None, global average
pooling is applied, otherwise an average pooling of pool_kernel_size is
applied.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
activation: A `str` for type of activation to be used. Defaults to 'relu'.
dropout: A float for the dropout rate before output. Defaults to 0.5.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
use_depthwise_convolution: Allows spatial pooling to be separable
depthwise convolusions. [Encoder-Decoder with Atrous Separable
Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1802.02611.pdf)
**kwargs: Other keyword arguments for the layer.
"""
super().__init__(**kwargs)
self._output_channels = output_channels
self._dilation_rates = dilation_rates
self._use_sync_bn = use_sync_bn
self._batchnorm_momentum = batchnorm_momentum
self._batchnorm_epsilon = batchnorm_epsilon
self._activation = activation
self._dropout = dropout
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._interpolation = interpolation
self._pool_kernel_size = pool_kernel_size
self._use_depthwise_convolution = use_depthwise_convolution
self._activation_fn = tf_utils.get_activation(activation)
if self._use_sync_bn:
self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
def build(self, input_shape):
height = input_shape[1]
width = input_shape[2]
channels = input_shape[3]
self.aspp_layers = []
conv1 = tf.keras.layers.Conv2D(
filters=self._output_channels,
kernel_size=(1, 1),
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_bias=False)
norm1 = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
self.aspp_layers.append([conv1, norm1])
for dilation_rate in self._dilation_rates:
leading_layers = []
kernel_size = (3, 3)
if self._use_depthwise_convolution:
leading_layers += [
tf.keras.layers.DepthwiseConv2D(
depth_multiplier=1,
kernel_size=kernel_size,
padding='same',
depthwise_regularizer=self._kernel_regularizer,
depthwise_initializer=self._kernel_initializer,
dilation_rate=dilation_rate,
use_bias=False)
]
kernel_size = (1, 1)
conv_dilation = leading_layers + [
tf.keras.layers.Conv2D(
filters=self._output_channels,
kernel_size=kernel_size,
padding='same',
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
dilation_rate=dilation_rate,
use_bias=False)
]
norm_dilation = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
self.aspp_layers.append(conv_dilation + [norm_dilation])
if self._pool_kernel_size is None:
pooling = [
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Reshape((1, 1, channels))
]
else:
pooling = [tf.keras.layers.AveragePooling2D(self._pool_kernel_size)]
conv2 = tf.keras.layers.Conv2D(
filters=self._output_channels,
kernel_size=(1, 1),
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_bias=False)
norm2 = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
self.aspp_layers.append(pooling + [conv2, norm2])
self._resize_layer = tf.keras.layers.Resizing(
height, width, interpolation=self._interpolation, dtype=tf.float32)
self._projection = [
tf.keras.layers.Conv2D(
filters=self._output_channels,
kernel_size=(1, 1),
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
use_bias=False),
self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
]
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
self._concat_layer = tf.keras.layers.Concatenate(axis=-1)
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor:
if training is None:
training = tf.keras.backend.learning_phase()
result = []
for i, layers in enumerate(self.aspp_layers):
x = inputs
for layer in layers:
# Apply layers sequentially.
x = layer(x, training=training)
x = self._activation_fn(x)
# Apply resize layer to the end of the last set of layers.
if i == len(self.aspp_layers) - 1:
x = self._resize_layer(x)
result.append(tf.cast(x, inputs.dtype))
x = self._concat_layer(result)
for layer in self._projection:
x = layer(x, training=training)
x = self._activation_fn(x)
return self._dropout_layer(x)
def get_config(self):
config = {
'output_channels': self._output_channels,
'dilation_rates': self._dilation_rates,
'pool_kernel_size': self._pool_kernel_size,
'use_sync_bn': self._use_sync_bn,
'batchnorm_momentum': self._batchnorm_momentum,
'batchnorm_epsilon': self._batchnorm_epsilon,
'activation': self._activation,
'dropout': self._dropout,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'interpolation': self._interpolation,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
...@@ -406,5 +406,19 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -406,5 +406,19 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
[[[[[1.]]], [[[[[1.]]],
[[[3.]]]]]) [[[3.]]]]])
@parameterized.parameters(
(None, []),
(None, [6, 12, 18]),
([32, 32], [6, 12, 18]),
)
def test_aspp(self, pool_kernel_size, dilation_rates):
inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32)
layer = nn_layers.SpatialPyramidPooling(
output_channels=256,
dilation_rates=dilation_rates,
pool_kernel_size=pool_kernel_size)
output = layer(inputs)
self.assertAllEqual([None, 64, 64, 256], output.shape)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment