Unverified Commit 9e788719 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Expose `get_weight` to Torch Hub (#6026)

* Prefixing `_get_enum_from_fn` with underscore

* Exposing `get_weight` to Torch Hub
parent 8e5844fc
# Optional list of dependencies required by the package # Optional list of dependencies required by the package
dependencies = ["torch"] dependencies = ["torch"]
from torchvision.models import get_weight
from torchvision.models.alexnet import alexnet from torchvision.models.alexnet import alexnet
from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
......
...@@ -107,7 +107,7 @@ def get_weight(name: str) -> WeightsEnum: ...@@ -107,7 +107,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name) return weights_enum.from_str(value_name)
def get_enum_from_fn(fn: Callable) -> WeightsEnum: def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
""" """
Internal method that gets the weight enum of a specific model builder method. Internal method that gets the weight enum of a specific model builder method.
Might be removed after the handle_legacy_interface is removed. Might be removed after the handle_legacy_interface is removed.
......
...@@ -6,7 +6,7 @@ from torchvision.ops import misc as misc_nn_ops ...@@ -6,7 +6,7 @@ from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
from .. import mobilenet, resnet from .. import mobilenet, resnet
from .._api import WeightsEnum, get_enum_from_fn from .._api import WeightsEnum, _get_enum_from_fn
from .._utils import IntermediateLayerGetter, handle_legacy_interface from .._utils import IntermediateLayerGetter, handle_legacy_interface
...@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module): ...@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module):
@handle_legacy_interface( @handle_legacy_interface(
weights=( weights=(
"pretrained", "pretrained",
lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
), ),
) )
def resnet_fpn_backbone( def resnet_fpn_backbone(
...@@ -177,7 +177,7 @@ def _validate_trainable_layers( ...@@ -177,7 +177,7 @@ def _validate_trainable_layers(
@handle_legacy_interface( @handle_legacy_interface(
weights=( weights=(
"pretrained", "pretrained",
lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
), ),
) )
def mobilenet_backbone( def mobilenet_backbone(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment