# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import pytest
import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample, PointData


def _equal(a, b):
    """Check whether ``a`` and ``b`` hold equal values.

    When both operands are tensors, or both are arrays, they must have the
    same shape and equal elements. A plain ``(a == b).all()`` would
    broadcast, so e.g. a ``(2, 3)`` tensor of ones would compare equal to a
    ``(3,)`` tensor of ones.

    Args:
        a: Value under test: a ``torch.Tensor``, an ``np.ndarray`` or any
            object supporting ``==``.
        b: Expected value.

    Returns:
        bool: ``True`` if the values are equal. For non-array ``a`` this
        is the result of ``a == b``.
    """
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        # ``==`` broadcasts, so the shape must be compared explicitly.
        return a.shape == b.shape and bool((a == b).all())
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return bool(np.array_equal(a, b))
    if isinstance(a, (torch.Tensor, np.ndarray)):
        # Mixed operands (e.g. tensor vs. scalar): keep element-wise
        # semantics, but always hand back a plain bool.
        return bool((a == b).all())
    return a == b


class TestDet3DDataSample(TestCase):
    """Unit tests for the fields exposed by :class:`Det3DDataSample`."""

    # Prediction fields that hold an ``InstanceData`` with the same layout.
    _PRED_INSTANCE_FIELDS = ('pred_instances_3d', 'pts_pred_instances_3d',
                             'img_pred_instances_3d')
    # Every field that must hold an ``InstanceData``.
    _INSTANCE_FIELDS = ('gt_instances_3d', ) + _PRED_INSTANCE_FIELDS
    # Every field that must hold a ``PointData``.
    _PTS_SEG_FIELDS = ('gt_pts_seg', 'pred_pts_seg')

    @staticmethod
    def _assert_field_matches(data_sample, field, expected):
        """Assert ``data_sample.<field>`` is set and holds ``expected``.

        Args:
            data_sample (Det3DDataSample): Sample under test.
            field (str): Name of the data field to check.
            expected (dict): Mapping of sub-field name to the value that
                was assigned to it.
        """
        assert field in data_sample, f'{field} not set on data sample'
        element = getattr(data_sample, field)
        for key, value in expected.items():
            assert _equal(getattr(element, key), value), \
                f'{field}.{key} does not match the assigned value'

    def test_init(self):
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))

        det3d_data_sample = Det3DDataSample(metainfo=meta_info)
        assert 'img_size' in det3d_data_sample
        assert det3d_data_sample.img_size == [256, 256]
        assert det3d_data_sample.get('img_size') == [256, 256]
        # Array / tensor metainfo values must round-trip with equal contents.
        assert _equal(det3d_data_sample.scale_factor,
                      meta_info['scale_factor'])
        assert _equal(det3d_data_sample.img_shape, meta_info['img_shape'])

    def test_setter(self):
        det3d_data_sample = Det3DDataSample()

        # test gt_instances_3d
        gt_instances_3d_data = dict(
            bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
        det3d_data_sample.gt_instances_3d = InstanceData(
            **gt_instances_3d_data)
        self._assert_field_matches(det3d_data_sample, 'gt_instances_3d',
                                   gt_instances_3d_data)

        # test pred_instances_3d, pts_pred_instances_3d and
        # img_pred_instances_3d
        for field in self._PRED_INSTANCE_FIELDS:
            pred_data = dict(
                bboxes_3d=torch.rand(2, 7),
                labels_3d=torch.rand(2),
                scores_3d=torch.rand(2))
            setattr(det3d_data_sample, field, InstanceData(**pred_data))
            self._assert_field_matches(det3d_data_sample, field, pred_data)

        # test gt_pts_seg and pred_pts_seg
        for field in self._PTS_SEG_FIELDS:
            pts_seg_data = dict(
                pts_instance_mask=torch.rand(20),
                pts_semantic_mask=torch.rand(20))
            setattr(det3d_data_sample, field, PointData(**pts_seg_data))
            self._assert_field_matches(det3d_data_sample, field, pts_seg_data)

        # test type error: each typed field is expected to reject a raw
        # tensor with an AssertionError
        for field in self._INSTANCE_FIELDS:
            with pytest.raises(AssertionError):
                setattr(det3d_data_sample, field, torch.rand(2, 4))
        for field in self._PTS_SEG_FIELDS:
            with pytest.raises(AssertionError):
                setattr(det3d_data_sample, field, torch.rand(20))

    def test_deleter(self):
        tmp_instances_3d_data = dict(
            bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
        tmp_pts_seg_data = dict(
            pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))

        det3d_data_sample = Det3DDataSample()
        for field in self._INSTANCE_FIELDS:
            # Unpack the dict so the element gets real ``bboxes_3d`` /
            # ``labels_3d`` fields instead of a single ``data`` field.
            setattr(det3d_data_sample, field,
                    InstanceData(**tmp_instances_3d_data))
            assert field in det3d_data_sample
            delattr(det3d_data_sample, field)
            assert field not in det3d_data_sample

        for field in self._PTS_SEG_FIELDS:
            setattr(det3d_data_sample, field, PointData(**tmp_pts_seg_data))
            assert field in det3d_data_sample
            delattr(det3d_data_sample, field)
            assert field not in det3d_data_sample