test_point_data.py 3.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) OpenMMLab. All rights reserved.
import random
from unittest import TestCase

import numpy as np
import pytest
import torch

from mmdet3d.structures import PointData


class TestPointData(TestCase):
    """Tests for attribute assignment, indexing and ``len`` of
    ``PointData``."""

    def setup_data(self):
        """Create the shared fixture.

        Returns:
            PointData: An element whose metainfo carries a random
            ``sample_idx`` and whose ``points`` field is a random tensor
            of shape ``(5, 3)``.
        """
        sample_idx = random.randint(0, 100)
        xyz = torch.rand((5, 3))
        return PointData(metainfo={'sample_idx': sample_idx}, points=xyz)

    def test_set_data(self):
        """Check which attributes may and may not be assigned."""
        data = self.setup_data()

        # Both private bookkeeping attributes must refuse assignment.
        for private_name in ('_metainfo_fields', '_data_fields'):
            with self.assertRaises(AttributeError):
                setattr(data, private_name, 1)

        # A new field added by plain assignment becomes visible through
        # the ``in`` operator.
        data.keypoints = torch.rand((5, 2))
        assert 'keypoints' in data

    def test_getitem(self):
        """Index the fixture with each kind of index used below and verify
        the resulting length, or the exception raised for a bad index."""
        # Indexing an element that holds no data is an IndexError.
        blank = PointData()
        with self.assertRaises(IndexError):
            blank[1]

        data = self.setup_data()
        assert len(data) == 5

        # Slice and plain integer indices.
        head = data[:2]
        assert len(head) == 2
        single = data[1]
        assert len(single) == 1
        # An integer equal to the length is out of range.
        with pytest.raises(IndexError):
            data[5]

        # A float tensor is rejected as an index.
        float_index = torch.Tensor([1, 2, 3, 4])
        with pytest.raises(AssertionError):
            data[float_index]

        # A boolean mask with 4 entries does not match the 5 stored rows.
        short_mask = float_index.bool()
        with pytest.raises(AssertionError):
            data[short_mask]

        # Integer tensor index: one result row per index value.
        index_tensor = torch.randint(5, (2, ))
        picked = data[index_tensor]
        assert len(picked) == len(index_tensor)

        # Boolean tensor masks: first a random one, then an all-False one
        # (``torch.rand`` samples from [0, 1), so ``> 1`` never holds).
        mask = torch.rand(5) > 0.5
        masked = data[mask]
        assert len(masked) == mask.sum()
        all_false = torch.rand(5) > 1
        nothing = data[all_false]
        assert len(nothing) == all_false.sum()

        # Python list of integer positions.
        positions = [1, 2]
        by_positions = data[positions]
        assert len(by_positions) == len(positions)

        # Python list of booleans containing two ``True`` entries.
        flags = [True, False, True, False, False]
        by_flags = data[flags]
        assert len(by_flags) == 2

        # NumPy integer array, then NumPy boolean array.
        np_positions = np.random.randint(5, size=2)
        by_np_positions = data[np_positions]
        assert len(by_np_positions) == len(np_positions)

        np_mask = np.random.rand(5) > 0.5
        by_np_mask = data[np_mask]
        assert len(by_np_mask) == np_mask.sum()

    def test_len(self):
        """``len`` is 5 for the fixture and 0 for an empty ``PointData``."""
        populated = self.setup_data()
        assert len(populated) == 5
        blank = PointData()
        assert len(blank) == 0