"research/gan/stargan/train.py" did not exist on "4f7074f67c5ca838c7ed7a86b45ef6efd55be373"
fast_rcnn.py 1.65 KB
Newer Older
pangjm's avatar
pangjm committed
1
from .two_stage import TwoStageDetector
Kai Chen's avatar
Kai Chen committed
2
from ..registry import DETECTORS
pangjm's avatar
pangjm committed
3
4


Kai Chen's avatar
Kai Chen committed
5
@DETECTORS.register_module
class FastRCNN(TwoStageDetector):
    """Detector registered in ``DETECTORS`` that builds on ``TwoStageDetector``.

    The constructor hands every argument straight to the parent class, and
    ``forward_test`` expects the caller to supply ``proposals`` explicitly.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 mask_roi_extractor=None,
                 mask_head=None,
                 pretrained=None):
        """Forward all construction arguments, unchanged, to the parent.

        Every parameter is passed by keyword to
        ``TwoStageDetector.__init__``; nothing is stored or altered here.
        """
        parent_kwargs = dict(
            backbone=backbone,
            neck=neck,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            pretrained=pretrained)
        super(FastRCNN, self).__init__(**parent_kwargs)

    def forward_test(self, imgs, img_metas, proposals, **kwargs):
        """Validate the test inputs and dispatch to the matching test routine.

        Args:
            imgs (list): One entry per augmentation. Each entry must expose
                ``size(0)`` (presumably a batched image tensor -- confirm
                against the caller).
            img_metas (list): One entry per augmentation; must have the same
                length as ``imgs``.
            proposals: Indexed with ``[0]`` in the single-augmentation case
                and passed through whole otherwise.
            **kwargs: Forwarded untouched to the selected test routine.

        Returns:
            Whatever ``simple_test`` (exactly one augmentation) or
            ``aug_test`` (any other count) returns.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If ``imgs`` and ``img_metas`` differ in length.
            AssertionError: If the first entry of ``imgs`` does not hold
                exactly one image along dimension 0.
        """
        # Check the containers in a fixed order so that ``imgs`` is always
        # the first one reported when both are wrong.
        for label, value in (('imgs', imgs), ('img_metas', img_metas)):
            if isinstance(value, list):
                continue
            raise TypeError('{} must be a list, but got {}'.format(
                label, type(value)))

        aug_count = len(imgs)
        meta_count = len(img_metas)
        if aug_count != meta_count:
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    aug_count, meta_count))

        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        assert imgs[0].size(0) == 1

        if aug_count != 1:
            return self.aug_test(imgs, img_metas, proposals, **kwargs)

        # Single augmentation: unwrap the one-element lists before testing.
        return self.simple_test(
            imgs[0], img_metas[0], proposals[0], **kwargs)