Commit 65c81380 authored by Fan Yang, committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 411683806
parent 6f4e596a
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
# Lint as: python3 # Lint as: python3
"""Decoders configurations.""" """Decoders configurations."""
from typing import Optional, List import dataclasses
from typing import List, Optional
# Import libraries # Import libraries
import dataclasses
from official.modeling import hyperparams from official.modeling import hyperparams
...@@ -53,6 +53,8 @@ class ASPP(hyperparams.Config): ...@@ -53,6 +53,8 @@ class ASPP(hyperparams.Config):
num_filters: int = 256 num_filters: int = 256
use_depthwise_convolution: bool = False use_depthwise_convolution: bool = False
pool_kernel_size: Optional[List[int]] = None # Use global average pooling. pool_kernel_size: Optional[List[int]] = None # Use global average pooling.
spp_layer_version: str = 'v1'
output_tensor: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder.""" """Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional, Union
# Import libraries # Import libraries
...@@ -22,6 +22,9 @@ import tensorflow as tf ...@@ -22,6 +22,9 @@ import tensorflow as tf
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.modeling.decoders import factory from official.vision.beta.modeling.decoders import factory
from official.vision.beta.modeling.layers import deeplab from official.vision.beta.modeling.layers import deeplab
from official.vision.beta.modeling.layers import nn_layers
TensorMapUnion = Union[tf.Tensor, Mapping[str, tf.Tensor]]
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
...@@ -43,6 +46,8 @@ class ASPP(tf.keras.layers.Layer): ...@@ -43,6 +46,8 @@ class ASPP(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear', interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False, use_depthwise_convolution: bool = False,
spp_layer_version: str = 'v1',
output_tensor: bool = False,
**kwargs): **kwargs):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer. """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
...@@ -67,9 +72,12 @@ class ASPP(tf.keras.layers.Layer): ...@@ -67,9 +72,12 @@ class ASPP(tf.keras.layers.Layer):
`gaussian`, or `mitchellcubic`. `gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True depthwise separable convolutions will use_depthwise_convolution: If True depthwise separable convolutions will
be added to the Atrous spatial pyramid pooling. be added to the Atrous spatial pyramid pooling.
spp_layer_version: A `str` of spatial pyramid pooling layer version.
output_tensor: Whether to output a single tensor or a dictionary of tensors.
Default is false.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(ASPP, self).__init__(**kwargs) super().__init__(**kwargs)
self._config_dict = { self._config_dict = {
'level': level, 'level': level,
'dilation_rates': dilation_rates, 'dilation_rates': dilation_rates,
...@@ -84,7 +92,11 @@ class ASPP(tf.keras.layers.Layer): ...@@ -84,7 +92,11 @@ class ASPP(tf.keras.layers.Layer):
'kernel_regularizer': kernel_regularizer, 'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation, 'interpolation': interpolation,
'use_depthwise_convolution': use_depthwise_convolution, 'use_depthwise_convolution': use_depthwise_convolution,
'spp_layer_version': spp_layer_version,
'output_tensor': output_tensor
} }
self._aspp_layer = deeplab.SpatialPyramidPooling if self._config_dict[
'spp_layer_version'] == 'v1' else nn_layers.SpatialPyramidPooling
def build(self, input_shape): def build(self, input_shape):
pool_kernel_size = None pool_kernel_size = None
...@@ -93,7 +105,8 @@ class ASPP(tf.keras.layers.Layer): ...@@ -93,7 +105,8 @@ class ASPP(tf.keras.layers.Layer):
int(p_size // 2**self._config_dict['level']) int(p_size // 2**self._config_dict['level'])
for p_size in self._config_dict['pool_kernel_size'] for p_size in self._config_dict['pool_kernel_size']
] ]
self.aspp = deeplab.SpatialPyramidPooling(
self.aspp = self._aspp_layer(
output_channels=self._config_dict['num_filters'], output_channels=self._config_dict['num_filters'],
dilation_rates=self._config_dict['dilation_rates'], dilation_rates=self._config_dict['dilation_rates'],
pool_kernel_size=pool_kernel_size, pool_kernel_size=pool_kernel_size,
...@@ -108,28 +121,32 @@ class ASPP(tf.keras.layers.Layer): ...@@ -108,28 +121,32 @@ class ASPP(tf.keras.layers.Layer):
use_depthwise_convolution=self._config_dict['use_depthwise_convolution'] use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
) )
def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]: def call(self, inputs: TensorMapUnion) -> TensorMapUnion:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input. """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
level is present. Hence, this will be compatible with the rest of the level is present, if output_tensor is false. Hence, this will be compatible
segmentation model interfaces. with the rest of the segmentation model interfaces.
If output_tensor is true, a single tensor is output.
Args: Args:
inputs: A `dict` of `tf.Tensor` where inputs: A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or
a `dict` of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps. - key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of shape [batch, height_l, width_l, - values: A `tf.Tensor` of shape [batch, height_l, width_l,
filter_size]. filter_size].
Returns: Returns:
A `dict` of `tf.Tensor` where A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or a `dict`
of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps. - key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of output of ASPP module. - values: A `tf.Tensor` of output of ASPP module.
""" """
outputs = {} outputs = {}
level = str(self._config_dict['level']) level = str(self._config_dict['level'])
outputs[level] = self.aspp(inputs[level]) backbone_output = inputs[level] if isinstance(inputs, dict) else inputs
return outputs outputs = self.aspp(backbone_output)
return outputs if self._config_dict['output_tensor'] else {level: outputs}
def get_config(self) -> Mapping[str, Any]: def get_config(self) -> Mapping[str, Any]:
return self._config_dict return self._config_dict
...@@ -180,4 +197,6 @@ def build_aspp_decoder( ...@@ -180,4 +197,6 @@ def build_aspp_decoder(
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer,
spp_layer_version=decoder_cfg.spp_layer_version,
output_tensor=decoder_cfg.output_tensor)
...@@ -26,14 +26,15 @@ from official.vision.beta.modeling.decoders import aspp ...@@ -26,14 +26,15 @@ from official.vision.beta.modeling.decoders import aspp
class ASPPTest(parameterized.TestCase, tf.test.TestCase): class ASPPTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(3, [6, 12, 18, 24], 128), (3, [6, 12, 18, 24], 128, 'v1'),
(3, [6, 12, 18], 128), (3, [6, 12, 18], 128, 'v1'),
(3, [6, 12], 256), (3, [6, 12], 256, 'v1'),
(4, [6, 12, 18, 24], 128), (4, [6, 12, 18, 24], 128, 'v2'),
(4, [6, 12, 18], 128), (4, [6, 12, 18], 128, 'v2'),
(4, [6, 12], 256), (4, [6, 12], 256, 'v2'),
) )
def test_network_creation(self, level, dilation_rates, num_filters): def test_network_creation(self, level, dilation_rates, num_filters,
spp_layer_version):
"""Test creation of ASPP.""" """Test creation of ASPP."""
input_size = 256 input_size = 256
...@@ -45,7 +46,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase): ...@@ -45,7 +46,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
network = aspp.ASPP( network = aspp.ASPP(
level=level, level=level,
dilation_rates=dilation_rates, dilation_rates=dilation_rates,
num_filters=num_filters) num_filters=num_filters,
spp_layer_version=spp_layer_version)
endpoints = backbone(inputs) endpoints = backbone(inputs)
feats = network(endpoints) feats = network(endpoints)
...@@ -71,7 +73,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase): ...@@ -71,7 +73,8 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
interpolation='bilinear', interpolation='bilinear',
dropout_rate=0.2, dropout_rate=0.2,
use_depthwise_convolution='false', use_depthwise_convolution='false',
) spp_layer_version='v1',
output_tensor=False)
network = aspp.ASPP(**kwargs) network = aspp.ASPP(**kwargs)
expected_config = dict(kwargs) expected_config = dict(kwargs)
......
...@@ -1217,7 +1217,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1217,7 +1217,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers.append(pooling + [conv2, norm2]) self.aspp_layers.append(pooling + [conv2, norm2])
self._resize_layer = tf.keras.layers.Resizing( self._resizing_layer = tf.keras.layers.Resizing(
height, width, interpolation=self._interpolation, dtype=tf.float32) height, width, interpolation=self._interpolation, dtype=tf.float32)
self._projection = [ self._projection = [
...@@ -1250,7 +1250,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1250,7 +1250,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
# Apply resize layer to the end of the last set of layers. # Apply resize layer to the end of the last set of layers.
if i == len(self.aspp_layers) - 1: if i == len(self.aspp_layers) - 1:
x = self._resize_layer(x) x = self._resizing_layer(x)
result.append(tf.cast(x, inputs.dtype)) result.append(tf.cast(x, inputs.dtype))
x = self._concat_layer(result) x = self._concat_layer(result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment