"official/projects/movinet/train.py" did not exist on "db4acd913f99659e270c44aab0e417f9cf0b8aaa"
test_sunrgbd_dataset.py 4.4 KB
Newer Older
liyinhao's avatar
liyinhao committed
1
import numpy as np
liyinhao's avatar
liyinhao committed
2
3
import pytest
import torch
liyinhao's avatar
liyinhao committed
4

5
from mmdet3d.datasets import SUNRGBDDataset
liyinhao's avatar
liyinhao committed
6
7
8
9


def test_getitem():
    """Check sample loading and ``classes`` overriding of SUNRGBDDataset.

    The first sample is run through a small indoor pipeline and compared
    against hard-coded points, 3D boxes and labels. Afterwards the dataset
    is re-created with ``classes`` given as a list, a tuple and a path to
    a text file, and ``CLASSES`` is checked in each case.
    """
    # Seed NumPy so the hard-coded expectations below stay reproducible
    # (the augmentation steps in the pipeline are presumably random).
    np.random.seed(0)
    root_path = './tests/data/sunrgbd/sunrgbd_trainval'
    ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl'
    class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
                   'dresser', 'night_stand', 'bookshelf', 'bathtub')
    pipelines = [
        dict(
            type='IndoorLoadPointsFromFile',
            use_height=True,
            load_dim=6,
            use_dim=[0, 1, 2]),
        dict(type='IndoorFlipData', flip_ratio_yz=1.0),
        dict(
            type='IndoorGlobalRotScale',
            use_height=True,
            rot_range=[-1 / 6, 1 / 6],
            scale_range=[0.85, 1.15]),
        dict(type='IndoorPointSample', num_points=5),
        dict(type='DefaultFormatBundle3D', class_names=class_names),
        dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels']),
    ]

    sunrgbd_dataset = SUNRGBDDataset(root_path, ann_file, pipelines)
    data = sunrgbd_dataset[0]
    # Each collected entry is a wrapper; the payload is read from ``_data``.
    points = data['points']._data
    gt_bboxes_3d = data['gt_bboxes_3d']._data
    gt_labels = data['gt_labels']._data

    expected_points = np.array(
        [[0.6570105, 1.5538014, 0.24514851, 1.0165423],
         [0.656101, 1.558591, 0.21755838, 0.98895216],
         [0.6293659, 1.5679953, -0.10004003, 0.67135376],
         [0.6068739, 1.5974995, -0.41063973, 0.36075398],
         [0.6464709, 1.5573514, 0.15114647, 0.9225402]])
    expected_gt_bboxes_3d = np.array([
        [
            -2.012483, 3.9473376, -0.25446942, 2.3730404, 1.9457763,
            2.0303352, 1.2205974
        ],
        [
            -3.7036808, 4.2396426, -0.81091917, 0.6032123, 0.91040343,
            1.003341, 1.2662518
        ],
        [
            0.6528646, 2.1638472, -0.15228128, 0.7347852, 1.6113238,
            2.1694272, 2.81404
        ],
    ])
    expected_gt_labels = np.array([0, 7, 6])
    original_classes = sunrgbd_dataset.CLASSES

    assert np.allclose(points, expected_points)
    assert np.allclose(gt_bboxes_3d, expected_gt_bboxes_3d)
    assert np.all(gt_labels.numpy() == expected_gt_labels)
    assert original_classes == class_names

    # ``classes`` given as a list replaces the default class names.
    list_dataset = SUNRGBDDataset(
        root_path, ann_file, pipeline=None, classes=['bed', 'table'])
    assert list_dataset.CLASSES != original_classes
    assert list_dataset.CLASSES == ['bed', 'table']

    # ``classes`` given as a tuple is kept as a tuple.
    tuple_dataset = SUNRGBDDataset(
        root_path, ann_file, pipeline=None, classes=('bed', 'table'))
    assert tuple_dataset.CLASSES != original_classes
    assert tuple_dataset.CLASSES == ('bed', 'table')

    # ``classes`` given as a path to a text file with one name per line.
    # The file lives in a temporary directory instead of being a still-open
    # NamedTemporaryFile reopened by name: reopening is not portable across
    # platforms and the original handle was never closed explicitly.
    import os
    import tempfile
    with tempfile.TemporaryDirectory() as tmp_dir:
        classes_path = os.path.join(tmp_dir, 'classes.txt')
        with open(classes_path, 'w') as f:
            f.write('bed\ntable\n')

        file_dataset = SUNRGBDDataset(
            root_path, ann_file, pipeline=None, classes=classes_path)
        assert file_dataset.CLASSES != original_classes
        assert file_dataset.CLASSES == ['bed', 'table']
liyinhao's avatar
liyinhao committed
84
85
86
87
88
89
90
91


def test_evaluate():
    """Check ``SUNRGBDDataset.evaluate`` on one hand-written prediction.

    Three predicted boxes with labels 0, 7 and 6 are evaluated at the
    thresholds 0.25 and 0.5; the AP at 0.25 for 'bed', 'dresser' and
    'night_stand' is expected to be (close to) 1. The test is skipped when
    CUDA is not available because the predictions are moved to the GPU.
    """
    if not torch.cuda.is_available():
        # Give the skip a reason so it is explained in the test report.
        pytest.skip('test requires GPU and torch+cuda')

    root_path = './tests/data/sunrgbd'
    ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl'
    sunrgbd_dataset = SUNRGBDDataset(root_path, ann_file)

    pred_boxes = dict(
        box3d_lidar=np.array([
            [
                4.168696, -1.047307, -1.231666, 1.887584, 2.30207, 1.969614,
                1.69564944
            ],
            [
                4.811675, -2.583086, -1.273334, 0.883176, 0.585172, 0.973334,
                1.64999513
            ],
            [
                1.904545, 1.086364, -1.2, 1.563134, 0.71281, 2.104546,
                0.1022069
            ],
        ]),
        label_preds=torch.Tensor([0, 7, 6]).cuda(),
        scores=torch.Tensor([0.5, 1.0, 1.0]).cuda())
    # One entry per sample; each entry is a list holding the prediction dict.
    results = [[pred_boxes]]

    metric = [0.25, 0.5]
    ap_dict = sunrgbd_dataset.evaluate(results, metric)

    bed_precision_25 = ap_dict['bed_AP_0.25']
    dresser_precision_25 = ap_dict['dresser_AP_0.25']
    night_stand_precision_25 = ap_dict['night_stand_AP_0.25']
    assert abs(bed_precision_25 - 1) < 0.01
    assert abs(dresser_precision_25 - 1) < 0.01
    assert abs(night_stand_precision_25 - 1) < 0.01