inference.py 12.2 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ChaimZhu's avatar
ChaimZhu committed
2
import warnings
3
4
from copy import deepcopy
from os import path as osp
ChaimZhu's avatar
ChaimZhu committed
5
6
from pathlib import Path
from typing import Optional, Sequence, Union
7

8
import mmengine
9
import numpy as np
wuyuefeng's avatar
Demo  
wuyuefeng committed
10
import torch
ChaimZhu's avatar
ChaimZhu committed
11
import torch.nn as nn
12
from mmengine.config import Config
ChaimZhu's avatar
ChaimZhu committed
13
14
from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint
wuyuefeng's avatar
Demo  
wuyuefeng committed
15

zhangshilong's avatar
zhangshilong committed
16
17
18
from mmdet3d.registry import MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList
wuyuefeng's avatar
Demo  
wuyuefeng committed
19
20


21
22
23
24
def convert_SyncBN(config):
    """Recursively replace ``naiveSyncBN*`` norm layers with plain ``BN``.

    Walks the (possibly nested) model config dict in place; every
    ``norm_cfg`` entry found has the ``naiveSyncBN`` prefix of its ``type``
    rewritten to ``BN``. Non-dict inputs are ignored.

    Args:
         config (str or :obj:`mmengine.Config`): Config file path or the config
            object.
    """
    if not isinstance(config, dict):
        return
    for key in config:
        value = config[key]
        if key == 'norm_cfg':
            # Rewrite e.g. 'naiveSyncBN2d' -> 'BN2d' in place.
            value['type'] = value['type'].replace('naiveSyncBN', 'BN')
        else:
            # Descend into nested sub-configs.
            convert_SyncBN(value)


ChaimZhu's avatar
ChaimZhu committed
37
38
39
40
def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a model from config file, which could be a 3D detector or a
    3D segmentor.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device to use. Defaults to 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config`` object.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    elif 'init_cfg' in config.model.backbone:
        # Without explicit overrides, drop the backbone's `init_cfg` so no
        # pretrained backbone weights get downloaded at inference time
        # (mirrors mmdet's `init_detector`).
        config.model.backbone.init_cfg = None
    convert_SyncBN(config.model)
    config.model.train_cfg = None
    model = MODELS.build(config.model)

    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # FIX: `checkpoint['meta']` raised KeyError on checkpoints saved
        # without a meta section, even though every lookup below already
        # guarded with `.get('meta', {})`. Fetch the dict safely once.
        meta = checkpoint.get('meta', {})
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in meta:
            # mmdet3d 1.x
            model.dataset_meta = meta['dataset_meta']
        elif 'CLASSES' in meta:
            # < mmdet3d 1.x
            model.dataset_meta = {'CLASSES': meta['CLASSES']}

            if 'PALETTE' in meta:  # 3D Segmentor
                model.dataset_meta['PALETTE'] = meta['PALETTE']
        else:
            # < mmdet3d 1.x: fall back to the class names from the config
            model.dataset_meta = {'CLASSES': config.class_names}

            if 'PALETTE' in meta:  # 3D Segmentor
                model.dataset_meta['PALETTE'] = meta['PALETTE']

    model.cfg = config  # save the config in the model for convenience
    if device != 'cpu':
        torch.cuda.set_device(device)
    else:
        warnings.warn('Don\'t suggest using CPU device. '
                      'Some functions are not supported for now.')

    model.to(device)
    model.eval()
    return model


