# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp

import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16

from mmdet3d.core import Box3DMode, Coord3DMode, show_result
from mmdet.models.detectors import BaseDetector


class Base3DDetector(BaseDetector):
    """Base class for detectors."""

    def forward_test(self, points, img_metas, img=None, **kwargs):
        """
        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(points), len(img_metas)))

        if num_augs == 1:
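            # img may be None for points-only models; wrap it so that img[0]
            # below is still a valid (possibly None) argument to simple_test.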
            img = [img] if img is None else img
            return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
        else:
            return self.aug_test(points, img_metas, img, **kwargs)

    @auto_fp16(apply_to=('img', 'points'))
    def forward(self, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, img and
        img_metas should be double nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def show_results(self, data, result, out_dir, show=False, score_thr=None):
        """Results visualization.

        Args:
            data (list[dict]): Input points and the information of the sample.
            result (list[dict]): Prediction results.
            out_dir (str): Output directory of visualization result.
            show (bool, optional): Whether to show the visualization results
                with open3d. Defaults to False.
            score_thr (float, optional): Score threshold of bounding boxes.
                Defaults to None.
        """
        for batch_id in range(len(result)):
            if isinstance(data['points'][0], DC):
                points = data['points'][0]._data[0][batch_id].numpy()
            elif mmcv.is_list_of(data['points'][0], torch.Tensor):
                points = data['points'][0][batch_id]
            else:
                ValueError(f"Unsupported data type {type(data['points'][0])} "
                           f'for visualization!')
            if isinstance(data['img_metas'][0], DC):
                pts_filename = data['img_metas'][0]._data[0][batch_id][
                    'pts_filename']
                box_mode_3d = data['img_metas'][0]._data[0][batch_id][
                    'box_mode_3d']
            elif mmcv.is_list_of(data['img_metas'][0], dict):
                pts_filename = data['img_metas'][0][batch_id]['pts_filename']
                box_mode_3d = data['img_metas'][0][batch_id]['box_mode_3d']
            else:
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            assert out_dir is not None, 'Expect out_dir, got none.'

            pred_bboxes = result[batch_id]['boxes_3d']
            pred_labels = result[batch_id]['labels_3d']

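            # Optionally drop low-confidence predictions before drawing.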
            if score_thr is not None:
                mask = result[batch_id]['scores_3d'] > score_thr
                pred_bboxes = pred_bboxes[mask]
                pred_labels = pred_labels[mask]

            # show_result visualizes in depth mode, so for now we convert
            # points and bboxes into depth mode first
            if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d
                                                  == Box3DMode.LIDAR):
                points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
                                                   Coord3DMode.DEPTH)
                pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
                                                Box3DMode.DEPTH)
            elif box_mode_3d != Box3DMode.DEPTH:
                raise ValueError(
                    f'Unsupported box_mode_3d {box_mode_3d} for conversion!')
            pred_bboxes = pred_bboxes.tensor.cpu().numpy()
            show_result(
                points,
                None,
                pred_bboxes,
                out_dir,
                file_name,
                show=show,
                pred_labels=pred_labels)
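

# A minimal usage sketch (hypothetical paths; assumes the standard mmdet3d
# inference helpers from mmdet3d.apis):
#
#   from mmdet3d.apis import inference_detector, init_model
#
#   model = init_model('configs/some_config.py', 'checkpoints/some_ckpt.pth')
#   result, data = inference_detector(model, 'demo/data/points.bin')
#   model.show_results(data, result, out_dir='demo_out', score_thr=0.3)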