test_backbone_utils.py 425 Bytes
Newer Older
1
2
3
import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

4
import pytest
5
6


7
8
9
10
11
@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
def test_resnet_fpn_backbone(backbone_name):
    """Smoke-test ``resnet_fpn_backbone`` on a randomly initialised ResNet.

    Checks that the FPN output is keyed ``'0'``-``'3'`` plus ``'pool'`` (in
    that order), and that every feature map keeps the batch size, has
    ``backbone.out_channels`` channels and the spatial size expected for
    its stride.
    """
    size = 300
    x = torch.rand(1, 3, size, size, dtype=torch.float32, device='cpu')
    # NOTE(review): newer torchvision releases deprecate ``pretrained`` in
    # favour of ``weights=None``; kept as-is for compatibility with older
    # versions -- confirm against the pinned torchvision version.
    backbone = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)
    with torch.no_grad():  # forward-only check; no autograd graph needed
        y = backbone(x)

    assert list(y.keys()) == ['0', '1', '2', '3', 'pool']

    # Expected output strides per level: '0'-'3' at 4, 8, 16, 32 and 'pool'
    # one more 2x downsample (64). Each stride-2 step rounds up, so every
    # side is ceil(size / stride): 75, 38, 19, 10, 5 for a 300x300 input.
    strides = {'0': 4, '1': 8, '2': 16, '3': 32, 'pool': 64}
    for name, feat in y.items():
        side = -(-size // strides[name])  # ceil division
        assert feat.shape == (1, backbone.out_channels, side, side), name