inference.py 12.2 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ChaimZhu's avatar
ChaimZhu committed
2
import warnings
3
4
from copy import deepcopy
from os import path as osp
ChaimZhu's avatar
ChaimZhu committed
5
6
from pathlib import Path
from typing import Optional, Sequence, Union
7

8
import mmengine
9
import numpy as np
wuyuefeng's avatar
Demo  
wuyuefeng committed
10
import torch
ChaimZhu's avatar
ChaimZhu committed
11
import torch.nn as nn
12
from mmengine.config import Config
13
from mmengine.dataset import Compose, pseudo_collate
ChaimZhu's avatar
ChaimZhu committed
14
from mmengine.runner import load_checkpoint
wuyuefeng's avatar
Demo  
wuyuefeng committed
15

zhangshilong's avatar
zhangshilong committed
16
17
18
from mmdet3d.registry import MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList
wuyuefeng's avatar
Demo  
wuyuefeng committed
19
20


21
22
23
24
def convert_SyncBN(config):
    """Recursively replace every ``naiveSyncBN`` norm type in a config with
    plain ``BN``.

    Walks nested dict-like config values; any ``norm_cfg`` entry has the
    ``naiveSyncBN`` prefix of its ``type`` rewritten to ``BN``. Non-dict
    values are left untouched.

    Args:
         config (str or :obj:`mmengine.Config`): Config file path or the config
            object.
    """
    if not isinstance(config, dict):
        return
    for key, value in config.items():
        if key == 'norm_cfg':
            value['type'] = value['type'].replace('naiveSyncBN', 'BN')
        else:
            convert_SyncBN(value)


