# deeplab.py
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for DeepLabV3."""

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='keras_cv')
class SpatialPyramidPooling(tf.keras.layers.Layer):
  """Implements the Atrous Spatial Pyramid Pooling.

  The layer applies several branches to the same 4-D input: a 1x1
  convolution, one 3x3 dilated convolution per entry of `dilation_rates`, and
  a global-average-pooling branch that is resized back to the input's spatial
  size. Branch outputs are concatenated on the last axis and projected to
  `output_channels` by a final 1x1 convolution followed by dropout.

  The input is read as `(batch, height, width, channels)`.

  Reference:
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
      https://arxiv.org/pdf/1706.05587.pdf)
  """

  def __init__(
      self,
      output_channels,
      dilation_rates,
      batchnorm_momentum=0.99,
      dropout=0.5,
      kernel_initializer='glorot_uniform',
      kernel_regularizer=None,
      interpolation='bilinear',
      **kwargs):
    """Initializes `SpatialPyramidPooling`.

    Arguments:
      output_channels: Number of channels produced by SpatialPyramidPooling.
      dilation_rates: A list of integers for parallel dilated conv.
      batchnorm_momentum: A float for the momentum in BatchNorm. Defaults to
        0.99.
      dropout: A float for the dropout rate before output. Defaults to 0.5.
      kernel_initializer: Kernel initializer for conv layers. Defaults to
        `glorot_uniform`.
      kernel_regularizer: Kernel regularizer for conv layers. Defaults to None.
      interpolation: The interpolation method for upsampling. Defaults to
        `bilinear`.
      **kwargs: Other keyword arguments for the layer.
    """
    super().__init__(**kwargs)

    self.output_channels = output_channels
    self.dilation_rates = dilation_rates
    self.batchnorm_momentum = batchnorm_momentum
    self.dropout = dropout
    # Resolve string/dict identifiers to objects once, so every conv layer
    # created in `build` receives the same initializer and regularizer.
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.interpolation = interpolation
    self.input_spec = tf.keras.layers.InputSpec(ndim=4)

  def _conv_bn_relu_layers(self, kernel_size, **conv_kwargs):
    """Returns a new [Conv2D, BatchNormalization, ReLU] list of layers.

    Arguments:
      kernel_size: Kernel size forwarded to `Conv2D`.
      **conv_kwargs: Extra keyword arguments forwarded to `Conv2D`, e.g.
        `padding` or `dilation_rate`.

    Returns:
      A list of three freshly created, unbuilt Keras layers.
    """
    return [
        tf.keras.layers.Conv2D(
            filters=self.output_channels,
            kernel_size=kernel_size,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            use_bias=False,
            **conv_kwargs),
        tf.keras.layers.BatchNormalization(momentum=self.batchnorm_momentum),
        tf.keras.layers.Activation('relu'),
    ]

  def build(self, input_shape):
    """Creates the parallel ASPP branches and the output projection.

    Arguments:
      input_shape: Shape of the 4-D input, read as
        `(batch, height, width, channels)`.
    """
    height = input_shape[1]
    width = input_shape[2]
    channels = input_shape[3]

    # `Resizing` is constructed with a fixed target size. When either spatial
    # dimension is unknown at build time, the pooled branch is instead resized
    # inside `call`, where the runtime shape of the input is available.
    self._resize_pool_in_call = height is None or width is None

    # Branch 0: pointwise (1x1) convolution.
    self.aspp_layers = [
        tf.keras.Sequential(self._conv_bn_relu_layers(kernel_size=(1, 1)))]

    # One 3x3 dilated-convolution branch per requested rate. 'same' padding
    # keeps the spatial size equal to the input so branches can be concatenated.
    for dilation_rate in self.dilation_rates:
      self.aspp_layers.append(tf.keras.Sequential(
          self._conv_bn_relu_layers(
              kernel_size=(3, 3), padding='same',
              dilation_rate=dilation_rate)))

    # Image-level branch: global average pool, restore the two spatial axes as
    # 1x1, convolve, then upsample back to the input's spatial size.
    pool_layers = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Reshape((1, 1, channels)),
        *self._conv_bn_relu_layers(kernel_size=(1, 1)),
    ]
    if not self._resize_pool_in_call:
      pool_layers.append(
          tf.keras.layers.experimental.preprocessing.Resizing(
              height, width, interpolation=self.interpolation))
    # The pooling branch must stay last: `call` relies on that position when
    # it resizes dynamically.
    self.aspp_layers.append(tf.keras.Sequential(pool_layers))

    # Fuses the concatenated branch outputs back to `output_channels`.
    self.projection = tf.keras.Sequential(
        self._conv_bn_relu_layers(kernel_size=(1, 1))
        + [tf.keras.layers.Dropout(rate=self.dropout)])

  def call(self, inputs, training=None):
    """Runs every branch on `inputs`, concatenates and projects the results.

    Arguments:
      inputs: A 4-D tensor, read as `(batch, height, width, channels)`.
      training: Whether the sub-layers run in training mode. When None, the
        Keras backend learning phase is used.

    Returns:
      The output of the projection block applied to the channel-wise
      concatenation of all branch outputs.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    branch_outputs = [
        branch(inputs, training=training) for branch in self.aspp_layers]

    if self._resize_pool_in_call:
      # Spatial size was unknown in `build`; upsample the 1x1 pooled features
      # to the runtime height/width of the input.
      pooled = branch_outputs[-1]
      resized = tf.image.resize(
          pooled, tf.shape(inputs)[1:3], method=self.interpolation)
      # Cast back so the dtype matches the other branches for the concat.
      branch_outputs[-1] = tf.cast(resized, pooled.dtype)

    merged = tf.concat(branch_outputs, axis=-1)
    return self.projection(merged, training=training)

  def get_config(self):
    """Returns the constructor arguments merged into the base layer config."""
    config = {
        'output_channels': self.output_channels,
        'dilation_rates': self.dilation_rates,
        'batchnorm_momentum': self.batchnorm_momentum,
        'dropout': self.dropout,
        'kernel_initializer': tf.keras.initializers.serialize(
            self.kernel_initializer),
        'kernel_regularizer': tf.keras.regularizers.serialize(
            self.kernel_regularizer),
        'interpolation': self.interpolation,
    }
    base_config = super().get_config()
    # Keys defined by this layer take precedence over the base config.
    return {**base_config, **config}