Unverified Commit 9e788719 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Expose `get_weight` to Torch Hub (#6026)

* Prefixing `_get_enum_from_fn` with underscore

* Exposing `get_weight` to Torch Hub
parent 8e5844fc
# Optional list of dependencies required by the package
dependencies = ["torch"]
from torchvision.models import get_weight
from torchvision.models.alexnet import alexnet
from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
......
......@@ -107,7 +107,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name)
def get_enum_from_fn(fn: Callable) -> WeightsEnum:
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
"""
Internal method that gets the weight enum of a specific model builder method.
Might be removed after the handle_legacy_interface is removed.
......
......@@ -6,7 +6,7 @@ from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
from .. import mobilenet, resnet
from .._api import WeightsEnum, get_enum_from_fn
from .._api import WeightsEnum, _get_enum_from_fn
from .._utils import IntermediateLayerGetter, handle_legacy_interface
......@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module):
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
),
)
def resnet_fpn_backbone(
......@@ -177,7 +177,7 @@ def _validate_trainable_layers(
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
),
)
def mobilenet_backbone(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment