lidar_det3d_inferencer.py 9.89 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union

import mmengine
import numpy as np
7
import torch
8
from mmengine.dataset import Compose
9
10
from mmengine.fileio import (get_file_backend, isdir, join_path,
                             list_dir_or_file)
11
12
13
14
from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData

from mmdet3d.registry import INFERENCERS
15
16
from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
                                Det3DDataSample, LiDARInstance3DBoxes)
17
from mmdet3d.utils import ConfigType
18
from .base_3d_inferencer import Base3DInferencer
19
20
21
22
23
24
25
26
27
28
29

# Type aliases used in the signatures of this module.
# A list of ``InstanceData`` containers.
InstanceList = List[InstanceData]
# A single raw input: either a string or a NumPy array.
InputType = Union[str, np.ndarray]
# One raw input, or a sequence of raw inputs.
InputsType = Union[InputType, Sequence[InputType]]
# Predictions given as one ``InstanceData`` or as a list of them.
PredType = Union[InstanceData, InstanceList]
# One NumPy array, or a sequence of NumPy arrays.
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
# A result given as dict(s) or as ``InstanceData`` object(s).
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


@INFERENCERS.register_module(name='det3d-lidar')
@INFERENCERS.register_module()
30
class LidarDet3DInferencer(Base3DInferencer):
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
    """The inferencer of LiDAR-based detection.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "pointpillars_kitti-3class" or
            "configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py". # noqa: E501
            If model is not specified, user must provide the
            `weights` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not specified
            and model is a model name of metafile, the weights will be loaded
            from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
46
47
48
        scope (str): The scope of the model. Defaults to 'mmdet3d'.
        palette (str): Color palette used for visualization. The order of
            priority is palette -> config -> checkpoint. Defaults to 'none'.
49
50
51
52
53
54
    """

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: str = 'mmdet3d',
                 palette: str = 'none') -> None:
        """Reset the frame counter and initialize the base inferencer.

        Args:
            model (str, optional): Path to the config file or the model name
                defined in metafile. Defaults to None.
            weights (str, optional): Path to the checkpoint.
                Defaults to None.
            device (str, optional): Device to run inference.
                Defaults to None.
            scope (str): The scope of the model. Defaults to 'mmdet3d'.
            palette (str): Color palette used for visualization.
                Defaults to 'none'.
        """
        # Running count of visualized frames. It is used to name outputs
        # and is set before the base initializer is invoked.
        self.num_visualized_frames = 0

        base_kwargs = dict(
            model=model,
            weights=weights,
            device=device,
            scope=scope,
            palette=palette)
        super().__init__(**base_kwargs)
66

