# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
import torch
from mmcv.transforms.base import BaseTransform
from mmengine.registry import TRANSFORMS

from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.datasets import KittiDataset


def _generate_kitti_dataset_config():
    """Build the shared KITTI dataset configuration used by the tests.

    On first use this also registers a minimal ``Identity`` transform in
    ``TRANSFORMS`` so the dataset pipeline can run without the full
    loading/formatting transforms.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)``, where ``pipeline`` is a single-step list referencing the
        ``Identity`` transform and ``modality`` enables LiDAR but not camera.
    """
    data_root = 'tests/data/kitti'
    ann_file = 'kitti_infos_train.pkl'
    classes = ['Pedestrian', 'Cyclist', 'Car']
    # TODO: replace this placeholder pipeline once the pipeline refactor lands.

    # Only register once; presumably this avoids a duplicate-registration
    # error when the helper is called more than once in a session.
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):
            """Pass ``info`` through, lifting 3D labels to the top level."""

            def transform(self, info):
                # Copy the annotated 3D labels out of ``ann_info`` so the
                # pipeline output carries ``gt_labels_3d`` directly.
                if 'ann_info' in info:
                    info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
                return info

    pipeline = [
        dict(type='Identity'),
    ]
    modality = dict(use_lidar=True, use_camera=False)
    data_prefix = dict(pts='training/velodyne_reduced', img='training/image_2')
    return data_root, ann_file, classes, data_prefix, pipeline, modality


def test_getitem():
    """Check KittiDataset path resolution, annotation parsing and filtering.

    The dataset is built twice from the shared test config:

    * with all three classes, where exactly one annotated instance (label 0)
      is expected;
    * restricted to ``Car``, where every instance is expected to be filtered
      out.
    """
    np.random.seed(0)
    (data_root, ann_file, classes, data_prefix, pipeline,
     modality) = _generate_kitti_dataset_config()
    modality['use_camera'] = True

    kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        data_prefix=dict(
            pts='training/velodyne_reduced',
            img='training/image_2',
        ),
        pipeline=pipeline,
        metainfo=dict(CLASSES=classes),
        modality=modality)

    # Smoke-test prepare_data and __getitem__ on the first sample.
    kitti_dataset.prepare_data(0)
    input_dict = kitti_dataset.get_data_info(0)
    kitti_dataset[0]

    # The resolved paths should contain both data_root and data_prefix.
    lidar_path = input_dict['lidar_points']['lidar_path']
    assert data_prefix['pts'] in lidar_path
    assert data_root in lidar_path
    for img_info in input_dict['images'].values():
        if 'img_path' in img_info:
            assert data_prefix['img'] in img_info['img_path']
            assert data_root in img_info['img_path']

    ann_info = kitti_dataset.parse_ann_info(input_dict)

    # Check the keys in ann_info and their types.
    assert 'gt_labels' in ann_info
    assert ann_info['gt_labels'].dtype == np.int64
    # The sample holds exactly one instance, of the first class.
    assert len(ann_info['gt_labels']) == 1
    assert (ann_info['gt_labels'] == 0).all()
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    assert 'gt_bboxes' in ann_info
    assert ann_info['gt_bboxes'].dtype == np.float64
    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
                          torch.tensor(7.2650))
    # Per-instance integer attributes must all be present as int64 arrays.
    for key in ('group_id', 'occluded', 'difficulty', 'num_lidar_pts',
                'truncated'):
        assert key in ann_info
        assert ann_info[key].dtype == np.int64

    car_kitti_dataset = KittiDataset(
        data_root,
        ann_file,
        data_prefix=dict(
            pts='training/velodyne_reduced',
            img='training/image_2',
        ),
        pipeline=pipeline,
        metainfo=dict(CLASSES=['Car']),
        modality=modality)

    input_dict = car_kitti_dataset.get_data_info(0)
    ann_info = car_kitti_dataset.parse_ann_info(input_dict)

    # Check the keys in ann_info and their types.
    assert 'gt_labels' in ann_info
    assert ann_info['gt_labels'].dtype == np.int64
    # All instances have been filtered out by the class restriction.
    assert len(ann_info['gt_labels']) == 0
    assert len(car_kitti_dataset.metainfo['CLASSES']) == 1