test_kitti_dataset.py 4.1 KB
Newer Older
jshilong's avatar
jshilong committed
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
jshilong's avatar
jshilong committed
4
5
import torch
from mmcv.transforms.base import BaseTransform
6
from mmengine.data import InstanceData
jshilong's avatar
jshilong committed
7
from mmengine.registry import TRANSFORMS
jshilong's avatar
jshilong committed
8
9

from mmdet3d.core import LiDARInstance3DBoxes
10
from mmdet3d.core.data_structures import Det3DDataSample
jshilong's avatar
jshilong committed
11
12
13
14
15
16
17
18
from mmdet3d.datasets import KittiDataset


def _generate_kitti_dataset_config():
    """Build the minimal KITTI dataset config shared by the tests below.

    As a side effect, registers a placeholder ``Identity`` transform in
    ``TRANSFORMS`` the first time it is called.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)``, in that order.
    """
    data_root = 'tests/data/kitti'
    ann_file = 'kitti_infos_train.pkl'
    classes = ['Pedestrian', 'Cyclist', 'Car']

    # Placeholder transform until the pipeline refactor lands.
    # The membership check avoids registering the same name twice when this
    # helper is called more than once in a session.
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):
            """Pack 3D labels into a ``Det3DDataSample``.

            Despite its name this is not a no-op: it copies
            ``info['ann_info']['gt_labels_3d']`` to ``info['gt_labels_3d']``
            (when ``ann_info`` is present) and stores a data sample holding
            those labels under ``info['data_sample']``.
            """

            def transform(self, info):
                if 'ann_info' in info:
                    info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']

                # NOTE(review): raises KeyError when neither 'ann_info' nor
                # 'gt_labels_3d' is present in ``info``.
                data_sample = Det3DDataSample()
                gt_instances_3d = InstanceData()
                gt_instances_3d.labels_3d = info['gt_labels_3d']
                data_sample.gt_instances_3d = gt_instances_3d
                info['data_sample'] = data_sample
                return info

    pipeline = [
        dict(type='Identity'),
    ]
    modality = dict(use_lidar=True, use_camera=False)
    data_prefix = dict(pts='training/velodyne_reduced', img='training/image_2')
    return data_root, ann_file, classes, data_prefix, pipeline, modality


def test_getitem():
    """Check path resolution and annotation parsing of ``KittiDataset``."""
    np.random.seed(0)
    (data_root, ann_file, classes, data_prefix, pipeline,
     modality) = _generate_kitti_dataset_config()
    modality['use_camera'] = True

    kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        # Pass a copy of the shared prefix dict so the dict used in the
        # assertions below is the same value the dataset was built with.
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(CLASSES=classes),
        modality=modality)

    # Smoke-test the data pipeline; the returned values are not inspected.
    kitti_dataset.prepare_data(0)
    input_dict = kitti_dataset.get_data_info(0)
    kitti_dataset[0]

    # The resolved paths should contain both data_prefix and data_root.
    assert data_prefix['pts'] in input_dict['lidar_points']['lidar_path']
    assert data_root in input_dict['lidar_points']['lidar_path']
    for img_info in input_dict['images'].values():
        if 'img_path' in img_info:
            assert data_prefix['img'] in img_info['img_path']
            assert data_root in img_info['img_path']

    ann_info = kitti_dataset.parse_ann_info(input_dict)

    # Check the keys in ann_info and their types.
    assert 'gt_labels' in ann_info
    assert ann_info['gt_labels'].dtype == np.int64
    # Only one instance, and its label is index 0 ('Pedestrian') in classes.
    assert len(ann_info['gt_labels']) == 1
    assert (ann_info['gt_labels'] == 0).all()
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    assert 'gt_bboxes' in ann_info
    assert ann_info['gt_bboxes'].dtype == np.float64
    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    # The sum of all box parameters serves as a cheap fingerprint of the
    # parsed 3D box.
    assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
                          torch.tensor(7.2650))
    assert 'group_id' in ann_info
    assert ann_info['group_id'].dtype == np.int64
    assert 'occluded' in ann_info
    assert ann_info['occluded'].dtype == np.int64
    assert 'difficulty' in ann_info
    assert ann_info['difficulty'].dtype == np.int64
    assert 'num_lidar_pts' in ann_info
    assert ann_info['num_lidar_pts'].dtype == np.int64
    assert 'truncated' in ann_info
    assert ann_info['truncated'].dtype == np.int64

    car_kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        data_prefix=dict(data_prefix),
        pipeline=pipeline,
        metainfo=dict(CLASSES=['Car']),
        modality=modality)

    input_dict = car_kitti_dataset.get_data_info(0)
    ann_info = car_kitti_dataset.parse_ann_info(input_dict)

    # Check the keys in ann_info and their types.
    assert 'gt_labels' in ann_info
    assert ann_info['gt_labels'].dtype == np.int64
    # All instances have been filtered out by the restricted class list.
    assert len(ann_info['gt_labels']) == 0
    assert len(car_kitti_dataset.metainfo['CLASSES']) == 1