Unverified Commit cac4e228 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Make get_model_builder public (#6560)

parent a67cc87a
......@@ -29,6 +29,21 @@ def test_get_model(name, model_class):
assert isinstance(models.get_model(name), model_class)
@pytest.mark.parametrize(
"name, model_fn",
[
("resnet50", models.resnet50),
("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
("raft_large", models.optical_flow.raft_large),
("quantized_resnet50", models.quantization.resnet50),
("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
("mvit_v1_b", models.video.mvit_v1_b),
],
)
def test_get_model_builder(name, model_fn):
assert models.get_model_builder(name) == model_fn
@pytest.mark.parametrize(
"name, weight",
[
......
......@@ -17,7 +17,7 @@ import torch.nn as nn
from _utils_internal import get_relative_path
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from torchvision import models
from torchvision.models._api import find_model, list_models
from torchvision.models import get_model_builder, list_models
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
......@@ -25,7 +25,7 @@ SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
def list_model_fns(module):
return [find_model(name) for name in list_models(module)]
return [get_model_builder(name) for name in list_models(module)]
@pytest.fixture
......
......@@ -14,4 +14,4 @@ from .vgg import *
from .vision_transformer import *
from .swin_transformer import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_weights, get_weight, list_models
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
......@@ -13,7 +13,7 @@ from torchvision._utils import StrEnum
from .._internally_replaced_utils import load_state_dict_from_url
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"]
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"]
@dataclass
......@@ -127,7 +127,7 @@ def get_model_weights(name: Union[Callable, str]) -> W:
Returns:
weights_enum (W): The weights enum class associated with the model.
"""
model = find_model(name) if isinstance(name, str) else name
model = get_model_builder(name) if isinstance(name, str) else name
return cast(W, _get_enum_from_fn(model))
......@@ -199,7 +199,18 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return sorted(models)
def find_model(name: str) -> Callable[..., M]:
def get_model_builder(name: str) -> Callable[..., M]:
"""
Gets the model name and returns the model builder method.
.. betastatus:: function
Args:
name (str): The name under which the model is registered.
Returns:
fn (Callable): The model builder method.
"""
name = name.lower()
try:
fn = BUILTIN_MODELS[name]
......@@ -221,5 +232,5 @@ def get_model(name: str, **config: Any) -> M:
Returns:
model (nn.Module): The initialized model.
"""
fn = find_model(name)
fn = get_model_builder(name)
return fn(**config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment