test_s3dis_dataset.py 4.12 KB
Newer Older
ZCMax's avatar
ZCMax committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import numpy as np
import torch

from mmdet3d.datasets import S3DISSegDataset
from mmdet3d.utils import register_all_modules


def _generate_s3dis_seg_dataset_config():
    data_root = './tests/data/s3dis/'
    ann_file = 's3dis_infos.pkl'
    classes = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
               'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter')
    palette = [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
               [255, 0, 255], [100, 100, 255], [200, 200, 100],
               [170, 120, 200], [255, 0, 0], [200, 100, 100], [10, 200, 100],
               [200, 200, 200], [50, 50, 50]]
    scene_idxs = [0 for _ in range(20)]
    modality = dict(use_lidar=True, use_camera=False)
    pipeline = [
        dict(
            type='LoadPointsFromFile',
            coord_type='DEPTH',
            shift_height=False,
            use_color=True,
            load_dim=6,
            use_dim=[0, 1, 2, 3, 4, 5]),
        dict(
            type='LoadAnnotations3D',
            with_bbox_3d=False,
            with_label_3d=False,
            with_mask_3d=False,
            with_seg_3d=True),
        dict(type='PointSegClassMapping'),
        dict(
            type='IndoorPatchPointSample',
            num_points=5,
            block_size=1.0,
            ignore_index=len(classes),
            use_normalized_coord=True,
            enlarge_size=0.2,
            min_unique_num=None),
        dict(type='NormalizePointsColor', color_mean=None),
        dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
    ]

    data_prefix = dict(
        pts='points',
        pts_instance_mask='instance_mask',
        pts_semantic_mask='semantic_mask')

    return (data_root, ann_file, classes, palette, scene_idxs, data_prefix,
            pipeline, modality)


class TestS3DISDataset(unittest.TestCase):
    """Unit tests for ``S3DISSegDataset``."""

    def test_s3dis_seg(self):
        """Check points and labels of the first sample after the pipeline.

        Builds the dataset from ``_generate_s3dis_seg_dataset_config`` and
        compares the output of ``prepare_data(0)`` against hard-coded
        expected points and semantic labels.
        """
        # Fix NumPy's global RNG so the sampled patch is reproducible;
        # presumably IndoorPatchPointSample draws from ``np.random``.
        np.random.seed(0)
        (data_root, ann_file, classes, palette, scene_idxs, data_prefix,
         pipeline, modality) = _generate_s3dis_seg_dataset_config()

        # Presumably needed so the registry can resolve the pipeline
        # ``type`` strings when the dataset is built.
        register_all_modules()
        s3dis_seg_dataset = S3DISSegDataset(
            data_root,
            ann_file,
            metainfo=dict(CLASSES=classes, PALETTE=palette),
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            scene_idxs=scene_idxs)

        input_dict = s3dis_seg_dataset.prepare_data(0)

        points = input_dict['inputs']['points']
        data_sample = input_dict['data_sample']
        pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask

        # One row per sampled point (num_points=5). Each row holds the six
        # loaded channels followed by three more, presumably the normalized
        # coordinates added by ``use_normalized_coord=True``.
        expected_points = torch.tensor([
            [0.0000, 0.0000, 3.1720, 0.4706, 0.4431, 0.3725,
             0.4624, 0.7502, 0.9543],
            [0.2880, -0.5900, 0.0650, 0.3451, 0.3373, 0.3490,
             0.5119, 0.5518, 0.0196],
            [0.1570, 0.6000, 3.1700, 0.4941, 0.4667, 0.3569,
             0.4893, 0.9519, 0.9537],
            [-0.1320, 0.3950, 0.2720, 0.3216, 0.2863, 0.2275,
             0.4397, 0.8830, 0.0818],
            [-0.4860, -0.0640, 3.1710, 0.3843, 0.3725, 0.3059,
             0.3789, 0.7286, 0.9540],
        ])

        expected_pts_semantic_mask = np.array([0, 1, 0, 8, 0])

        # ``torch.allclose`` broadcasts its inputs, so compare shapes
        # explicitly first to stop a mis-shaped result from passing.
        self.assertEqual(points.shape, expected_points.shape)
        # Use ``self.assert*`` rather than a bare ``assert`` so the check
        # still runs under ``python -O``. Only a relative tolerance is set;
        # ``atol`` keeps torch's default.
        self.assertTrue(torch.allclose(points, expected_points, rtol=1e-2))
        # Checks shape and values, and reports mismatching elements.
        np.testing.assert_array_equal(pts_semantic_mask.numpy(),
                                      expected_pts_semantic_mask)