test_dataset_wrappers.py 4.36 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np
import pytest
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData

from mmdet3d.datasets import CBGSDataset, NuScenesDataset
from mmdet3d.registry import DATASETS, TRANSFORMS
from mmdet3d.structures import Det3DDataSample


def is_equal(dict_a, dict_b):
    """Return whether two dicts hold the same keys and deeply equal values.

    The comparison is symmetric and deep:

    * both dicts must have exactly the same set of keys, and every key is
      checked (a nested dict no longer ends the comparison early);
    * nested dicts, lists and tuples are compared recursively, element by
      element;
    * numpy arrays (and anything exposing ``__array__``) must have the same
      shape and all elements equal; NaNs in float arrays compare equal;
    * float NaNs compare equal to each other;
    * objects whose class keeps the default identity-based ``__eq__`` are
      compared through their instance attributes (``vars``), so two
      separately built but identical objects are considered equal.

    Args:
        dict_a (dict): First dict.
        dict_b (dict): Second dict.

    Returns:
        bool: True if the two dicts are deeply equal.
    """
    if dict_a.keys() != dict_b.keys():
        return False
    return all(_values_equal(dict_a[key], dict_b[key]) for key in dict_a)


def _values_equal(value_a, value_b):
    """Return whether two values are deeply equal (rules as in is_equal)."""
    if isinstance(value_a, dict):
        return isinstance(value_b, dict) and is_equal(value_a, value_b)
    if isinstance(value_a, (list, tuple)):
        return (type(value_a) is type(value_b)
                and len(value_a) == len(value_b)
                and all(map(_values_equal, value_a, value_b)))
    if hasattr(value_a, '__array__') or hasattr(value_b, '__array__'):
        array_a, array_b = np.asarray(value_a), np.asarray(value_b)
        # ``equal_nan`` relies on ``np.isnan``, which rejects non-numeric
        # dtypes; only float/complex arrays can hold NaN anyway.
        floating = array_a.dtype.kind in 'fc' and array_b.dtype.kind in 'fc'
        return bool(np.array_equal(array_a, array_b, equal_nan=floating))
    if isinstance(value_a, float) and isinstance(value_b, float):
        return bool(value_a == value_b
                    or (np.isnan(value_a) and np.isnan(value_b)))
    if (type(value_a) is type(value_b)
            and type(value_a).__eq__ is object.__eq__
            and hasattr(value_a, '__dict__')):
        # The default ``__eq__`` is identity; compare attributes instead so
        # two independently constructed copies can match.
        return is_equal(vars(value_a), vars(value_b))
    return bool(value_a == value_b)


@TRANSFORMS.register_module()
class Identity(BaseTransform):
    """Test transform that keeps only the 3D ground-truth labels.

    Every field of the input info dict is discarded except
    ``info['ann_info']['gt_labels_3d']`` (when ``ann_info`` is present),
    which is stored as ``gt_instances_3d.labels_3d`` of a fresh
    ``Det3DDataSample``.
    """

    def transform(self, info):
        """Pack ``info`` into ``dict(data_samples=Det3DDataSample())``.

        Args:
            info (dict): Input data info.

        Returns:
            dict: ``{'data_samples': sample}``; ``sample.gt_instances_3d`` is
            only set when ``info`` contains ``'ann_info'``.
        """
        sample = Det3DDataSample()
        if 'ann_info' in info:
            instances = InstanceData()
            instances.labels_3d = info['ann_info']['gt_labels_3d']
            sample.gt_instances_3d = instances
        return dict(data_samples=sample)


@DATASETS.register_module()
class CustomDataset(NuScenesDataset):
    """``NuScenesDataset`` subclass registered under its own name.

    It adds no behavior; ``TestCBGSDataset.test_init`` uses it as the
    ``type`` of a config dict to check that ``CBGSDataset`` can build its
    inner dataset from a config.
    """


