# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy
from os import path as osp
from typing import Sequence, Union

import mmcv
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint

from mmdet3d.registry import MODELS
from mmdet3d.structures import Box3DMode, Det3DDataSample, get_box_type
from mmdet3d.structures.det3d_data_sample import SampleList


def convert_SyncBN(config):
    """Convert config's naiveSyncBN to BN.

    Args:
        config (dict or :obj:`mmengine.Config`): Config dict or config
            object whose ``norm_cfg`` types are converted in place.
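
    Example:
        A small in-place conversion sketch on a plain dict:

        >>> cfg = dict(norm_cfg=dict(type='naiveSyncBN2d'))
        >>> convert_SyncBN(cfg)
        >>> cfg['norm_cfg']['type']
        'BN2d'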
    """
    if isinstance(config, dict):
        for item in config:
            if item == 'norm_cfg':
                config[item]['type'] = config[item]['type'].replace(
                    'naiveSyncBN', 'BN')
            else:
                convert_SyncBN(config[item])


def init_model(config, checkpoint=None, device='cuda:0'):
    """Initialize a model from config file, which could be a 3D detector or a
    3D segmentor.

    Args:
        config (str or :obj:`mmengine.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device to use.

    Returns:
        nn.Module: The constructed model.
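
    Example:
        A minimal usage sketch; the config and checkpoint paths below are
        placeholders, not files shipped with the repo:

        >>> from mmdet3d.apis import init_model
        >>> model = init_model('path/to/config.py',
        ...                    'path/to/checkpoint.pth',
        ...                    device='cuda:0')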
    """
    if isinstance(config, str):
        config = mmengine.Config.fromfile(config)
    elif not isinstance(config, mmengine.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    config.model.pretrained = None
    convert_SyncBN(config.model)
    config.model.train_cfg = None
    model = MODELS.build(config.model)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            model.CLASSES = config.class_names
        if 'PALETTE' in checkpoint['meta']:  # 3D Segmentor
            model.PALETTE = checkpoint['meta']['PALETTE']
    model.cfg = config  # save the config in the model for convenience
    if device != 'cpu':
        torch.cuda.set_device(device)
    else:
        warnings.warn('Running on CPU is not recommended. '
                      'Some functions are not supported for now.')

    model.to(device)
    model.eval()
    return model


PointsType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module,
                       pcds: PointsType) -> Union[Det3DDataSample, SampleList]:
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, ndarray, Sequence[str/ndarray]):
            Either point cloud files or loaded point clouds.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If ``pcds`` is a list or tuple, a list of results of the same
        length is returned; otherwise the detection result is returned
        directly.
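
    Example:
        A minimal sketch; the paths are placeholders and the imports assume
        these helpers are exported via ``mmdet3d.apis``:

        >>> from mmdet3d.apis import inference_detector, init_model
        >>> model = init_model('path/to/config.py', 'path/to/checkpoint.pth')
        >>> result = inference_detector(model, 'path/to/points.bin')
        >>> # a list of inputs yields a list of results of the same length
        >>> results = inference_detector(
        ...     model, ['path/to/a.bin', 'path/to/b.bin'])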
    """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    if not isinstance(pcds[0], str):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.test_dataloader.dataset.pipeline[0].type = 'LoadPointsFromDict'

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    # box_type_3d, box_mode_3d = get_box_type(
    # cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for pcd in pcds:
        # prepare data
        if isinstance(pcd, str):
            # load from point cloud file
            data_ = dict(
                lidar_points=dict(lidar_path=pcd),
                # for ScanNet demo we need axis_align_matrix
                ann_info=dict(axis_align_matrix=np.eye(4)),
                sweeps=[],
                # set timestamp = 0
                timestamp=[0])
        else:
            # directly use loaded point cloud
            data_ = dict(
                points=pcd,
                # for ScanNet demo we need axis_align_matrix
                ann_info=dict(axis_align_matrix=np.eye(4)),
                sweeps=[],
                # set timestamp = 0
                timestamp=[0])
        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results


def inference_multi_modality_detector(model: nn.Module,
                                      pcds: Union[str, Sequence[str]],
                                      imgs: Union[str, Sequence[str]],
                                      ann_files: Union[str, Sequence[str]]):
    """Inference point cloud with the multi-modality detector.

    Args:
        model (nn.Module): The loaded detector.
        pcds (str, Sequence[str]): Point cloud files.
        imgs (str, Sequence[str]): Image files.
        ann_files (str, Sequence[str]): Annotation files.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If ``pcds`` is a list or tuple, a list of results of the same
        length is returned; otherwise the detection result is returned
        directly.
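
    Example:
        A minimal sketch; every path below is a placeholder and the
        annotation file is assumed to follow the dataset's info-file format
        (imports assumed via ``mmdet3d.apis``):

        >>> from mmdet3d.apis import (inference_multi_modality_detector,
        ...                           init_model)
        >>> model = init_model('path/to/config.py', 'path/to/checkpoint.pth')
        >>> result = inference_multi_modality_detector(
        ...     model, 'path/to/points.bin', 'path/to/image.png',
        ...     'path/to/infos.pkl')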
    """

    # TODO: We will support
    if isinstance(pcds, (list, tuple)):
        is_batch = True
        assert isinstance(imgs, (list, tuple))
        assert isinstance(ann_files, (list, tuple))
        assert len(pcds) == len(imgs) == len(ann_files)
    else:
        pcds = [pcds]
        imgs = [imgs]
        ann_files = [ann_files]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for index, pcd in enumerate(pcds):
        # get data info containing calib
        img = imgs[index]
        ann_file = ann_files[index]
        data_info = mmcv.load(ann_file)[0]
        # TODO: check the name consistency of
        # image file and point cloud file
        data_ = dict(
            lidar_points=dict(lidar_path=pcd),
            img_path=imgs[index],
            img_prefix=osp.dirname(img),
            img_info=dict(filename=osp.basename(img)),
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)
        data_ = test_pipeline(data_)

        # LiDAR to image conversion for KITTI dataset
        if box_mode_3d == Box3DMode.LIDAR:
            data_['lidar2img'] = data_info['images']['CAM2']['lidar2img']
        # Depth to image conversion for SUNRGBD dataset
        elif box_mode_3d == Box3DMode.DEPTH:
            data_['depth2img'] = data_info['images']['CAM0']['depth2img']

        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results


def inference_mono_3d_detector(model: nn.Module, imgs: ImagesType,
                               ann_files: Union[str, Sequence[str]]):
    """Inference image with the monocular 3D detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, Sequence[str]):
            Either image files or loaded images.
        ann_files (str, Sequence[str]): Annotation files.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If ``imgs`` is a list or tuple, a list of results of the same
        length is returned; otherwise the detection result is returned
        directly.
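
    Example:
        A minimal sketch; every path below is a placeholder (imports assumed
        via ``mmdet3d.apis``):

        >>> from mmdet3d.apis import inference_mono_3d_detector, init_model
        >>> model = init_model('path/to/config.py', 'path/to/checkpoint.pth')
        >>> result = inference_mono_3d_detector(
        ...     model, 'path/to/image.png', 'path/to/infos.pkl')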
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = \
        get_box_type(cfg.test_dataloader.dataset.box_type_3d)

    data = []
    for index, img in enumerate(imgs):
        ann_file = ann_files[index]
        # get data info containing calib
        data_info = mmcv.load(ann_file)[0]
        data_ = dict(
            img_path=img,
            images=data_info['images'],
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d)

        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results


def inference_segmentor(model: nn.Module, pcds: PointsType):
    """Inference point cloud with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        pcds (str, Sequence[str]):
            Either point cloud files or loaded point clouds.

    Returns:
        :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]:
        If ``pcds`` is a list or tuple, a list of results of the same
        length is returned; otherwise the segmentation result is returned
        directly.
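
    Example:
        A minimal sketch; the paths are placeholders (imports assumed via
        ``mmdet3d.apis``):

        >>> from mmdet3d.apis import inference_segmentor, init_model
        >>> model = init_model('path/to/config.py', 'path/to/checkpoint.pth')
        >>> result = inference_segmentor(model, 'path/to/points.bin')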
    """
    if isinstance(pcds, (list, tuple)):
        is_batch = True
    else:
        pcds = [pcds]
        is_batch = False

    cfg = model.cfg

    # build the data pipeline
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    test_pipeline = Compose(test_pipeline)

    data = []
    for pcd in pcds:
        data_ = dict(lidar_points=dict(lidar_path=pcd))
        data_ = test_pipeline(data_)
        data.append(data_)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    if not is_batch:
        return results[0]
    else:
        return results