aspp.py 6.85 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

Fan Yang's avatar
Fan Yang committed
15
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
16
from typing import Any, List, Mapping, Optional
Abdullah Rashwan's avatar
Abdullah Rashwan committed
17
18

# Import libraries
19

Abdullah Rashwan's avatar
Abdullah Rashwan committed
20
21
import tensorflow as tf

22
from official.modeling import hyperparams
Abdullah Rashwan's avatar
Abdullah Rashwan committed
23
from official.vision import keras_cv
24
from official.vision.beta.modeling.decoders import factory
Abdullah Rashwan's avatar
Abdullah Rashwan committed
25
26
27
28


@tf.keras.utils.register_keras_serializable(package='Vision')
class ASPP(tf.keras.layers.Layer):
  """Creates an Atrous Spatial Pyramid Pooling (ASPP) layer.

  The heavy lifting is delegated to `keras_cv.layers.SpatialPyramidPooling`,
  which is created lazily in `build`. This wrapper selects a single feature
  level from a multilevel input dict and returns the result in the same
  dict-of-levels format.
  """

  def __init__(
      self,
      level: int,
      dilation_rates: List[int],
      num_filters: int = 256,
      pool_kernel_size: Optional[List[int]] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      activation: str = 'relu',
      dropout_rate: float = 0.0,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      interpolation: str = 'bilinear',
      **kwargs):
    """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.

    Args:
      level: An `int` level to apply ASPP.
      dilation_rates: A `list` of dilation rates.
      num_filters: An `int` number of output filters in ASPP.
      pool_kernel_size: A `list` of [height, width] of pooling kernel size or
        None. Pooling size is with respect to original image size, it will be
        scaled down by 2**level. If None, global average pooling is used.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      activation: A `str` activation to be used in ASPP.
      dropout_rate: A `float` rate for dropout regularization.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      interpolation: A `str` of interpolation method. It should be one of
        `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
        `gaussian`, or `mitchellcubic`.
      **kwargs: Additional keyword arguments passed to the base
        `tf.keras.layers.Layer` (e.g. `name`, `trainable`, `dtype`).
    """
    super(ASPP, self).__init__(**kwargs)
    # Everything needed to re-create this layer; also the source of truth read
    # by `build` and `call`.
    self._config_dict = {
        'level': level,
        'dilation_rates': dilation_rates,
        'num_filters': num_filters,
        'pool_kernel_size': pool_kernel_size,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'activation': activation,
        'dropout_rate': dropout_rate,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'interpolation': interpolation,
    }

  def build(self, input_shape):
    """Creates the wrapped `SpatialPyramidPooling` sub-layer.

    Args:
      input_shape: Shape(s) of the inputs; not used to size the sub-layer.
    """
    level = self._config_dict['level']
    pool_kernel_size = None
    if self._config_dict['pool_kernel_size']:
      # The configured kernel is expressed in original-image pixels; rescale
      # it by 2**level to match the resolution of the selected feature map.
      pool_kernel_size = [
          int(p_size // 2**level)
          for p_size in self._config_dict['pool_kernel_size']
      ]
    self.aspp = keras_cv.layers.SpatialPyramidPooling(
        output_channels=self._config_dict['num_filters'],
        dilation_rates=self._config_dict['dilation_rates'],
        pool_kernel_size=pool_kernel_size,
        use_sync_bn=self._config_dict['use_sync_bn'],
        batchnorm_momentum=self._config_dict['norm_momentum'],
        batchnorm_epsilon=self._config_dict['norm_epsilon'],
        activation=self._config_dict['activation'],
        dropout=self._config_dict['dropout_rate'],
        kernel_initializer=self._config_dict['kernel_initializer'],
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        interpolation=self._config_dict['interpolation'])
    super(ASPP, self).build(input_shape)

  def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.

    The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
    level is present. Hence, this will be compatible with the rest of the
    segmentation model interfaces.

    Args:
      inputs: A `dict` of `tf.Tensor` where
        - key: A `str` of the level of the multilevel feature maps.
        - values: A `tf.Tensor` of shape [batch, height_l, width_l,
          filter_size].

    Returns:
      A `dict` of `tf.Tensor` where
        - key: A `str` of the level of the multilevel feature maps.
        - values: A `tf.Tensor` of output of ASPP module.
    """
    # Input dicts are keyed by the string form of the level.
    level = str(self._config_dict['level'])
    return {level: self.aspp(inputs[level])}

  def get_config(self) -> Mapping[str, Any]:
    """Returns a serializable config for this layer.

    The base `Layer` config (such as `name`, `trainable`, `dtype`) is merged
    with this layer's own arguments so `from_config` round-trips it. A fresh
    dict is returned, so callers cannot mutate the layer's internal config.
    """
    base_config = super(ASPP, self).get_config()
    return dict(list(base_config.items()) + list(self._config_dict.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates the layer from the output of `get_config`."""
    return cls(**config)
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176


@factory.register_decoder_builder('aspp')
def build_aspp_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.layers.Layer:
  """Builds ASPP decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone. Note this is only for a consistent
      interface, and is not used by the ASPP decoder.
    model_config: A `hyperparams.Config` model config. Its `decoder` field is
      read as a one-of config (`type` and `get()`), and its `norm_activation`
      field supplies the normalization/activation settings.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
      None.

  Returns:
    An `ASPP` layer (a `tf.keras.layers.Layer`) configured from `model_config`.

  Raises:
    ValueError: If the model_config.decoder.type is not `aspp`.
  """
  del input_specs  # input_specs is not used by ASPP decoder.
  decoder_type = model_config.decoder.type
  if decoder_type != 'aspp':
    raise ValueError(f'Inconsistent decoder type {decoder_type}. '
                     'Need to be `aspp`.')
  # Resolve the active one-of sub-config only after its type is validated.
  decoder_cfg = model_config.decoder.get()

  norm_activation_config = model_config.norm_activation
  # `kernel_initializer` and `interpolation` are not passed here, so the
  # `ASPP` constructor defaults apply.
  return ASPP(
      level=decoder_cfg.level,
      dilation_rates=decoder_cfg.dilation_rates,
      num_filters=decoder_cfg.num_filters,
      pool_kernel_size=decoder_cfg.pool_kernel_size,
      dropout_rate=decoder_cfg.dropout_rate,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      activation=norm_activation_config.activation,
      kernel_regularizer=l2_regularizer)