test_sunrgbd_dataset.py 3.47 KB
Newer Older
liyinhao's avatar
liyinhao committed
1
import numpy as np
liyinhao's avatar
liyinhao committed
2
3
import pytest
import torch
liyinhao's avatar
liyinhao committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

from mmdet3d.datasets import SunrgbdDataset


def test_getitem():
    np.random.seed(0)
    root_path = './tests/data/sunrgbd'
    ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl'
    class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
                   'dresser', 'night_stand', 'bookshelf', 'bathtub')
    pipelines = [
        dict(
            type='IndoorLoadPointsFromFile',
            use_height=True,
            load_dim=6,
            use_dim=[0, 1, 2]),
        dict(type='IndoorFlipData', flip_ratio_yz=1.0),
        dict(
            type='IndoorGlobalRotScale',
            use_height=True,
            rot_range=[-np.pi / 6, np.pi / 6],
            scale_range=[0.85, 1.15]),
        dict(type='IndoorPointSample', num_points=5),
        dict(type='DefaultFormatBundle3D', class_names=class_names),
        dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels']),
    ]

    sunrgbd_dataset = SunrgbdDataset(root_path, ann_file, pipelines, True)
    data = sunrgbd_dataset[0]
    points = data['points']._data
    gt_bboxes_3d = data['gt_bboxes_3d']._data
    gt_labels = data['gt_labels']._data

    expected_points = np.array(
        [[0.6570105, 1.5538014, 0.24514851, 1.0165423],
         [0.656101, 1.558591, 0.21755838, 0.98895216],
         [0.6293659, 1.5679953, -0.10004003, 0.67135376],
         [0.6068739, 1.5974995, -0.41063973, 0.36075398],
         [0.6464709, 1.5573514, 0.15114647, 0.9225402]])
    expected_gt_bboxes_3d = np.array([[
        -2.012483, 3.9473376, -0.25446942, 2.3730404, 1.9457763, 2.0303352,
        1.2205974
    ],
                                      [
                                          -3.7036808, 4.2396426, -0.81091917,
                                          0.6032123, 0.91040343, 1.003341,
                                          1.2662518
                                      ],
                                      [
                                          0.6528646, 2.1638472, -0.15228128,
                                          0.7347852, 1.6113238, 2.1694272,
                                          2.81404
                                      ]])
    expected_gt_labels = np.array([0, 7, 6])

    assert np.allclose(points, expected_points)
    assert np.allclose(gt_bboxes_3d, expected_gt_bboxes_3d)
    assert np.all(gt_labels.numpy() == expected_gt_labels)
liyinhao's avatar
liyinhao committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82


def test_evaluate():

    if not torch.cuda.is_available():
        pytest.skip()
    root_path = './tests/data/sunrgbd'
    ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl'
    sunrgbd_dataset = SunrgbdDataset(root_path, ann_file)
    results = []
    pred_boxes = dict()
    pred_boxes['box3d_lidar'] = np.array([[
        1.047307, 4.168696, -0.246859, 2.30207, 1.887584, 1.969614, 1.69564944
    ], [
        2.583086, 4.811675, -0.786667, 0.585172, 0.883176, 0.973334, 1.64999513
    ], [
        -1.086364, 1.904545, -0.147727, 0.71281, 1.563134, 2.104546, 0.1022069
    ]])
    pred_boxes['label_preds'] = torch.Tensor([0, 7, 6]).cuda()
    pred_boxes['scores'] = torch.Tensor([0.5, 1.0, 1.0]).cuda()
    results.append([pred_boxes])
liyinhao's avatar
liyinhao committed
83
    metric = [0.25, 0.5]
liyinhao's avatar
liyinhao committed
84
    ap_dict = sunrgbd_dataset.evaluate(results, metric)
liyinhao's avatar
liyinhao committed
85
86
87
    bed_precision_25 = ap_dict['bed_AP_25']
    dresser_precision_25 = ap_dict['dresser_AP_25']
    night_stand_precision_25 = ap_dict['night_stand_AP_25']
liyinhao's avatar
liyinhao committed
88
89
90
    assert abs(bed_precision_25 - 1) < 0.01
    assert abs(dresser_precision_25 - 1) < 0.01
    assert abs(night_stand_precision_25 - 1) < 0.01