Unverified Commit c72b2843 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Expose on Hub the public methods of the registration API (#6364)

* Expose on Hub the public methods of the registration API

* Limit methods and update docs.
parent 84469834
...@@ -176,6 +176,15 @@ Most pre-trained models can be accessed directly via PyTorch Hub without having ...@@ -176,6 +176,15 @@ Most pre-trained models can be accessed directly via PyTorch Hub without having
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2") weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights) model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:
.. code:: python
import torch
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])
The only exception to the above are the detection models included on The only exception to the above are the detection models included on
:mod:`torchvision.models.detection`. These models require TorchVision :mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators. to be installed because they depend on custom C++ operators.
......
# Optional list of dependencies required by the package # Optional list of dependencies required by the package
dependencies = ["torch"] dependencies = ["torch"]
from torchvision.models import get_weight from torchvision.models import get_model_weights, get_weight
from torchvision.models.alexnet import alexnet from torchvision.models.alexnet import alexnet
from torchvision.models.convnext import convnext_base, convnext_large, convnext_small, convnext_tiny from torchvision.models.convnext import convnext_base, convnext_large, convnext_small, convnext_tiny
from torchvision.models.densenet import densenet121, densenet161, densenet169, densenet201 from torchvision.models.densenet import densenet121, densenet161, densenet169, densenet201
......
...@@ -115,7 +115,7 @@ def get_weight(name: str) -> WeightsEnum: ...@@ -115,7 +115,7 @@ def get_weight(name: str) -> WeightsEnum:
W = TypeVar("W", bound=WeightsEnum) W = TypeVar("W", bound=WeightsEnum)
def get_model_weights(model: Union[Callable, str]) -> W: def get_model_weights(name: Union[Callable, str]) -> W:
""" """
Retuns the weights enum class associated to the given model. Retuns the weights enum class associated to the given model.
...@@ -127,8 +127,7 @@ def get_model_weights(model: Union[Callable, str]) -> W: ...@@ -127,8 +127,7 @@ def get_model_weights(model: Union[Callable, str]) -> W:
Returns: Returns:
weights_enum (W): The weights enum class associated with the model. weights_enum (W): The weights enum class associated with the model.
""" """
if isinstance(model, str): model = find_model(name) if isinstance(name, str) else name
model = find_model(model)
return cast(W, _get_enum_from_fn(model)) return cast(W, _get_enum_from_fn(model))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment