Commit 65c81380 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 411683806
parent 6f4e596a
......@@ -14,10 +14,10 @@
# Lint as: python3
"""Decoders configurations."""
from typing import Optional, List
import dataclasses
from typing import List, Optional
# Import libraries
import dataclasses
from official.modeling import hyperparams
......@@ -53,6 +53,8 @@ class ASPP(hyperparams.Config):
num_filters: int = 256
use_depthwise_convolution: bool = False
pool_kernel_size: Optional[List[int]] = None # Use global average pooling.
spp_layer_version: str = 'v1'
output_tensor: bool = False
@dataclasses.dataclass
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from typing import Any, List, Mapping, Optional
from typing import Any, List, Mapping, Optional, Union
# Import libraries
......@@ -22,6 +22,9 @@ import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling.decoders import factory
from official.vision.beta.modeling.layers import deeplab
from official.vision.beta.modeling.layers import nn_layers
TensorMapUnion = Union[tf.Tensor, Mapping[str, tf.Tensor]]
@tf.keras.utils.register_keras_serializable(package='Vision')
......@@ -43,6 +46,8 @@ class ASPP(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False,
spp_layer_version: str = 'v1',
output_tensor: bool = False,
**kwargs):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
......@@ -67,9 +72,12 @@ class ASPP(tf.keras.layers.Layer):
`gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True depthwise separable convolutions will
be added to the Atrous spatial pyramid pooling.
spp_layer_version: A `str` of spatial pyramid pooling layer version.
output_tensor: Whether to output a single tensor or a dictionary of tensors.
Default is false.
**kwargs: Additional keyword arguments to be passed.
"""
super(ASPP, self).__init__(**kwargs)
super().__init__(**kwargs)
self._config_dict = {
'level': level,
'dilation_rates': dilation_rates,
......@@ -84,7 +92,11 @@ class ASPP(tf.keras.layers.Layer):
'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation,
'use_depthwise_convolution': use_depthwise_convolution,
'spp_layer_version': spp_layer_version,
'output_tensor': output_tensor
}
self._aspp_layer = deeplab.SpatialPyramidPooling if self._config_dict[
'spp_layer_version'] == 'v1' else nn_layers.SpatialPyramidPooling
def build(self, input_shape):
pool_kernel_size = None
......@@ -93,7 +105,8 @@ class ASPP(tf.keras.layers.Layer):
int(p_size // 2**self._config_dict['level'])
for p_size in self._config_dict['pool_kernel_size']
]
self.aspp = deeplab.SpatialPyramidPooling(
self.aspp = self._aspp_layer(
output_channels=self._config_dict['num_filters'],
dilation_rates=self._config_dict['dilation_rates'],
pool_kernel_size=pool_kernel_size,
......@@ -108,28 +121,32 @@ class ASPP(tf.keras.layers.Layer):
use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
)
def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
def call(self, inputs: TensorMapUnion) -> TensorMapUnion:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
level is present. Hence, this will be compatible with the rest of the
segmentation model interfaces.
level is present, if output_tensor is false. Hence, this will be compatible
with the rest of the segmentation model interfaces.
If output_tensor is true, a single tensor is output.
Args:
inputs: A `dict` of `tf.Tensor` where
inputs: A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or
a `dict` of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of shape [batch, height_l, width_l,
filter_size].
Returns:
A `dict` of `tf.Tensor` where
A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or a `dict`
of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of output of ASPP module.
"""
outputs = {}
level = str(self._config_dict['level'])
outputs[level] = self.aspp(inputs[level])
return outputs
backbone_output = inputs[level] if isinstance(inputs, dict) else inputs
outputs = self.aspp(backbone_output)
return outputs if self._config_dict['output_tensor'] else {level: outputs}
def get_config(self) -> Mapping[str, Any]:
return self._config_dict
......@@ -180,4 +197,6 @@ def build_aspp_decoder(
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
activation=norm_activation_config.activation,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
spp_layer_version=decoder_cfg.spp_layer_version,
output_tensor=decoder_cfg.output_tensor)
......@@ -26,14 +26,15 @@ from official.vision.beta.modeling.decoders import aspp
class ASPPTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(3, [6, 12, 18, 24], 128),
(3, [6, 12, 18], 128),
(3, [6, 12], 256),
(4, [6, 12, 18, 24], 128),
(4, [6, 12, 18], 128),
(4, [6, 12], 256),
(3, [6, 12, 18, 24], 128, 'v1'),
(3, [6, 12, 18], 128, 'v1'),
(3, [6, 12], 256, 'v1'),
(4, [6, 12, 18, 24], 128, 'v2'),
(4, [6, 12, 18], 128, 'v2'),
(4, [6, 12], 256, 'v2'),
)
def test_network_creation(self, level, dilation_rates, num_filters):
def test_network_creation(self, level, dilation_rates, num_filters,
spp_layer_version):
"""Test creation of ASPP."""
input_size = 256
......@@ -45,7 +46,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
network = aspp.ASPP(
level=level,
dilation_rates=dilation_rates,
num_filters=num_filters)
num_filters=num_filters,
spp_layer_version=spp_layer_version)
endpoints = backbone(inputs)
feats = network(endpoints)
......@@ -71,7 +73,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
interpolation='bilinear',
dropout_rate=0.2,
use_depthwise_convolution='false',
)
spp_layer_version='v1',
output_tensor=False)
network = aspp.ASPP(**kwargs)
expected_config = dict(kwargs)
......
......@@ -1217,7 +1217,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers.append(pooling + [conv2, norm2])
self._resize_layer = tf.keras.layers.Resizing(
self._resizing_layer = tf.keras.layers.Resizing(
height, width, interpolation=self._interpolation, dtype=tf.float32)
self._projection = [
......@@ -1250,7 +1250,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
# Apply resize layer to the end of the last set of layers.
if i == len(self.aspp_layers) - 1:
x = self._resize_layer(x)
x = self._resizing_layer(x)
result.append(tf.cast(x, inputs.dtype))
x = self._concat_layer(result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment