inference.py 12.2 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ChaimZhu's avatar
ChaimZhu committed
2
import warnings
3
4
from copy import deepcopy
from os import path as osp
ChaimZhu's avatar
ChaimZhu committed
5
6
from pathlib import Path
from typing import Optional, Sequence, Union
7

8
import mmengine
9
import numpy as np
wuyuefeng's avatar
Demo  
wuyuefeng committed
10
import torch
ChaimZhu's avatar
ChaimZhu committed
11
import torch.nn as nn
12
from mmengine.config import Config
13
from mmengine.dataset import Compose, pseudo_collate
ChaimZhu's avatar
ChaimZhu committed
14
from mmengine.runner import load_checkpoint
wuyuefeng's avatar
Demo  
wuyuefeng committed
15

zhangshilong's avatar
zhangshilong committed
16
17
18
from mmdet3d.registry import MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList
wuyuefeng's avatar
Demo  
wuyuefeng committed
19
20


21
22
23
24
def convert_SyncBN(config):
    """Convert config's naiveSyncBN to BN, in place.

    SyncBN variants require distributed training; for single-device
    inference they must be replaced with plain BN.

    Args:
        config (dict or :obj:`mmengine.Config`): The model config (or any
            nested fragment of it) to be converted in place.
    """
    if isinstance(config, dict):
        for item in config:
            if item == 'norm_cfg':
                config[item]['type'] = config[item]['type'].replace(
                    'naiveSyncBN', 'BN')
            else:
                convert_SyncBN(config[item])
    elif isinstance(config, (list, tuple)):
        # Configs may nest dicts inside lists (e.g. multi-branch heads or
        # pipelines); recurse into each element as well.
        for item in config:
            convert_SyncBN(item)


ChaimZhu's avatar
ChaimZhu committed
37
38
39
40
def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a model from config file, which could be a 3D detector or a
    3D segmentor.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights. Defaults to None.
        device (str): Device to use. Defaults to 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in
            the used config. Defaults to None.

    Returns:
        nn.Module: The constructed detector or segmentor, in eval mode,
        with ``model.cfg`` and ``model.dataset_meta`` attached.

    Raises:
        TypeError: If ``config`` is neither a filename nor a Config object.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # SyncBN only works in distributed training; fall back to plain BN
    # for single-device inference.
    convert_SyncBN(config.model)
    config.model.train_cfg = None
    model = MODELS.build(config.model)

    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # `meta` may be missing entirely in very old checkpoints; use .get
        # so we fall through to the config-based defaults instead of raising.
        meta = checkpoint.get('meta', {})
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in meta:
            # mmdet3d 1.x checkpoints carry the full dataset meta.
            model.dataset_meta = meta['dataset_meta']
        else:
            # < mmdet3d 1.x: fall back to the checkpoint's CLASSES, then to
            # the class names recorded in the config.
            if 'CLASSES' in meta:
                model.dataset_meta = {'CLASSES': meta['CLASSES']}
            else:
                model.dataset_meta = {'CLASSES': config.class_names}

            if 'PALETTE' in meta:  # 3D Segmentor
                model.dataset_meta['PALETTE'] = meta['PALETTE']

    model.cfg = config  # save the config in the model for convenience

    if device != 'cpu':
        torch.cuda.set_device(device)
    else:
        warnings.warn('Don\'t suggest using CPU device. '
                      'Some functions are not supported for now.')

    model.to(device)
    model.eval()
    return model


