from mmdet.models.detectors import BaseDetector


class Base3DDetector(BaseDetector):
    """Base class for detectors"""

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """
        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

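        # The number of point-cloud augmentations must match the number of
        # image-meta groups, one entry per test-time augmentation.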
        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        samples_per_gpu = len(points[0])
        assert samples_per_gpu == 1

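        # A single augmentation is dispatched to simple_test; multiple
        # augmentations go through aug_test.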
        if num_augs == 1:
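            # Wrap None in a list so that indexing img[0] below works whether
            # or not images were provided.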
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        else:
            return self.aug_test(points, img_metas, img, **kwargs)

    def forward(self, return_loss=True, **kwargs):
        """
        Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, img and
        img_metas should be double nested
        (i.e.  list[torch.Tensor], list[list[dict]]), with the outer list
        indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)