class TestCBGSDataset:
    """Tests for ``CBGSDataset`` wrapping the nuScenes test fixture.

    Each test starts from a freshly built wrapper whose ``sample_indices``
    are pinned to ``[0, 0, 1, 1, 1]``, so wrapped index ``i`` corresponds to
    original sample ``sample_indices[i]``.
    """

    @staticmethod
    def _dataset_kwargs():
        """Return fresh constructor kwargs for the nuScenes test fixture."""
        return dict(
            data_root=osp.join(osp.dirname(__file__), '../data/nuscenes'),
            ann_file='nus_info.pkl',
            data_prefix=dict(
                pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP'),
            pipeline=[dict(type=Identity)])

    @staticmethod
    def _gt_labels(packed_input):
        """Return the 3D GT labels from a packed pipeline output."""
        return packed_input['data_samples'].gt_instances_3d.labels_3d

    def setup_method(self):
        # ``setup_method`` is pytest's native per-test hook; the nose-style
        # ``setup`` name is deprecated since pytest 7.2 and is not called by
        # pytest 8.
        self.dataset = NuScenesDataset(**self._dataset_kwargs())

        self.sample_indices = [0, 0, 1, 1, 1]
        # test init from an already-built dataset instance
        self.cbgs_datasets = CBGSDataset(dataset=self.dataset)
        # Pin the resampled indices so expectations are deterministic.
        self.cbgs_datasets.sample_indices = self.sample_indices

    def test_init(self):
        # Build the inner dataset from a config dict with a custom type.
        dataset_cfg = dict(type=CustomDataset, **self._dataset_kwargs())
        cbgs_datasets = CBGSDataset(dataset=dataset_cfg)
        cbgs_datasets.sample_indices = self.sample_indices
        # Reuse the reference dataset's pipeline object.
        cbgs_datasets.dataset.pipeline = self.dataset.pipeline
        assert len(cbgs_datasets) == len(self.cbgs_datasets)
        for idx in range(len(cbgs_datasets)):
            assert is_equal(
                cbgs_datasets.get_data_info(idx),
                self.cbgs_datasets.get_data_info(idx))
            # Compare whole label arrays, not "any element matches".
            assert np.array_equal(
                self._gt_labels(cbgs_datasets[idx]),
                self._gt_labels(self.cbgs_datasets[idx]))

        # An invalid ``dataset`` argument (here a list) must be rejected.
        with pytest.raises(TypeError):
            CBGSDataset(dataset=[0])

    def test_full_init(self):
        self.cbgs_datasets.full_init()
        self.cbgs_datasets.sample_indices = self.sample_indices
        assert len(self.cbgs_datasets) == len(self.sample_indices)
        # Reinit `sample_indices`: once the wrapper is marked as not fully
        # initialized, the pinned indices are expected to be replaced.
        self.cbgs_datasets._fully_initialized = False
        self.cbgs_datasets.sample_indices = self.sample_indices
        assert len(self.cbgs_datasets) != len(self.sample_indices)

        # Subsetting is not supported by the wrapper.
        with pytest.raises(NotImplementedError):
            self.cbgs_datasets.get_subset_(1)

        with pytest.raises(NotImplementedError):
            self.cbgs_datasets.get_subset(1)

    def test_metainfo(self):
        assert self.cbgs_datasets.metainfo == self.dataset.metainfo

    def test_length(self):
        assert len(self.cbgs_datasets) == len(self.sample_indices)

    def test_getitem(self):
        for idx, ori_idx in enumerate(self.sample_indices):
            assert np.array_equal(
                self._gt_labels(self.cbgs_datasets[idx]),
                self._gt_labels(self.dataset[ori_idx]))

    def test_get_data_info(self):
        for idx, ori_idx in enumerate(self.sample_indices):
            assert is_equal(
                self.cbgs_datasets.get_data_info(idx),
                self.dataset.get_data_info(ori_idx))

    def test_get_cat_ids(self):
        for idx, ori_idx in enumerate(self.sample_indices):
            assert (self.cbgs_datasets.get_cat_ids(idx) ==
                    self.dataset.get_cat_ids(ori_idx))