# Tests for AnchorGenerator and DefaultBoxGenerator
# (torchvision.models.detection.anchor_utils).
import torch
from common_utils import TestCase
from _assert_utils import assert_equal
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
from torchvision.models.detection.image_list import ImageList
import pytest


class Tester(TestCase):
    """Unit tests for ``AnchorGenerator`` and ``DefaultBoxGenerator``."""

    def test_incorrect_anchors(self):
        """An anchor configuration that does not match the feature maps must raise ``ValueError``."""
        # Two size groups are configured, but only a single feature map is
        # passed in; the generator is expected to reject this mismatch.
        incorrect_sizes = ((2, 4, 8), (32, 8))
        incorrect_aspects = (0.5, 1.0)
        anc = AnchorGenerator(incorrect_sizes, incorrect_aspects)
        image1 = torch.randn(3, 800, 800)
        image_list = ImageList(image1, [(800, 800)])
        feature_maps = [torch.randn(1, 50)]

        with pytest.raises(ValueError):
            anc(image_list, feature_maps)

    def _init_test_anchor_generator(self):
        """Return an ``AnchorGenerator`` with one level, anchor size 10 and aspect ratio 1."""
        anchor_sizes = ((10,),)
        aspect_ratios = ((1,),)
        anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

        return anchor_generator

    def _init_test_defaultbox_generator(self):
        """Return a ``DefaultBoxGenerator`` configured with a single aspect-ratio list ``[2]``."""
        aspect_ratios = [[2]]
        dbox_generator = DefaultBoxGenerator(aspect_ratios)

        return dbox_generator

    def get_features(self, images):
        """Return a one-element list holding a random ``(2, 8, H // 5, W // 5)`` feature map.

        ``H`` and ``W`` are the last two dimensions of ``images``. Only the
        spatial size of the feature map matters to the tests; its values are
        random.
        """
        s0, s1 = images.shape[-2:]
        features = [torch.rand(2, 8, s0 // 5, s1 // 5)]
        return features

    def test_anchor_generator(self):
        """Anchors for a 2x15x15 batch match the hand-computed 3x3 grid of 10x10 boxes."""
        images = torch.randn(2, 3, 15, 15)
        features = self.get_features(images)
        image_shapes = [i.shape[-2:] for i in images]
        image_list = ImageList(images, image_shapes)

        model = self._init_test_anchor_generator()
        model.eval()
        anchors = model(image_list, features)

        # Estimate the number of target anchors: one anchor per grid cell per
        # anchor configuration, summed over all feature levels.
        grid_sizes = [f.shape[-2:] for f in features]
        num_anchors_estimated = 0
        for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()):
            num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc

        # The 15x15 images yield a 3x3 feature map (15 // 5). The expected boxes
        # are the 10x10 anchor [-5, -5, 5, 5] shifted by 0, 5 and 10 px along
        # each axis, in row-major (y outer, x inner) order.
        anchors_output = torch.tensor([[-5., -5., 5., 5.],
                                       [0., -5., 10., 5.],
                                       [5., -5., 15., 5.],
                                       [-5., 0., 5., 10.],
                                       [0., 0., 10., 10.],
                                       [5., 0., 15., 10.],
                                       [-5., 5., 5., 15.],
                                       [0., 5., 10., 15.],
                                       [5., 5., 15., 15.]])

        assert num_anchors_estimated == 9
        assert len(anchors) == 2
        assert tuple(anchors[0].shape) == (9, 4)
        assert tuple(anchors[1].shape) == (9, 4)

        assert_equal(anchors[0], anchors_output)
        assert_equal(anchors[1], anchors_output)

    def test_defaultbox_generator(self):
        """Default boxes for a single 1x1 feature map match precomputed reference values."""
        images = torch.zeros(2, 3, 15, 15)
        features = [torch.zeros(2, 8, 1, 1)]
        image_shapes = [i.shape[-2:] for i in images]
        image_list = ImageList(images, image_shapes)

        model = self._init_test_defaultbox_generator()
        model.eval()
        dboxes = model(image_list, features)

        # Four boxes for the single location of the 1x1 feature map, given as
        # reference values; they are compared with a tolerance because they
        # are rounded to four decimals.
        dboxes_output = torch.tensor([
            [6.3750, 6.3750, 8.6250, 8.6250],
            [4.7443, 4.7443, 10.2557, 10.2557],
            [5.9090, 6.7045, 9.0910, 8.2955],
            [6.7045, 5.9090, 8.2955, 9.0910]
        ])

        assert len(dboxes) == 2
        assert tuple(dboxes[0].shape) == (4, 4)
        assert tuple(dboxes[1].shape) == (4, 4)

        torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8)
        torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8)