ChaimZhu's avatar
ChaimZhu committed
37
38
39
40
def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a model from config file, which could be a 3D detector or a
    3D segmentor.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device to use.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector.
    """
    # Normalize `config` to a Config object.
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # SyncBN variants are not usable for single-sample inference.
    convert_SyncBN(config.model)
    config.model.train_cfg = None
    model = MODELS.build(config.model)

    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = checkpoint.get('meta', {})
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in meta:
            # mmdet3d 1.x
            model.dataset_meta = meta['dataset_meta']
        elif 'CLASSES' in meta:
            # < mmdet3d 1.x
            model.dataset_meta = {'CLASSES': meta['CLASSES']}
            if 'PALETTE' in meta:  # 3D Segmentor
                model.dataset_meta['PALETTE'] = meta['PALETTE']
        else:
            # < mmdet3d 1.x: fall back to class names from the config
            model.dataset_meta = {'CLASSES': config.class_names}
            if 'PALETTE' in meta:  # 3D Segmentor
                model.dataset_meta['PALETTE'] = meta['PALETTE']

    model.cfg = config  # save the config in the model for convenience

    if device != 'cpu':
        torch.cuda.set_device(device)
    else:
        warnings.warn('Don\'t suggest using CPU device. '
                      'Some functions are not supported for now.')

    model.to(device)
    model.eval()
    return model


ChaimZhu's avatar
ChaimZhu committed
100
101
102
103
104
105
# Accepted input forms for the inference helpers below: a single file path
# or loaded array, or a sequence of either.
PointsType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module, pcds: PointsType):
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, ndarray, Sequence[str/ndarray]):
            Either point cloud files or loaded point cloud.

    Returns:
        tuple: A ``(results, data)`` pair. If ``pcds`` is a list or tuple,
        ``results`` is a same-length list of :obj:`Det3DDataSample` and
        ``data`` the corresponding list of pipeline outputs; otherwise a
        single :obj:`Det3DDataSample` and its pipeline output are returned.
    """
    # NOTE: the previous ``-> Union[Det3DDataSample, SampleList]`` annotation
    # was wrong: this function returns a (results, data) tuple.
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    if not isinstance(pcds[0], str):
        # Loaded arrays rather than paths: copy the config so the shared
        # model.cfg is not mutated, then switch the loading step.
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.test_dataloader.dataset.pipeline[0].type = 'LoadPointsFromDict'

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for pcd in pcds:
        # prepare data
        if isinstance(pcd, str):
            # load from point cloud file
            data_ = dict(
                lidar_points=dict(lidar_path=pcd),
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        else:
            # directly use loaded point cloud
            data_ = dict(
                points=pcd,
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        data_ = test_pipeline(data_)
        data.append(data_)

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0], data[0]
    else:
        return results, data
wuyuefeng's avatar
Demo  
wuyuefeng committed
171
172


ChaimZhu's avatar
ChaimZhu committed
173
174
175
def inference_multi_modality_detector(model: nn.Module,
                                      pcds: Union[str, Sequence[str]],
                                      imgs: Union[str, Sequence[str]],
                                      ann_file: Union[str, Sequence[str]],
                                      cam_type: str = 'CAM_FRONT'):
    """Inference point cloud with the multi-modality detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.
        imgs (str, Sequence[str]):
           Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files. Must contain a
            ``data_list`` entry with one record per image, in the same order
            as ``imgs``.
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        tuple: A ``(results, data)`` pair of detection results
        (:obj:`Det3DDataSample`) and pipeline outputs.
        If pcds is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    # TODO: We will support
    if isinstance(pcds, (list, tuple)):
        is_batch = True
        assert isinstance(imgs, (list, tuple))
        assert len(pcds) == len(imgs)
    else:
        pcds = [pcds]
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    # NOTE(review): assumes the annotation file follows the mmdet3d 1.x info
    # format, with per-sample records under 'data_list' aligned 1:1 with imgs.
    data_list = mmengine.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    for index, pcd in enumerate(pcds):
        # get data info containing calib
        img = imgs[index]
        data_info = data_list[index]
        img_path = data_info['images'][cam_type]['img_path']

        # Only basenames are compared, so the info file must describe the
        # exact image that was passed in.
        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # TODO: check the name consistency of
        # image file and point cloud file
        data_ = dict(
            lidar_points=dict(lidar_path=pcd),
            img_path=img,
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        # LiDAR to image conversion for KITTI dataset
        if box_mode_3d == Box3DMode.LIDAR:
            data_['lidar2img'] = np.array(
                data_info['images'][cam_type]['lidar2img'])
        # Depth to image conversion for SUNRGBD dataset
        elif box_mode_3d == Box3DMode.DEPTH:
            data_['depth2img'] = np.array(
                data_info['images'][cam_type]['depth2img'])

        data_ = test_pipeline(data_)
        data.append(data_)

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0], data[0]
    else:
        return results, data
259
260


261
262
263
264
def inference_mono_3d_detector(model: nn.Module,
                               imgs: ImagesType,
                               ann_file: Union[str, Sequence[str]],
                               cam_type: str = 'CAM_FRONT'):
    """Inference image with the monocular 3D detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, Sequence[str]):
           Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files. Must contain a
            ``data_list`` entry with one record per image, in the same order
            as ``imgs``.
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    # FIX: extract the 'data_list' entry of the info file, consistent with
    # inference_multi_modality_detector. Previously the top-level dict was
    # used directly, so the length assert compared against the number of
    # top-level keys and data_list[index] indexed a dict by int.
    data_list = mmengine.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    for index, img in enumerate(imgs):
        # get data info containing calib
        data_info = data_list[index]
        img_path = data_info['images'][cam_type]['img_path']
        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # replace the img_path in data_info with img
        data_info['images'][cam_type]['img_path'] = img
        data_ = dict(
            images=data_info['images'],
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        data_ = test_pipeline(data_)
        data.append(data_)

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0]
    else:
        return results
327

ChaimZhu's avatar
ChaimZhu committed
328
329

def inference_segmentor(model: nn.Module, pcds: PointsType):
    """Inference point cloud with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If pcds is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    is_batch = isinstance(pcds, (list, tuple))
    if not is_batch:
        pcds = [pcds]

    cfg = model.cfg

    # build the data pipeline, dropping the annotation-loading step since
    # there is no ground truth at inference time
    steps = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(
        [step for step in steps if step['type'] != 'LoadAnnotations3D'])

    # TODO: support load points array
    data = [
        test_pipeline(dict(lidar_points=dict(lidar_path=pcd)))
        for pcd in pcds
    ]

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if is_batch:
        return results, data
    return results[0], data[0]