test_kitti_dataset.py 4.47 KB
Newer Older
jshilong's avatar
jshilong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np

from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.datasets import KittiDataset


def _generate_kitti_dataset_config():
    data_root = 'tests/data/kitti'
    ann_file = 'kitti_infos_train.pkl'
    classes = ['Pedestrian', 'Cyclist', 'Car']
    # wait for pipline refactor
    pipeline = [
        dict(
            type='LoadPointsFromFile',
            coord_type='LIDAR',
            load_dim=4,
            use_dim=4,
            file_client_args=dict(backend='disk')),
        dict(
            type='MultiScaleFlipAug3D',
            img_scale=(1333, 800),
            pts_scale_ratio=1,
            flip=False,
            transforms=[
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[0, 0],
                    scale_ratio_range=[1.0, 1.0],
                    translation_std=[0, 0, 0]),
                dict(type='RandomFlip3D'),
                dict(
                    type='PointsRangeFilter',
                    point_cloud_range=[0, -40, -3, 70.4, 40, 1]),
                dict(
                    type='DefaultFormatBundle3D',
                    class_names=classes,
                    with_label=False),
                dict(type='Collect3D', keys=['points'])
            ])
    ]
    modality = dict(use_lidar=True, use_camera=False)
    data_prefix = dict(pts='training/velodyne_reduced', img='training/image_2')
    return data_root, ann_file, classes, data_prefix, pipeline, modality


def test_getitem():
    """Check ``KittiDataset`` path resolution and annotation parsing.

    Two datasets are built from the same annotation file:

    * one using the full class list, for which sample 0 must resolve its
      lidar/image paths under ``data_root`` + ``data_prefix`` and yield
      exactly one parsed instance with the expected keys and dtypes;
    * one restricted to ``'Car'``, for which every instance of sample 0
      must be filtered out.
    """
    np.random.seed(0)
    data_root, ann_file, classes, data_prefix, \
        _, modality, = _generate_kitti_dataset_config()
    modality['use_camera'] = True

    from mmcv.transforms.base import BaseTransform
    from mmengine.registry import TRANSFORMS

    # ``force=True`` makes the registration idempotent: without it the
    # registry raises ``KeyError`` when an ``Identity`` transform is already
    # registered (e.g. this test is executed twice in one process, or
    # another test module registered the same name first).
    @TRANSFORMS.register_module(force=True)
    class Identity(BaseTransform):
        """Pass-through transform that lifts ``gt_labels_3d`` to top level."""

        def transform(self, info):
            if 'ann_info' in info:
                info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
            return info

    pipeline = [
        dict(type='Identity'),
    ]
    kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        # Hand over a copy so the prefixes asserted against below stay
        # exactly as the config helper produced them.
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(CLASSES=classes),
        modality=modality)

    # Return values are ignored on purpose: these two calls only need to
    # complete without raising.
    kitti_dataset.prepare_data(0)
    input_dict = kitti_dataset.get_data_info(0)
    kitti_dataset[0]

    # The resolved paths should contain both data_prefix and data_root.
    lidar_path = input_dict['lidar_points']['lidar_path']
    assert data_prefix['pts'] in lidar_path
    assert data_root in lidar_path
    for img_info in input_dict['images'].values():
        if 'img_path' in img_info:
            assert data_prefix['img'] in img_info['img_path']
            assert data_root in img_info['img_path']

    ann_info = kitti_dataset.parse_ann_info(input_dict)

    # Assert the keys in ann_info and their types.
    assert 'gt_labels' in ann_info
    assert ann_info['gt_labels'].dtype == np.int64
    # only one instance
    assert len(ann_info['gt_labels']) == 1
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    assert 'gt_bboxes' in ann_info
    assert ann_info['gt_bboxes'].dtype == np.float64
    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    # Remaining per-instance attributes are all expected to be int64 arrays.
    for key in ('group_id', 'occluded', 'difficulty', 'num_lidar_pts',
                'truncated'):
        assert key in ann_info
        assert ann_info[key].dtype == np.int64

    car_kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(CLASSES=['Car']),
        modality=modality)

    input_dict = car_kitti_dataset.get_data_info(0)
    ann_info = car_kitti_dataset.parse_ann_info(input_dict)

    # Assert the keys in ann_info and their types.
    assert 'gt_labels' in ann_info
    assert ann_info['gt_labels'].dtype == np.int64
    # all instances have been filtered out by the class restriction
    assert len(ann_info['gt_labels']) == 0
    assert len(car_kitti_dataset.metainfo['CLASSES']) == 1