"""Tests for the mmdet3d point-cloud and 3D-annotation loading pipelines."""
import os.path as osp

import mmcv
import numpy as np
import pytest

from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.datasets.pipelines import LoadAnnotations3D, LoadPointsFromFile


def test_load_points_from_indoor_file():
    """Check LoadPointsFromFile on the first SUN RGB-D and ScanNet samples.

    Both loaders are built with ``shift_height=True``. For each dataset the
    first entry of its info file is loaded, and the resulting ``points``
    array is expected to have shape ``(100, 4)``.
    """
    # SUN RGB-D: load_dim passed explicitly as 6.
    sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')[0]
    sunrgbd_load_points_from_file = LoadPointsFromFile(6, shift_height=True)
    data_path = './tests/data/sunrgbd'
    sunrgbd_results = dict()
    sunrgbd_results['pts_filename'] = osp.join(data_path,
                                               sunrgbd_info['pts_path'])
    sunrgbd_results = sunrgbd_load_points_from_file(sunrgbd_results)
    sunrgbd_point_cloud = sunrgbd_results['points']
    # NOTE(review): the 4 columns are presumably the selected coordinate
    # dims plus the height channel added by shift_height -- confirm against
    # LoadPointsFromFile.
    assert sunrgbd_point_cloud.shape == (100, 4)

    # ScanNet: relies on LoadPointsFromFile's default load_dim.
    scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
    scannet_load_data = LoadPointsFromFile(shift_height=True)
    data_path = './tests/data/scannet'
    scannet_results = dict()
    scannet_results['pts_filename'] = osp.join(data_path,
                                               scannet_info['pts_path'])
    scannet_results = scannet_load_data(scannet_results)
    scannet_point_cloud = scannet_results['points']
    assert scannet_point_cloud.shape == (100, 4)


def test_load_points_from_outdoor_file():
    """Check LoadPointsFromFile on a KITTI-style binary point file.

    The file is loaded twice: once with ``use_dim`` given as the int ``4``
    and once as the explicit index list ``[0, 1, 2, 3]``. Both loads are
    expected to produce the same ``(50, 4)`` array. Constructing the loader
    with ``use_dim=5`` and ``load_dim=4`` is expected to raise
    ``AssertionError``.
    """
    data_path = 'tests/data/kitti/a.bin'

    # use_dim given as an int.
    load_points_from_file = LoadPointsFromFile(4, 4)
    results = dict()
    results['pts_filename'] = data_path
    results = load_points_from_file(results)
    points = results['points']
    assert points.shape == (50, 4)
    assert np.allclose(points.sum(), 2637.479)

    # use_dim given as an explicit list of column indices.
    load_points_from_file = LoadPointsFromFile(4, [0, 1, 2, 3])
    results = dict()
    results['pts_filename'] = data_path
    results = load_points_from_file(results)
    new_points = results['points']
    assert new_points.shape == (50, 4)
    # Validate the second load itself, not the first one again.
    assert np.allclose(new_points.sum(), 2637.479)
    # Both column selections read the same file, so the arrays must match
    # exactly; the comparison has to be asserted to have any effect.
    assert np.array_equal(points, new_points)

    with pytest.raises(AssertionError):
        LoadPointsFromFile(4, 5)


def test_load_annotations3D():
    """Check LoadAnnotations3D on the first ScanNet test sample.

    Builds the ``ann_info`` dict the pipeline consumes (instance and
    semantic mask paths, 3D boxes, labels). Runs the loader with boxes,
    labels, instance masks and semantic masks enabled, then checks the
    shapes of the loaded fields.
    """
    # Test scannet LoadAnnotations3D
    scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
    scannet_load_annotations3D = LoadAnnotations3D(
        with_bbox_3d=True,
        with_label_3d=True,
        with_mask_3d=True,
        with_seg_3d=True)
    scannet_results = dict()
    data_path = './tests/data/scannet'

    if scannet_info['annos']['gt_num'] != 0:
        scannet_gt_bboxes_3d = scannet_info['annos']['gt_boxes_upright_depth']
        scannet_gt_labels_3d = scannet_info['annos']['class']
    else:
        # Placeholder for a sample without annotations: one all-zero
        # 6-value box and one label. Labels are presumably integer class
        # indices, so use an integer dtype instead of numpy's float default.
        scannet_gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
        scannet_gt_labels_3d = np.zeros((1, ), dtype=np.int64)

    # prepare input of loading pipeline
    scannet_results['ann_info'] = dict()
    scannet_results['ann_info']['pts_instance_mask_path'] = osp.join(
        data_path, scannet_info['pts_instance_mask_path'])
    scannet_results['ann_info']['pts_semantic_mask_path'] = osp.join(
        data_path, scannet_info['pts_semantic_mask_path'])
    # Boxes are supplied as 6 values per box without a yaw angle.
    scannet_results['ann_info']['gt_bboxes_3d'] = DepthInstance3DBoxes(
        scannet_gt_bboxes_3d, box_dim=6, with_yaw=False)
    scannet_results['ann_info']['gt_labels_3d'] = scannet_gt_labels_3d

    # NOTE(review): field-name registries the loader presumably appends
    # to -- confirm against LoadAnnotations3D.
    scannet_results['bbox3d_fields'] = []
    scannet_results['pts_mask_fields'] = []
    scannet_results['pts_seg_fields'] = []

    scannet_results = scannet_load_annotations3D(scannet_results)
    scannet_gt_boxes = scannet_results['gt_bboxes_3d']
    scannet_gt_labels = scannet_results['gt_labels_3d']

    scannet_pts_instance_mask = scannet_results['pts_instance_mask']
    scannet_pts_semantic_mask = scannet_results['pts_semantic_mask']
    # 7 box columns are expected although 6 were supplied (with_yaw=False);
    # presumably the box class pads a yaw column -- confirm.
    assert scannet_gt_boxes.tensor.shape == (27, 7)
    assert scannet_gt_labels.shape == (27, )
    assert scannet_pts_instance_mask.shape == (100, )
    assert scannet_pts_semantic_mask.shape == (100, )