Unverified Commit 2cd25c1a authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Fix resnet_fpn_backbone(pretrained=True) (#7172)

parent 135a0f9e
......@@ -9,6 +9,7 @@ from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models
from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
run_if_test_with_extended = pytest.mark.skipif(
os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
......@@ -425,7 +426,11 @@ class TestHandleLegacyInterface:
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
+ TM.list_model_fns(models.optical_flow)
+ [
lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
],
)
@run_if_test_with_extended
def test_pretrained_deprecation(self, model_fn):
......
......@@ -6,7 +6,7 @@ from enum import Enum
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
from torch import nn
......@@ -138,7 +138,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum[value_name]
def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
"""
Returns the weights enum class associated to the given model.
......@@ -152,7 +152,7 @@ def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
return _get_enum_from_fn(model)
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
"""
Internal method that gets the weight enum of a specific model builder method.
......@@ -182,7 +182,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
"The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)
return cast(WeightsEnum, weights_enum)
return weights_enum
M = TypeVar("M", bound=nn.Module)
......
......@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module):
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
),
)
def resnet_fpn_backbone(
......@@ -177,7 +177,7 @@ def _validate_trainable_layers(
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
),
)
def mobilenet_backbone(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment