Unverified Commit cac4e228 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Make get_model_builder public (#6560)

parent a67cc87a
...@@ -29,6 +29,21 @@ def test_get_model(name, model_class): ...@@ -29,6 +29,21 @@ def test_get_model(name, model_class):
assert isinstance(models.get_model(name), model_class) assert isinstance(models.get_model(name), model_class)
@pytest.mark.parametrize(
"name, model_fn",
[
("resnet50", models.resnet50),
("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
("raft_large", models.optical_flow.raft_large),
("quantized_resnet50", models.quantization.resnet50),
("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
("mvit_v1_b", models.video.mvit_v1_b),
],
)
def test_get_model_builder(name, model_fn):
assert models.get_model_builder(name) == model_fn
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, weight", "name, weight",
[ [
......
...@@ -17,7 +17,7 @@ import torch.nn as nn ...@@ -17,7 +17,7 @@ import torch.nn as nn
from _utils_internal import get_relative_path from _utils_internal import get_relative_path
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from torchvision import models from torchvision import models
from torchvision.models._api import find_model, list_models from torchvision.models import get_model_builder, list_models
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
...@@ -25,7 +25,7 @@ SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" ...@@ -25,7 +25,7 @@ SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
def list_model_fns(module): def list_model_fns(module):
return [find_model(name) for name in list_models(module)] return [get_model_builder(name) for name in list_models(module)]
@pytest.fixture @pytest.fixture
......
...@@ -14,4 +14,4 @@ from .vgg import * ...@@ -14,4 +14,4 @@ from .vgg import *
from .vision_transformer import * from .vision_transformer import *
from .swin_transformer import * from .swin_transformer import *
from . import detection, optical_flow, quantization, segmentation, video from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_weights, get_weight, list_models from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
...@@ -13,7 +13,7 @@ from torchvision._utils import StrEnum ...@@ -13,7 +13,7 @@ from torchvision._utils import StrEnum
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"] __all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"]
@dataclass @dataclass
...@@ -127,7 +127,7 @@ def get_model_weights(name: Union[Callable, str]) -> W: ...@@ -127,7 +127,7 @@ def get_model_weights(name: Union[Callable, str]) -> W:
Returns: Returns:
weights_enum (W): The weights enum class associated with the model. weights_enum (W): The weights enum class associated with the model.
""" """
model = find_model(name) if isinstance(name, str) else name model = get_model_builder(name) if isinstance(name, str) else name
return cast(W, _get_enum_from_fn(model)) return cast(W, _get_enum_from_fn(model))
...@@ -199,7 +199,18 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: ...@@ -199,7 +199,18 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return sorted(models) return sorted(models)
def find_model(name: str) -> Callable[..., M]: def get_model_builder(name: str) -> Callable[..., M]:
"""
Gets the model name and returns the model builder method.
.. betastatus:: function
Args:
name (str): The name under which the model is registered.
Returns:
fn (Callable): The model builder method.
"""
name = name.lower() name = name.lower()
try: try:
fn = BUILTIN_MODELS[name] fn = BUILTIN_MODELS[name]
...@@ -221,5 +232,5 @@ def get_model(name: str, **config: Any) -> M: ...@@ -221,5 +232,5 @@ def get_model(name: str, **config: Any) -> M:
Returns: Returns:
model (nn.Module): The initialized model. model (nn.Module): The initialized model.
""" """
fn = find_model(name) fn = get_model_builder(name)
return fn(**config) return fn(**config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment