inference.py 11.1 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
ChaimZhu's avatar
ChaimZhu committed
2
import warnings
3
4
from copy import deepcopy
from os import path as osp
ChaimZhu's avatar
ChaimZhu committed
5
from typing import Sequence, Union
6

wuyuefeng's avatar
Demo  
wuyuefeng committed
7
import mmcv
8
import numpy as np
wuyuefeng's avatar
Demo  
wuyuefeng committed
9
import torch
ChaimZhu's avatar
ChaimZhu committed
10
import torch.nn as nn
11
from mmengine.config import Config
ChaimZhu's avatar
ChaimZhu committed
12
13
from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint
wuyuefeng's avatar
Demo  
wuyuefeng committed
14

zhangshilong's avatar
zhangshilong committed
15
16
17
from mmdet3d.registry import MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList
wuyuefeng's avatar
Demo  
wuyuefeng committed
18
19


20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def convert_SyncBN(config):
    """Recursively convert ``naiveSyncBN*`` norm layers to plain ``BN*``.

    The replacement is done in place on every ``norm_cfg`` entry found in
    the (possibly nested) config, so that a model trained with naive
    synchronized batch-norm can be built for single-device inference.

    Args:
        config (dict or :obj:`mmengine.Config`): The (sub-)config to
            convert. Lists/tuples are traversed as well; any other value
            terminates the recursion.
    """
    if isinstance(config, dict):
        for key, value in config.items():
            if key == 'norm_cfg' and isinstance(value, dict):
                value['type'] = value['type'].replace('naiveSyncBN', 'BN')
            else:
                convert_SyncBN(value)
    elif isinstance(config, (list, tuple)):
        # Configs may nest dicts inside lists (e.g. multi-branch modules);
        # the previous implementation skipped these entirely.
        for item in config:
            convert_SyncBN(item)


36
37
38
def init_model(config: Union[str, Config],
               checkpoint: Union[str, None] = None,
               device: str = 'cuda:0') -> nn.Module:
    """Initialize a model from config file, which could be a 3D detector or a
    3D segmentor.

    Args:
        config (str or :obj:`mmengine.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device to use. Defaults to 'cuda:0'.

    Returns:
        nn.Module: The constructed detector.

    Raises:
        TypeError: If ``config`` is neither a filename nor a ``Config``.
    """
    if isinstance(config, str):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    # Synchronized BN variants need distributed training; swap to plain BN
    # so the model can be built for single-device inference.
    convert_SyncBN(config.model)
    # Inference only: the train config is irrelevant and is dropped.
    config.model.train_cfg = None
    model = MODELS.build(config.model)
    if checkpoint is not None:
        # map_location='cpu' deserializes weights on CPU, so loading does
        # not require the target GPU to be available yet.
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            # Checkpoint carries no class metadata; fall back to the config.
            model.CLASSES = config.class_names
        if 'PALETTE' in checkpoint['meta']:  # 3D Segmentor
            model.PALETTE = checkpoint['meta']['PALETTE']
    model.cfg = config  # save the config in the model for convenience
    if device != 'cpu':
        torch.cuda.set_device(device)
    else:
        warnings.warn('Don\'t suggest using CPU device. '
                      'Some functions are not supported for now.')
    model.to(device)
    model.eval()
    return model


