test_kitti_dataset.py 3.6 KB
Newer Older
jshilong's avatar
jshilong committed
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
jshilong's avatar
jshilong committed
4
5
import torch
from mmcv.transforms.base import BaseTransform
6
from mmengine.data import InstanceData
jshilong's avatar
jshilong committed
7
from mmengine.registry import TRANSFORMS
jshilong's avatar
jshilong committed
8
9

from mmdet3d.datasets import KittiDataset
zhangshilong's avatar
zhangshilong committed
10
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
jshilong's avatar
jshilong committed
11
12
13
14
15
16
17


def _generate_kitti_dataset_config():
    """Build the arguments used to construct a ``KittiDataset`` in tests.

    Registers a minimal ``Identity`` transform in ``TRANSFORMS`` (guarded so
    it is registered only once even if this helper is called repeatedly)
    and uses it as the whole pipeline.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)``.
    """
    data_root = 'tests/data/kitti'
    ann_file = 'kitti_infos_train.pkl'
    classes = ['Pedestrian', 'Cyclist', 'Car']
    # NOTE: placeholder pipeline until the pipeline refactor lands.
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):
            """Pack the 3D labels of ``info`` into a ``Det3DDataSample``."""

            def transform(self, info):
                # Prefer labels from the parsed annotations when present;
                # otherwise ``info`` must already carry ``gt_labels_3d``
                # (a KeyError is raised below if it does not).
                if 'ann_info' in info:
                    info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
                data_sample = Det3DDataSample()
                gt_instances_3d = InstanceData()
                gt_instances_3d.labels_3d = info['gt_labels_3d']
                data_sample.gt_instances_3d = gt_instances_3d
                info['data_sample'] = data_sample
                return info

    pipeline = [
        dict(type='Identity'),
    ]
    modality = dict(use_lidar=True, use_camera=False)
    data_prefix = dict(pts='training/velodyne_reduced', img='training/image_2')
    return data_root, ann_file, classes, data_prefix, pipeline, modality


def test_getitem():
    """Check KittiDataset path joining and annotation parsing.

    Builds the dataset twice: once with all three classes and once with
    only ``'Car'``, where every fixture instance is expected to be filtered.
    """
    np.random.seed(0)
    data_root, ann_file, classes, data_prefix, \
        pipeline, modality = _generate_kitti_dataset_config()
    modality['use_camera'] = True

    kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        # Pass a copy so ``data_prefix`` used in the assertions below stays
        # untouched even if the dataset rewrites its prefixes internally.
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(CLASSES=classes),
        modality=modality)

    # Both entry points must run through the pipeline without raising.
    kitti_dataset.prepare_data(0)
    input_dict = kitti_dataset.get_data_info(0)
    kitti_dataset[0]
    # The paths should contain both data_prefix and data_root.
    lidar_path = input_dict['lidar_points']['lidar_path']
    assert data_prefix['pts'] in lidar_path
    assert data_root in lidar_path
    for img_info in input_dict['images'].values():
        if 'img_path' in img_info:
            assert data_prefix['img'] in img_info['img_path']
            assert data_root in img_info['img_path']

    ann_info = kitti_dataset.parse_ann_info(input_dict)

    # Check the keys in ann_info and their types.
    assert 'instances' in ann_info

    # only one instance
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64

    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
                          torch.tensor(7.2650))
    assert 'centers_2d' in ann_info
    assert ann_info['centers_2d'].dtype == np.float64
    assert 'depths' in ann_info
    assert ann_info['depths'].dtype == np.float64

    car_kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(CLASSES=['Car']),
        modality=modality)

    input_dict = car_kitti_dataset.get_data_info(0)
    ann_info = car_kitti_dataset.parse_ann_info(input_dict)

    # Check the keys in ann_info and their types.
    assert 'instances' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    # All instances have been filtered out by the class list.
    assert len(ann_info['gt_labels_3d']) == 0
    assert len(car_kitti_dataset.metainfo['CLASSES']) == 1