Commit a053b555 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 339559679
parent 3a32af9b
...@@ -30,6 +30,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -30,6 +30,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self, self,
output_channels, output_channels,
dilation_rates, dilation_rates,
pool_kernel_size=None,
use_sync_bn=False, use_sync_bn=False,
batchnorm_momentum=0.99, batchnorm_momentum=0.99,
batchnorm_epsilon=0.001, batchnorm_epsilon=0.001,
...@@ -44,6 +45,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -44,6 +45,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
Arguments: Arguments:
output_channels: Number of channels produced by SpatialPyramidPooling. output_channels: Number of channels produced by SpatialPyramidPooling.
dilation_rates: A list of integers for parallel dilated conv. dilation_rates: A list of integers for parallel dilated conv.
pool_kernel_size: A list of integers or None. If None, global average
pooling is applied, otherwise an average pooling of pool_kernel_size
is applied.
use_sync_bn: A bool, whether or not to use sync batch normalization. use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99. 0.99.
...@@ -71,6 +75,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -71,6 +75,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self.interpolation = interpolation self.interpolation = interpolation
self.input_spec = tf.keras.layers.InputSpec(ndim=4) self.input_spec = tf.keras.layers.InputSpec(ndim=4)
self.pool_kernel_size = pool_kernel_size
def build(self, input_shape): def build(self, input_shape):
height = input_shape[1] height = input_shape[1]
...@@ -115,11 +120,19 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -115,11 +120,19 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
tf.keras.layers.Activation(self.activation)]) tf.keras.layers.Activation(self.activation)])
self.aspp_layers.append(conv_sequential) self.aspp_layers.append(conv_sequential)
if self.pool_kernel_size is None:
pool_sequential = tf.keras.Sequential([ pool_sequential = tf.keras.Sequential([
tf.keras.layers.GlobalAveragePooling2D(), tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Reshape((1, 1, channels)), tf.keras.layers.Reshape((1, 1, channels))])
else:
pool_sequential = tf.keras.Sequential([
tf.keras.layers.AveragePooling2D(self.pool_kernel_size)])
pool_sequential.add(
tf.keras.Sequential([
tf.keras.layers.Conv2D( tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1), filters=self.output_channels,
kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer, kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer, kernel_regularizer=self.kernel_regularizer,
use_bias=False), use_bias=False),
...@@ -129,7 +142,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -129,7 +142,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
epsilon=self.batchnorm_epsilon), epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation(self.activation), tf.keras.layers.Activation(self.activation),
tf.keras.layers.experimental.preprocessing.Resizing( tf.keras.layers.experimental.preprocessing.Resizing(
height, width, interpolation=self.interpolation)]) height, width, interpolation=self.interpolation)
]))
self.aspp_layers.append(pool_sequential) self.aspp_layers.append(pool_sequential)
self.projection = tf.keras.Sequential([ self.projection = tf.keras.Sequential([
...@@ -159,6 +174,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -159,6 +174,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
config = { config = {
'output_channels': self.output_channels, 'output_channels': self.output_channels,
'dilation_rates': self.dilation_rates, 'dilation_rates': self.dilation_rates,
'pool_kernel_size': self.pool_kernel_size,
'use_sync_bn': self.use_sync_bn, 'use_sync_bn': self.use_sync_bn,
'batchnorm_momentum': self.batchnorm_momentum, 'batchnorm_momentum': self.batchnorm_momentum,
'batchnorm_epsilon': self.batchnorm_epsilon, 'batchnorm_epsilon': self.batchnorm_epsilon,
......
...@@ -23,10 +23,15 @@ from official.vision.keras_cv.layers import deeplab ...@@ -23,10 +23,15 @@ from official.vision.keras_cv.layers import deeplab
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class DeeplabTest(keras_parameterized.TestCase): class DeeplabTest(keras_parameterized.TestCase):
def test_aspp(self): @keras_parameterized.parameterized.parameters(
(None,),
([32, 32],),
)
def test_aspp(self, pool_kernel_size):
inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32) inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32)
layer = deeplab.SpatialPyramidPooling(output_channels=256, layer = deeplab.SpatialPyramidPooling(output_channels=256,
dilation_rates=[6, 12, 18]) dilation_rates=[6, 12, 18],
pool_kernel_size=None)
output = layer(inputs) output = layer(inputs)
self.assertAllEqual([None, 64, 64, 256], output.shape) self.assertAllEqual([None, 64, 64, 256], output.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment