import mmcv
import torch
from copy import deepcopy
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from os import path as osp

from mmdet3d.core import Box3DMode, show_result
from mmdet3d.core.bbox import get_box_type
from mmdet3d.datasets.pipelines import Compose
from mmdet3d.models import build_detector


def convert_SyncBN(config):
    """Convert all ``naiveSyncBN*`` norm layers in a config to plain ``BN*``.

    Walks the (possibly nested) model config in place and rewrites every
    ``norm_cfg['type']`` so that e.g. ``naiveSyncBN2d`` becomes ``BN2d``.
    SyncBN variants require distributed training; single-device inference
    needs ordinary BN.

    Args:
        config (dict | list | tuple): Model config (or sub-config) to
            convert in place. Non-container values are ignored.
    """
    if isinstance(config, dict):
        for key in config:
            if key == 'norm_cfg':
                config[key]['type'] = config[key]['type'].replace(
                    'naiveSyncBN', 'BN')
            else:
                convert_SyncBN(config[key])
    elif isinstance(config, (list, tuple)):
        # Config values may themselves be lists of sub-configs (e.g.
        # multi-branch heads); recurse so nested norm_cfg entries are
        # converted too.
        for item in config:
            convert_SyncBN(item)


def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str): Device to use.

    Returns:
        nn.Module: The constructed detector.

    Raises:
        TypeError: If ``config`` is neither a filename nor a Config object.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    config.model.pretrained = None
    # Single-device inference cannot use distributed SyncBN layers.
    convert_SyncBN(config.model)
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        # Load onto CPU first so a checkpoint saved on GPU can be restored
        # on a CPU-only host; the model is moved to `device` below anyway.
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            # Older checkpoints may lack class metadata; fall back to the
            # class names declared in the config.
            model.CLASSES = config.class_names
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_detector(model, pcd):
    """Inference point cloud with the detector.

    Args:
        model (nn.Module): The loaded detector.
        pcd (str): Point cloud files.

    Returns:
        tuple: Predicted results and data from pipeline.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # Build the test-time pipeline from a copy so the config cached on the
    # model stays untouched.
    pipeline = Compose(deepcopy(cfg.data.test.pipeline))
    box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d)
    input_dict = dict(
        pts_filename=pcd,
        box_type_3d=box_type_3d,
        box_mode_3d=box_mode_3d,
        sweeps=[],
        # set timestamp = 0
        timestamp=[0])
    # The loading/formatting transforms expect these bookkeeping lists to
    # already exist on the sample.
    for field in ('img_fields', 'bbox3d_fields', 'pts_mask_fields',
                  'pts_seg_fields', 'bbox_fields', 'mask_fields',
                  'seg_fields'):
        input_dict[field] = []
    batch = collate([pipeline(input_dict)], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        batch = scatter(batch, [device.index])[0]
    else:
        # this is a workaround to avoid the bug of MMDataParallel
        batch['img_metas'] = batch['img_metas'][0].data
        batch['points'] = batch['points'][0].data
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **batch)
    return result, batch


def show_result_meshlab(data, result, out_dir):
    """Show result by meshlab.

    Args:
        data (dict): Contain data from pipeline.
        result (dict): Predicted result from model.
        out_dir (str): Directory to save visualized result.
    """
    points = data['points'][0][0].cpu().numpy()
    img_meta = data['img_metas'][0][0]
    # Derive the output file stem from the input point-cloud file name.
    file_name = osp.split(img_meta['pts_filename'])[-1].split('.')[0]

    assert out_dir is not None, 'Expect out_dir, got none.'

    # Multi-modality results nest the 3D boxes under 'pts_bbox'.
    res = result[0]
    source = res['pts_bbox'] if 'pts_bbox' in res else res
    pred_bboxes = source['boxes_3d'].tensor.numpy()

    # for now we convert points into depth mode
    if img_meta['box_mode_3d'] != Box3DMode.DEPTH:
        points = points[..., [1, 0, 2]]
        points[..., 0] *= -1
        pred_bboxes = Box3DMode.convert(pred_bboxes, img_meta['box_mode_3d'],
                                        Box3DMode.DEPTH)
    show_result(points, None, pred_bboxes, out_dir, file_name, show=False)
    return out_dir, file_name