import argparse
import importlib
import os
import sys

import mmcv

from renderer import Renderer

# Integer class id assigned to each map-element category name.
CAT2ID = dict(
    ped_crossing=0,
    divider=1,
    boundary=2,
)

# Reverse lookup: class id -> category name.
ID2CAT = {cat_id: cat_name for cat_name, cat_id in CAT2ID.items()}

# Region-of-interest size passed to ``Renderer(roi_size=...)``.
# NOTE(review): units and axis order are not stated in this file -- confirm
# against the renderer.
ROI_SIZE = (60, 30)

def parse_args(argv=None):
    """Parse the command-line arguments of the visualization script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace: Namespace with the attributes ``log_id``,
        ``ann_file``, ``result`` (``None`` when not given), ``thr``
        (defaults to ``0``) and ``out_dir`` (defaults to ``'demo'``).
    """
    parser = argparse.ArgumentParser(
        description='Visualize groundtruth and results')
    parser.add_argument('log_id', type=str,
                        help='log_id of data to visualize')
    parser.add_argument('ann_file',
                        help='gt file to visualize')
    parser.add_argument('--result',
                        type=str,
                        help='prediction result to visualize')
    parser.add_argument('--thr',
                        type=float,
                        default=0,
                        help='score threshold to filter predictions')
    parser.add_argument(
        '--out-dir',
        default='demo',
        help='directory where visualize results will be saved')
    return parser.parse_args(argv)

def _import_module_from_dir(path):
    """Import the package located at the directory part of ``path``.

    ``os.path.dirname(path)`` is converted to a dotted module name by
    replacing every ``/`` with ``.`` (``'a/b/'`` -> ``'a.b'``) and imported.
    """
    module_path = os.path.dirname(path).replace('/', '.')
    print(f'importing {module_path}/')
    importlib.import_module(module_path)


def import_plugin(cfg):
    """Import the plugin packages named by ``cfg`` so their registries update.

    The current working directory is added to ``sys.path`` first, so plugin
    paths are resolved relative to where the script is launched.

    Args:
        cfg: Config-like object. Nothing is imported unless ``cfg.plugin`` is
            truthy. ``cfg.plugin_dir`` may be a single path or a list/tuple
            of paths; when it is absent, the directory of ``cfg.filename``
            (the config file path) is imported instead.

    Raises:
        ValueError: If ``cfg.plugin`` is set but neither ``plugin_dir`` nor a
            config file path is available.
    """
    cwd = os.path.abspath('.')
    if cwd not in sys.path:  # avoid growing sys.path on repeated calls
        sys.path.append(cwd)

    if not getattr(cfg, 'plugin', False):
        return

    if hasattr(cfg, 'plugin_dir'):
        plugin_dirs = cfg.plugin_dir
        if not isinstance(plugin_dirs, (list, tuple)):
            plugin_dirs = [plugin_dirs]
        for plugin_dir in plugin_dirs:
            _import_module_from_dir(plugin_dir)
        return

    # Import dir is the dirpath of the config file. The path is read from the
    # config object itself because this script's CLI has no ``config``
    # argument to take it from.
    config_path = getattr(cfg, 'filename', None)
    if not config_path:
        raise ValueError(
            'cfg.plugin is set but neither cfg.plugin_dir nor the config '
            'file path (cfg.filename) is available')
    _import_module_from_dir(config_path)

def _collect_predictions(pred, thr):
    """Group one frame's predicted vectors by category name.

    Only predictions whose score is strictly greater than ``thr`` are kept.
    Returns a dict with one (possibly empty) list per category in ``CAT2ID``.
    """
    vectors = {cat: [] for cat in CAT2ID}
    for score, label, vector in zip(pred['scores'], pred['labels'],
                                    pred['vectors']):
        if score > thr:
            vectors[ID2CAT[label]].append(vector)
    return vectors


def main(args):
    """Render ground truth, and optionally predictions, for every frame of a log.

    For each frame of ``ann[args.log_id]`` a BEV image and per-camera
    overlays are written to ``<out_dir>/<log_id>/<timestamp>/gt``. When
    ``args.result`` is given, predictions scoring above ``args.thr`` are
    rendered the same way into ``.../<timestamp>/pred``.

    Args:
        args: Namespace providing ``log_id``, ``ann_file``, ``result``,
            ``thr`` and ``out_dir`` (see ``parse_args``).

    Raises:
        KeyError: If ``log_id`` is missing from the annotation file, or a
            frame's timestamp is missing from the loaded results.
    """
    log_id = args.log_id
    ann = mmcv.load(args.ann_file)
    # Image paths are resolved relative to the annotation file's directory.
    root_path = os.path.dirname(args.ann_file)
    out_dir = os.path.join(args.out_dir, str(log_id))

    log_ann = ann[log_id]
    renderer = Renderer(roi_size=ROI_SIZE)

    result = mmcv.load(args.result)['results'] if args.result else None

    for frame in mmcv.track_iter_progress(log_ann):
        timestamp = frame['timestamp']
        annotation = frame['annotation']
        # Materialise once so images and calibrations share the same order.
        sensors = list(frame['sensor'].values())
        imgs = [
            mmcv.imread(
                os.path.join(root_path, 'argoverse2', cam['image_path']))
            for cam in sensors
        ]
        extrinsics = [cam['extrinsic'] for cam in sensors]
        intrinsics = [cam['intrinsic'] for cam in sensors]

        frame_dir = os.path.join(out_dir, timestamp, 'gt')
        os.makedirs(frame_dir, exist_ok=True)
        renderer.render_bev_from_vectors(annotation, out_dir=frame_dir)
        # NOTE(review): the positional ``4`` is presumably a line thickness --
        # confirm against Renderer.render_camera_views_from_vectors.
        renderer.render_camera_views_from_vectors(annotation, imgs, extrinsics,
                                                  intrinsics, 4, frame_dir)

        if result is None:
            continue

        vectors = _collect_predictions(result[timestamp], args.thr)
        frame_dir = os.path.join(out_dir, timestamp, 'pred')
        os.makedirs(frame_dir, exist_ok=True)
        renderer.render_bev_from_vectors(vectors, out_dir=frame_dir)
        renderer.render_camera_views_from_vectors(vectors, imgs, extrinsics,
                                                  intrinsics, 4, frame_dir)

if __name__ == '__main__':
    # Script entry point: parse the command line and render the requested log.
    args = parse_args()
    main(args)