deeplab.py 7.28 KB
Newer Older
Zhenyu Tan's avatar
Zhenyu Tan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for DeepLabV3."""

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='keras_cv')
class SpatialPyramidPooling(tf.keras.layers.Layer):
  """Implements the Atrous Spatial Pyramid Pooling.

  The layer runs several branches in parallel on the same input:
    * a 1x1 conv,
    * one 3x3 atrous conv per entry in `dilation_rates`, and
    * an image-pooling branch (global or windowed average pooling, then a
      1x1 conv), which is resized back to the input's spatial size.
  Each branch is conv -> BatchNorm -> activation. The branch outputs are
  concatenated along the last axis and projected with a 1x1 conv, BatchNorm,
  activation and dropout.

  NOTE(review): the last-axis concatenation, the pooling-branch `Reshape`,
  and the `[1:3]` spatial slice in `call` assume channels_last inputs, while
  the BatchNorm axis follows `tf.keras.backend.image_data_format()`.

  Reference:
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
      https://arxiv.org/pdf/1706.05587.pdf)
  """

  def __init__(
      self,
      output_channels,
      dilation_rates,
      pool_kernel_size=None,
      use_sync_bn=False,
      batchnorm_momentum=0.99,
      batchnorm_epsilon=0.001,
      activation='relu',
      dropout=0.5,
      kernel_initializer='glorot_uniform',
      kernel_regularizer=None,
      interpolation='bilinear',
      **kwargs):
    """Initializes `SpatialPyramidPooling`.

    Args:
      output_channels: Number of channels produced by SpatialPyramidPooling.
        Every branch and the final projection use this many filters.
      dilation_rates: A list of integers for parallel dilated conv. One 3x3
        atrous conv branch is created per entry.
      pool_kernel_size: A list of integers or None. If None, global average
        pooling is applied, otherwise an average pooling of pool_kernel_size
        is applied.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
        0.99.
      batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
        0.001.
      activation: A `str` for type of activation to be used. Defaults to 'relu'.
      dropout: A float for the dropout rate before output. Defaults to 0.5.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
      interpolation: The interpolation method used to upsample the pooling
        branch back to the input's spatial size; passed as `method` to
        `tf.image.resize`. Defaults to `bilinear`.
      **kwargs: Other keyword arguments for the layer.
    """
    super(SpatialPyramidPooling, self).__init__(**kwargs)

    self.output_channels = output_channels
    self.dilation_rates = dilation_rates
    self.use_sync_bn = use_sync_bn
    self.batchnorm_momentum = batchnorm_momentum
    self.batchnorm_epsilon = batchnorm_epsilon
    self.activation = activation
    self.dropout = dropout
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.interpolation = interpolation
    self.input_spec = tf.keras.layers.InputSpec(ndim=4)
    self.pool_kernel_size = pool_kernel_size

  def build(self, input_shape):
    """Creates the parallel ASPP branches and the projection head.

    Only the channel dimension of `input_shape` is used here. The spatial
    size is read dynamically in `call`, so height/width may be unknown.

    Args:
      input_shape: Shape of the 4D input; index 3 is read as the channel
        count for the pooling branch's `Reshape`.
    """
    channels = input_shape[3]

    # Branches in the order they are applied; the pooling branch must stay
    # last because `call` resizes the final entry back to the input size.
    self.aspp_layers = []

    if self.use_sync_bn:
      bn_op = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_op = tf.keras.layers.BatchNormalization

    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Branch 0: plain 1x1 conv.
    conv_sequential = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            filters=self.output_channels, kernel_size=(1, 1),
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon),
        tf.keras.layers.Activation(self.activation)
    ])
    self.aspp_layers.append(conv_sequential)

    # One 3x3 atrous conv branch per dilation rate; 'same' padding keeps the
    # input's spatial size so these outputs can be concatenated directly.
    for dilation_rate in self.dilation_rates:
      conv_sequential = tf.keras.Sequential([
          tf.keras.layers.Conv2D(
              filters=self.output_channels, kernel_size=(3, 3),
              padding='same', kernel_regularizer=self.kernel_regularizer,
              kernel_initializer=self.kernel_initializer,
              dilation_rate=dilation_rate, use_bias=False),
          bn_op(axis=bn_axis, momentum=self.batchnorm_momentum,
                epsilon=self.batchnorm_epsilon),
          tf.keras.layers.Activation(self.activation)])
      self.aspp_layers.append(conv_sequential)

    # Image-pooling branch. Its spatial size differs from the input's (1x1
    # for global pooling), so `call` resizes it back using the input's
    # dynamic shape instead of a size fixed at build time.
    if self.pool_kernel_size is None:
      pool_sequential = tf.keras.Sequential([
          tf.keras.layers.GlobalAveragePooling2D(),
          tf.keras.layers.Reshape((1, 1, channels))])
    else:
      pool_sequential = tf.keras.Sequential([
          tf.keras.layers.AveragePooling2D(self.pool_kernel_size)])

    pool_sequential.add(
        tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                filters=self.output_channels,
                kernel_size=(1, 1),
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.kernel_regularizer,
                use_bias=False),
            bn_op(
                axis=bn_axis,
                momentum=self.batchnorm_momentum,
                epsilon=self.batchnorm_epsilon),
            tf.keras.layers.Activation(self.activation),
        ]))

    self.aspp_layers.append(pool_sequential)

    # Projection head applied to the concatenated branch outputs.
    self.projection = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            filters=self.output_channels, kernel_size=(1, 1),
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon),
        tf.keras.layers.Activation(self.activation),
        tf.keras.layers.Dropout(rate=self.dropout)])

  def call(self, inputs, training=None):
    """Applies all ASPP branches, concatenates them, and projects the result.

    Args:
      inputs: A 4D tensor; dims 1 and 2 are treated as height and width.
      training: Optional bool. If None, falls back to
        `tf.keras.backend.learning_phase()`.

    Returns:
      The projected tensor with `output_channels` channels.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()
    # Read the spatial size at call time so unknown (None) or varying
    # height/width are supported.
    target_size = tf.shape(inputs)[1:3]
    last_index = len(self.aspp_layers) - 1
    result = []
    for i, layer in enumerate(self.aspp_layers):
      x = layer(inputs, training=training)
      if i == last_index:
        # The pooling branch (appended last in `build`) has reduced spatial
        # resolution; upsample it back to the input size. Resize in float32,
        # then cast back below so every branch shares the input dtype.
        x = tf.image.resize(
            tf.cast(x, tf.float32), target_size, method=self.interpolation)
      result.append(tf.cast(x, inputs.dtype))
    result = tf.concat(result, axis=-1)
    result = self.projection(result, training=training)
    return result

  def get_config(self):
    """Returns the base layer config merged with this layer's init args."""
    config = {
        'output_channels': self.output_channels,
        'dilation_rates': self.dilation_rates,
        'pool_kernel_size': self.pool_kernel_size,
        'use_sync_bn': self.use_sync_bn,
        'batchnorm_momentum': self.batchnorm_momentum,
        'batchnorm_epsilon': self.batchnorm_epsilon,
        'activation': self.activation,
        'dropout': self.dropout,
        'kernel_initializer': tf.keras.initializers.serialize(
            self.kernel_initializer),
        'kernel_regularizer': tf.keras.regularizers.serialize(
            self.kernel_regularizer),
        'interpolation': self.interpolation,
    }
    base_config = super(SpatialPyramidPooling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))