inference.py 12.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ChaimZhu's avatar
ChaimZhu committed
2
import warnings
3
4
from copy import deepcopy
from os import path as osp
ChaimZhu's avatar
ChaimZhu committed
5
6
from pathlib import Path
from typing import Optional, Sequence, Union
7

8
import mmengine
9
import numpy as np
wuyuefeng's avatar
Demo  
wuyuefeng committed
10
import torch
ChaimZhu's avatar
ChaimZhu committed
11
import torch.nn as nn
12
from mmengine.config import Config
13
from mmengine.dataset import Compose, pseudo_collate
14
from mmengine.registry import init_default_scope
ChaimZhu's avatar
ChaimZhu committed
15
from mmengine.runner import load_checkpoint
wuyuefeng's avatar
Demo  
wuyuefeng committed
16

zhangshilong's avatar
zhangshilong committed
17
18
19
from mmdet3d.registry import MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList
wuyuefeng's avatar
Demo  
wuyuefeng committed
20
21


22
23
24
25
def convert_SyncBN(config):
    """Convert every ``naiveSyncBN`` norm layer of a config to plain ``BN``.

    The config is walked recursively and modified in place. For each
    ``norm_cfg`` entry, the substring ``'naiveSyncBN'`` inside its ``type``
    is replaced by ``'BN'`` (e.g. ``'naiveSyncBN2d'`` becomes ``'BN2d'``).
    The contents of a ``norm_cfg`` entry itself are not searched further.

    Args:
        config (dict | list | tuple): The config, or any sub-tree of it.
            Values of any other type are ignored.
    """
    if isinstance(config, dict):
        for key in config:
            value = config[key]
            if key != 'norm_cfg':
                convert_SyncBN(value)
                continue
            # ``norm_cfg`` may be None or hold a non-string ``type``;
            # in both cases there is nothing to rename.
            if isinstance(value, dict) and \
                    isinstance(value.get('type'), str):
                value['type'] = value['type'].replace('naiveSyncBN', 'BN')
    elif isinstance(config, (list, tuple)):
        # Sub-configs can be stored in sequences; visit each element.
        for element in config:
            convert_SyncBN(element)


ChaimZhu's avatar
ChaimZhu committed
38
39
40
41
def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a model from config file, which could be a 3D detector or a
    3D segmentor.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object. A config object passed in is
            modified in place (``cfg_options`` are merged, SyncBN layers are
            converted and ``model.train_cfg`` is set to None).
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights and ``model.dataset_meta`` is not set.
        device (str): Device to use.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed model, moved to ``device`` and switched to
        eval mode. The config is attached to it as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a path nor a config object.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    convert_SyncBN(config.model)
    config.model.train_cfg = None
    init_default_scope(config.get('default_scope', 'mmdet3d'))
    model = MODELS.build(config.model)

    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = checkpoint.get('meta', {})
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in meta:
            # mmdet3d 1.x
            model.dataset_meta = meta['dataset_meta']
        else:
            # < mmdet3d 1.x: class names come from the checkpoint when it
            # stores them, otherwise from the config.
            if 'CLASSES' in meta:
                classes = meta['CLASSES']
            else:
                classes = config.class_names
            model.dataset_meta = {'classes': classes}

            if 'PALETTE' in meta:  # 3D Segmentor
                model.dataset_meta['palette'] = meta['PALETTE']

    model.cfg = config  # save the config in the model for convenience

    # Decide on the device *type* instead of comparing against the literal
    # 'cpu', so that only CUDA devices are handed to torch.cuda.set_device.
    target = torch.device(device)
    if target.type == 'cuda':
        # Without an explicit index the current CUDA device is kept.
        if target.index is not None:
            torch.cuda.set_device(target)
    elif target.type == 'cpu':
        warnings.warn('Don\'t suggest using CPU device. '
                      'Some functions are not supported for now.')

    model.to(device)
    model.eval()
    return model


