"""Fast R-CNN detector built on ``TwoStageDetector``."""
from .two_stage import TwoStageDetector
from ..registry import DETECTORS


@DETECTORS.register_module
class FastRCNN(TwoStageDetector):
    """Fast R-CNN detector.

    A thin specialization of ``TwoStageDetector``. The constructor forwards
    its components unchanged and does not accept or forward an RPN head.
    Proposals are instead supplied explicitly to ``forward_test``.
    """

    def __init__(self,
                 backbone,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 shared_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 pretrained=None):
        super(FastRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            pretrained=pretrained)

    def forward_test(self, imgs, img_metas, proposals, **kwargs):
        """Run inference, dispatching to single or augmented testing.

        Args:
            imgs (list): One entry per test-time augmentation. Each entry is
                queried with ``.size(0)`` as the number of images per GPU,
                presumably an N x C x H x W tensor (TODO confirm). Only
                N == 1 is currently supported.
            img_metas (list): Image meta info, one entry per augmentation.
            proposals (list): Precomputed proposals, one entry per
                augmentation.
            **kwargs: Forwarded to ``simple_test`` / ``aug_test``.

        Returns:
            The result of ``self.simple_test`` when there is exactly one
            augmentation, otherwise the result of ``self.aug_test``.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If no augmentations are given, or if ``img_metas``
                or ``proposals`` does not have one entry per augmentation.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        # Guard the ``imgs[0]`` access below; an empty list would otherwise
        # surface as an opaque IndexError.
        if num_augs == 0:
            raise ValueError('imgs must contain at least one augmentation')
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # Without this check a single-augmentation call silently uses
        # proposals[0] and ignores any extra entries.
        if num_augs != len(proposals):
            raise ValueError(
                'num of augmentations ({}) != num of proposals ({})'.format(
                    num_augs, len(proposals)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], proposals[0],
                                    **kwargs)
        else:
            return self.aug_test(imgs, img_metas, proposals, **kwargs)