67
    def _inputs_to_list(self, inputs: Union[dict, list], **kwargs) -> list:
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - dict: the value with key 'points' is
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string according
              to the task.

        Args:
            inputs (Union[dict, list]): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
        if isinstance(inputs, dict) and isinstance(inputs['points'], str):
            pcd = inputs['points']
            backend = get_file_backend(pcd)
            if hasattr(backend, 'isdir') and isdir(pcd):
                # Backends like HttpsBackend do not implement `isdir`, so
                # only those backends that implement `isdir` could accept
                # the inputs as a directory
                filename_list = list_dir_or_file(pcd, list_dir=False)
                inputs = [{
                    'points': join_path(pcd, filename)
                } for filename in filename_list]

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        return list(inputs)
101
102
103
104
105

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Locates the 'LoadPointsFromFile' transform in
        ``cfg.test_dataloader.dataset.pipeline``, stores its ``coord_type``,
        ``load_dim`` and ``use_dim`` settings on the inferencer, and
        overwrites that entry's type with 'LidarDet3DInferencerLoader'
        before composing the pipeline.

        Args:
            cfg (ConfigType): Config providing
                ``test_dataloader.dataset.pipeline``.

        Returns:
            Compose: The composed pipeline.

        Raises:
            ValueError: If 'LoadPointsFromFile' is not in the pipeline.
        """
        transforms = cfg.test_dataloader.dataset.pipeline

        loader_idx = self._get_transform_idx(transforms, 'LoadPointsFromFile')
        if loader_idx == -1:
            raise ValueError(
                'LoadPointsFromFile is not found in the test pipeline')

        loader_cfg = transforms[loader_idx]

        # Read both settings before storing either of them.
        coord_type = loader_cfg['coord_type']
        load_dim = loader_cfg['load_dim']
        self.coord_type = coord_type
        self.load_dim = load_dim

        # An integer `use_dim` means "the first N dimensions"; expand it to
        # an explicit index list. Any other value is stored unchanged.
        use_dim = loader_cfg['use_dim']
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        self.use_dim = use_dim

        transforms[loader_idx]['type'] = 'LidarDet3DInferencerLoader'
        return Compose(transforms)

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = -1,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  no_save_vis: bool = False,
                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (InputsType): Inputs for the inferencer. Each item is
                indexed with the key 'points', whose value is either a file
                path (str) or an array of points (np.ndarray).
            preds (PredType): Predictions of the model, paired with
                ``inputs`` item by item.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the result in a popup window.
                Defaults to False.
            wait_time (int): The interval of show (s), forwarded to the
                visualizer. Defaults to -1.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. A save path is
                only passed to the visualizer when ``show`` is also True.
                Defaults to ''.

        Returns:
            List[np.ndarray] or None: None when there is nothing to do
            (``show`` is False, no output directory is in effect and
            ``return_vis`` is False). Otherwise the list of point arrays
            that were handed to the visualizer, one per input.

        Raises:
            ValueError: If the inferencer has no visualizer, or if an input
                is neither a str nor an np.ndarray.
        """
        if no_save_vis is True:
            img_out_dir = ''

        if not show and img_out_dir == '' and not return_vis:
            return None

        # A default is supplied so that a missing attribute is reported the
        # same way as an explicit None.
        if getattr(self, 'visualizer', None) is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []

        for single_input, pred in zip(inputs, preds):
            single_input = single_input['points']
            if isinstance(single_input, str):
                # Read the raw float32 buffer and keep only the dimensions
                # recorded when the pipeline was initialized.
                pts_bytes = mmengine.fileio.get(single_input)
                points = np.frombuffer(pts_bytes, dtype=np.float32)
                points = points.reshape(-1, self.load_dim)
                points = points[:, self.use_dim]
                pc_name = osp.basename(single_input).split('.bin')[0]
                pc_name = f'{pc_name}.png'
            elif isinstance(single_input, np.ndarray):
                points = single_input.copy()
                # Array inputs have no file name, so name the output after
                # the zero-padded running frame counter.
                pc_num = str(self.num_visualized_frames).zfill(8)
                pc_name = f'{pc_num}.png'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            if img_out_dir != '' and show:
                o3d_save_path = osp.join(img_out_dir, 'vis_lidar', pc_name)
                mmengine.mkdir_or_exist(osp.dirname(o3d_save_path))
            else:
                o3d_save_path = None

            data_input = dict(points=points)
            self.visualizer.add_datasample(
                pc_name,
                data_input,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                o3d_save_path=o3d_save_path,
                vis_task='lidar_det',
            )
            results.append(points)
            self.num_visualized_frames += 1

        return results
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242

    def visualize_preds_fromfile(self, inputs: InputsType, preds: PredType,
                                 **kwargs) -> Union[List[np.ndarray], None]:
        """Visualize predictions stored in files.

        Each entry of ``preds`` is loaded with ``mmengine.load`` and turned
        into a ``Det3DDataSample`` carrying ``labels_3d``, ``scores_3d`` and
        ``bboxes_3d``. The samples are then passed to :meth:`visualize`.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            preds (PredType): Prediction files to load. Each loaded record
                must provide the keys 'labels_3d', 'scores_3d', 'bboxes_3d'
                and 'box_type_3d'.
            **kwargs: Extra keyword arguments forwarded to
                :meth:`visualize`.

        Returns:
            List[np.ndarray] or None: Whatever :meth:`visualize` returns.

        Raises:
            ValueError: If 'box_type_3d' is not one of 'LiDAR', 'Camera'
                or 'Depth'.
        """
        loaded_samples = []
        for pred_file in preds:
            record = mmengine.load(pred_file)

            sample = Det3DDataSample()
            sample.pred_instances_3d = InstanceData()
            instances = sample.pred_instances_3d

            instances.labels_3d = torch.tensor(record['labels_3d'])
            instances.scores_3d = torch.tensor(record['scores_3d'])

            # Choose the box class first, then construct the boxes once.
            box_type = record['box_type_3d']
            if box_type == 'LiDAR':
                box_cls = LiDARInstance3DBoxes
            elif box_type == 'Camera':
                box_cls = CameraInstance3DBoxes
            elif box_type == 'Depth':
                box_cls = DepthInstance3DBoxes
            else:
                raise ValueError(f'Unsupported box type: {box_type}')
            instances.bboxes_3d = box_cls(record['bboxes_3d'])

            loaded_samples.append(sample)

        return self.visualize(inputs=inputs, preds=loaded_samples, **kwargs)