deeplab.py 6.61 KB
Newer Older
Zhenyu Tan's avatar
Zhenyu Tan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for DeepLabV3."""

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='keras_cv')
class SpatialPyramidPooling(tf.keras.layers.Layer):
  """Implements the Atrous Spatial Pyramid Pooling.

  The layer applies several parallel branches to the same 4D input and
  concatenates their outputs along the last axis before a final projection:

    * one 1x1 convolution branch,
    * one 3x3 dilated convolution branch per entry of `dilation_rates`,
    * one pooling branch (global average pooling, 1x1 convolution, then
      resizing back to the input height and width).

  Every convolution is followed by batch normalization and an activation.
  The projection is a 1x1 convolution + batch normalization + activation,
  followed by dropout.

  Reference:
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
      https://arxiv.org/pdf/1706.05587.pdf)
  """

  def __init__(
      self,
      output_channels,
      dilation_rates,
      use_sync_bn=False,
      batchnorm_momentum=0.99,
      batchnorm_epsilon=0.001,
      activation='relu',
      dropout=0.5,
      kernel_initializer='glorot_uniform',
      kernel_regularizer=None,
      interpolation='bilinear',
      **kwargs):
    """Initializes `SpatialPyramidPooling`.

    Arguments:
      output_channels: Number of channels produced by SpatialPyramidPooling.
      dilation_rates: A list of integers for parallel dilated conv.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
        0.99.
      batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
        0.001.
      activation: A `str` for type of activation to be used. Defaults to 'relu'.
      dropout: A float for the dropout rate before output. Defaults to 0.5.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
      interpolation: The interpolation method for upsampling. Defaults to
        `bilinear`.
      **kwargs: Other keyword arguments for the layer.
    """
    super().__init__(**kwargs)

    self.output_channels = output_channels
    self.dilation_rates = dilation_rates
    self.use_sync_bn = use_sync_bn
    self.batchnorm_momentum = batchnorm_momentum
    self.batchnorm_epsilon = batchnorm_epsilon
    self.activation = activation
    self.dropout = dropout
    # Resolve identifiers to objects once; `get_config` serializes them back.
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.interpolation = interpolation
    # Only rank-4 inputs are accepted.
    self.input_spec = tf.keras.layers.InputSpec(ndim=4)

  def _conv_bn_activation(self, kernel_size, bn_op, bn_axis, **conv_kwargs):
    """Creates the Conv2D -> BatchNorm -> Activation layers of one branch.

    The layers are instantiated in exactly that order, so the sequence in
    which sub-layers are created by `build` does not depend on this helper.

    Arguments:
      kernel_size: Kernel size forwarded to `Conv2D`.
      bn_op: Batch normalization layer class to instantiate.
      bn_axis: Axis argument for the batch normalization layer.
      **conv_kwargs: Extra keyword arguments forwarded to `Conv2D` (for
        example `padding` and `dilation_rate`).

    Returns:
      A list `[conv, batch_norm, activation]` of freshly created layers.
    """
    return [
        tf.keras.layers.Conv2D(
            filters=self.output_channels,
            kernel_size=kernel_size,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False,
            **conv_kwargs),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon),
        tf.keras.layers.Activation(self.activation),
    ]

  def build(self, input_shape):
    """Creates the parallel ASPP branches and the output projection.

    Arguments:
      input_shape: Shape of the 4D input. Indices 1, 2 and 3 are read as
        height, width and channels respectively.
    """
    # NOTE(review): the shape is indexed as (batch, height, width, channels)
    # regardless of `image_data_format()`, while `bn_axis` below does follow
    # it. Presumably only 'channels_last' is really supported -- confirm.
    height = input_shape[1]
    width = input_shape[2]
    channels = input_shape[3]

    if self.use_sync_bn:
      bn_op = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_op = tf.keras.layers.BatchNormalization

    channels_last = tf.keras.backend.image_data_format() == 'channels_last'
    bn_axis = -1 if channels_last else 1

    self.aspp_layers = []

    # Branch 1: plain 1x1 convolution.
    self.aspp_layers.append(
        tf.keras.Sequential(
            self._conv_bn_activation((1, 1), bn_op, bn_axis)))

    # One 3x3 dilated-convolution branch per requested rate. 'same' padding
    # keeps the spatial size so that all branches can be concatenated.
    for dilation_rate in self.dilation_rates:
      self.aspp_layers.append(
          tf.keras.Sequential(
              self._conv_bn_activation(
                  (3, 3), bn_op, bn_axis,
                  padding='same', dilation_rate=dilation_rate)))

    # Pooling branch: global average pooling drops the spatial axes, so the
    # result is reshaped to 1x1 before the convolution and finally resized
    # to the input height and width.
    # NOTE(review): `height` and `width` come from the static build shape;
    # presumably they must be known (not None) for `Resizing` -- confirm.
    pool_layers = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Reshape((1, 1, channels)),
    ]
    pool_layers.extend(self._conv_bn_activation((1, 1), bn_op, bn_axis))
    pool_layers.append(
        tf.keras.layers.experimental.preprocessing.Resizing(
            height, width, interpolation=self.interpolation))
    self.aspp_layers.append(tf.keras.Sequential(pool_layers))

    # Output head applied to the concatenated branches.
    projection_layers = self._conv_bn_activation((1, 1), bn_op, bn_axis)
    projection_layers.append(tf.keras.layers.Dropout(rate=self.dropout))
    self.projection = tf.keras.Sequential(projection_layers)

  def call(self, inputs, training=None):
    """Runs every branch on `inputs`, concatenates and projects the result.

    Arguments:
      inputs: The 4D input tensor.
      training: Whether the layer runs in training mode. When None, the
        Keras backend learning phase is used.

    Returns:
      The output of the projection applied to the branch outputs
      concatenated along the last axis.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()
    branch_outputs = [
        branch(inputs, training=training) for branch in self.aspp_layers]
    merged = tf.concat(branch_outputs, axis=-1)
    return self.projection(merged, training=training)

  def get_config(self):
    """Returns the base layer config extended with the constructor arguments."""
    config = super().get_config()
    config.update({
        'output_channels': self.output_channels,
        'dilation_rates': self.dilation_rates,
        'use_sync_bn': self.use_sync_bn,
        'batchnorm_momentum': self.batchnorm_momentum,
        'batchnorm_epsilon': self.batchnorm_epsilon,
        'activation': self.activation,
        'dropout': self.dropout,
        'kernel_initializer': tf.keras.initializers.serialize(
            self.kernel_initializer),
        'kernel_regularizer': tf.keras.regularizers.serialize(
            self.kernel_regularizer),
        'interpolation': self.interpolation,
    })
    return config