# fast_rcnn.py
from ..registry import DETECTORS
from .two_stage import TwoStageDetector


@DETECTORS.register_module
class FastRCNN(TwoStageDetector):
    """Fast R-CNN detector built on ``TwoStageDetector``.

    This class neither accepts nor forwards an ``rpn_head``, so it presumably
    relies on externally supplied region proposals (see the ``proposals``
    argument of ``forward_test``).
    """

    def __init__(self,
                 backbone,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 shared_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 pretrained=None):
        """Forward every component config to ``TwoStageDetector``.

        All arguments are passed through by keyword, unchanged. Their exact
        semantics are defined by ``TwoStageDetector.__init__``.
        """
        super(FastRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            pretrained=pretrained)

    def forward_test(self, imgs, img_metas, proposals, **kwargs):
        """Run inference, dispatching to single or augmented testing.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            proposals (List[List[Tensor | None]]): predefined proposals for
                each test-time augmentation and each item.
            **kwargs: forwarded unchanged to ``simple_test`` / ``aug_test``.

        Returns:
            Whatever ``self.simple_test`` returns when there is exactly one
            augmentation, otherwise whatever ``self.aug_test`` returns.

        Raises:
            TypeError: if ``imgs`` or ``img_metas`` is not a list.
            ValueError: if ``imgs`` and ``img_metas`` differ in length.
            AssertionError: if the first augmentation holds more than one
                image (only one image per GPU is supported at test time).
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(imgs), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        # Batch size is read from dim 0 of the first augmentation's tensor.
        imgs_per_gpu = imgs[0].size(0)
        assert imgs_per_gpu == 1, (
            'only 1 image per GPU is supported at test time, '
            'but got {}'.format(imgs_per_gpu))

        if num_augs == 1:
            # Single augmentation: unwrap the outer aug-level list.
            return self.simple_test(imgs[0], img_metas[0], proposals[0],
                                    **kwargs)
        # NOTE(review): ``proposals`` is passed positionally as the third
        # argument of ``aug_test``; confirm that TwoStageDetector.aug_test
        # actually takes proposals in that position (and not e.g. ``rescale``).
        return self.aug_test(imgs, img_metas, proposals, **kwargs)