# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp

import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16

from mmdet3d.core import Box3DMode, Coord3DMode, show_result
from mmdet.models.detectors import BaseDetector


class Base3DDetector(BaseDetector):
    """Base class for detectors."""

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """
        Args:
liyinhao's avatar
liyinhao committed
18
19
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
zhangwenwei's avatar
zhangwenwei committed
20
                which contains all points in the batch.
liyinhao's avatar
liyinhao committed
21
            img_metas (list[list[dict]]): the outer list indicates test-time
zhangwenwei's avatar
zhangwenwei committed
22
23
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
liyinhao's avatar
liyinhao committed
24
25
26
27
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.
zhangwenwei's avatar
zhangwenwei committed
28
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))

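        # With a single augmentation, dispatch to simple_test; img is wrapped
        # in a one-element list so that indexing img[0] works even when no
        # image is passed.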
        if num_augs == 1:
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        else:
            return self.aug_test(points, img_metas, img, **kwargs)

    @auto_fp16(apply_to=('img', 'points'))
    def forward(self, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, img and
        img_metas should be double nested (i.e.  list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
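
        Example:
            A minimal sketch (``detector``, ``pts``, ``metas`` and the gt_*
            kwargs are hypothetical stand-ins; the accepted keyword arguments
            depend on the subclass)::

                # Training: single-nested inputs, returns a dict of losses.
                losses = detector(points=pts, img_metas=metas,
                                  gt_bboxes_3d=gt_boxes,
                                  gt_labels_3d=gt_labels)
                # Testing: double-nested inputs, one entry per test-time aug.
                results = detector(return_loss=False, points=[pts],
                                   img_metas=[metas])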
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def show_results(self, data, result, out_dir, show=False, score_thr=None):
        """Results visualization.

        Args:
            data (dict): Input points and the information of the sample.
            result (list[dict]): Prediction results.
            out_dir (str): Output directory of visualization result.
            show (bool, optional): Whether to visualize the results with
                Open3D. Defaults to False.
            score_thr (float, optional): Score threshold of bounding boxes.
                Defaults to None.
        """
        for batch_id in range(len(result)):
            if isinstance(data['points'][0], DC):
                points = data['points'][0]._data[0][batch_id].numpy()
            elif mmcv.is_list_of(data['points'][0], torch.Tensor):
                points = data['points'][0][batch_id]
            else:
                ValueError(f"Unsupported data type {type(data['points'][0])} "
                           f'for visualization!')
            if isinstance(data['img_metas'][0], DC):
                pts_filename = data['img_metas'][0]._data[0][batch_id][
                    'pts_filename']
                box_mode_3d = data['img_metas'][0]._data[0][batch_id][
                    'box_mode_3d']
            elif mmcv.is_list_of(data['img_metas'][0], dict):
                pts_filename = data['img_metas'][0][batch_id]['pts_filename']
                box_mode_3d = data['img_metas'][0][batch_id]['box_mode_3d']
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            assert out_dir is not None, 'Expect out_dir, got none.'

            pred_bboxes = result[batch_id]['boxes_3d']
            pred_labels = result[batch_id]['labels_3d']

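            # Optionally filter out low-scoring predictions before drawing.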
            if score_thr is not None:
                mask = result[batch_id]['scores_3d'] > score_thr
                pred_bboxes = pred_bboxes[mask]
                pred_labels = pred_labels[mask]

            # for now we convert points and bbox into depth mode
            if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d
                                                  == Box3DMode.LIDAR):
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')
            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
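            # show_result consumes numpy arrays; boxes are in depth mode here.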
            show_result(
                points,
                None,
                pred_bboxes,
                out_dir,
                file_name,
                show=show,
                pred_labels=pred_labels)