# browse_dataset.py (7.55 KB); last committed by dingchang.
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import mmcv
4
5
import numpy as np
import warnings
6
from mmcv import Config, DictAction, mkdir_or_exist
7
from os import path as osp
8
from pathlib import Path
9

10
11
from mmdet3d.core.bbox import (Box3DMode, CameraInstance3DBoxes, Coord3DMode,
                               DepthInstance3DBoxes, LiDARInstance3DBoxes)
12
13
14
from mmdet3d.core.visualizer import (show_multi_modality_result, show_result,
                                     show_seg_result)
from mmdet3d.datasets import build_dataset
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


def parse_args(argv=None):
    """Parse the command-line arguments of the dataset browser.

    Args:
        argv (list[str], optional): Argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, so
            existing command-line use is unchanged; passing a list allows the
            parser to be driven from other scripts or tests.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['Normalize'],
        help='skip some useless pipeline')
    parser.add_argument(
        '--output-dir',
        default=None,
        type=str,
        help='If there is no display interface, you can save it')
    # `main` dispatches solely on this value; without it no visualization
    # branch matches and the whole dataset would be iterated for nothing,
    # so it must be given explicitly.
    parser.add_argument(
        '--task',
        type=str,
        required=True,
        choices=['det', 'seg', 'multi_modality-det', 'mono-det'],
        help='Determine the visualization method depending on the task.')
    parser.add_argument(
        '--aug',
        action='store_true',
        help='Whether to visualize augmented datasets or original dataset.')
    parser.add_argument(
        '--online',
        action='store_true',
        help='Whether to perform online visualization. Note that you often '
        'need a monitor to do so.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args(argv)
    return args


def build_data_cfg(config_path, skip_type, aug, cfg_options):
    """Build data config for loading visualization data.

    Args:
        config_path (str): Path of the config file to load.
        skip_type (list[str]): Pipeline step ``type`` names to drop from the
            displayed pipeline.
        aug (bool): If True, display through ``cfg.train_pipeline``.
            Otherwise display through ``cfg.eval_pipeline`` with every
            ``LoadAnnotations3D`` step of the training pipeline inserted at
            the same index it has there.
        cfg_options (dict | None): Overrides merged into the loaded config.

    Returns:
        Config: The loaded config. ``cfg.data.train`` is replaced by the
        innermost dataset config (wrappers removed) and its ``pipeline`` by
        the display pipeline.
    """
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)

    # Strip dataset wrappers so `cfg.data.train` is a plain dataset config
    # and we don't need to worry about them later. Looping handles nested
    # wrappers (e.g. a `ConcatDataset` of `RepeatDataset`s) as well.
    while cfg.data.train['type'] in ('RepeatDataset', 'ConcatDataset'):
        if cfg.data.train['type'] == 'RepeatDataset':
            cfg.data.train = cfg.data.train.dataset
        else:
            # use only first dataset for `ConcatDataset`
            cfg.data.train = cfg.data.train.datasets[0]
    train_data_cfg = cfg.data.train

    if aug:
        show_pipeline = cfg.train_pipeline
    else:
        # Work on a copy: inserting into `cfg.eval_pipeline` directly would
        # silently modify the config's eval pipeline as a side effect.
        show_pipeline = list(cfg.eval_pipeline)
        for i, transform in enumerate(cfg.train_pipeline):
            if transform['type'] == 'LoadAnnotations3D':
                show_pipeline.insert(i, transform)

    train_data_cfg['pipeline'] = [
        x for x in show_pipeline if x['type'] not in skip_type
    ]

    return cfg


def to_depth_mode(points, bboxes):
    """Re-express LiDAR-frame points and boxes in the Depth frame.

    A ``None`` argument is passed through as ``None``. Conversion is applied
    to a ``.copy()`` of ``points`` and a ``.clone()`` of ``bboxes`` rather
    than to the objects the caller handed in.

    Args:
        points: Points in LiDAR coordinates, or ``None``.
        bboxes: Boxes in LiDAR box mode, or ``None``.

    Returns:
        tuple: ``(points, bboxes)`` in Depth coordinates / Depth box mode.
    """
    depth_points = None if points is None else Coord3DMode.convert_point(
        points.copy(), Coord3DMode.LIDAR, Coord3DMode.DEPTH)
    depth_bboxes = None if bboxes is None else Box3DMode.convert(
        bboxes.clone(), Box3DMode.LIDAR, Box3DMode.DEPTH)
    return depth_points, depth_bboxes


def show_det_data(input, out_dir, show=False):
    """Render one sample's point cloud together with its 3D GT boxes.

    Args:
        input (dict): A dataset sample; ``img_metas``, ``points`` and
            ``gt_bboxes_3d`` are read through their ``._data`` attribute
            (presumably mmcv ``DataContainer`` wrappers -- confirm).
        out_dir (str | None): Output directory forwarded to ``show_result``.
        show (bool): Forwarded to ``show_result`` to request online display.
    """
    meta = input['img_metas']._data
    cloud = input['points']._data.numpy()
    boxes = input['gt_bboxes_3d']._data.tensor

    # Anything not already in Depth box mode goes through `to_depth_mode`,
    # which treats its input as LiDAR-frame data.
    # NOTE(review): a camera-mode sample would also be converted as if it
    # were LiDAR here -- confirm this path only sees LiDAR/Depth data.
    needs_conversion = meta['box_mode_3d'] != Box3DMode.DEPTH
    if needs_conversion:
        cloud, boxes = to_depth_mode(cloud, boxes)

    stem = osp.splitext(osp.basename(meta['pts_filename']))[0]
    show_result(
        cloud, boxes.clone(), None, out_dir, stem, show=show, snapshot=True)


def show_seg_data(input, out_dir, show=False):
    """Render one sample's point cloud colored by its semantic GT mask.

    Args:
        input (dict): A dataset sample; ``img_metas``, ``points`` and
            ``pts_semantic_mask`` are read through their ``._data``
            attribute.
        out_dir (str | None): Output directory forwarded to
            ``show_seg_result``.
        show (bool): Forwarded to ``show_seg_result`` for online display.
    """
    meta = input['img_metas']._data
    cloud = input['points']._data.numpy()
    seg_mask = input['pts_semantic_mask']._data.numpy()
    stem = osp.splitext(osp.basename(meta['pts_filename']))[0]

    # Class colors and the ignored label both come from the sample metadata.
    palette = np.array(meta['PALETTE'])
    ignored_label = meta['ignore_index']

    show_seg_result(
        cloud,
        seg_mask.copy(),
        None,
        out_dir,
        stem,
        palette,
        ignored_label,
        show=show,
        snapshot=True)


def show_proj_bbox_img(input, out_dir, show=False, is_nus_mono=False):
    """Visualize 3D bboxes on 2D image by projection.

    Samples without GT boxes are shown as the bare image (no warning). Box
    types other than Depth/LiDAR/Camera boxes cannot be projected, so the
    bare image is shown and a warning is emitted.

    Args:
        input (dict): A dataset sample; ``gt_bboxes_3d``, ``img_metas`` and
            ``img`` are read through their ``._data`` attribute.
        out_dir (str | None): Output directory forwarded to
            ``show_multi_modality_result``.
        show (bool): Forwarded to ``show_multi_modality_result``.
        is_nus_mono (bool): Not used by this function; kept so existing
            callers that pass it continue to work.
    """
    gt_bboxes = input['gt_bboxes_3d']._data
    img_metas = input['img_metas']._data
    img = input['img']._data.numpy()
    # Move axis 0 to the end, e.g. (C, H, W) -> (H, W, C); this assumes the
    # loaded image is channel-first -- TODO confirm against the pipeline.
    img = img.transpose(1, 2, 0)
    filename = Path(img_metas['filename']).name

    # no 3D gt bboxes, just show img (expected case, so no warning)
    if gt_bboxes.tensor.shape[0] == 0:
        show_multi_modality_result(
            img, None, None, None, out_dir, filename, show=show)
        return

    # Pick the box mode name and projection matrix matching the box type.
    if isinstance(gt_bboxes, DepthInstance3DBoxes):
        box_mode, proj_mat = 'depth', None
    elif isinstance(gt_bboxes, LiDARInstance3DBoxes):
        box_mode, proj_mat = 'lidar', img_metas['lidar2img']
    elif isinstance(gt_bboxes, CameraInstance3DBoxes):
        box_mode, proj_mat = 'camera', img_metas['cam2img']
    else:
        # can't project, just show img
        warnings.warn(
            f'unrecognized gt box type {type(gt_bboxes)}, only show image')
        show_multi_modality_result(
            img, None, None, None, out_dir, filename, show=show)
        return

    show_multi_modality_result(
        img,
        gt_bboxes,
        None,
        proj_mat,
        out_dir,
        filename,
        box_mode=box_mode,
        img_metas=img_metas,
        show=show)


def main():
    """Build the training dataset from a config and visualize each sample.

    The renderer is chosen by ``--task``: ``det`` draws 3D boxes on the point
    cloud, ``mono-det`` projects 3D boxes onto the image,
    ``multi_modality-det`` does both, and ``seg`` draws the semantic mask on
    the point cloud.

    Raises:
        ValueError: If ``--task`` was not given, since no renderer would run
            and the dataset would be iterated for nothing.
    """
    args = parse_args()
    # configure visualization mode
    vis_task = args.task  # 'det', 'seg', 'multi_modality-det', 'mono-det'
    if vis_task is None:
        raise ValueError(
            'No --task given (expected one of det, seg, multi_modality-det, '
            'mono-det); nothing would be visualized.')

    if args.output_dir is not None:
        mkdir_or_exist(args.output_dir)

    cfg = build_data_cfg(args.config, args.skip_type, args.aug,
                         args.cfg_options)
    try:
        dataset = build_dataset(
            cfg.data.train, default_args=dict(filter_empty_gt=False))
    except TypeError:  # seg dataset doesn't have `filter_empty_gt` key
        dataset = build_dataset(cfg.data.train)

    # Loop-invariant flag forwarded to `show_proj_bbox_img`.
    is_nus_mono = cfg.dataset_type == 'NuScenesMonoDataset'
    progress_bar = mmcv.ProgressBar(len(dataset))

    # `sample` rather than `input` so the builtin is not shadowed.
    for sample in dataset:
        if vis_task in ['det', 'multi_modality-det']:
            # show 3D bboxes on 3D point clouds
            show_det_data(sample, args.output_dir, show=args.online)
        # Deliberately a separate `if`: 'multi_modality-det' renders both the
        # point-cloud view above and the image projection below.
        if vis_task in ['multi_modality-det', 'mono-det']:
            # project 3D bboxes to 2D image
            show_proj_bbox_img(
                sample,
                args.output_dir,
                show=args.online,
                is_nus_mono=is_nus_mono)
        elif vis_task == 'seg':
            # show 3D segmentation mask on 3D point clouds
            show_seg_data(sample, args.output_dir, show=args.online)
        progress_bar.update()


if __name__ == '__main__':
    main()