# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from typing import Any, List, Optional, Mapping

# Import libraries
import tensorflow as tf

from official.vision import keras_cv


@tf.keras.utils.register_keras_serializable(package='Vision')
class ASPP(tf.keras.layers.Layer):
  """Creates an Atrous Spatial Pyramid Pooling (ASPP) layer."""

  def __init__(
      self,
      level: int,
      dilation_rates: List[int],
      num_filters: int = 256,
      pool_kernel_size: Optional[List[int]] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      activation: str = 'relu',
      dropout_rate: float = 0.0,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      interpolation: str = 'bilinear',
      **kwargs):
    """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.

    Args:
      level: An `int` level to apply ASPP.
      dilation_rates: A `list` of dilation rates.
      num_filters: An `int` number of output filters in ASPP.
      pool_kernel_size: A `list` of [height, width] of pooling kernel size or
        None. Pooling size is with respect to original image size, it will be
        scaled down by 2**level. If None, global average pooling is used.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      activation: A `str` activation to be used in ASPP.
      dropout_rate: A `float` rate for dropout regularization.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      interpolation: A `str` of interpolation method. It should be one of
        `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
        `gaussian`, or `mitchellcubic`.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(ASPP, self).__init__(**kwargs)
    # All constructor arguments are kept in one dict so that `build` and
    # `get_config` read from a single source of truth.
    self._config_dict = {
        'level': level,
        'dilation_rates': dilation_rates,
        'num_filters': num_filters,
        'pool_kernel_size': pool_kernel_size,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'activation': activation,
        'dropout_rate': dropout_rate,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'interpolation': interpolation,
    }

  def build(self, input_shape):
    """Creates the underlying `SpatialPyramidPooling` layer.

    Args:
      input_shape: The shape of the inputs. It is not used here; the wrapped
        layer is configured from the constructor arguments only.
    """
    pool_kernel_size = None
    # A `None` (or empty) kernel size is forwarded as `None`.
    if self._config_dict['pool_kernel_size']:
      # The configured size refers to the original image, so it is divided
      # (floor) by the stride 2**level of the feature map being pooled.
      pool_kernel_size = [
          int(p_size // 2**self._config_dict['level'])
          for p_size in self._config_dict['pool_kernel_size']
      ]
    self.aspp = keras_cv.layers.SpatialPyramidPooling(
        output_channels=self._config_dict['num_filters'],
        dilation_rates=self._config_dict['dilation_rates'],
        pool_kernel_size=pool_kernel_size,
        use_sync_bn=self._config_dict['use_sync_bn'],
        batchnorm_momentum=self._config_dict['norm_momentum'],
        batchnorm_epsilon=self._config_dict['norm_epsilon'],
        activation=self._config_dict['activation'],
        dropout=self._config_dict['dropout_rate'],
        kernel_initializer=self._config_dict['kernel_initializer'],
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        interpolation=self._config_dict['interpolation'])

  def call(self, inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.

    The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
    level is present. Hence, this will be compatible with the rest of the
    segmentation model interfaces.

    Args:
      inputs: A `dict` of `tf.Tensor` where
        - key: A `str` of the level of the multilevel feature maps.
        - values: A `tf.Tensor` of shape [batch, height_l, width_l,
          filter_size].

    Returns:
      A `dict` of `tf.Tensor` where
        - key: A `str` of the level of the multilevel feature maps.
        - values: A `tf.Tensor` of output of ASPP module.
    """
    outputs = {}
    # Feature maps are keyed by the string form of the level.
    level = str(self._config_dict['level'])
    outputs[level] = self.aspp(inputs[level])
    return outputs

  def get_config(self) -> Mapping[str, Any]:
    """Returns the layer config, including the base `Layer` config.

    A new dict is returned on every call so that callers cannot mutate the
    layer's internal configuration through the returned value.
    """
    base_config = super(ASPP, self).get_config()
    return {**base_config, **self._config_dict}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds an `ASPP` layer from a config dict.

    Args:
      config: A mapping of constructor keyword arguments, e.g. the output of
        `get_config`.
      custom_objects: Unused.

    Returns:
      A new `ASPP` instance.
    """
    return cls(**config)