kitti2d_dataset.py 6.92 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
2
3
4
5
6
import mmcv
import numpy as np

from mmdet.datasets import DATASETS, CustomDataset


7
@DATASETS.register_module()
class Kitti2DDataset(CustomDataset):
    """KITTI dataset wrapper for 2D (image-plane) bounding-box detection.

    The annotation file is loaded with ``mmcv.load`` and is expected to be a
    list with one entry per image. Only ``image``/``image_path``,
    ``image``/``image_idx`` and ``annos`` are read by this class; the other
    keys are listed for reference.

    Annotation format:
    [
        {
            'image': {
                'image_idx': 0,
                'image_path': 'training/image_2/000000.png',
                'image_shape': array([ 370, 1224], dtype=int32)
            },
            'point_cloud': {
                 'num_features': 4,
                 'velodyne_path': 'training/velodyne/000000.bin'
             },
             'calib': {
                 'P0': <np.ndarray> (4, 4),
                 'P1': <np.ndarray> (4, 4),
                 'P2': <np.ndarray> (4, 4),
                 'P3': <np.ndarray> (4, 4),
                 'R0_rect': <np.ndarray> (4, 4),
                 'Tr_velo_to_cam': <np.ndarray> (4, 4),
                 'Tr_imu_to_velo': <np.ndarray> (4, 4)
             },
             'annos': {
                 'name': <np.ndarray> (n),
                 'truncated': <np.ndarray> (n),
                 'occluded': <np.ndarray> (n),
                 'alpha': <np.ndarray> (n),
                 'bbox': <np.ndarray> (n, 4),
                 'dimensions': <np.ndarray> (n, 3),
                 'location': <np.ndarray> (n, 3),
                 'rotation_y': <np.ndarray> (n),
                 'score': <np.ndarray> (n),
                 'index': array([0], dtype=int32),
                 'group_ids': array([0], dtype=int32),
                 'difficulty': array([0], dtype=int32),
                 'num_points_in_gt': <np.ndarray> (n),
             }
        }
    ]
    """

    CLASSES = ('car', 'pedestrian', 'cyclist')

    def load_annotations(self, ann_file):
        """Load annotations from ann_file.

        Besides returning the loaded infos, this stores them on
        ``self.data_infos`` and builds ``self.cat2label``, the mapping from
        class name to its index in ``CLASSES``.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: List of annotations.
        """
        self.data_infos = mmcv.load(ann_file)
        self.cat2label = {
            cat_name: i
            for i, cat_name in enumerate(self.CLASSES)
        }
        return self.data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images without ground truths.

        Args:
            min_size (int): Unused by this implementation; images are kept
                or dropped purely on whether they have any annotation.
                NOTE(review): presumably kept only to match the parent
                class signature - confirm.

        Returns:
            list[int]: Indices into ``self.data_infos`` of the images that
                have at least one annotated object.
        """
        return [
            i for i, img_info in enumerate(self.data_infos)
            if len(img_info['annos']['name']) > 0
        ]

    def get_ann_info(self, index):
        """Get annotation info according to the given index.

        Args:
            index (int): Index of the annotation data to get.

        Returns:
            dict: Standard annotation dictionary
                consists of the data information.

                - bboxes (np.ndarray): ground truth bboxes (float32)
                - labels (np.ndarray): labels of ground truths (int64)
        """
        # Use index to get the annos, thus the evalhook could also use this api
        annos = self.data_infos[index]['annos']
        gt_names = annos['name']
        gt_bboxes = annos['bbox']

        # Remove ground truths whose class is not in CLASSES.
        selected = self.keep_arrays_by_name(gt_names, self.CLASSES)
        gt_bboxes = gt_bboxes[selected]
        gt_names = gt_names[selected]
        # Pin the dtype: without it an empty selection produces a float64
        # array, and the integer width otherwise depends on the platform.
        gt_labels = np.array([self.cat2label[n] for n in gt_names],
                             dtype=np.int64)

        return dict(
            bboxes=gt_bboxes.astype(np.float32),
            labels=gt_labels,
        )

    def prepare_train_img(self, idx):
        """Training image preparation.

        Args:
            idx (int): Index for accessing the target image data.

        Returns:
            dict | None: Training image data dict after preprocessing
                corresponding to the index, or None when the image has no
                ground truth box of the classes of interest.
        """
        img_raw_info = self.data_infos[idx]['image']
        img_info = dict(filename=img_raw_info['image_path'])
        ann_info = self.get_ann_info(idx)
        if len(ann_info['bboxes']) == 0:
            # Nothing to train on for this sample.
            return None
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Prepare data for testing.

        Args:
            idx (int): Index for accessing the target image data.

        Returns:
            dict: Testing image data dict after preprocessing
                corresponding to the index.
        """
        img_raw_info = self.data_infos[idx]['image']
        img_info = dict(filename=img_raw_info['image_path'])
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def drop_arrays_by_name(self, gt_names, used_classes):
        """Drop irrelevant ground truths by name.

        Args:
            gt_names (list[str]): Names of ground truths.
            used_classes (list[str]): Classes of interest.

        Returns:
            np.ndarray: Indices (int64) of ground truths that will be
                dropped, i.e. those whose name is not in ``used_classes``.
        """
        inds = [i for i, x in enumerate(gt_names) if x not in used_classes]
        return np.array(inds, dtype=np.int64)

    def keep_arrays_by_name(self, gt_names, used_classes):
        """Keep useful ground truths by name.

        Args:
            gt_names (list[str]): Names of ground truths.
            used_classes (list[str]): Classes of interest.

        Returns:
            np.ndarray: Indices (int64) of ground truths that will be kept,
                i.e. those whose name is in ``used_classes``.
        """
        inds = [i for i, x in enumerate(gt_names) if x in used_classes]
        return np.array(inds, dtype=np.int64)

    def reformat_bbox(self, outputs, out=None):
        """Convert detection outputs with ``bbox2result_kitti2d``.

        The sample indices passed along are the ``image_idx`` of every entry
        in ``self.data_infos``, in order.

        Args:
            outputs: Detection results, forwarded unchanged to
                ``bbox2result_kitti2d``.
            out: Forwarded unchanged to ``bbox2result_kitti2d``.
                NOTE(review): presumably an output path for the converted
                results - confirm against that helper. Default: None.

        Returns:
            Whatever ``bbox2result_kitti2d`` returns (the reformatted
            results).
        """
        # Imported lazily, as in the rest of this class's mmdet3d usage.
        from mmdet3d.core.bbox.transforms import bbox2result_kitti2d
        sample_idx = [info['image']['image_idx'] for info in self.data_infos]
        return bbox2result_kitti2d(outputs, self.CLASSES, sample_idx, out)

    def evaluate(self, result_files, eval_types=None):
        """Evaluation in KITTI protocol.

        Args:
            result_files (str): Path of result files.
            eval_types (str): Types of evaluation. Default: None.
                KITTI dataset only support 'bbox' evaluation type.

        Returns:
            tuple (str, dict): Average precision results in str format
                and average precision results in dict format.
        """
        from mmdet3d.core.evaluation import kitti_eval
        eval_types = ['bbox'] if not eval_types else eval_types
        # Only 'bbox' (as a str or a one-element list) is accepted.
        assert eval_types in ('bbox', ['bbox']), \
            'KITTI data set only evaluate bbox'
        gt_annos = [info['annos'] for info in self.data_infos]
        ap_result_str, ap_dict = kitti_eval(
            gt_annos, result_files, self.CLASSES, eval_types=['bbox'])
        return ap_result_str, ap_dict