fpn.py 9.03 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

Fan Yang's avatar
Fan Yang committed
15
"""Contains the definitions of Feature Pyramid Networks (FPN)."""
Fan Yang's avatar
Fan Yang committed
16
from typing import Any, Mapping, Optional
Abdullah Rashwan's avatar
Abdullah Rashwan committed
17
18

# Import libraries
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
19
from absl import logging
Abdullah Rashwan's avatar
Abdullah Rashwan committed
20
21
import tensorflow as tf

22
from official.modeling import hyperparams
Abdullah Rashwan's avatar
Abdullah Rashwan committed
23
from official.modeling import tf_utils
24
from official.vision.beta.modeling.decoders import factory
Abdullah Rashwan's avatar
Abdullah Rashwan committed
25
26
27
28
29
from official.vision.beta.ops import spatial_transform_ops


@tf.keras.utils.register_keras_serializable(package='Vision')
class FPN(tf.keras.Model):
  """Creates a Feature Pyramid Network (FPN).

  This implements the paper:
  Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan, and
  Serge Belongie.
  Feature Pyramid Networks for Object Detection.
  (https://arxiv.org/pdf/1612.03144)
  """

  def __init__(
      self,
      input_specs: Mapping[str, tf.TensorShape],
      min_level: int = 3,
      max_level: int = 7,
      num_filters: int = 256,
      fusion_type: str = 'sum',
      use_separable_conv: bool = False,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a Feature Pyramid Network (FPN).

    Args:
      input_specs: A `dict` of input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      min_level: An `int` of minimum level in FPN output feature maps.
      max_level: An `int` of maximum level in FPN output feature maps.
      num_filters: An `int` number of filters in FPN layers.
      fusion_type: A `str` of `sum` or `concat`. Whether performing sum or
        concat for feature fusion.
      use_separable_conv: A `bool`.  If True use separable convolution for
        convolution in FPN layers.
      activation: A `str` name of the activation function. It is applied only
        between the extra coarser levels built above the backbone's levels.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Default is None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If `fusion_type` is not `sum` or `concat`, if `input_specs`
        is empty, or if the smallest backbone level is greater than
        `min_level`.
    """
    # Validate before building anything. Checking only inside the top-down
    # loop would silently accept a bad value whenever that loop runs zero
    # iterations (backbone max level == min_level).
    if fusion_type not in ('sum', 'concat'):
      raise ValueError('Fusion type {} not supported.'.format(fusion_type))

    self._config_dict = {
        'input_specs': input_specs,
        'min_level': min_level,
        'max_level': max_level,
        'num_filters': num_filters,
        'fusion_type': fusion_type,
        'use_separable_conv': use_separable_conv,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    if use_separable_conv:
      conv2d = tf.keras.layers.SeparableConv2D
    else:
      conv2d = tf.keras.layers.Conv2D
    if use_sync_bn:
      norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      norm = tf.keras.layers.BatchNormalization
    activation_fn = tf.keras.layers.Activation(
        tf_utils.get_activation(activation))

    # Normalize over the channel axis of the configured data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Get input feature pyramid from backbone.
    logging.info('FPN input_specs: %s', input_specs)
    inputs = self._build_input_pyramid(input_specs, min_level)
    # Level keys are strings. Compare them as integers, because lexicographic
    # max() misorders multi-digit levels (e.g. max('9', '10') == '9').
    backbone_max_level = min(max(int(level) for level in inputs), max_level)

    # Build lateral connections: a 1x1 conv projects every backbone level in
    # [min_level, backbone_max_level] to `num_filters` channels.
    feats_lateral = {}
    for level in range(min_level, backbone_max_level + 1):
      feats_lateral[str(level)] = conv2d(
          filters=num_filters,
          kernel_size=1,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer)(
              inputs[str(level)])

    # Build top-down path, from the coarsest backbone level to min_level.
    feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
    for level in range(backbone_max_level - 1, min_level - 1, -1):
      # Nearest-neighbour 2x upsampling of the previously merged level.
      # Assumes adjacent backbone levels differ by 2x in spatial size.
      feat_a = spatial_transform_ops.nearest_upsampling(
          feats[str(level + 1)], 2)
      feat_b = feats_lateral[str(level)]

      if fusion_type == 'sum':
        feats[str(level)] = feat_a + feat_b
      else:  # 'concat' (validated at the top of __init__).
        # The channel count grows as merged maps are concatenated again. The
        # 3x3 convs below project every level back to `num_filters`.
        feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)

    # TODO(xianzhi): consider to remove bias in conv2d.
    # Build post-hoc 3x3 convolution kernel.
    for level in range(min_level, backbone_max_level + 1):
      feats[str(level)] = conv2d(
          filters=num_filters,
          strides=1,
          kernel_size=3,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer)(
              feats[str(level)])

    # TODO(xianzhi): consider to remove bias in conv2d.
    # Build coarser FPN levels introduced for RetinaNet.
    for level in range(backbone_max_level + 1, max_level + 1):
      feats_in = feats[str(level - 1)]
      # The first extra level consumes the top FPN level directly. Later extra
      # levels apply the activation to the previous extra level first.
      if level > backbone_max_level + 1:
        feats_in = activation_fn(feats_in)
      feats[str(level)] = conv2d(
          filters=num_filters,
          strides=2,
          kernel_size=3,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer)(
              feats_in)

    # Apply batch norm layers.
    for level in range(min_level, max_level + 1):
      feats[str(level)] = norm(
          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
              feats[str(level)])

    self._output_specs = {
        str(level): feats[str(level)].get_shape()
        for level in range(min_level, max_level + 1)
    }

    super(FPN, self).__init__(inputs=inputs, outputs=feats, **kwargs)

  def _build_input_pyramid(self, input_specs: Mapping[str, tf.TensorShape],
                           min_level: int):
    """Creates one `tf.keras.Input` per backbone level.

    Args:
      input_specs: A `dict` of {level: TensorShape} from a backbone. Each
        shape's first (batch) dimension is dropped when creating the Input.
      min_level: An `int` of the minimum FPN output level.

    Returns:
      A `dict` mapping each key of `input_specs` to a `tf.keras.Input`.

    Raises:
      ValueError: If `input_specs` is empty or its smallest level is greater
        than `min_level`.
    """
    assert isinstance(input_specs, dict)
    if not input_specs:
      raise ValueError('input_specs must contain at least one level.')
    # Compare levels numerically. Lexicographic string comparison misorders
    # multi-digit levels (e.g. '10' < '3').
    if min(int(level) for level in input_specs) > min_level:
      raise ValueError(
          'Backbone min level should be less or equal to FPN min level')

    return {
        level: tf.keras.Input(shape=spec[1:])
        for level, spec in input_specs.items()
    }

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments recorded in `__init__`."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates an FPN from a `get_config()` dictionary."""
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239


@factory.register_decoder_builder('fpn')
def build_fpn_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Constructs an `FPN` decoder from a model config.

  The decoder-specific options (`num_filters`, `fusion_type`,
  `use_separable_conv`) come from `model_config.decoder`. The level range
  comes from `model_config`, and the normalization/activation settings come
  from `model_config.norm_activation`.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance used as the
      convolution kernel regularizer. Default to None.

  Returns:
    A `tf.keras.Model` instance of the FPN decoder.

  Raises:
    ValueError: If the model_config.decoder.type is not `fpn`.
  """
  decoder_type = model_config.decoder.type
  fpn_cfg = model_config.decoder.get()
  if decoder_type != 'fpn':
    raise ValueError(f'Inconsistent decoder type {decoder_type}. '
                     'Need to be `fpn`.')

  norm_act = model_config.norm_activation
  fpn_kwargs = dict(
      input_specs=input_specs,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_filters=fpn_cfg.num_filters,
      fusion_type=fpn_cfg.fusion_type,
      use_separable_conv=fpn_cfg.use_separable_conv,
      activation=norm_act.activation,
      use_sync_bn=norm_act.use_sync_bn,
      norm_momentum=norm_act.norm_momentum,
      norm_epsilon=norm_act.norm_epsilon,
      kernel_regularizer=l2_regularizer,
  )
  return FPN(**fpn_kwargs)