import copy
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from os import path as osp

from mmdet3d.core import Box3DMode, show_result
from mmdet.models.detectors import BaseDetector


class Base3DDetector(BaseDetector):
    """Base class for detectors."""

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """
        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.
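
        Example:
            >>> # Hypothetical usage with one test-time augmentation and a
            >>> # batch of one sample; contents are illustrative only.
            >>> points = [torch.rand(16384, 4)]   # 1 aug; 16384 points, C=4
            >>> img_metas = [[dict(flip=False)]]  # 1 aug; 1 sample
            >>> # results = detector.forward_test(points, img_metas)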
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        samples_per_gpu = len(points[0])
        assert samples_per_gpu == 1

        if num_augs == 1:
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        else:
            return self.aug_test(points, img_metas, img, **kwargs)

    def forward(self, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, img and
        img_metas should be double-nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
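
        Example:
            >>> # Schematic only: ``detector`` stands for an instance of a
            >>> # hypothetical concrete subclass.
            >>> # losses = detector(return_loss=True, points=pts,
            >>> #                   img_metas=metas)     # single-nested
            >>> # results = detector(return_loss=False, points=[pts],
            >>> #                    img_metas=[metas])  # double-nested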
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def show_results(self, data, result, out_dir):
        """Results visualization.

        Args:
            data (dict): Input points and the information of the sample.
            result (dict): Prediction results.
            out_dir (str): Output directory of visualization result.
        """
        if isinstance(data['points'][0], DC):
            points = data['points'][0]._data[0][0].numpy()
        elif mmcv.is_list_of(data['points'][0], torch.Tensor):
            points = data['points'][0][0]
        else:
            ValueError(f"Unsupported data type {type(data['points'][0])} "
                       f'for visualization!')
        if isinstance(data['img_metas'][0], DC):
            pts_filename = data['img_metas'][0]._data[0][0]['pts_filename']
            box_mode_3d = data['img_metas'][0]._data[0][0]['box_mode_3d']
        elif mmcv.is_list_of(data['img_metas'][0], dict):
            pts_filename = data['img_metas'][0][0]['pts_filename']
            box_mode_3d = data['img_metas'][0][0]['box_mode_3d']
        else:
            ValueError(f"Unsupported data type {type(data['img_metas'][0])} "
                       f'for visualization!')
        file_name = osp.split(pts_filename)[-1].split('.')[0]

        assert out_dir is not None, 'Expect out_dir, got None.'

        pred_bboxes = copy.deepcopy(result['boxes_3d'].tensor.numpy())
        # For now we convert points and boxes into depth mode for
        # visualization. Depth-mode boxes store z at the bottom center, so
        # adding half the box height (dim index 5) shifts it to the gravity
        # center for display.
        if box_mode_3d == Box3DMode.DEPTH:
            pred_bboxes[..., 2] += pred_bboxes[..., 5] / 2
        elif box_mode_3d == Box3DMode.CAM or box_mode_3d == Box3DMode.LIDAR:
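            # Swap the first two point axes and flip the new x axis so the
            # points follow the depth-mode convention before converting the
            # boxes below.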
            points = points[..., [1, 0, 2]]
            points[..., 0] *= -1
            pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                            Box3DMode.DEPTH)
            pred_bboxes[..., 2] += pred_bboxes[..., 5] / 2
        else:
            raise ValueError(
                f'Unsupported box_mode_3d {box_mode_3d} for conversion!')

        show_result(points, None, pred_bboxes, out_dir, file_name)
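
# A minimal usage sketch, kept as comments: ``MyDetector``, ``pts``, ``meta``
# and ``data`` below are hypothetical, and Base3DDetector itself is abstract.
#
#   class MyDetector(Base3DDetector):
#       def forward_train(self, points, img_metas, **kwargs): ...
#       def simple_test(self, points, img_metas, imgs=None, **kwargs): ...
#       def aug_test(self, points, img_metas, imgs=None, **kwargs): ...
#
#   detector = MyDetector()
#   # Test mode: double-nested inputs, outer list = test-time augmentations.
#   results = detector(return_loss=False, points=[pts], img_metas=[[meta]])
#   detector.show_results(data, results[0], out_dir='./vis')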