from __future__ import division

import torch
import torch.nn as nn

from mmdet import ops
from mmdet.core import force_fp32
from ..registry import ROI_EXTRACTORS


@ROI_EXTRACTORS.register_module
class SingleRoIExtractor(nn.Module):
    """Extract RoI features from a single level feature map.

    If there are multiple input feature levels, each RoI is mapped to a level
    according to its scale.

    Args:
        roi_layer (dict): Specify RoI layer type and arguments. Must contain
            a ``'type'`` key naming an attribute of ``mmdet.ops``; the
            remaining keys are forwarded to that layer as keyword arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (list[int]): Strides of input feature maps, one per
            feature level.
        finest_scale (int): Scale threshold of mapping to level 0.
    """

    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 finest_scale=56):
        super(SingleRoIExtractor, self).__init__()
        self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides
        self.finest_scale = finest_scale
        # NOTE(review): presumably consulted by the ``force_fp32`` decorator
        # applied to ``forward`` -- confirm against mmdet.core.
        self.fp16_enabled = False

    @property
    def num_inputs(self):
        """int: Input feature map levels."""
        return len(self.featmap_strides)

    def init_weights(self):
        """No-op; this module does nothing at weight initialization."""
        pass

    def build_roi_layers(self, layer_cfg, featmap_strides):
        """Build one RoI layer per feature map stride.

        Args:
            layer_cfg (dict): RoI layer config. ``'type'`` selects the layer
                class from ``mmdet.ops``; all other items are passed to the
                layer constructor as keyword arguments.
            featmap_strides (list[int]): Strides of input feature maps. Each
                layer is created with ``spatial_scale=1 / stride``.

        Returns:
            nn.ModuleList: RoI layers, in the same order as
                ``featmap_strides``.
        """
        # Work on a copy so that popping 'type' does not mutate the caller's
        # config dict.
        cfg = layer_cfg.copy()
        layer_type = cfg.pop('type')
        assert hasattr(ops, layer_type)
        layer_cls = getattr(ops, layer_type)
        roi_layers = nn.ModuleList(
            [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
        return roi_layers

    def map_roi_levels(self, rois, num_levels):
        """Map rois to corresponding feature levels by scales.

        With 4 levels the mapping is:

        - scale < finest_scale * 2: level 0
        - finest_scale * 2 <= scale < finest_scale * 4: level 1
        - finest_scale * 4 <= scale < finest_scale * 8: level 2
        - scale >= finest_scale * 8: level 3

        In general the level is ``floor(log2(scale / finest_scale))`` clamped
        to ``[0, num_levels - 1]``.

        Args:
            rois (Tensor): Input RoIs, shape (k, 5). Columns 1-4 are used as
                box corners; column 0 is not read here (presumably the batch
                index -- confirm against the caller).
            num_levels (int): Total level number.

        Returns:
            Tensor: Level index (0-based) of each RoI, shape (k, )
        """
        # Scale is the square root of the box area, using the inclusive
        # (+1) width/height convention.
        scale = torch.sqrt(
            (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
        # The small epsilon is added to the ratio before taking log2.
        target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
        target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
        return target_lvls

    @force_fp32(apply_to=('feats',), out_fp16=True)
    def forward(self, feats, rois):
        """Extract RoI features from the feature level assigned to each RoI.

        Args:
            feats (list[Tensor] | tuple[Tensor]): Feature maps, one per
                level, ordered like ``featmap_strides``.
            rois (Tensor): Input RoIs, shape (k, 5).

        Returns:
            Tensor: With a single feature level, the output of the first RoI
                layer. Otherwise a tensor of shape
                (k, out_channels, out_size, out_size) in which each row is
                filled by the RoI layer of the level its RoI was mapped to.
        """
        if len(feats) == 1:
            return self.roi_layers[0](feats[0], rois)

        out_size = self.roi_layers[0].out_size
        num_levels = len(feats)
        target_lvls = self.map_roi_levels(rois, num_levels)
        # Allocate on the same device / dtype as the feature maps.
        roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels,
                                       out_size, out_size)
        for i in range(num_levels):
            inds = target_lvls == i
            # Skip levels that received no RoIs.
            if inds.any():
                rois_ = rois[inds, :]
                roi_feats_t = self.roi_layers[i](feats[i], rois_)
                roi_feats[inds] += roi_feats_t
        return roi_feats