test_rpn_heads.py 4.22 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import logging
import unittest

import torch
from detectron2.modeling import build_anchor_generator, build_backbone
from detectron2.modeling.proposal_generator import rpn
from d2go.runner import GeneralizedRCNNRunner

logger = logging.getLogger(__name__)


# Optional per-head config overrides, keyed by RPN_HEAD_REGISTRY name. The
# value is passed to cfg.merge_from_file() on top of the default
# GeneralizedRCNNRunner config; heads not listed here use the default config.
RPN_CFGS = {}


class TestRPNHeads(unittest.TestCase):
    """Smoke tests for every head registered in ``rpn.RPN_HEAD_REGISTRY``.

    Each head is built from the default ``GeneralizedRCNNRunner`` config
    (merged with its ``RPN_CFGS`` override, if any), run on random feature
    maps, and the shapes of its two outputs (objectness logits and box
    regression deltas) are checked for every RPN input level.
    """

    def test_build_rpn_heads(self):
        """Make sure every registered RPN head runs with the default anchors."""
        self._check_all_rpn_heads()

    def test_build_rpn_heads_with_rotated_anchor_generator(self):
        """Make sure every registered RPN head works with rotated anchors."""
        self._check_all_rpn_heads(anchor_generator_name="RotatedAnchorGenerator")

    def _check_all_rpn_heads(self, anchor_generator_name=None):
        """Build, run and shape-check every registered RPN head.

        Args:
            anchor_generator_name: if not None, assigned to
                ``cfg.MODEL.ANCHOR_GENERATOR.NAME`` after any ``RPN_CFGS``
                override is merged, so the head and anchor generator are
                built with it.
        """
        # An empty registry would turn the loop into a no-op and let the test
        # pass vacuously.
        self.assertGreater(len(rpn.RPN_HEAD_REGISTRY._obj_map), 0)

        for name, builder in rpn.RPN_HEAD_REGISTRY._obj_map.items():
            # subTest names the failing head and, under runners that support
            # it, keeps checking the remaining heads.
            with self.subTest(rpn_head=name):
                logger.info("Testing %s...", name)
                cfg = GeneralizedRCNNRunner().get_default_cfg()
                if name in RPN_CFGS:
                    cfg.merge_from_file(RPN_CFGS[name])
                if anchor_generator_name is not None:
                    cfg.MODEL.ANCHOR_GENERATOR.NAME = anchor_generator_name
                self._check_rpn_head(cfg, builder)

    def _check_rpn_head(self, cfg, builder):
        """Build one RPN head with ``builder`` from ``cfg`` and check its outputs.

        The head is fed the same random ``(N, C, H, W)`` tensor at every level
        in ``cfg.MODEL.RPN.IN_FEATURES``. For each level the objectness logits
        must be ``(N, A, H, W)`` and the box deltas ``(N, A * box_dim, H, W)``,
        where ``A`` and ``box_dim`` come from the configured anchor generator.
        """
        backbone_shape = build_backbone(cfg).output_shape()
        rpn_input_shape = [backbone_shape[f] for f in cfg.MODEL.RPN.IN_FEATURES]
        rpn_head = builder(cfg, rpn_input_shape)

        anchor_generator = build_anchor_generator(cfg, rpn_input_shape)
        num_anchors = anchor_generator.num_cell_anchors[0]
        # box_dim depends on the anchor generator type, so read it instead of
        # hard-coding 4; this keeps the check valid under RPN_CFGS overrides.
        box_dim = anchor_generator.box_dim

        # Size the input from the features the RPN head actually consumes,
        # not from whichever backbone output happens to be listed first.
        in_channels = rpn_input_shape[0].channels
        N, H, W = 2, 24, 32
        features = torch.rand([N, in_channels, H, W], dtype=torch.float32)

        num_levels = len(rpn_input_shape)
        out = rpn_head([features] * num_levels)
        self.assertEqual(len(out), 2)
        logits, bbox_reg = out
        # One output per input level; checked up front so zip() below cannot
        # silently skip levels.
        self.assertEqual(len(logits), num_levels)
        self.assertEqual(len(bbox_reg), num_levels)
        # The head is expected to preserve each level's spatial size.
        for level_logits, level_deltas in zip(logits, bbox_reg):
            self.assertEqual(level_logits.shape, torch.Size([N, num_anchors, H, W]))
            self.assertEqual(
                level_deltas.shape,
                torch.Size([N, num_anchors * box_dim, H, W]),
            )


if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()