Commit cc30e3e9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388566700
parent 7797ebad
...@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config): ...@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config):
dilation_rates: List[int] = dataclasses.field(default_factory=list) dilation_rates: List[int] = dataclasses.field(default_factory=list)
dropout_rate: float = 0.0 dropout_rate: float = 0.0
num_filters: int = 256 num_filters: int = 256
use_depthwise_convolution: bool = False
pool_kernel_size: Optional[List[int]] = None # Use global average pooling. pool_kernel_size: Optional[List[int]] = None # Use global average pooling.
......
...@@ -42,6 +42,7 @@ class ASPP(tf.keras.layers.Layer): ...@@ -42,6 +42,7 @@ class ASPP(tf.keras.layers.Layer):
kernel_initializer: str = 'VarianceScaling', kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear', interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False,
**kwargs): **kwargs):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer. """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
...@@ -64,6 +65,8 @@ class ASPP(tf.keras.layers.Layer): ...@@ -64,6 +65,8 @@ class ASPP(tf.keras.layers.Layer):
interpolation: A `str` of interpolation method. It should be one of interpolation: A `str` of interpolation method. It should be one of
`bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`, `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
`gaussian`, or `mitchellcubic`. `gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True depthwise separable convolutions will
be added to the Atrous spatial pyramid pooling.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(ASPP, self).__init__(**kwargs) super(ASPP, self).__init__(**kwargs)
...@@ -80,6 +83,7 @@ class ASPP(tf.keras.layers.Layer): ...@@ -80,6 +83,7 @@ class ASPP(tf.keras.layers.Layer):
'kernel_initializer': kernel_initializer, 'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer, 'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation, 'interpolation': interpolation,
'use_depthwise_convolution': use_depthwise_convolution,
} }
def build(self, input_shape): def build(self, input_shape):
...@@ -100,7 +104,9 @@ class ASPP(tf.keras.layers.Layer): ...@@ -100,7 +104,9 @@ class ASPP(tf.keras.layers.Layer):
dropout=self._config_dict['dropout_rate'], dropout=self._config_dict['dropout_rate'],
kernel_initializer=self._config_dict['kernel_initializer'], kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'], kernel_regularizer=self._config_dict['kernel_regularizer'],
interpolation=self._config_dict['interpolation']) interpolation=self._config_dict['interpolation'],
use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
)
def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]: def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input. """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
......
...@@ -70,6 +70,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase): ...@@ -70,6 +70,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
kernel_regularizer=None, kernel_regularizer=None,
interpolation='bilinear', interpolation='bilinear',
dropout_rate=0.2, dropout_rate=0.2,
use_depthwise_convolution='false',
) )
network = aspp.ASPP(**kwargs) network = aspp.ASPP(**kwargs)
......
...@@ -21,9 +21,11 @@ import tensorflow as tf ...@@ -21,9 +21,11 @@ import tensorflow as tf
class SpatialPyramidPooling(tf.keras.layers.Layer): class SpatialPyramidPooling(tf.keras.layers.Layer):
"""Implements the Atrous Spatial Pyramid Pooling. """Implements the Atrous Spatial Pyramid Pooling.
Reference: References:
[Rethinking Atrous Convolution for Semantic Image Segmentation]( [Rethinking Atrous Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1706.05587.pdf) https://arxiv.org/pdf/1706.05587.pdf)
[Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
""" """
def __init__( def __init__(
...@@ -39,6 +41,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -39,6 +41,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
kernel_initializer='glorot_uniform', kernel_initializer='glorot_uniform',
kernel_regularizer=None, kernel_regularizer=None,
interpolation='bilinear', interpolation='bilinear',
use_depthwise_convolution=False,
**kwargs): **kwargs):
"""Initializes `SpatialPyramidPooling`. """Initializes `SpatialPyramidPooling`.
...@@ -60,6 +63,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -60,6 +63,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None. kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to interpolation: The interpolation method for upsampling. Defaults to
`bilinear`. `bilinear`.
use_depthwise_convolution: Allows spatial pooling to be separable
depthwise convolusions. [Encoder-Decoder with Atrous Separable
Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1802.02611.pdf)
**kwargs: Other keyword arguments for the layer. **kwargs: Other keyword arguments for the layer.
""" """
super(SpatialPyramidPooling, self).__init__(**kwargs) super(SpatialPyramidPooling, self).__init__(**kwargs)
...@@ -76,6 +83,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -76,6 +83,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.interpolation = interpolation self.interpolation = interpolation
self.input_spec = tf.keras.layers.InputSpec(ndim=4) self.input_spec = tf.keras.layers.InputSpec(ndim=4)
self.pool_kernel_size = pool_kernel_size self.pool_kernel_size = pool_kernel_size
self.use_depthwise_convolution = use_depthwise_convolution
def build(self, input_shape): def build(self, input_shape):
height = input_shape[1] height = input_shape[1]
...@@ -109,9 +117,20 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -109,9 +117,20 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers.append(conv_sequential) self.aspp_layers.append(conv_sequential)
for dilation_rate in self.dilation_rates: for dilation_rate in self.dilation_rates:
conv_sequential = tf.keras.Sequential([ leading_layers = []
kernel_size = (3, 3)
if self.use_depthwise_convolution:
leading_layers += [
tf.keras.layers.DepthwiseConv2D(
depth_multiplier=1, kernel_size=kernel_size,
padding='same', depthwise_regularizer=self.kernel_regularizer,
depthwise_initializer=self.kernel_initializer,
dilation_rate=dilation_rate, use_bias=False)
]
kernel_size = (1, 1)
conv_sequential = tf.keras.Sequential(leading_layers + [
tf.keras.layers.Conv2D( tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(3, 3), filters=self.output_channels, kernel_size=kernel_size,
padding='same', kernel_regularizer=self.kernel_regularizer, padding='same', kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.kernel_initializer, kernel_initializer=self.kernel_initializer,
dilation_rate=dilation_rate, use_bias=False), dilation_rate=dilation_rate, use_bias=False),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment