"""Argoverse2 map dataset (argo_dataset.py)."""
import os
from time import time

import mmcv
import numpy as np
from mmdet.datasets import DATASETS
from shapely.geometry import LineString

from .base_dataset import BaseMapDataset


@DATASETS.register_module()
class AV2Dataset(BaseMapDataset):
    """Argoverse2 map dataset class.

    Args:
        ann_file (str): annotation file path
        cat2id (dict): category to class id
        roi_size (tuple): bev range
        eval_config (Config): evaluation config
        meta (dict): meta information
        pipeline (Config): data processing pipeline config,
        interval (int): annotation load interval
        work_dir (str): path to work dir
        test_mode (bool): whether in test mode
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded untouched to the base dataset.
        super().__init__(**kwargs)

    def load_annotations(self, ann_file):
        """Load annotations from ann_file into ``self.samples``.

        The annotation file is read with ``mmcv.load`` and is iterated as a
        mapping whose values are sequences of samples (the keys are
        presumably segment/log ids -- they are not used here). All sequences
        are concatenated in iteration order and then subsampled by keeping
        every ``self.interval``-th sample.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            None: the collected samples are stored on ``self.samples``
            rather than returned.
        """

        start_time = time()
        ann = mmcv.load(ann_file)

        # Flatten the per-key sequences into one list of samples.
        samples = []
        for sequence in ann.values():
            samples.extend(sequence)
        samples = samples[::self.interval]

        print(f'collected {len(samples)} samples in {(time() - start_time):.2f}s')
        self.samples = samples

    def get_sample(self, idx):
        """Get data sample. For each sample, map extractor will be applied to extract
        map elements.

        Args:
            idx (int): data index

        Returns:
            result (dict): dict of input. It holds the keys ``token``,
            ``img_filenames``, ``cam_intrinsics``, ``cam_extrinsics``,
            ``ego2img``, ``ego2global_translation`` and
            ``ego2global_rotation``; ``map_geoms`` is added only when
            ``self.test_mode`` is false.
        """

        sample = self.samples[idx]

        if not self.test_mode:
            ann = sample['annotation']

            # Keep only the categories listed in cat2id, keyed by class id.
            map_label2geom = {}
            for category, polylines in ann.items():
                if category in self.cat2id:
                    # Only the first three columns of each point are used
                    # (presumably x, y, z -- confirm against the annotation
                    # format).
                    map_label2geom[self.cat2id[category]] = [
                        LineString(np.array(points)[:, :3])
                        for points in polylines
                    ]

        # Materialise the camera records once; every per-camera list below
        # is built in this same order.
        cams = list(sample['sensor'].values())

        ego2img_rts = []
        for cam in cams:
            extrinsic = np.array(cam['extrinsic'])
            intrinsic = np.array(cam['intrinsic'])
            # Embed the intrinsic matrix in the top-left corner of a 4x4
            # identity so it can be multiplied with the extrinsic.
            viewpad = np.eye(4)
            viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
            ego2img_rts.append(viewpad @ extrinsic)

        pose = sample['pose']
        input_dict = {
            'token': sample['timestamp'],
            'img_filenames': [os.path.join(self.root_path, cam['image_path']) for cam in cams],
            # intrinsics are 3x3 Ks
            'cam_intrinsics': [cam['intrinsic'] for cam in cams],
            # extrinsics are 4x4 transform matrix, NOTE: **ego2cam**
            'cam_extrinsics': [cam['extrinsic'] for cam in cams],
            'ego2img': ego2img_rts,
            'ego2global_translation': pose['ego2global_translation'],
            'ego2global_rotation': pose['ego2global_rotation'],
        }
        if not self.test_mode:
            # {class_id: List[LineString], ...}
            input_dict['map_geoms'] = map_label2geom

        return input_dict