ChaimZhu's avatar
ChaimZhu committed
101
102
103
104
105
106
# A point-cloud input: one file path or one loaded array, or a sequence of
# either (batched inference).
PointsType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
# An image input: one file path or one loaded array, or a sequence of either.
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module, pcds: PointsType):
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, ndarray, Sequence[str/ndarray]):
            Either point cloud files or loaded point cloud.

    Returns:
        tuple: Predicted results and the data that went through the test
        pipeline. If ``pcds`` is a list or tuple, same-length lists of
        :obj:`Det3DDataSample` results and pipeline data dicts are
        returned; otherwise a single result and its data dict.
    """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    if not isinstance(pcds[0], str):
        # Loaded arrays cannot go through the file-loading transform; swap
        # in the dict-based loader (copy cfg so model.cfg stays untouched).
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.test_dataloader.dataset.pipeline[0].type = 'LoadPointsFromDict'

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for pcd in pcds:
        # prepare data
        if isinstance(pcd, str):
            # load from point cloud file
            data_ = dict(
                lidar_points=dict(lidar_path=pcd),
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        else:
            # directly use loaded point cloud
            data_ = dict(
                points=pcd,
                timestamp=1,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        data_ = test_pipeline(data_)
        data.append(data_)

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0], data[0]
    else:
        return results, data
wuyuefeng's avatar
Demo  
wuyuefeng committed
172
173


ChaimZhu's avatar
ChaimZhu committed
174
175
176
def inference_multi_modality_detector(model: nn.Module,
                                      pcds: Union[str, Sequence[str]],
                                      imgs: Union[str, Sequence[str]],
                                      ann_file: Union[str, Sequence[str]],
                                      cam_type: str = 'CAM_FRONT'):
    """Inference point cloud with the multi-modality detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.
        imgs (str, Sequence[str]):
            Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files. Must contain one
            data-info entry per input sample, in order.
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        tuple: Predicted results and the data that went through the test
        pipeline. If ``pcds`` is a list or tuple, same-length lists of
        :obj:`Det3DDataSample` results and pipeline data dicts are
        returned; otherwise a single result and its data dict.

    Raises:
        ValueError: If an input image does not match the image path
            recorded in ``ann_file``.
    """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
        assert isinstance(imgs, (list, tuple))
        assert len(pcds) == len(imgs)
    else:
        pcds = [pcds]
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data_list = mmengine.load(ann_file)['data_list']
    # One info entry per input, in the same order.
    assert len(imgs) == len(data_list)

    data = []
    for index, pcd in enumerate(pcds):
        # get data info containing calib
        img = imgs[index]
        data_info = data_list[index]
        img_path = data_info['images'][cam_type]['img_path']

        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # TODO: check the name consistency of
        # image file and point cloud file
        data_ = dict(
            lidar_points=dict(lidar_path=pcd),
            img_path=img,
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        # LiDAR to image conversion for KITTI dataset
        if box_mode_3d == Box3DMode.LIDAR:
            data_['lidar2img'] = np.array(
                data_info['images'][cam_type]['lidar2img'])
        # Depth to image conversion for SUNRGBD dataset
        elif box_mode_3d == Box3DMode.DEPTH:
            data_['depth2img'] = np.array(
                data_info['images'][cam_type]['depth2img'])

        data_ = test_pipeline(data_)
        data.append(data_)

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0], data[0]
    else:
        return results, data
260
261


262
263
264
265
def inference_mono_3d_detector(model: nn.Module,
                               imgs: ImagesType,
                               ann_file: Union[str, Sequence[str]],
                               cam_type: str = 'CAM_FRONT'):
    """Inference image with the monocular 3D detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, Sequence[str]):
            Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files. Must contain one
            data-info entry per input image, in order.
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.

    Raises:
        ValueError: If an input image does not match the image path
            recorded in ``ann_file``.
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    # The info file is a dict holding the per-sample entries under
    # 'data_list', same format as in `inference_multi_modality_detector`.
    data_list = mmengine.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    for index, img in enumerate(imgs):
        # get data info containing calib
        data_info = data_list[index]
        img_path = data_info['images'][cam_type]['img_path']
        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # replace the img_path in data_info with img
        data_info['images'][cam_type]['img_path'] = img
        data_ = dict(
            images=data_info['images'],
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        data_ = test_pipeline(data_)
        data.append(data_)

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0]
    else:
        return results
328

ChaimZhu's avatar
ChaimZhu committed
329
330

def inference_segmentor(model: nn.Module, pcds: PointsType):
    """Inference point cloud with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.

    Returns:
        tuple: Predicted results and the data that went through the test
        pipeline. If ``pcds`` is a list or tuple, same-length lists of
        :obj:`Det3DDataSample` results and pipeline data dicts are
        returned; otherwise a single result and its data dict.
    """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)

    # Inference has no ground truth; strip the annotation-loading step
    # from the test pipeline before composing it.
    new_test_pipeline = []
    for pipeline in test_pipeline:
        if pipeline['type'] != 'LoadAnnotations3D':
            new_test_pipeline.append(pipeline)
    test_pipeline = Compose(new_test_pipeline)

    data = []
    # TODO: support load points array
    for pcd in pcds:
        data_ = dict(lidar_points=dict(lidar_path=pcd))
        data_ = test_pipeline(data_)
        data.append(data_)

    collate_data = pseudo_collate(data)

    # forward the model
    with torch.no_grad():
        results = model.test_step(collate_data)

    if not is_batch:
        return results[0], data[0]
    else:
        return results, data