test_scannet_dataset.py 5.39 KB
Newer Older
yinchimaoliang's avatar
yinchimaoliang committed
1
import numpy as np
yinchimaoliang's avatar
yinchimaoliang committed
2
import pytest
yinchimaoliang's avatar
yinchimaoliang committed
3
import torch
yinchimaoliang's avatar
yinchimaoliang committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

from mmdet3d.datasets.scannet_dataset import ScannetDataset


def test_getitem():
    """Run one ScanNet fixture sample through the pipeline and check it.

    The sample comes from the small fixture under ``tests/data/scannet``.
    The pipeline loads points and annotations, samples 5 points, flips,
    applies a global rotation, format-bundles and collects the keys below.
    Each collected output is compared against hard-coded expectations.
    """
    # Seed numpy's global RNG so the expectations below are reproducible;
    # presumably the point sampling / rotation steps draw from it.
    np.random.seed(0)

    data_root = './tests/data/scannet'
    info_path = './tests/data/scannet/scannet_infos.pkl'
    class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
                   'window', 'bookshelf', 'picture', 'counter', 'desk',
                   'curtain', 'refrigerator', 'showercurtrain', 'toilet',
                   'sink', 'bathtub', 'garbagebin')

    # Individual pipeline steps that carry non-trivial configuration.
    load_points = dict(
        type='IndoorLoadPointsFromFile',
        use_height=True,
        load_dim=6,
        use_dim=[0, 1, 2])
    rot_scale = dict(
        type='IndoorGlobalRotScale',
        use_height=True,
        rot_range=[-np.pi * 1 / 36, np.pi * 1 / 36],
        scale_range=None)
    collect = dict(
        type='Collect3D',
        keys=[
            'points', 'gt_bboxes_3d', 'gt_labels', 'pts_semantic_mask',
            'pts_instance_mask'
        ])

    pipeline = [
        load_points,
        dict(type='IndoorLoadAnnotations3D'),
        dict(type='IndoorPointSample', num_points=5),
        dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
        rot_scale,
        dict(type='DefaultFormatBundle3D', class_names=class_names),
        collect,
    ]

    # The fourth positional argument (True) is forwarded unchanged; its
    # meaning is defined by ScannetDataset's signature, not visible here.
    dataset = ScannetDataset(data_root, info_path, pipeline, True)
    sample = dataset[0]

    # points / boxes / labels are unwrapped via ``._data`` (presumably
    # DataContainer wrappers from the format bundle); the two masks are
    # compared as returned.
    points = sample['points']._data
    gt_bboxes_3d = sample['gt_bboxes_3d']._data
    gt_labels = sample['gt_labels']._data
    pts_semantic_mask = sample['pts_semantic_mask']
    pts_instance_mask = sample['pts_instance_mask']

    expected_points = np.array(
        [[-2.9078157, -1.9569951, 2.3543026, 2.389488],
         [-0.71360034, -3.4359822, 2.1330001, 2.1681855],
         [-1.332374, 1.474838, -0.04405887, -0.00887359],
         [2.1336637, -1.3265059, -0.02880373, 0.00638155],
         [0.43895668, -3.0259454, 1.5560012, 1.5911865]])
    expected_gt_bboxes_3d = np.array([
        [-1.5005362, -3.512584, 1.8565295, 1.7457027, 0.24149807, 0.57235193],
        [-2.8848705, 3.4961755, 1.5268247, 0.66170084, 0.17433672, 0.67153597],
        [-1.1585636, -2.192365, 0.61649567, 0.5557011, 2.5375574, 1.2144762],
        [-2.930457, -2.4856408, 0.9722377, 0.6270478, 1.8461524, 0.28697443],
        [3.3114715, -0.00476722, 1.0712197, 0.46191898, 3.8605113, 2.1603441]
    ])
    expected_gt_labels = np.array([
        6, 6, 4, 9, 11, 11, 10, 0, 15, 17, 17, 17, 3, 12, 4, 4, 14, 1, 0, 0, 0,
        0, 0, 0, 5, 5, 5
    ])
    expected_pts_semantic_mask = np.array([3, 1, 2, 2, 15])
    expected_pts_instance_mask = np.array([44, 22, 10, 10, 57])

    assert np.allclose(points, expected_points)
    # Only the first five ground-truth boxes are pinned.
    leading_boxes = gt_bboxes_3d[:5]
    assert leading_boxes.shape == (5, 6)
    assert np.allclose(leading_boxes, expected_gt_bboxes_3d)
    assert np.all(gt_labels.numpy() == expected_gt_labels)
    assert np.all(pts_semantic_mask == expected_pts_semantic_mask)
    assert np.all(pts_instance_mask == expected_pts_instance_mask)
yinchimaoliang's avatar
yinchimaoliang committed
73
74
75


def test_evaluate():
    """Check ScanNet AP evaluation against hand-crafted predictions.

    Five predicted boxes with labels 6, 6, 4, 9, 11 are evaluated on the
    fixture dataset at IoU thresholds 0.25 and 0.5. The per-class AP@0.25
    for table, window, counter and curtain is then compared with known
    values. The test is skipped without CUDA because the predicted labels
    and scores are moved to the GPU.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    root_path = './tests/data/scannet'
    ann_file = './tests/data/scannet/scannet_infos.pkl'
    scannet_dataset = ScannetDataset(root_path, ann_file)

    # Seven values per predicted box.
    # NOTE(review): presumably (x, y, z, dx, dy, dz, yaw) -- confirm against
    # the evaluator.
    box3d_lidar = np.array([
        [
            3.52074146e+00, -1.48129511e+00, 1.57035351e+00, 2.31956959e-01,
            1.74445975e+00, 5.72351933e-01, 0
        ],
        [
            -3.48033905e+00, -2.90395617e+00, 1.19105673e+00, 1.70723915e-01,
            6.60776615e-01, 6.71535969e-01, 0
        ],
        [
            2.19867110e+00, -1.14655101e+00, 9.25755501e-03, 2.53463078e+00,
            5.41841269e-01, 1.21447623e+00, 0
        ],
        [
            2.50163722, -2.91681337, 0.82875049, 1.84280431, 0.61697435,
            0.28697443, 0
        ],
        [
            -0.01335114, 3.3114481, -0.00895238, 3.85815716, 0.44081616,
            2.16034412, 0
        ],
    ])
    pred_boxes = dict()
    pred_boxes['box3d_lidar'] = box3d_lidar
    pred_boxes['label_preds'] = torch.Tensor([6, 6, 4, 9, 11]).cuda()
    pred_boxes['scores'] = torch.Tensor([0.5, 1.0, 1.0, 1.0, 1.0]).cuda()

    # A single entry whose value is a one-element list of prediction dicts.
    results = [[pred_boxes]]

    # The key spelling ('THRESHHOLDS') is kept as-is; presumably it must
    # match what ScannetDataset.evaluate reads -- do not "fix" it here.
    metric = dict()
    metric['AP_IOU_THRESHHOLDS'] = [0.25, 0.5]
    ap_dict = scannet_dataset.evaluate(results, metric)

    # Expected AP at IoU 0.25 per class name, checked within 0.01.
    expected_ap_25 = (
        ('table', 0.3333),
        ('window', 1),
        ('counter', 1),
        ('curtain', 0.5),
    )
    for class_name, expected in expected_ap_25:
        ap = ap_dict[class_name + ' Average Precision 25']
        assert abs(ap - expected) < 0.01, (class_name, ap, expected)