Commit 93581b17 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Renaming ASPP to SpatialPyramidPooling.

PiperOrigin-RevId: 335049970
parent f1bcd9bb
......@@ -13,4 +13,4 @@
# limitations under the License.
# ==============================================================================
"""Keras-CV layers package definition."""
from official.vision.keras_cv.layers.deeplab import ASPP
from official.vision.keras_cv.layers.deeplab import SpatialPyramidPooling
......@@ -18,7 +18,7 @@ import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='keras_cv')
class ASPP(tf.keras.layers.Layer):
class SpatialPyramidPooling(tf.keras.layers.Layer):
"""Implements the Atrous Spatial Pyramid Pooling.
Reference:
......@@ -36,10 +36,10 @@ class ASPP(tf.keras.layers.Layer):
kernel_regularizer=None,
interpolation='bilinear',
**kwargs):
"""Initializes `ASPP`.
"""Initializes `SpatialPyramidPooling`.
Arguments:
output_channels: Number of channels produced by ASPP.
output_channels: Number of channels produced by SpatialPyramidPooling.
dilation_rates: A list of integers for parallel dilated conv.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99.
......@@ -51,7 +51,7 @@ class ASPP(tf.keras.layers.Layer):
`bilinear`.
**kwargs: Other keyword arguments for the layer.
"""
super(ASPP, self).__init__(**kwargs)
super(SpatialPyramidPooling, self).__init__(**kwargs)
self.output_channels = output_channels
self.dilation_rates = dilation_rates
......@@ -133,5 +133,5 @@ class ASPP(tf.keras.layers.Layer):
self.kernel_regularizer),
'interpolation': self.interpolation,
}
base_config = super(ASPP, self).get_config()
base_config = super(SpatialPyramidPooling, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......@@ -25,20 +25,22 @@ class DeeplabTest(keras_parameterized.TestCase):
def test_aspp(self):
inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32)
layer = deeplab.ASPP(output_channels=256, dilation_rates=[6, 12, 18])
layer = deeplab.SpatialPyramidPooling(output_channels=256,
dilation_rates=[6, 12, 18])
output = layer(inputs)
self.assertAllEqual([None, 64, 64, 256], output.shape)
def test_aspp_invalid_shape(self):
inputs = tf.keras.Input(shape=(64, 64), dtype=tf.float32)
layer = deeplab.ASPP(output_channels=256, dilation_rates=[6, 12, 18])
layer = deeplab.SpatialPyramidPooling(output_channels=256,
dilation_rates=[6, 12, 18])
with self.assertRaises(ValueError):
_ = layer(inputs)
def test_config_with_custom_name(self):
layer = deeplab.ASPP(256, [5], name='aspp')
layer = deeplab.SpatialPyramidPooling(256, [5], name='aspp')
config = layer.get_config()
layer_1 = deeplab.ASPP.from_config(config)
layer_1 = deeplab.SpatialPyramidPooling.from_config(config)
self.assertEqual(layer_1.name, layer.name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment