Commit f143929a authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 339559679
parent b625f436
......@@ -30,6 +30,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self,
output_channels,
dilation_rates,
pool_kernel_size=None,
use_sync_bn=False,
batchnorm_momentum=0.99,
batchnorm_epsilon=0.001,
......@@ -44,6 +45,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
Arguments:
output_channels: Number of channels produced by SpatialPyramidPooling.
dilation_rates: A list of integers for parallel dilated conv.
pool_kernel_size: A list of integers or None. If None, global average
pooling is applied, otherwise an average pooling of pool_kernel_size
is applied.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
......@@ -71,6 +75,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self.interpolation = interpolation
self.input_spec = tf.keras.layers.InputSpec(ndim=4)
self.pool_kernel_size = pool_kernel_size
def build(self, input_shape):
height = input_shape[1]
......@@ -115,11 +120,19 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
tf.keras.layers.Activation(self.activation)])
self.aspp_layers.append(conv_sequential)
if self.pool_kernel_size is None:
pool_sequential = tf.keras.Sequential([
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Reshape((1, 1, channels)),
tf.keras.layers.Reshape((1, 1, channels))])
else:
pool_sequential = tf.keras.Sequential([
tf.keras.layers.AveragePooling2D(self.pool_kernel_size)])
pool_sequential.add(
tf.keras.Sequential([
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1),
filters=self.output_channels,
kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer,
use_bias=False),
......@@ -129,7 +142,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation(self.activation),
tf.keras.layers.experimental.preprocessing.Resizing(
height, width, interpolation=self.interpolation)])
height, width, interpolation=self.interpolation)
]))
self.aspp_layers.append(pool_sequential)
self.projection = tf.keras.Sequential([
......@@ -159,6 +174,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
config = {
'output_channels': self.output_channels,
'dilation_rates': self.dilation_rates,
'pool_kernel_size': self.pool_kernel_size,
'use_sync_bn': self.use_sync_bn,
'batchnorm_momentum': self.batchnorm_momentum,
'batchnorm_epsilon': self.batchnorm_epsilon,
......
......@@ -23,10 +23,15 @@ from official.vision.keras_cv.layers import deeplab
@keras_parameterized.run_all_keras_modes
class DeeplabTest(keras_parameterized.TestCase):
def test_aspp(self):
@keras_parameterized.parameterized.parameters(
(None,),
([32, 32],),
)
def test_aspp(self, pool_kernel_size):
inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32)
layer = deeplab.SpatialPyramidPooling(output_channels=256,
dilation_rates=[6, 12, 18])
dilation_rates=[6, 12, 18],
pool_kernel_size=None)
output = layer(inputs)
self.assertAllEqual([None, 64, 64, 256], output.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment