from collections import OrderedDict
from typing import Tuple, List, Dict, Callable, Optional

import torch.nn.functional as F
from torch import nn, Tensor

from ..ops.misc import Conv2dNormActivation
from ..utils import _log_api_usage_once


class ExtraFPNBlock(nn.Module):
    """
    Base class for the extra block in the FPN.

    Args:
        results (List[Tensor]): the result of the FPN
        x (List[Tensor]): the original feature maps
        names (List[str]): the names for each one of the
            original feature maps

    Returns:
        results (List[Tensor]): the extended set of results
            of the FPN
        names (List[str]): the extended set of names for the results
    """

    def forward(
        self,
        results: List[Tensor],
        x: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        pass


class FeaturePyramidNetwork(nn.Module):
    """
    Module that adds an FPN on top of a set of feature maps. This is based on
    `"Feature Pyramid Networks for Object Detection" <https://arxiv.org/abs/1612.03144>`_.

    The feature maps are currently supposed to be in increasing depth
    order.

    The input to the model is expected to be an OrderedDict[Tensor], containing
    the feature maps on top of which the FPN will be added.

    Args:
        in_channels_list (list[int]): number of channels for each feature map that
            is passed to the module
        out_channels (int): number of channels of the FPN representation
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the FPN features, the original
            features and the names of the original features as input, and return
            a new list of feature maps and their corresponding names
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None

    Examples::

        >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
        >>> # get some dummy data
        >>> x = OrderedDict()
        >>> x['feat0'] = torch.rand(1, 10, 64, 64)
        >>> x['feat2'] = torch.rand(1, 20, 16, 16)
        >>> x['feat3'] = torch.rand(1, 30, 8, 8)
        >>> # compute the FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('feat0', torch.Size([1, 5, 64, 64])),
        >>>    ('feat2', torch.Size([1, 5, 16, 16])),
        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]

    """

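    # the serialized layout changed in version 2 (the blocks are now Conv2dNormActivation);
    # see _load_from_state_dict for the corresponding checkpoint key remapping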
    _version = 2

    def __init__(
        self,
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
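        # lateral 1x1 convs (inner_blocks) project each input feature map to out_channels,
        # while 3x3 convs (layer_blocks) smooth the merged top-down features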
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in in_channels_list:
            if in_channels == 0:
                raise ValueError("in_channels=0 is currently not supported")
            inner_block_module = Conv2dNormActivation(
                in_channels, out_channels, kernel_size=1, padding=0, norm_layer=norm_layer, activation_layer=None
            )
            layer_block_module = Conv2dNormActivation(
                out_channels, out_channels, kernel_size=3, norm_layer=norm_layer, activation_layer=None
            )
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        if extra_blocks is not None:
            if not isinstance(extra_blocks, ExtraFPNBlock):
                raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
        self.extra_blocks = extra_blocks

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
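            # version 1 checkpoints stored plain conv parameters (e.g. "inner_blocks.0.weight");
            # in version 2 each block is a Conv2dNormActivation whose conv sits at sub-index 0,
            # so remap the old keys to the "inner_blocks.0.0.weight" layout before loading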
            num_blocks = len(self.inner_blocks)
            for block in ["inner_blocks", "layer_blocks"]:
                for i in range(num_blocks):
                    for param in ["weight", "bias"]:
                        old_key = f"{prefix}{block}.{i}.{param}"
                        new_key = f"{prefix}{block}.{i}.0.{param}"
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.inner_blocks[idx](x),
        but TorchScript doesn't support this yet
        """
        num_blocks = len(self.inner_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.inner_blocks):
            if i == idx:
                out = module(x)
        return out

    def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.layer_blocks[idx](x),
        but TorchScript doesn't support this yet
        """
        num_blocks = len(self.layer_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.layer_blocks):
            if i == idx:
                out = module(x)
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Computes the FPN for a set of feature maps.

        Args:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x = list(x.values())

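        # start the top-down pathway from the deepest (lowest-resolution) feature map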
        last_inner = self.get_result_from_inner_blocks(x[-1], -1)
        results = []
        results.append(self.get_result_from_layer_blocks(last_inner, -1))

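        # walk the remaining levels from second-deepest up to the highest resolution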
        for idx in range(len(x) - 2, -1, -1):
            inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
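            # upsample the coarser top-down feature to the lateral's spatial size,
            # then merge the two by elementwise sum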
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
            last_inner = inner_lateral + inner_top_down
            results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out


class LastLevelMaxPool(ExtraFPNBlock):
    """
    Applies a max_pool2d on top of the last feature map
    """

    def forward(
        self,
        x: List[Tensor],
        y: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
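        # a 1x1 max pool with stride 2 simply subsamples the last FPN level,
        # appending one extra, coarser level named "pool"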
        names.append("pool")
        x.append(F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0))
        return x, names


class LastLevelP6P7(ExtraFPNBlock):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7.
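
    Example (a minimal sketch; the backbone channel counts below are
    illustrative assumptions, not values required by this module)::

        >>> extra_blocks = LastLevelP6P7(256, 256)
        >>> fpn = FeaturePyramidNetwork([512, 1024, 2048], 256, extra_blocks=extra_blocks)
        >>> # after calling fpn(x), the output contains two extra levels named "p6" and "p7"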
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            nn.init.kaiming_uniform_(module.weight, a=1)
            nn.init.constant_(module.bias, 0)
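        # when the input channel count matches the FPN output channels, P6 can be
        # computed from P5 directly instead of from the raw C5 backbone feature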
        self.use_P5 = in_channels == out_channels

    def forward(
        self,
        p: List[Tensor],
        c: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
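        # P6 is a stride-2 conv on P5 (or C5), and P7 is a stride-2 conv on ReLU(P6)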
        p5, c5 = p[-1], c[-1]
        x = p5 if self.use_P5 else c5
        p6 = self.p6(x)
        p7 = self.p7(F.relu(p6))
        p.extend([p6, p7])
        names.extend(["p6", "p7"])
        return p, names