Commit 0c94ae2d authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 337550702
parent 3bc37a6a
......@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Decoders configurations."""
from typing import Optional
from typing import Optional, List
# Import libraries
import dataclasses
......@@ -35,6 +35,15 @@ class FPN(hyperparams.Config):
use_separable_conv: bool = False
@dataclasses.dataclass
class ASPP(hyperparams.Config):
"""ASPP config."""
level: int = 4
dilation_rates: List[int] = dataclasses.field(default_factory=list)
dropout_rate: float = 0.0
num_filters: int = 256
@dataclasses.dataclass
class Decoder(hyperparams.OneOfConfig):
"""Configuration for decoders.
......@@ -46,3 +55,4 @@ class Decoder(hyperparams.OneOfConfig):
type: Optional[str] = None
fpn: FPN = FPN()
identity: Identity = Identity()
aspp: ASPP = ASPP()
......@@ -15,4 +15,5 @@
# ==============================================================================
"""Decoders package definition."""
from official.vision.beta.modeling.decoders.aspp import ASPP
from official.vision.beta.modeling.decoders.fpn import FPN
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ASPP decoder."""
# Import libraries
import tensorflow as tf
from official.vision import keras_cv
@tf.keras.utils.register_keras_serializable(package='Vision')
class ASPP(tf.keras.layers.Layer):
"""ASPP."""
def __init__(self,
level,
dilation_rates,
num_filters=256,
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
dropout_rate=0.0,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
interpolation='bilinear',
**kwargs):
"""ASPP initialization function.
Args:
level: `int` level to apply ASPP.
dilation_rates: `list` of dilation rates.
num_filters: `int` number of output filters in ASPP.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
dropout_rate: `float` rate for dropout regularization.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
interpolation: interpolation method, one of bilinear, nearest, bicubic,
area, lanczos3, lanczos5, gaussian, or mitchellcubic.
**kwargs: keyword arguments to be passed.
"""
super(ASPP, self).__init__(**kwargs)
self._config_dict = {
'level': level,
'dilation_rates': dilation_rates,
'num_filters': num_filters,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'dropout_rate': dropout_rate,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation,
}
def build(self, input_shape):
self.aspp = keras_cv.layers.SpatialPyramidPooling(
output_channels=self._config_dict['num_filters'],
dilation_rates=self._config_dict['dilation_rates'],
use_sync_bn=self._config_dict['use_sync_bn'],
batchnorm_momentum=self._config_dict['norm_momentum'],
batchnorm_epsilon=self._config_dict['norm_epsilon'],
dropout=self._config_dict['dropout_rate'],
kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
interpolation=self._config_dict['interpolation'])
def call(self, inputs):
"""ASPP call method.
The output of ASPP will be a dict of level, Tensor even if only one
level is present. Hence, this will be compatible with the rest of the
segmentation model interfaces..
Args:
inputs: A dict of tensors
- key: `str`, the level of the multilevel feature maps.
- values: `Tensor`, [batch, height_l, width_l, filter_size].
Returns:
A dict of tensors
- key: `str`, the level of the multilevel feature maps.
- values: `Tensor`, output of ASPP module.
"""
outputs = {}
level = str(self._config_dict['level'])
outputs[level] = self.aspp(inputs[level])
return outputs
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for aspp."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import aspp
class ASPPTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(3, [6, 12, 18, 24], 128),
(3, [6, 12, 18], 128),
(3, [6, 12], 256),
(4, [6, 12, 18, 24], 128),
(4, [6, 12, 18], 128),
(4, [6, 12], 256),
)
def test_network_creation(self, level, dilation_rates, num_filters):
"""Test creation of ASPP."""
input_size = 256
tf.keras.backend.set_image_data_format('channels_last')
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
backbone = resnet.ResNet(model_id=50)
network = aspp.ASPP(
level=level,
dilation_rates=dilation_rates,
num_filters=num_filters)
endpoints = backbone(inputs)
feats = network(endpoints)
self.assertIn(str(level), feats)
self.assertAllEqual(
[1, input_size // 2**level, input_size // 2**level, num_filters],
feats[str(level)].shape.as_list())
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
kwargs = dict(
level=3,
dilation_rates=[6, 12],
num_filters=256,
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
interpolation='bilinear',
dropout_rate=0.2,
)
network = aspp.ASPP(**kwargs)
expected_config = dict(kwargs)
self.assertEqual(network.get_config(), expected_config)
# Create another network object from the first object's config.
new_network = aspp.ASPP.from_config(network.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == '__main__':
tf.test.main()
......@@ -52,6 +52,16 @@ def build_decoder(input_specs,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif decoder_type == 'aspp':
decoder = decoders.ASPP(
level=decoder_cfg.level,
dilation_rates=decoder_cfg.dilation_rates,
num_filters=decoder_cfg.num_filters,
dropout_rate=decoder_cfg.dropout_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
else:
raise ValueError('Decoder {!r} not implement'.format(decoder_type))
......
......@@ -30,7 +30,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self,
output_channels,
dilation_rates,
use_sync_bn=False,
batchnorm_momentum=0.99,
batchnorm_epsilon=0.001,
dropout=0.5,
kernel_initializer='glorot_uniform',
kernel_regularizer=None,
......@@ -41,8 +43,11 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
Arguments:
output_channels: Number of channels produced by SpatialPyramidPooling.
dilation_rates: A list of integers for parallel dilated conv.
use_sync_bn: A bool, whether or not to use sync batch normalization.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
0.001.
dropout: A float for the dropout rate before output. Defaults to 0.5.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
......@@ -55,7 +60,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.output_channels = output_channels
self.dilation_rates = dilation_rates
self.use_sync_bn = use_sync_bn
self.batchnorm_momentum = batchnorm_momentum
self.batchnorm_epsilon = batchnorm_epsilon
self.dropout = dropout
self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
......@@ -69,13 +76,28 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers = []
if self.use_sync_bn:
bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
else:
bn_axis = 1
conv_sequential = tf.keras.Sequential([
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
tf.keras.layers.Activation('relu')])
kernel_regularizer=self.kernel_regularizer,
use_bias=False),
bn_op(
axis=bn_axis,
momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation('relu')
])
self.aspp_layers.append(conv_sequential)
for dilation_rate in self.dilation_rates:
......@@ -85,7 +107,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
padding='same', kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.kernel_initializer,
dilation_rate=dilation_rate, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
bn_op(axis=bn_axis, momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation('relu')])
self.aspp_layers.append(conv_sequential)
......@@ -95,8 +118,12 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
kernel_regularizer=self.kernel_regularizer,
use_bias=False),
bn_op(
axis=bn_axis,
momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation('relu'),
tf.keras.layers.experimental.preprocessing.Resizing(
height, width, interpolation=self.interpolation)])
......@@ -106,8 +133,12 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
kernel_regularizer=self.kernel_regularizer,
use_bias=False),
bn_op(
axis=bn_axis,
momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dropout(rate=self.dropout)])
......@@ -125,7 +156,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
config = {
'output_channels': self.output_channels,
'dilation_rates': self.dilation_rates,
'use_sync_bn': self.use_sync_bn,
'batchnorm_momentum': self.batchnorm_momentum,
'batchnorm_epsilon': self.batchnorm_epsilon,
'dropout': self.dropout,
'kernel_initializer': tf.keras.initializers.serialize(
self.kernel_initializer),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment