Commit 93581b17 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Renaming ASPP to SpatialPyramidPooling.

PiperOrigin-RevId: 335049970
parent f1bcd9bb
...@@ -13,4 +13,4 @@ ...@@ -13,4 +13,4 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-CV layers package definition.""" """Keras-CV layers package definition."""
from official.vision.keras_cv.layers.deeplab import ASPP from official.vision.keras_cv.layers.deeplab import SpatialPyramidPooling
...@@ -18,7 +18,7 @@ import tensorflow as tf ...@@ -18,7 +18,7 @@ import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='keras_cv') @tf.keras.utils.register_keras_serializable(package='keras_cv')
class ASPP(tf.keras.layers.Layer): class SpatialPyramidPooling(tf.keras.layers.Layer):
"""Implements the Atrous Spatial Pyramid Pooling. """Implements the Atrous Spatial Pyramid Pooling.
Reference: Reference:
...@@ -36,10 +36,10 @@ class ASPP(tf.keras.layers.Layer): ...@@ -36,10 +36,10 @@ class ASPP(tf.keras.layers.Layer):
kernel_regularizer=None, kernel_regularizer=None,
interpolation='bilinear', interpolation='bilinear',
**kwargs): **kwargs):
"""Initializes `ASPP`. """Initializes `SpatialPyramidPooling`.
Arguments: Arguments:
output_channels: Number of channels produced by ASPP. output_channels: Number of channels produced by SpatialPyramidPooling.
dilation_rates: A list of integers for parallel dilated conv. dilation_rates: A list of integers for parallel dilated conv.
batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
0.99. 0.99.
...@@ -51,7 +51,7 @@ class ASPP(tf.keras.layers.Layer): ...@@ -51,7 +51,7 @@ class ASPP(tf.keras.layers.Layer):
`bilinear`. `bilinear`.
**kwargs: Other keyword arguments for the layer. **kwargs: Other keyword arguments for the layer.
""" """
super(ASPP, self).__init__(**kwargs) super(SpatialPyramidPooling, self).__init__(**kwargs)
self.output_channels = output_channels self.output_channels = output_channels
self.dilation_rates = dilation_rates self.dilation_rates = dilation_rates
...@@ -133,5 +133,5 @@ class ASPP(tf.keras.layers.Layer): ...@@ -133,5 +133,5 @@ class ASPP(tf.keras.layers.Layer):
self.kernel_regularizer), self.kernel_regularizer),
'interpolation': self.interpolation, 'interpolation': self.interpolation,
} }
base_config = super(ASPP, self).get_config() base_config = super(SpatialPyramidPooling, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -25,20 +25,22 @@ class DeeplabTest(keras_parameterized.TestCase): ...@@ -25,20 +25,22 @@ class DeeplabTest(keras_parameterized.TestCase):
def test_aspp(self): def test_aspp(self):
inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32) inputs = tf.keras.Input(shape=(64, 64, 128), dtype=tf.float32)
layer = deeplab.ASPP(output_channels=256, dilation_rates=[6, 12, 18]) layer = deeplab.SpatialPyramidPooling(output_channels=256,
dilation_rates=[6, 12, 18])
output = layer(inputs) output = layer(inputs)
self.assertAllEqual([None, 64, 64, 256], output.shape) self.assertAllEqual([None, 64, 64, 256], output.shape)
def test_aspp_invalid_shape(self): def test_aspp_invalid_shape(self):
inputs = tf.keras.Input(shape=(64, 64), dtype=tf.float32) inputs = tf.keras.Input(shape=(64, 64), dtype=tf.float32)
layer = deeplab.ASPP(output_channels=256, dilation_rates=[6, 12, 18]) layer = deeplab.SpatialPyramidPooling(output_channels=256,
dilation_rates=[6, 12, 18])
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = layer(inputs) _ = layer(inputs)
def test_config_with_custom_name(self): def test_config_with_custom_name(self):
layer = deeplab.ASPP(256, [5], name='aspp') layer = deeplab.SpatialPyramidPooling(256, [5], name='aspp')
config = layer.get_config() config = layer.get_config()
layer_1 = deeplab.ASPP.from_config(config) layer_1 = deeplab.SpatialPyramidPooling.from_config(config)
self.assertEqual(layer_1.name, layer.name) self.assertEqual(layer_1.name, layer.name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment