test_backbone_utils.py 877 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import unittest


import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone


class ResnetFPNBackboneTester(unittest.TestCase):
    """Smoke tests for the backbones built by ``resnet_fpn_backbone``.

    Each test builds a backbone without pretrained weights, runs one forward
    pass on a random CPU tensor and checks the keys of the returned mapping.
    """

    # Expected output keys, in order, after normalising every key with str().
    _EXPECTED_KEYS = ('0', '1', '2', '3', 'pool')

    @classmethod
    def setUpClass(cls):
        """Set the dtype used for the input tensor of every test."""
        cls.dtype = torch.float32

    def _check_output_keys(self, backbone_name):
        """Run ``backbone_name`` once and verify the keys of its output.

        Args:
            backbone_name: Value forwarded to ``resnet_fpn_backbone`` as
                ``backbone_name`` (e.g. ``'resnet18'``).
        """
        device = torch.device('cpu')
        x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
        backbone = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)

        # Only the output keys are inspected, so autograd tracking is skipped.
        with torch.no_grad():
            y = backbone(x)

        # Keys are compared as strings so the check holds whether the levels
        # are labelled 0..3 or '0'..'3'.
        # NOTE(review): confirm which form the pinned torchvision returns.
        # assertEqual (unlike a bare assert) still runs under ``python -O``
        # and reports both sequences on failure.
        self.assertEqual(
            tuple(str(key) for key in y.keys()),
            self._EXPECTED_KEYS,
        )

    def test_resnet18_fpn_backbone(self):
        """The ResNet-18 FPN backbone returns the expected output keys."""
        self._check_output_keys('resnet18')

    def test_resnet50_fpn_backbone(self):
        """The ResNet-50 FPN backbone returns the expected output keys."""
        self._check_output_keys('resnet50')