ChaimZhu's avatar
ChaimZhu committed
102
103
104
105
106
107
# Point cloud input accepted by the inference functions: one item or a
# sequence of items, each either a str (point cloud file) or a loaded
# ``np.ndarray``.
PointsType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
# Image input accepted by the inference functions: one item or a sequence
# of items, each either a str (image file) or a loaded ``np.ndarray``.
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module,
                       pcds: PointsType) -> tuple:
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector. It must carry its config as
            ``model.cfg``.
        pcds (str, ndarray, Sequence[str/ndarray]):
            Either point cloud files or loaded point cloud. The type of the
            first item decides how the whole batch is loaded.

    Returns:
        tuple: ``(results, data)``. If pcds is a list or tuple, ``results``
        is the list returned by ``model.test_step`` and ``data`` the list of
        pipeline outputs, both with one entry per input. Otherwise the
        single result and the single pipeline output are returned directly.
    """
    is_batch = isinstance(pcds, (list, tuple))
    if not is_batch:
        pcds = [pcds]

    cfg = model.cfg

    # build the data pipeline from a private deep copy, so that the edit
    # below never reaches the config stored on the model
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    if not isinstance(pcds[0], str):
        # already-loaded points: set loading pipeline type
        test_pipeline[0]['type'] = 'LoadPointsFromDict'
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for pcd in pcds:
        # prepare data
        if isinstance(pcd, str):
            # load from point cloud file
            data_ = dict(
                lidar_points=dict(lidar_path=pcd),
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        else:
            # directly use loaded point cloud
            data_ = dict(
                points=pcd,
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        data.append(test_pipeline(data_))

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0], data[0]
    return results, data
wuyuefeng's avatar
Demo  
wuyuefeng committed
173
174


ChaimZhu's avatar
ChaimZhu committed
175
176
177
def inference_multi_modality_detector(model: nn.Module,
                                      pcds: Union[str, Sequence[str]],
                                      imgs: Union[str, Sequence[str]],
                                      ann_file: Union[str, Sequence[str]],
                                      cam_type: str = 'CAM2'):
    """Run a multi-modality (point cloud + image) detector.

    Only a single camera view per sample is used; multi-view image loading
    is not supported by this function.

    Args:
        model (nn.Module): The loaded detector, carrying its config as
            ``model.cfg``.
        pcds (str, Sequence[str]): Point cloud file(s).
        imgs (str, Sequence[str]): Image file(s), one per point cloud.
        ann_file (str, Sequence[str]): Annotation file. Its ``'data_list'``
            entry at position ``i`` must describe the ``i``-th input.
        cam_type (str): Camera whose image and calibration are used.
            For kitti dataset, it should be 'CAM2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM2'.

    Returns:
        tuple: ``(results, data)``. If pcds is a list or tuple, both are
        lists with one entry per input; otherwise the single result and the
        single pipeline output are returned directly.

    Raises:
        ValueError: If the file name of an image differs from the file name
            recorded in the matching annotation entry.
    """
    is_batch = isinstance(pcds, (list, tuple))
    if is_batch:
        assert isinstance(imgs, (list, tuple))
        assert len(pcds) == len(imgs)
    else:
        pcds, imgs = [pcds], [imgs]

    dataset_cfg = model.cfg.test_dataloader.dataset

    # build the data pipeline
    test_pipeline = Compose(deepcopy(dataset_cfg.pipeline))
    box_type_3d, box_mode_3d = get_box_type(dataset_cfg.box_type_3d)

    data_list = mmengine.load(ann_file)['data_list']

    data = []
    for index, pcd in enumerate(pcds):
        img = imgs[index]
        # calibration and path info of the selected camera
        cam_info = data_list[index]['images'][cam_type]
        recorded_path = cam_info['img_path']

        if osp.basename(recorded_path) != osp.basename(img):
            raise ValueError(
                f'the info file of {recorded_path} is not provided.')

        # point the annotation entry at the image actually supplied
        cam_info['img_path'] = img

        sample = dict(
            lidar_points=dict(lidar_path=pcd),
            img_path=img,
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d,
            cam2img=np.array(cam_info['cam2img']))

        if box_mode_3d == Box3DMode.LIDAR:
            # LiDAR to image conversion for KITTI dataset
            sample['lidar2img'] = np.array(cam_info['lidar2img'])
        elif box_mode_3d == Box3DMode.DEPTH:
            # Depth to image conversion for SUNRGBD dataset
            sample['depth2img'] = np.array(cam_info['depth2img'])

        data.append(test_pipeline(sample))

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if is_batch:
        return results, data
    return results[0], data[0]
267
268


269
270
271
272
def inference_mono_3d_detector(model: nn.Module,
                               imgs: ImagesType,
                               ann_file: Union[str, Sequence[str]],
                               cam_type: str = 'CAM_FRONT'):
    """Run a monocular 3D detector on one or several images.

    Args:
        model (nn.Module): The loaded detector, carrying its config as
            ``model.cfg``.
        imgs (str, Sequence[str]): Image file(s).
        ann_file (str, Sequence[str]): Annotation file. Its ``'data_list'``
            must have exactly one entry per image, in the same order.
        cam_type (str): Camera whose image info is matched against ``imgs``.
            For kitti dataset, it should be 'CAM2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.

    Raises:
        ValueError: If the file name of an image differs from the file name
            recorded in the matching annotation entry.
    """
    is_batch = isinstance(imgs, (list, tuple))
    if not is_batch:
        imgs = [imgs]

    dataset_cfg = model.cfg.test_dataloader.dataset

    # build the data pipeline
    test_pipeline = Compose(deepcopy(dataset_cfg.pipeline))
    box_type_3d, box_mode_3d = get_box_type(dataset_cfg.box_type_3d)

    data_list = mmengine.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    for index, img in enumerate(imgs):
        # image info (including calib) of every camera for this sample
        images_info = data_list[index]['images']
        recorded_path = images_info[cam_type]['img_path']
        if osp.basename(recorded_path) != osp.basename(img):
            raise ValueError(
                f'the info file of {recorded_path} is not provided.')

        # replace the recorded img_path with the image actually supplied
        images_info[cam_type]['img_path'] = img

        sample = dict(
            images=images_info,
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)
        data.append(test_pipeline(sample))

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    return results if is_batch else results[0]
335

ChaimZhu's avatar
ChaimZhu committed
336
337

def inference_segmentor(model: nn.Module, pcds: PointsType):
    """Run a 3D segmentor on one or several point cloud files.

    The test pipeline from ``model.cfg`` is used with every
    ``LoadAnnotations3D`` step removed.

    Args:
        model (nn.Module): The loaded segmentor, carrying its config as
            ``model.cfg``.
        pcds (str, Sequence[str]): Point cloud file(s). Each item is passed
            to the pipeline as ``lidar_path``.

    Returns:
        tuple: ``(results, data)``. If pcds is a list or tuple, both are
        lists with one entry per input; otherwise the single result and the
        single pipeline output are returned directly.
    """
    is_batch = isinstance(pcds, (list, tuple))
    if not is_batch:
        pcds = [pcds]

    # build the data pipeline, leaving out the annotation loading steps
    pipeline_cfg = deepcopy(model.cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose([
        step for step in pipeline_cfg if step['type'] != 'LoadAnnotations3D'
    ])

    # TODO: support load points array
    data = [
        test_pipeline(dict(lidar_points=dict(lidar_path=pcd)))
        for pcd in pcds
    ]

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if is_batch:
        return results, data
    return results[0], data[0]