# Source: browse_dataset.py (7.86 KB), last committed by dingchang.
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import warnings
4
from os import path as osp
5
from pathlib import Path
6

7
8
9
10
import mmcv
import numpy as np
from mmcv import Config, DictAction, mkdir_or_exist

11
12
from mmdet3d.core.bbox import (Box3DMode, CameraInstance3DBoxes, Coord3DMode,
                               DepthInstance3DBoxes, LiDARInstance3DBoxes)
13
14
15
from mmdet3d.core.visualizer import (show_multi_modality_result, show_result,
                                     show_seg_result)
from mmdet3d.datasets import build_dataset
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


def parse_args():
    """Parse command-line arguments for browsing a dataset.

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``config``,
        ``skip_type``, ``output_dir``, ``task``, ``aug``, ``online`` and
        ``cfg_options``.
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['Normalize'],
        help='skip some useless pipeline')
    parser.add_argument(
        '--output-dir',
        default=None,
        type=str,
        help='If there is no display interface, you can save it')
    # NOTE: no default is given, so ``args.task`` is None when omitted.
    parser.add_argument(
        '--task',
        type=str,
        choices=['det', 'seg', 'multi_modality-det', 'mono-det'],
        help='Determine the visualization method depending on the task.')
    parser.add_argument(
        '--aug',
        action='store_true',
        help='Whether to visualize augmented datasets or original dataset.')
    parser.add_argument(
        '--online',
        action='store_true',
        help='Whether to perform online visualization. Note that you often '
        'need a monitor to do so.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


60
def build_data_cfg(config_path, skip_type, aug, cfg_options):
    """Build data config for loading visualization data.

    Args:
        config_path (str): Path to the config file.
        skip_type (list[str]): Transform ``type`` names to drop from the
            visualization pipeline.
        aug (bool): If True, visualize with ``cfg.train_pipeline`` as-is.
            Otherwise start from a copy of ``cfg.eval_pipeline`` and splice
            in the ``LoadAnnotations3D`` and ``Collect3D`` steps taken from
            ``cfg.train_pipeline``.
        cfg_options (dict | None): Overrides merged into the loaded config.

    Returns:
        Config: The loaded config. Its ``data.train`` is the unwrapped
        dataset config, with ``pipeline`` replaced by the visualization
        pipeline.
    """
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)

    # Unwrap `RepeatDataset` / `ConcatDataset` so `cfg.data.train` is a plain
    # dataset config and we don't need to worry about it later. Loop so that
    # nested wrappers in any order are fully unwrapped. Only the first
    # dataset of a `ConcatDataset` is kept.
    while cfg.data.train['type'] in ('RepeatDataset', 'ConcatDataset'):
        if cfg.data.train['type'] == 'RepeatDataset':
            cfg.data.train = cfg.data.train.dataset
        else:
            cfg.data.train = cfg.data.train.datasets[0]
    train_data_cfg = cfg.data.train

    if aug:
        show_pipeline = cfg.train_pipeline
    else:
        # Copy so the insert/append below do not mutate
        # `cfg.eval_pipeline` in place.
        show_pipeline = list(cfg.eval_pipeline)
        for i, transform in enumerate(cfg.train_pipeline):
            # Insert the annotation loader at the position it occupies in
            # the train pipeline.
            if transform['type'] == 'LoadAnnotations3D':
                show_pipeline.insert(i, transform)
            # Collect points as well as labels
            if transform['type'] == 'Collect3D':
                if show_pipeline and show_pipeline[-1]['type'] == 'Collect3D':
                    show_pipeline[-1] = transform
                else:
                    show_pipeline.append(transform)

    train_data_cfg['pipeline'] = [
        x for x in show_pipeline if x['type'] not in skip_type
    ]

    return cfg


96
97
98
99
100
101
102
103
104
105
106
def to_depth_mode(points, bboxes):
    """Convert points and bboxes from the LiDAR frame to the Depth frame.

    Each argument may be ``None``, in which case ``None`` is returned in its
    place. Conversion is applied to ``points.copy()`` / ``bboxes.clone()``,
    so the objects passed in are not converted themselves.

    Returns:
        tuple: ``(points, bboxes)`` in Depth coordinates / Depth box mode.
    """
    depth_points = None if points is None else Coord3DMode.convert_point(
        points.copy(), Coord3DMode.LIDAR, Coord3DMode.DEPTH)
    depth_bboxes = None if bboxes is None else Box3DMode.convert(
        bboxes.clone(), Box3DMode.LIDAR, Box3DMode.DEPTH)
    return depth_points, depth_bboxes


107
def show_det_data(input, out_dir, show=False):
    """Visualize 3D point cloud and 3D bboxes.

    Args:
        input (dict): One dataset sample. Its ``img_metas``, ``points`` and
            ``gt_bboxes_3d`` entries are read through ``._data`` (presumably
            mmcv ``DataContainer`` wrappers; confirm against the pipeline).
        out_dir (str | None): Output directory passed to ``show_result``.
        show (bool): Whether to display the result online.
    """
    img_metas = input['img_metas']._data
    points = input['points']._data.numpy()
    gt_bboxes = input['gt_bboxes_3d']._data.tensor
    # Data not already in Depth mode is converted to Depth mode before
    # visualization. `to_depth_mode` treats the source frame as LiDAR.
    if img_metas['box_mode_3d'] != Box3DMode.DEPTH:
        points, gt_bboxes = to_depth_mode(points, gt_bboxes)
    filename = osp.splitext(osp.basename(img_metas['pts_filename']))[0]
    show_result(
        points,
        gt_bboxes.clone(),
        None,
        out_dir,
        filename,
        show=show,
        snapshot=True)


125
def show_seg_data(input, out_dir, show=False):
    """Visualize 3D point cloud and segmentation mask.

    Args:
        input (dict): One dataset sample. Its ``img_metas``, ``points`` and
            ``pts_semantic_mask`` entries are read through ``._data``
            (presumably mmcv ``DataContainer`` wrappers; confirm).
        out_dir (str | None): Output directory passed to ``show_seg_result``.
        show (bool): Whether to display the result online.
    """
    img_metas = input['img_metas']._data
    points = input['points']._data.numpy()
    gt_seg = input['pts_semantic_mask']._data.numpy()
    filename = osp.splitext(osp.basename(img_metas['pts_filename']))[0]
    # The color palette and ignore index come from the sample's meta info.
    show_seg_result(
        points,
        gt_seg.copy(),
        None,
        out_dir,
        filename,
        np.array(img_metas['PALETTE']),
        img_metas['ignore_index'],
        show=show,
        snapshot=True)


143
def show_proj_bbox_img(input, out_dir, show=False, is_nus_mono=False):
    """Visualize 3D bboxes on 2D image by projection.

    Args:
        input (dict): One dataset sample. Its ``gt_bboxes_3d``, ``img_metas``
            and ``img`` entries are read through ``._data`` (presumably mmcv
            ``DataContainer`` wrappers; confirm).
        out_dir (str | None): Output directory passed to
            ``show_multi_modality_result``.
        show (bool): Whether to display the result online.
        is_nus_mono (bool): Currently unused. Kept so existing callers that
            pass it keep working.
    """
    gt_bboxes = input['gt_bboxes_3d']._data
    img_metas = input['img_metas']._data
    img = input['img']._data.numpy()
    # Move the channel axis from first to last (assumes `img` is CHW and
    # the visualizer wants HWC).
    img = img.transpose(1, 2, 0)
    filename = Path(img_metas['filename']).name

    # No 3D gt bboxes: just show the image. This is an expected case, so no
    # "unrecognized box type" warning is emitted.
    if gt_bboxes.tensor.shape[0] == 0:
        show_multi_modality_result(
            img, None, None, None, out_dir, filename, show=show)
        return

    # Box class -> (img_metas key of the projection matrix, box mode).
    # Depth boxes are drawn without a projection matrix from img_metas.
    proj_specs = (
        (DepthInstance3DBoxes, None, 'depth'),
        (LiDARInstance3DBoxes, 'lidar2img', 'lidar'),
        (CameraInstance3DBoxes, 'cam2img', 'camera'),
    )
    for box_cls, proj_key, box_mode in proj_specs:
        if isinstance(gt_bboxes, box_cls):
            proj_mat = None if proj_key is None else img_metas[proj_key]
            show_multi_modality_result(
                img,
                gt_bboxes,
                None,
                proj_mat,
                out_dir,
                filename,
                box_mode=box_mode,
                img_metas=img_metas,
                show=show)
            return

    # can't project, just show img
    warnings.warn(
        f'unrecognized gt box type {type(gt_bboxes)}, only show image')
    show_multi_modality_result(
        img, None, None, None, out_dir, filename, show=show)


195
196
197
198
199
200
def main():
    """Browse the training dataset of a config and visualize each sample.

    The visualization routine(s) applied to each sample are selected by
    ``--task``. If ``--task`` is not given, nothing would be shown, so the
    function warns and returns without building the dataset.
    """
    args = parse_args()

    vis_task = args.task  # 'det', 'seg', 'multi_modality-det', 'mono-det'
    if vis_task is None:
        warnings.warn('--task is not specified, nothing will be visualized')
        return

    if args.output_dir is not None:
        mkdir_or_exist(args.output_dir)

    cfg = build_data_cfg(args.config, args.skip_type, args.aug,
                         args.cfg_options)
    try:
        dataset = build_dataset(
            cfg.data.train, default_args=dict(filter_empty_gt=False))
    except TypeError:  # seg dataset doesn't have `filter_empty_gt` key
        dataset = build_dataset(cfg.data.train)

    # `dataset_type` is optional in the config; it only feeds `is_nus_mono`.
    dataset_type = cfg.get('dataset_type')
    is_nus_mono = dataset_type == 'NuScenesMonoDataset'
    progress_bar = mmcv.ProgressBar(len(dataset))

    for data in dataset:
        if vis_task in ('det', 'multi_modality-det'):
            # show 3D bboxes on 3D point clouds
            show_det_data(data, args.output_dir, show=args.online)
        if vis_task in ('multi_modality-det', 'mono-det'):
            # project 3D bboxes to 2D image
            show_proj_bbox_img(
                data,
                args.output_dir,
                show=args.online,
                is_nus_mono=is_nus_mono)
        elif vis_task == 'seg':
            # show 3D segmentation mask on 3D point clouds
            show_seg_data(data, args.output_dir, show=args.online)
        progress_bar.update()
229
230
231
232


if __name__ == '__main__':
    # Run the dataset browser only when executed as a script, not on import.
    main()