test_det3d_data_sample.py 6.77 KB
Newer Older
ZCMax's avatar
ZCMax committed
1
# Copyright (c) OpenMMLab. All rights reserved.
VVsssssk's avatar
VVsssssk committed
2
3
4
5
6
from unittest import TestCase

import numpy as np
import pytest
import torch
ZCMax's avatar
ZCMax committed
7
from mmengine.data import InstanceData
VVsssssk's avatar
VVsssssk committed
8

zhangshilong's avatar
zhangshilong committed
9
from mmdet3d.structures import Det3DDataSample, PointData
VVsssssk's avatar
VVsssssk committed
10
11
12
13
14
15
16
17
18


def _equal(a, b):
    """Return whether ``a`` and ``b`` hold the same value, as a plain bool.

    For tensors / arrays the shapes must match exactly and every element
    must compare equal. Shapes are checked first because ``==`` broadcasts:
    without the check, e.g. a length-3 and a length-1 tensor of zeros would
    be reported as equal. Any other type falls back to ``a == b``.
    """
    if isinstance(a, (torch.Tensor, np.ndarray)):
        # np.shape reads ``.shape`` when present (tensors, arrays) and only
        # converts via asarray for plain sequences.
        if tuple(a.shape) != tuple(np.shape(b)):
            return False
        return bool((a == b).all())
    return a == b


VVsssssk's avatar
VVsssssk committed
19
class TestDet3DataSample(TestCase):
    """Unit tests for ``Det3DDataSample``.

    Covers metainfo initialisation, the setters of the instance and
    point-segmentation fields, the type checks those setters perform, and
    deletion of each field.
    """

    # Fields holding ``InstanceData`` and ``PointData`` respectively.
    INSTANCE_FIELDS = ('gt_instances_3d', 'pred_instances_3d',
                       'pts_pred_instances_3d', 'img_pred_instances_3d')
    PTS_SEG_FIELDS = ('gt_pts_seg', 'pred_pts_seg')

    def _set_and_check(self, sample, name, data_cls, data):
        """Assign ``data_cls(**data)`` to ``sample.<name>`` and verify it.

        Checks that the field is reported as present and that every entry
        of ``data`` reads back unchanged from the stored field.
        """
        setattr(sample, name, data_cls(**data))
        assert name in sample
        stored = getattr(sample, name)
        for key, value in data.items():
            assert _equal(getattr(stored, key), value)

    def test_init(self):
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))

        det3d_data_sample = Det3DDataSample(metainfo=meta_info)
        assert 'img_size' in det3d_data_sample
        assert det3d_data_sample.img_size == [256, 256]
        assert det3d_data_sample.get('img_size') == [256, 256]
        # Array / tensor metainfo values must be readable back unchanged too.
        assert _equal(det3d_data_sample.scale_factor,
                      meta_info['scale_factor'])
        assert _equal(det3d_data_sample.img_shape, meta_info['img_shape'])

    def test_setter(self):
        det3d_data_sample = Det3DDataSample()

        # test gt_instances_3d
        self._set_and_check(
            det3d_data_sample, 'gt_instances_3d', InstanceData,
            dict(bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4)))

        # test pred_instances_3d, pts_pred_instances_3d and
        # img_pred_instances_3d, which all carry boxes, labels and scores
        for name in ('pred_instances_3d', 'pts_pred_instances_3d',
                     'img_pred_instances_3d'):
            self._set_and_check(
                det3d_data_sample, name, InstanceData,
                dict(
                    bboxes_3d=torch.rand(2, 4),
                    labels_3d=torch.rand(2),
                    scores_3d=torch.rand(2)))

        # test gt_pts_seg and pred_pts_seg
        for name in self.PTS_SEG_FIELDS:
            self._set_and_check(
                det3d_data_sample, name, PointData,
                dict(
                    pts_instance_mask=torch.rand(20),
                    pts_semantic_mask=torch.rand(20)))

        # test type error: every typed setter must reject a raw tensor
        for name in self.INSTANCE_FIELDS:
            with pytest.raises(AssertionError):
                setattr(det3d_data_sample, name, torch.rand(2, 4))
        for name in self.PTS_SEG_FIELDS:
            with pytest.raises(AssertionError):
                setattr(det3d_data_sample, name, torch.rand(20))

    def test_deleter(self):
        tmp_instances_3d_data = dict(
            bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4))
        tmp_pts_seg_data = dict(
            pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))

        det3d_data_sample = Det3DDataSample()

        for name in self.INSTANCE_FIELDS:
            # Unpack so the fields are bboxes_3d / labels_3d, rather than a
            # single field called ``data`` holding the whole dict.
            setattr(det3d_data_sample, name,
                    InstanceData(**tmp_instances_3d_data))
            assert name in det3d_data_sample
            delattr(det3d_data_sample, name)
            assert name not in det3d_data_sample

        for name in self.PTS_SEG_FIELDS:
            setattr(det3d_data_sample, name, PointData(**tmp_pts_seg_data))
            assert name in det3d_data_sample
            delattr(det3d_data_sample, name)
            assert name not in det3d_data_sample