base.py 2.28 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
from mmdet.models.detectors import BaseDetector
zhangwenwei's avatar
zhangwenwei committed
2
3


zhangwenwei's avatar
zhangwenwei committed
4
class Base3DDetector(BaseDetector):
    """Base class for 3D detectors.

    Extends ``BaseDetector`` with a ``forward_test`` that accepts point
    clouds (and optionally images) and dispatches to ``simple_test`` when
    there is a single test-time augmentation, or to ``aug_test`` otherwise.
    """

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """Run inference, dispatching on the number of test-time augs.

        Args:
            points (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                samples in a batch.
            img (List[Tensor], optional): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch. Defaults to None.
            **kwargs: Forwarded unchanged to ``simple_test`` / ``aug_test``.

        Returns:
            The result of ``self.simple_test`` when there is exactly one
            augmentation, otherwise the result of ``self.aug_test``.

        Raises:
            TypeError: If ``points`` or ``img_metas`` is not a list.
            ValueError: If ``points`` and ``img_metas`` have different
                lengths, or if ``points`` is empty.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))
        # Without this check an empty list would fail later with an
        # uninformative IndexError on points[0].
        if num_augs == 0:
            raise ValueError(
                'points must contain at least one test-time augmentation')

        # TODO: remove the restriction of samples_per_gpu == 1 when prepared
        # NOTE(review): len(points[0]) only counts samples if points[0] is a
        # per-sample sequence (e.g. a list of tensors); if it really were a
        # single NxC Tensor as documented above, this would be N (the number
        # of points). Confirm the actual input layout against the caller.
        samples_per_gpu = len(points[0])
        assert samples_per_gpu == 1, (
            'only samples_per_gpu == 1 is supported, but got {}'.format(
                samples_per_gpu))

        if num_augs == 1:
            # Wrap a missing image so that img[0] below evaluates to None.
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        return self.aug_test(points, img_metas, img, **kwargs)

    def forward(self, return_loss=True, **kwargs):
        """Call ``forward_train`` or ``forward_test`` based on ``return_loss``.

        Note this setting changes the expected inputs. When
        ``return_loss=True``, img and img_metas are single-nested (i.e.
        Tensor and List[dict]); when ``return_loss=False``, img and
        img_metas should be double-nested (i.e. List[Tensor],
        List[List[dict]]), with the outer list indicating test-time
        augmentations.

        Args:
            return_loss (bool): If True, call ``forward_train``; otherwise
                call ``forward_test``. Defaults to True.
            **kwargs: Forwarded unchanged to the selected method.

        Returns:
            Whatever the selected method returns.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        return self.forward_test(**kwargs)