# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
import torch
from mmcv.transforms.base import BaseTransform
from mmengine.registry import TRANSFORMS
from mmengine.structures import InstanceData

from mmdet3d.datasets import KittiDataset
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes


def _generate_kitti_dataset_config():
    """Build the arguments shared by the KITTI dataset tests.

    As a side effect, registers a minimal ``Identity`` transform in
    ``TRANSFORMS`` (only if that name is not registered yet), so the test
    pipeline can run without the full 3D data pipeline.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)``. ``pipeline`` contains only the ``Identity`` transform
        config, and ``modality`` is a freshly created dict that enables lidar
        input only, so callers may mutate it safely.
    """
    data_root = 'tests/data/kitti'
    ann_file = 'kitti_infos_train.pkl'
    classes = ['Pedestrian', 'Cyclist', 'Car']

    # TODO: replace ``Identity`` with the real loading/packing transforms
    # once the pipeline refactor lands.
    # The membership guard keeps repeated calls of this helper from
    # registering the same name twice (mmengine registries reject duplicate
    # names unless ``force=True`` is passed).
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):
            """Pack the 3D ground-truth labels into a ``Det3DDataSample``."""

            def transform(self, info):
                # Promote labels from the annotation dict when present;
                # otherwise ``info`` is expected to already carry them.
                if 'ann_info' in info:
                    info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
                data_sample = Det3DDataSample()
                gt_instances_3d = InstanceData()
                gt_instances_3d.labels_3d = info['gt_labels_3d']
                data_sample.gt_instances_3d = gt_instances_3d
                info['data_samples'] = data_sample
                return info

    pipeline = [
        dict(type='Identity'),
    ]
    modality = dict(use_lidar=True, use_camera=False)
    data_prefix = dict(pts='training/velodyne_reduced', img='training/image_2')
    return data_root, ann_file, classes, data_prefix, pipeline, modality


def test_getitem():
    """Check KittiDataset path resolution and annotation parsing.

    Builds the dataset twice: once with all three classes, where the parsed
    annotations must carry correctly typed labels, boxes, 2D centers and
    depths, and once restricted to ``['Car']``, where every instance of the
    first sample is expected to be filtered out.
    """
    np.random.seed(0)
    data_root, ann_file, classes, data_prefix, \
        pipeline, modality = _generate_kitti_dataset_config()
    modality['use_camera'] = True

    kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        # Pass a copy of the shared prefixes so the dataset cannot alter the
        # dict that the assertions below compare against.
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(classes=classes),
        modality=modality)

    # Smoke-test the pipeline entry points before inspecting raw info.
    kitti_dataset.prepare_data(0)
    input_dict = kitti_dataset.get_data_info(0)
    kitti_dataset[0]

    # The resolved paths must contain both data_root and data_prefix.
    lidar_path = input_dict['lidar_points']['lidar_path']
    assert data_prefix['pts'] in lidar_path
    assert data_root in lidar_path
    for img_info in input_dict['images'].values():
        if 'img_path' in img_info:
            assert data_prefix['img'] in img_info['img_path']
            assert data_root in img_info['img_path']

    ann_info = kitti_dataset.parse_ann_info(input_dict)

    # Check the keys present in ann_info and their types.
    assert 'instances' in ann_info
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    # Pin the box values of the fixture sample via their sum.
    assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
                          torch.tensor(7.2650))
    assert 'centers_2d' in ann_info
    assert ann_info['centers_2d'].dtype == np.float32
    assert 'depths' in ann_info
    assert ann_info['depths'].dtype == np.float32

    # Restricting the classes to ['Car'] should filter out every instance
    # of this sample.
    car_kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(classes=['Car']),
        modality=modality)

    input_dict = car_kitti_dataset.get_data_info(0)
    ann_info = car_kitti_dataset.parse_ann_info(input_dict)

    # Check the keys present in ann_info and their types.
    assert 'instances' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    # All instances have been filtered out by the class restriction.
    assert len(ann_info['gt_labels_3d']) == 0
    assert len(car_kitti_dataset.metainfo['classes']) == 1