fpn.py 4.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Feature Pyramid Networks.

Feature Pyramid Networks were proposed in:
[1] Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan,
    , and Serge Belongie
    Feature Pyramid Networks for Object Detection. CVPR 2017.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf

from tensorflow.python.keras import backend
from official.vision.detection.modeling.architecture import nn_ops
31
from official.vision.detection.ops import spatial_transform_ops
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93


class Fpn(object):
  """Feature pyramid networks."""

  def __init__(self,
               min_level=3,
               max_level=7,
               fpn_feat_dims=256,
               batch_norm_relu=nn_ops.BatchNormRelu):
    """FPN initialization function.

    Args:
      min_level: `int` minimum level in FPN output feature maps.
      max_level: `int` maximum level in FPN output feature maps.
      fpn_feat_dims: `int` number of filters in FPN layers.
      batch_norm_relu: an operation that includes a batch normalization layer
        followed by a relu layer(optional).
    """
    self._min_level = min_level
    self._max_level = max_level
    self._fpn_feat_dims = fpn_feat_dims

    self._batch_norm_relus = {}
    for level in range(self._min_level, self._max_level + 1):
      self._batch_norm_relus[level] = batch_norm_relu(
          relu=False, name='p%d-bn' % level)

  def __call__(self, multilevel_features, is_training=None):
    """Returns the FPN features for a given multilevel features.

    Args:
      multilevel_features: a `dict` containing `int` keys for continuous feature
        levels, e.g., [2, 3, 4, 5]. The values are corresponding features with
        shape [batch_size, height_l, width_l, num_filters].
      is_training: `bool` if True, the model is in training mode.

    Returns:
      a `dict` containing `int` keys for continuous feature levels
      [min_level, min_level + 1, ..., max_level]. The values are corresponding
      FPN features with shape [batch_size, height_l, width_l, fpn_feat_dims].
    """
    input_levels = multilevel_features.keys()
    if min(input_levels) > self._min_level:
      raise ValueError(
          'The minimum backbone level %d should be '%(min(input_levels)) +
          'less or equal to FPN minimum level %d.:'%(self._min_level))
    backbone_max_level = min(max(input_levels), self._max_level)
    with backend.get_graph().as_default(), tf.name_scope('fpn'):
      # Adds lateral connections.
      feats_lateral = {}
      for level in range(self._min_level, backbone_max_level + 1):
        feats_lateral[level] = tf.keras.layers.Conv2D(
            filters=self._fpn_feat_dims,
            kernel_size=(1, 1),
            padding='same',
            name='l%d' % level)(
                multilevel_features[level])

      # Adds top-down path.
      feats = {backbone_max_level: feats_lateral[backbone_max_level]}
      for level in range(backbone_max_level - 1, self._min_level - 1, -1):
94
        feats[level] = spatial_transform_ops.nearest_upsampling(
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
            feats[level + 1], 2) + feats_lateral[level]

      # Adds post-hoc 3x3 convolution kernel.
      for level in range(self._min_level, backbone_max_level + 1):
        feats[level] = tf.keras.layers.Conv2D(
            filters=self._fpn_feat_dims,
            strides=(1, 1),
            kernel_size=(3, 3),
            padding='same',
            name='post_hoc_d%d' % level)(
                feats[level])

      # Adds coarser FPN levels introduced for RetinaNet.
      for level in range(backbone_max_level + 1, self._max_level + 1):
        feats_in = feats[level - 1]
        if level > backbone_max_level + 1:
          feats_in = tf.nn.relu(feats_in)
        feats[level] = tf.keras.layers.Conv2D(
            filters=self._fpn_feat_dims,
            strides=(2, 2),
            kernel_size=(3, 3),
            padding='same',
            name='p%d' % level)(
                feats_in)
      # Adds batch_norm layer.
      for level in range(self._min_level, self._max_level + 1):
        feats[level] = self._batch_norm_relus[level](
            feats[level], is_training=is_training)
    return feats