ChaimZhu's avatar
ChaimZhu committed
78
79
80
81
82
83
# Accepted input forms for the inference helpers below: a single file path,
# a single loaded array, or a sequence of either.
PointsType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module,
                       pcds: PointsType) -> Union[Det3DDataSample, SampleList]:
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, ndarray, Sequence[str/ndarray]):
            Either point cloud files or loaded point cloud.

    Returns:
        tuple: ``(results, data)`` — the detection results and the
        pipeline-processed inputs. If ``pcds`` is a list or tuple,
        ``results`` is a list of :obj:`Det3DDataSample` and ``data`` a list
        of dicts of the same length; otherwise the single
        :obj:`Det3DDataSample` and its data dict are returned directly.
    """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        # Normalize the single-sample case to a one-element batch.
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    if not isinstance(pcds[0], str):
        # Inputs are already-loaded arrays, not file paths; copy the config
        # so the model's stored config is not mutated.
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.test_dataloader.dataset.pipeline[0].type = 'LoadPointsFromDict'

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for pcd in pcds:
        # prepare data
        if isinstance(pcd, str):
            # load from point cloud file
            data_ = dict(
                lidar_points=dict(lidar_path=pcd),
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        else:
            # directly use loaded point cloud
            data_ = dict(
                points=pcd,
                # for ScanNet demo we need axis_align_matrix
                axis_align_matrix=np.eye(4),
                box_type_3d=box_type_3d,
                box_mode_3d=box_mode_3d)
        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0], data[0]
    else:
        return results, data
wuyuefeng's avatar
Demo  
wuyuefeng committed
145
146


ChaimZhu's avatar
ChaimZhu committed
147
148
149
def inference_multi_modality_detector(model: nn.Module,
                                      pcds: Union[str, Sequence[str]],
                                      imgs: Union[str, Sequence[str]],
                                      ann_file: Union[str, Sequence[str]],
                                      cam_type: str = 'CAM_FRONT'):
    """Inference point cloud with the multi-modality detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.
        imgs (str, Sequence[str]):
           Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files. A single file is
            expected to describe all samples (one entry per element of
            ``imgs``/``pcds``).
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        tuple: ``(results, data)`` — the detection results and the
        pipeline-processed inputs. If ``pcds`` is a list or tuple, both are
        same-length lists; otherwise the single :obj:`Det3DDataSample` and
        its data dict are returned directly.

    Raises:
        ValueError: If an image's basename does not match the ``img_path``
            recorded in the annotation file at the same index.
    """
    # TODO: We will support
    if isinstance(pcds, (list, tuple)):
        is_batch = True
        assert isinstance(imgs, (list, tuple))
        assert len(pcds) == len(imgs)
    else:
        # Normalize the single-sample case to one-element batches.
        pcds = [pcds]
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    # Annotations are indexed positionally, so they must align 1:1 with the
    # provided images.
    data_list = mmcv.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    for index, pcd in enumerate(pcds):
        # get data info containing calib
        img = imgs[index]
        data_info = data_list[index]
        img_path = data_info['images'][cam_type]['img_path']

        # Guard against a mismatch between the supplied image and the
        # annotation entry at the same index (only basenames are compared).
        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # TODO: check the name consistency of
        # image file and point cloud file
        data_ = dict(
            lidar_points=dict(lidar_path=pcd),
            img_path=img,
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        # LiDAR to image conversion for KITTI dataset
        if box_mode_3d == Box3DMode.LIDAR:
            data_['lidar2img'] = np.array(
                data_info['images'][cam_type]['lidar2img'])
        # Depth to image conversion for SUNRGBD dataset
        elif box_mode_3d == Box3DMode.DEPTH:
            data_['depth2img'] = np.array(
                data_info['images'][cam_type]['depth2img'])

        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    # Propagate per-sample meta info (e.g. calibration) onto the results.
    for index in range(len(data)):
        meta_info = data[index]['data_sample'].metainfo
        results[index].set_metainfo(meta_info)

    if not is_batch:
        return results[0], data[0]
    else:
        return results, data
235
236


237
238
239
240
def inference_mono_3d_detector(model: nn.Module,
                               imgs: ImagesType,
                               ann_file: Union[str, Sequence[str]],
                               cam_type: str = 'CAM_FRONT'):
    """Inference image with the monocular 3D detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, Sequence[str]):
           Either image files or loaded images.
        ann_file (str, Sequence[str]): Annotation files. A single file is
            expected to describe all samples (one entry per element of
            ``imgs``).
        cam_type (str): Image of Camera chose to infer.
            For kitti dataset, it should be 'CAM_2',
            and for nuscenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM_FRONT'.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.

    Raises:
        ValueError: If an image's basename does not match the ``img_path``
            recorded in the annotation file at the same index.
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        # Normalize the single-sample case to a one-element batch.
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    # Index into the 'data_list' entry of the annotation file, consistent
    # with `inference_multi_modality_detector` (the per-sample access below
    # expects annotation entries, not the raw top-level dict).
    data_list = mmcv.load(ann_file)['data_list']
    assert len(imgs) == len(data_list)

    data = []
    for index, img in enumerate(imgs):
        # get data info containing calib
        data_info = data_list[index]
        img_path = data_info['images'][cam_type]['img_path']
        # Guard against a mismatch between the supplied image and the
        # annotation entry at the same index (only basenames are compared).
        if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'the info file of {img_path} is not provided.')

        # replace the img_path in data_info with img
        data_info['images'][cam_type]['img_path'] = img
        data_ = dict(
            images=data_info['images'],
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    # Propagate per-sample meta info (e.g. calibration) onto the results.
    for index in range(len(data)):
        meta_info = data[index]['data_sample'].metainfo
        results[index].set_metainfo(meta_info)

    if not is_batch:
        return results[0]
    else:
        return results
305

ChaimZhu's avatar
ChaimZhu committed
306
307

def inference_segmentor(model: nn.Module, pcds: PointsType):
    """Run a 3D segmentor on one or more point clouds.

    Args:
        model (nn.Module): The loaded segmentor.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point cloud.

    Returns:
        tuple: ``(results, data)`` — the segmentation results and the
        pipeline-processed inputs. For a list/tuple input both are
        same-length lists; otherwise the single result and data dict are
        returned directly.
    """
    # Treat a single input as a one-element batch, remembering the
    # original form so the return shape can match it.
    is_batch = isinstance(pcds, (list, tuple))
    if not is_batch:
        pcds = [pcds]

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = Compose(deepcopy(cfg.test_dataloader.dataset.pipeline))

    # TODO: support load points array
    data = [
        test_pipeline(dict(lidar_points=dict(lidar_path=pcd)))
        for pcd in pcds
    ]

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    return (results, data) if is_batch else (results[0], data[0])