Unverified Commit 0013d931 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Port test_backbone_utils.py to pytest (#3991)

parent 182f80df
import unittest
import torch import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
import pytest
class ResnetFPNBackboneTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dtype = torch.float32
def test_resnet18_fpn_backbone(self):
device = torch.device('cpu')
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False)
y = resnet18_fpn(x)
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
def test_resnet50_fpn_backbone(self): @pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
device = torch.device('cpu') def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False) y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
y = resnet50_fpn(x) assert list(y.keys()) == ['0', '1', '2', '3', 'pool']
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment