Unverified Commit 090d8237 authored by talregev's avatar talregev Committed by GitHub
Browse files

Improve test of backbone utils (#5552)

parent a8bde781
......@@ -7,7 +7,7 @@ import torch
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
......@@ -19,7 +19,9 @@ def get_available_models():
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)
assert isinstance(model, BackboneWithFPN)
y = model(x)
assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
......@@ -38,6 +40,10 @@ def test_mobilenet_backbone(backbone_name):
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])
model_fpn = mobilenet_backbone(backbone_name, False, fpn=True)
assert isinstance(model_fpn, BackboneWithFPN)
model = mobilenet_backbone(backbone_name, False, fpn=False)
assert isinstance(model, torch.nn.Sequential)
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment