Commit cc30e3e9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388566700
parent 7797ebad
......@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config):
dilation_rates: List[int] = dataclasses.field(default_factory=list)
dropout_rate: float = 0.0
num_filters: int = 256
use_depthwise_convolution: bool = False
pool_kernel_size: Optional[List[int]] = None # Use global average pooling.
......
......@@ -42,6 +42,7 @@ class ASPP(tf.keras.layers.Layer):
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False,
**kwargs):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
......@@ -64,6 +65,8 @@ class ASPP(tf.keras.layers.Layer):
interpolation: A `str` of interpolation method. It should be one of
`bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
`gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True depthwise separable convolutions will
be added to the Atrous spatial pyramid pooling.
**kwargs: Additional keyword arguments to be passed.
"""
super(ASPP, self).__init__(**kwargs)
......@@ -80,6 +83,7 @@ class ASPP(tf.keras.layers.Layer):
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation,
'use_depthwise_convolution': use_depthwise_convolution,
}
def build(self, input_shape):
......@@ -100,7 +104,9 @@ class ASPP(tf.keras.layers.Layer):
dropout=self._config_dict['dropout_rate'],
kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
interpolation=self._config_dict['interpolation'])
interpolation=self._config_dict['interpolation'],
use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
)
def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
......
......@@ -70,6 +70,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
kernel_regularizer=None,
interpolation='bilinear',
dropout_rate=0.2,
use_depthwise_convolution='false',
)
network = aspp.ASPP(**kwargs)
......
......@@ -21,9 +21,11 @@ import tensorflow as tf
class SpatialPyramidPooling(tf.keras.layers.Layer):
"""Implements the Atrous Spatial Pyramid Pooling.
Reference:
References:
[Rethinking Atrous Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1706.05587.pdf)
[Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
"""
def __init__(
......@@ -39,6 +41,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
kernel_initializer='glorot_uniform',
kernel_regularizer=None,
interpolation='bilinear',
use_depthwise_convolution=False,
**kwargs):
"""Initializes `SpatialPyramidPooling`.
......@@ -60,6 +63,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
use_depthwise_convolution: Allows spatial pooling to be separable
depthwise convolusions. [Encoder-Decoder with Atrous Separable
Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1802.02611.pdf)
**kwargs: Other keyword arguments for the layer.
"""
super(SpatialPyramidPooling, self).__init__(**kwargs)
......@@ -76,6 +83,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.interpolation = interpolation
self.input_spec = tf.keras.layers.InputSpec(ndim=4)
self.pool_kernel_size = pool_kernel_size
self.use_depthwise_convolution = use_depthwise_convolution
def build(self, input_shape):
height = input_shape[1]
......@@ -109,9 +117,20 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers.append(conv_sequential)
for dilation_rate in self.dilation_rates:
conv_sequential = tf.keras.Sequential([
leading_layers = []
kernel_size = (3, 3)
if self.use_depthwise_convolution:
leading_layers += [
tf.keras.layers.DepthwiseConv2D(
depth_multiplier=1, kernel_size=kernel_size,
padding='same', depthwise_regularizer=self.kernel_regularizer,
depthwise_initializer=self.kernel_initializer,
dilation_rate=dilation_rate, use_bias=False)
]
kernel_size = (1, 1)
conv_sequential = tf.keras.Sequential(leading_layers + [
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(3, 3),
filters=self.output_channels, kernel_size=kernel_size,
padding='same', kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.kernel_initializer,
dilation_rate=dilation_rate, use_bias=False),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment