# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmengine.registry import TRANSFORMS
from mmengine.structures import InstanceData

from mmdet3d.datasets import LyftDataset
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes


def _generate_nus_dataset_config():
    """Build the dataset configuration used by the Lyft dataset tests.

    Despite the ``nus`` in its name, every value here points at the Lyft
    test fixture under ``tests/data/lyft``; the name is kept unchanged so
    existing callers keep working.

    As a side effect, the first call registers a minimal ``Identity``
    transform in ``TRANSFORMS``. That transform wraps each sample into a
    ``Det3DDataSample`` and, when ``ann_info`` is present, copies its
    ``gt_labels_3d`` into ``gt_instances_3d.labels_3d``.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)``, in that order.
    """
    data_root = 'tests/data/lyft'
    ann_file = 'lyft_infos.pkl'
    classes = [
        'car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
        'motorcycle', 'bicycle', 'pedestrian', 'animal'
    ]
    # Guard so that calling this helper more than once does not attempt to
    # register the same transform name a second time.
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):
            """Pack ``info`` into a data sample, copying only GT labels."""

            def transform(self, info):
                data_sample = Det3DDataSample()
                if 'ann_info' in info:
                    labels_3d = info['ann_info']['gt_labels_3d']
                    data_sample.gt_instances_3d = InstanceData()
                    data_sample.gt_instances_3d.labels_3d = labels_3d
                return dict(data_samples=data_sample)

    # The pipeline only runs the stub transform registered above.
    pipeline = [
        dict(type='Identity'),
    ]
    modality = dict(use_lidar=True, use_camera=False)
    data_prefix = dict(pts='lidar', img='', sweeps='sweeps/LIDAR_TOP')
    return data_root, ann_file, classes, data_prefix, pipeline, modality


def test_getitem():
    """Check path resolution and annotation parsing of ``LyftDataset``.

    Builds the dataset from the shared test config, then verifies that:

    - the resolved lidar path contains both ``data_root`` and the ``pts``
      prefix;
    - ``parse_ann_info`` yields int64 labels and ``LiDARInstance3DBoxes``;
    - the configured class list is exposed through ``metainfo``.
    """
    np.random.seed(0)
    data_root, ann_file, classes, data_prefix, pipeline, modality = \
        _generate_nus_dataset_config()

    lyft_dataset = LyftDataset(
        data_root,
        ann_file,
        data_prefix=data_prefix,
        pipeline=pipeline,
        metainfo=dict(classes=classes),
        modality=modality)

    # Smoke-check: prepare the first sample; its result is not inspected.
    lyft_dataset.prepare_data(0)
    input_dict = lyft_dataset.get_data_info(0)

    # The resolved lidar path should contain both data_prefix and data_root.
    lidar_path = input_dict['lidar_points']['lidar_path']
    assert data_prefix['pts'] in lidar_path
    assert data_root in lidar_path

    ann_info = lyft_dataset.parse_ann_info(input_dict)

    # Check the expected keys in ann_info and the types of their values.
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    assert len(ann_info['gt_labels_3d']) == 3

    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)

    assert len(lyft_dataset.metainfo['classes']) == 9