ChaimZhu's avatar
ChaimZhu committed
103
104
105
106
107
108
# Accepted point-cloud inputs: file path(s) or already-loaded array(s).
PointsType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
# Accepted image inputs: file path(s) or already-loaded array(s).
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module,
                       pcds: PointsType) -> Union[Det3DDataSample, SampleList]:
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, ndarray, Sequence[str/ndarray]):
            Either point cloud files or loaded point cloud.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If pcds is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    # Normalize the input to a list, remembering whether the caller
    # passed a batch so the return shape can match.
    is_batch = isinstance(pcds, (list, tuple))
    if not is_batch:
        pcds = [pcds]

    cfg = model.cfg

    if not isinstance(pcds[0], str):
        # Inputs are loaded arrays: switch the first pipeline step so it
        # reads points from the sample dict instead of from disk.
        cfg = cfg.copy()
        cfg.test_dataloader.dataset.pipeline[0].type = 'LoadPointsFromDict'

    # build the data pipeline
    pipeline = Compose(deepcopy(cfg.test_dataloader.dataset.pipeline))
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for pcd in pcds:
        if isinstance(pcd, str):
            # load from point cloud file
            sample = dict(
                lidar_points=dict(lidar_path=pcd),
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        else:
            # directly use loaded point cloud
            sample = dict(
                points=pcd,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        data.append(pipeline(sample))

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if is_batch:
        return results, data
    return results[0], data[0]
wuyuefeng's avatar
Demo  
wuyuefeng committed
170
171


ChaimZhu's avatar
ChaimZhu committed
172
173
174
def inference_multi_modality_detector(model: nn.Module,
                                      pcds: Union[str, Sequence[str]],
                                      imgs: Union[str, Sequence[str]],
                                      ann_file: Union[str, Sequence[str]],
                                      cam_type: str = 'CAM_FRONT'):
    """Inference point cloud with the multi-modality detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.
        imgs (str, Sequence[str]):
           Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files.
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If pcds is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    # TODO: We will support
    is_batch = isinstance(pcds, (list, tuple))
    if is_batch:
        # A batch of point clouds must come with a matching batch of images.
        assert isinstance(imgs, (list, tuple))
        assert len(pcds) == len(imgs)
    else:
        pcds = [pcds]
        imgs = [imgs]

    cfg = model.cfg

    # build the data pipeline
    pipeline = Compose(deepcopy(cfg.test_dataloader.dataset.pipeline))
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data_list = mmengine.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    # All three sequences are the same length (asserted above), so zip is
    # equivalent to indexed iteration here.
    for pcd, img, data_info in zip(pcds, imgs, data_list):
        # get data info containing calib
        img_path = data_info['images'][cam_type]['img_path']

        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # TODO: check the name consistency of
        # image file and point cloud file
        sample = dict(
            lidar_points=dict(lidar_path=pcd),
            img_path=img,
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        if box_mode_3d == Box3DMode.LIDAR:
            # LiDAR to image conversion for KITTI dataset
            sample['lidar2img'] = np.array(
                data_info['images'][cam_type]['lidar2img'])
        elif box_mode_3d == Box3DMode.DEPTH:
            # Depth to image conversion for SUNRGBD dataset
            sample['depth2img'] = np.array(
                data_info['images'][cam_type]['depth2img'])

        data.append(pipeline(sample))

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    # Attach the per-sample meta info produced by the pipeline.
    for result, sample in zip(results, data):
        result.set_metainfo(sample['data_samples'].metainfo)

    if is_batch:
        return results, data
    return results[0], data[0]
260
261


262
263
264
265
def inference_mono_3d_detector(model: nn.Module,
                               imgs: ImagesType,
                               ann_file: Union[str, Sequence[str]],
                               cam_type: str = 'CAM_FRONT'):
    """Inference image with the monocular 3D detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, Sequence[str]):
           Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files.
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.

    Raises:
        ValueError: If an input image has no matching entry in ``ann_file``.
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    # FIX: index the 'data_list' field as `inference_multi_modality_detector`
    # does. The annotation file loads as a dict (with 'metainfo' and
    # 'data_list' keys), so using the raw load here broke both the length
    # check and the integer indexing below.
    data_list = mmengine.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    for index, img in enumerate(imgs):
        # get data info containing calib
        data_info = data_list[index]
        img_path = data_info['images'][cam_type]['img_path']
        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # replace the img_path in data_info with img
        data_info['images'][cam_type]['img_path'] = img
        data_ = dict(
            images=data_info['images'],
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    # Attach the per-sample meta info produced by the pipeline.
    for index in range(len(data)):
        meta_info = data[index]['data_samples'].metainfo
        results[index].set_metainfo(meta_info)

    if not is_batch:
        return results[0]
    else:
        return results
330

ChaimZhu's avatar
ChaimZhu committed
331
332

def inference_segmentor(model: nn.Module, pcds: PointsType):
    """Inference point cloud with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If pcds is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    # Normalize to a list, remembering the caller's input shape.
    is_batch = isinstance(pcds, (list, tuple))
    if not is_batch:
        pcds = [pcds]

    cfg = model.cfg

    # build the data pipeline
    pipeline = Compose(deepcopy(cfg.test_dataloader.dataset.pipeline))

    # TODO: support load points array
    data = [
        pipeline(dict(lidar_points=dict(lidar_path=pcd))) for pcd in pcds
    ]

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if is_batch:
        return results, data
    return results[0], data[0]