# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers for DeepLabV3."""

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='keras_cv')
class SpatialPyramidPooling(tf.keras.layers.Layer):
  """Implements the Atrous Spatial Pyramid Pooling.

  References:
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
      https://arxiv.org/pdf/1706.05587.pdf)
    [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
  """

  def __init__(
      self,
      output_channels,
      dilation_rates,
      pool_kernel_size=None,
      use_sync_bn=False,
      batchnorm_momentum=0.99,
      batchnorm_epsilon=0.001,
      activation='relu',
      dropout=0.5,
      kernel_initializer='glorot_uniform',
      kernel_regularizer=None,
      interpolation='bilinear',
      use_depthwise_convolution=False,
      **kwargs):
    """Initializes `SpatialPyramidPooling`.

    Args:
      output_channels: Number of channels produced by SpatialPyramidPooling.
      dilation_rates: A list of integers for parallel dilated conv.
      pool_kernel_size: A list of integers or None. If None, global average
        pooling is applied, otherwise an average pooling of pool_kernel_size
        is applied.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
        0.99.
      batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults
        to 0.001.
      activation: A `str` for the type of activation to be used. Defaults to
        'relu'.
      dropout: A float for the dropout rate before output. Defaults to 0.5.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
      interpolation: The interpolation method for upsampling. Defaults to
        `bilinear`.
      use_depthwise_convolution: Whether to use separable depthwise
        convolutions in the spatial pyramid branches. [Encoder-Decoder with
        Atrous Separable Convolution for Semantic Image Segmentation](
        https://arxiv.org/pdf/1802.02611.pdf)
      **kwargs: Other keyword arguments for the layer.
    """
    super(SpatialPyramidPooling, self).__init__(**kwargs)

    self.output_channels = output_channels
    self.dilation_rates = dilation_rates
    self.use_sync_bn = use_sync_bn
    self.batchnorm_momentum = batchnorm_momentum
    self.batchnorm_epsilon = batchnorm_epsilon
    self.activation = activation
    self.dropout = dropout
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.interpolation = interpolation
    self.input_spec = tf.keras.layers.InputSpec(ndim=4)
    self.pool_kernel_size = pool_kernel_size
    self.use_depthwise_convolution = use_depthwise_convolution

  def build(self, input_shape):
    height = input_shape[1]
    width = input_shape[2]
    channels = input_shape[3]

    self.aspp_layers = []

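    # Choose the batch normalization implementation and the channel axis
    # according to the image data format.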
    if self.use_sync_bn:
      bn_op = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_op = tf.keras.layers.BatchNormalization

    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

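    # First ASPP branch: a 1x1 convolution, batch norm, and activation.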
    conv_sequential = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            filters=self.output_channels, kernel_size=(1, 1),
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon),
        tf.keras.layers.Activation(self.activation)
    ])
    self.aspp_layers.append(conv_sequential)

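    # One atrous convolution branch per dilation rate. When
    # use_depthwise_convolution is True, the 3x3 convolution is split into a
    # depthwise 3x3 followed by a pointwise 1x1 convolution.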
    for dilation_rate in self.dilation_rates:
      leading_layers = []
      kernel_size = (3, 3)
      if self.use_depthwise_convolution:
        leading_layers += [
            tf.keras.layers.DepthwiseConv2D(
                depth_multiplier=1, kernel_size=kernel_size,
                padding='same', depthwise_regularizer=self.kernel_regularizer,
                depthwise_initializer=self.kernel_initializer,
                dilation_rate=dilation_rate, use_bias=False)
        ]
        kernel_size = (1, 1)
      conv_sequential = tf.keras.Sequential(leading_layers + [
          tf.keras.layers.Conv2D(
              filters=self.output_channels, kernel_size=kernel_size,
              padding='same', kernel_regularizer=self.kernel_regularizer,
              kernel_initializer=self.kernel_initializer,
              dilation_rate=dilation_rate, use_bias=False),
          bn_op(axis=bn_axis, momentum=self.batchnorm_momentum,
                epsilon=self.batchnorm_epsilon),
          tf.keras.layers.Activation(self.activation)])
      self.aspp_layers.append(conv_sequential)

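    # Image-pooling branch: global (or windowed) average pooling, a 1x1
    # convolution, and upsampling back to the input feature size.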
    if self.pool_kernel_size is None:
      pool_sequential = tf.keras.Sequential([
          tf.keras.layers.GlobalAveragePooling2D(),
          tf.keras.layers.Reshape((1, 1, channels))])
    else:
      pool_sequential = tf.keras.Sequential([
          tf.keras.layers.AveragePooling2D(self.pool_kernel_size)])

    pool_sequential.add(
        tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                filters=self.output_channels,
                kernel_size=(1, 1),
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.kernel_regularizer,
                use_bias=False),
            bn_op(
                axis=bn_axis,
                momentum=self.batchnorm_momentum,
                epsilon=self.batchnorm_epsilon),
            tf.keras.layers.Activation(self.activation),
            tf.keras.layers.experimental.preprocessing.Resizing(
                height,
                width,
                interpolation=self.interpolation,
                dtype=tf.float32)
        ]))

    self.aspp_layers.append(pool_sequential)

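    # Project the concatenated branch outputs back to `output_channels` and
    # apply dropout.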
    self.projection = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            filters=self.output_channels, kernel_size=(1, 1),
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon),
        tf.keras.layers.Activation(self.activation),
        tf.keras.layers.Dropout(rate=self.dropout)])

  def call(self, inputs, training=None):
    if training is None:
      training = tf.keras.backend.learning_phase()
    result = []
    for layer in self.aspp_layers:
      result.append(tf.cast(layer(inputs, training=training), inputs.dtype))
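    # Concatenate the branch outputs along the channel axis and project them.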
    result = tf.concat(result, axis=-1)
    result = self.projection(result, training=training)
    return result

  def get_config(self):
    config = {
        'output_channels': self.output_channels,
        'dilation_rates': self.dilation_rates,
        'pool_kernel_size': self.pool_kernel_size,
        'use_sync_bn': self.use_sync_bn,
        'batchnorm_momentum': self.batchnorm_momentum,
        'batchnorm_epsilon': self.batchnorm_epsilon,
        'activation': self.activation,
        'dropout': self.dropout,
        'kernel_initializer': tf.keras.initializers.serialize(
            self.kernel_initializer),
        'kernel_regularizer': tf.keras.regularizers.serialize(
            self.kernel_regularizer),
        'interpolation': self.interpolation,
        'use_depthwise_convolution': self.use_depthwise_convolution,
    }
    base_config = super(SpatialPyramidPooling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
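

# Minimal usage sketch (illustrative only): the input shape, dilation rates,
# and channel counts below are assumed values for demonstration, not part of
# this module's API.
if __name__ == '__main__':
  aspp = SpatialPyramidPooling(output_channels=256, dilation_rates=[6, 12, 18])
  # Stand-in for a backbone feature map of shape [batch, height, width, depth].
  features = tf.random.normal([2, 64, 64, 2048])
  outputs = aspp(features, training=False)
  print(outputs.shape)  # -> (2, 64, 64, 256)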