Commit bc7c670f authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Add ASPP for KerasCV.

PiperOrigin-RevId: 333644426
parent b547c6fa
...@@ -14,5 +14,6 @@ ...@@ -14,5 +14,6 @@
# ============================================================================== # ==============================================================================
"""Keras-CV package definition.""" """Keras-CV package definition."""
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from official.vision.keras_cv import layers
from official.vision.keras_cv import losses from official.vision.keras_cv import losses
from official.vision.keras_cv import ops from official.vision.keras_cv import ops
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-CV layers package definition."""
from official.vision.keras_cv.layers.deeplab import ASPP
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for DeepLabV3."""
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='keras_cv')
class ASPP(tf.keras.layers.Layer):
"""Implements the Atrous Spatial Pyramid Pooling.
Reference:
[Rethinking Atrous Convolution for Semantic Image Segmentation](
https://arxiv.org/pdf/1706.05587.pdf)
"""
def __init__(
self,
output_channels,
dilation_rates,
batchnorm_momentum=0.99,
dropout=0.5,
kernel_initializer='glorot_uniform',
kernel_regularizer=None,
interpolation='bilinear',
**kwargs):
"""Initializes `ASPP`.
Arguments:
output_channels: Number of channels produced by ASPP.
dilation_rates: A list of integers for parallel dilated conv.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
dropout: A float for the dropout rate before output. Defaults to 0.5.
kernel_initializer: Kernel initializer for conv layers. Defaults to
`glorot_uniform`.
kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
interpolation: The interpolation method for upsampling. Defaults to
`bilinear`.
**kwargs: Other keyword arguments for the layer.
"""
super(ASPP, self).__init__(**kwargs)
self.output_channels = output_channels
self.dilation_rates = dilation_rates
self.batchnorm_momentum = batchnorm_momentum
self.dropout = dropout
self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self.interpolation = interpolation
self.input_spec = tf.keras.layers.InputSpec(ndim=4)
def build(self, input_shape):
height = input_shape[1]
width = input_shape[2]
channels = input_shape[3]
self.aspp_layers = []
conv_sequential = tf.keras.Sequential([
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
tf.keras.layers.Activation('relu')])
self.aspp_layers.append(conv_sequential)
for dilation_rate in self.dilation_rates:
conv_sequential = tf.keras.Sequential([
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(3, 3),
padding='same', kernel_regularizer=self.kernel_regularizer,
kernel_initializer=self.kernel_initializer,
dilation_rate=dilation_rate, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
tf.keras.layers.Activation('relu')])
self.aspp_layers.append(conv_sequential)
pool_sequential = tf.keras.Sequential([
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Reshape((1, 1, channels)),
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
tf.keras.layers.Activation('relu'),
tf.keras.layers.experimental.preprocessing.Resizing(
height, width, interpolation=self.interpolation)])
self.aspp_layers.append(pool_sequential)
self.projection = tf.keras.Sequential([
tf.keras.layers.Conv2D(
filters=self.output_channels, kernel_size=(1, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer, use_bias=False),
tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dropout(rate=self.dropout)])
def call(self, inputs, training=None):
if training is None:
training = tf.keras.backend.learning_phase()
result = []
for layer in self.aspp_layers:
result.append(layer(inputs, training=training))
result = tf.concat(result, axis=-1)
result = self.projection(result, training=training)
return result
def get_config(self):
config = {
'output_channels': self.output_channels,
'dilation_rates': self.dilation_rates,
'batchnorm_momentum': self.batchnorm_momentum,
'dropout': self.dropout,
'kernel_initializer': tf.keras.initializers.serialize(
self.kernel_initializer),
'kernel_regularizer': tf.keras.regularizers.serialize(
self.kernel_regularizer),
'interpolation': self.interpolation,
}
base_config = super(ASPP, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ASPP."""
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized
from official.vision.keras_cv.layers import deeplab
@keras_parameterized.run_all_keras_modes
class DeeplabTest(keras_parameterized.TestCase):
def test_aspp(self):
inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32)
layer = deeplab.ASPP(output_channels=256, dilation_rates=[6, 12, 18])
output = layer(inputs)
self.assertAllEqual([None, 64, 64, 256], output.shape)
def test_aspp_invalid_shape(self):
inputs = tf.keras.Input(shape=(64, 64), dtype=tf.float32)
layer = deeplab.ASPP(output_channels=256, dilation_rates=[6, 12, 18])
with self.assertRaises(ValueError):
_ = layer(inputs)
def test_config_with_custom_name(self):
layer = deeplab.ASPP(256, [5], name='aspp')
config = layer.get_config()
layer_1 = deeplab.ASPP.from_config(config)
self.assertEqual(layer_1.name, layer.name)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment