from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple

import torch.nn.functional as F
from torch import nn, Tensor

from ..ops.misc import Conv2dNormActivation
from ..utils import _log_api_usage_once


class ExtraFPNBlock(nn.Module):
    """
    Base class for the extra block in the FPN.

    Args:
        results (List[Tensor]): the result of the FPN
        x (List[Tensor]): the original feature maps
        names (List[str]): the names for each one of the
            original feature maps

    Returns:
        results (List[Tensor]): the extended set of results
            of the FPN
        names (List[str]): the extended set of names for the results
    """

    def forward(
        self,
        results: List[Tensor],
        x: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
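        # Base hook is a no-op; subclasses such as LastLevelMaxPool and
        # LastLevelP6P7 below override it to append extra maps and names.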
        pass


class FeaturePyramidNetwork(nn.Module):
    """
    Module that adds an FPN on top of a set of feature maps. This is based on
    `"Feature Pyramid Network for Object Detection" <https://arxiv.org/abs/1612.03144>`_.

    The feature maps are currently supposed to be in increasing depth
    order.

    The input to the model is expected to be an OrderedDict[Tensor], containing
    the feature maps on top of which the FPN will be added.

    Args:
        in_channels_list (list[int]): number of channels for each feature map that
            is passed to the module
        out_channels (int): number of channels of the FPN representation
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and return
            a new list of feature maps and their corresponding names
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None

    Examples::

        >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
        >>> # get some dummy data
        >>> x = OrderedDict()
        >>> x['feat0'] = torch.rand(1, 10, 64, 64)
        >>> x['feat2'] = torch.rand(1, 20, 16, 16)
        >>> x['feat3'] = torch.rand(1, 30, 8, 8)
        >>> # compute the FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('feat0', torch.Size([1, 5, 64, 64])),
        >>>    ('feat2', torch.Size([1, 5, 16, 16])),
        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]

    """

    # Parameter-layout version; _load_from_state_dict below remaps keys from
    # checkpoints saved before version 2.
    _version = 2

    def __init__(
        self,
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
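        # inner_blocks hold the 1x1 lateral convs that project each input map to
        # out_channels; layer_blocks hold the 3x3 convs that smooth each merged output.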
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in in_channels_list:
            if in_channels == 0:
                raise ValueError("in_channels=0 is currently not supported")
            inner_block_module = Conv2dNormActivation(
                in_channels, out_channels, kernel_size=1, padding=0, norm_layer=norm_layer, activation_layer=None
            )
            layer_block_module = Conv2dNormActivation(
                out_channels, out_channels, kernel_size=3, norm_layer=norm_layer, activation_layer=None
            )
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        if extra_blocks is not None:
            if not isinstance(extra_blocks, ExtraFPNBlock):
                raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
        self.extra_blocks = extra_blocks

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
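        # Checkpoints saved before version 2 stored bare Conv2d parameters; the
        # convs are now wrapped in Conv2dNormActivation (conv at sub-module
        # index 0), so old keys need an extra ".0" segment.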
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            num_blocks = len(self.inner_blocks)
            for block in ["inner_blocks", "layer_blocks"]:
                for i in range(num_blocks):
                    for param_type in ["weight", "bias"]:
                        old_key = f"{prefix}{block}.{i}.{param_type}"
                        new_key = f"{prefix}{block}.{i}.0.{param_type}"
                        if old_key in state_dict:
                            state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.inner_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.inner_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.inner_blocks):
            if i == idx:
                out = module(x)
        return out

    def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.layer_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.layer_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.layer_blocks):
            if i == idx:
                out = module(x)
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Computes the FPN for a set of feature maps.

        Args:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from the highest to the lowest resolution.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x = list(x.values())

        last_inner = self.get_result_from_inner_blocks(x[-1], -1)
        results = []
        results.append(self.get_result_from_layer_blocks(last_inner, -1))

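        # Top-down pathway: from the second-deepest level to the finest,
        # upsample the running map, add the 1x1 lateral projection, and run
        # the 3x3 output conv at each level.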
        for idx in range(len(x) - 2, -1, -1):
            inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
            last_inner = inner_lateral + inner_top_down
            results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out


class LastLevelMaxPool(ExtraFPNBlock):
    """
    Applies a max_pool2d on top of the last feature map (not an actual
    max_pool2d: kernel_size=1 with stride 2 simply subsamples).
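
    A minimal sketch of attaching it to an FPN (channel sizes are illustrative,
    and the classes are assumed to be imported from this module):

    Examples::

        >>> m = FeaturePyramidNetwork([10, 20], 5, extra_blocks=LastLevelMaxPool())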
    """

    def forward(
        self,
        x: List[Tensor],
        y: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        names.append("pool")
        # Use max pooling to simulate stride 2 subsampling
        x.append(F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0))
        return x, names


class LastLevelP6P7(ExtraFPNBlock):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7.
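
    A minimal usage sketch (the channel counts are illustrative, and the
    classes are assumed to be imported from this module):

    Examples::

        >>> extra = LastLevelP6P7(in_channels=256, out_channels=256)
        >>> fpn = FeaturePyramidNetwork([256, 512], 256, extra_blocks=extra)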
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
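        # Two stride-2 3x3 convs: p6 halves the spatial resolution of its
        # input, and p7 halves it again after a ReLU.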
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            nn.init.kaiming_uniform_(module.weight, a=1)
            nn.init.constant_(module.bias, 0)
        self.use_P5 = in_channels == out_channels

    def forward(
        self,
        p: List[Tensor],
        c: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
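        # When in_channels == out_channels, P6 is computed from the FPN output
        # P5; otherwise it falls back to the raw backbone feature C5.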
        p5, c5 = p[-1], c[-1]
        x = p5 if self.use_P5 else c5
        p6 = self.p6(x)
        p7 = self.p7(F.relu(p6))
        p.extend([p6, p7])
        names.extend(["p6", "p7"])
        return p, names