# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for DeepLabV3."""

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='keras_cv')
class SpatialPyramidPooling(tf.keras.layers.Layer):
  """Implements the Atrous Spatial Pyramid Pooling.

  The layer feeds the same 4D input through several parallel branches:

    * a 1x1 convolution,
    * one 3x3 dilated convolution per entry of `dilation_rates`,
    * a global-average-pooling branch that is convolved with a 1x1 kernel and
      resized to the height/width read from the input shape.

  Every convolution is bias-free and followed by batch normalization and ReLU.
  The branch outputs are concatenated along the last axis and projected to
  `output_channels` by a final 1x1 convolution block followed by dropout.

  Reference:
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
      https://arxiv.org/pdf/1706.05587.pdf)
  """

  def __init__(
      self,
      output_channels,
      dilation_rates,
      use_sync_bn=False,
      batchnorm_momentum=0.99,
      batchnorm_epsilon=0.001,
      dropout=0.5,
      kernel_initializer='glorot_uniform',
      kernel_regularizer=None,
      interpolation='bilinear',
      **kwargs):
    """Initializes `SpatialPyramidPooling`.

    Arguments:
      output_channels: Number of channels produced by SpatialPyramidPooling.
      dilation_rates: A list of integers for parallel dilated conv.
      use_sync_bn: A bool, whether or not to use sync batch normalization.
      batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
        0.99.
      batchnorm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
        0.001.
      dropout: A float for the dropout rate before output. Defaults to 0.5.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
      interpolation: The interpolation method for upsampling. Defaults to
        `bilinear`.
      **kwargs: Other keyword arguments for the layer.
    """
    super().__init__(**kwargs)

    self.output_channels = output_channels
    self.dilation_rates = dilation_rates
    self.use_sync_bn = use_sync_bn
    self.batchnorm_momentum = batchnorm_momentum
    self.batchnorm_epsilon = batchnorm_epsilon
    self.dropout = dropout
    # Resolve string identifiers / configs into objects once, so every conv
    # layer created in `build` receives the same initializer and regularizer.
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.interpolation = interpolation
    # Only rank-4 inputs are accepted.
    self.input_spec = tf.keras.layers.InputSpec(ndim=4)

  def _conv_bn_relu(self, kernel_size, **conv_kwargs):
    """Returns a fresh `[Conv2D, BatchNorm, ReLU]` layer list.

    Arguments:
      kernel_size: Kernel size forwarded to `Conv2D`.
      **conv_kwargs: Extra `Conv2D` keyword arguments (e.g. `padding`,
        `dilation_rate`).

    Returns:
      A list of three newly created layers, in application order.
    """
    if self.use_sync_bn:
      bn_op = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_op = tf.keras.layers.BatchNormalization

    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    return [
        tf.keras.layers.Conv2D(
            filters=self.output_channels,
            kernel_size=kernel_size,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False,
            **conv_kwargs),
        bn_op(
            axis=bn_axis,
            momentum=self.batchnorm_momentum,
            epsilon=self.batchnorm_epsilon),
        tf.keras.layers.Activation('relu'),
    ]

  def build(self, input_shape):
    """Creates the parallel ASPP branches and the output projection.

    Arguments:
      input_shape: Shape of the 4D input. Entries 1, 2 and 3 are read as
        height, width and channels respectively.
    """
    # NOTE(review): the indices below, the `Reshape` target and the concat
    # axis in `call` all assume a channels_last layout, even though the batch
    # norm axis honours `image_data_format()` -- confirm channels_first is not
    # expected to work here.
    height = input_shape[1]
    width = input_shape[2]
    channels = input_shape[3]

    self.aspp_layers = []

    # Branch: 1x1 convolution.
    self.aspp_layers.append(
        tf.keras.Sequential(self._conv_bn_relu(kernel_size=(1, 1))))

    # Branches: one 3x3 dilated convolution per requested rate.
    for dilation_rate in self.dilation_rates:
      self.aspp_layers.append(
          tf.keras.Sequential(
              self._conv_bn_relu(
                  kernel_size=(3, 3),
                  padding='same',
                  dilation_rate=dilation_rate)))

    # Branch: image-level features. Global pooling drops the spatial axes, so
    # they are restored as 1x1 before the conv and then resized to the
    # height/width read from `input_shape`.
    # NOTE(review): `height`/`width` are passed to `Resizing` as-is; confirm
    # behaviour when the spatial dimensions are unknown at build time.
    pool_sequential = tf.keras.Sequential(
        [
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Reshape((1, 1, channels)),
        ] + self._conv_bn_relu(kernel_size=(1, 1)) + [
            tf.keras.layers.experimental.preprocessing.Resizing(
                height, width, interpolation=self.interpolation),
        ])
    self.aspp_layers.append(pool_sequential)

    # Fuses the concatenated branch outputs back to `output_channels`.
    self.projection = tf.keras.Sequential(
        self._conv_bn_relu(kernel_size=(1, 1)) +
        [tf.keras.layers.Dropout(rate=self.dropout)])

  def call(self, inputs, training=None):
    """Applies every branch to `inputs`, concatenates and projects.

    Arguments:
      inputs: The 4D input tensor.
      training: Optional training flag forwarded to all sub-layers. When
        `None`, `tf.keras.backend.learning_phase()` is used.

    Returns:
      The output of the projection block applied to the concatenated branch
      outputs.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()
    branch_outputs = [
        layer(inputs, training=training) for layer in self.aspp_layers
    ]
    merged = tf.concat(branch_outputs, axis=-1)
    return self.projection(merged, training=training)

  def get_config(self):
    """Returns the layer config, including all constructor arguments."""
    config = {
        'output_channels': self.output_channels,
        'dilation_rates': self.dilation_rates,
        'use_sync_bn': self.use_sync_bn,
        'batchnorm_momentum': self.batchnorm_momentum,
        'batchnorm_epsilon': self.batchnorm_epsilon,
        'dropout': self.dropout,
        'kernel_initializer': tf.keras.initializers.serialize(
            self.kernel_initializer),
        'kernel_regularizer': tf.keras.regularizers.serialize(
            self.kernel_regularizer),
        'interpolation': self.interpolation,
    }
    base_config = super().get_config()
    # Layer-specific entries take precedence over the base config.
    return {**base_config, **config}