Unverified Commit 0a919dbb authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add registration mechanism for models (#6333)

* Model registration mechanism.

* Add overwrite options to the dataset prototype registration mechanism.

* Adding example models.

* Fix module filtering

* Fix linter

* Fix docs

* Make name optional if same as model builder

* Apply updates from code-review.

* fix minor bug

* Adding getter for model weight enum

* Support both strings and callables on get_model_weight.

* linter fixes

* Fixing mypy.

* Renaming `get_model_weight` to `get_model_weights`

* Registering all classification models.

* Registering all video models.

* Registering all detection models.

* Registering all optical flow models.

* Fixing mypy.

* Registering all segmentation models.

* Registering all quantization models.

* Fixing linter

* Registering all prototype depth perception models.

* Adding tests and updating existing tests.

* Fix linters

* Fix tests.

* Add beta annotation on docs.

* Fix tests.

* Apply changes from code-review.

* Adding documentation.

* Fix docs.
parent 63870514
......@@ -120,6 +120,46 @@ behavior, such as batch normalization. To switch between these modes, use
# Set model to eval mode
model.eval()
Model Registration Mechanism
----------------------------
.. betastatus:: registration mechanism
As of v0.14, TorchVision offers a new model registration mechanism which allows retrieving models
and weights by their names. Here are a few examples on how to use them:
.. code:: python
# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
Here are the available public methods of the model registration mechanism:
.. currentmodule:: torchvision.models
.. autosummary::
:toctree: generated/
:template: function.rst
get_model
get_model_weights
get_weight
list_models
Using models from Hub
---------------------
......
......@@ -11,15 +11,6 @@ from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilen
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
def get_available_models():
    # TODO add a registration mechanism to torchvision.models
    # Collect the public lowercase callables exposed by torchvision.models,
    # skipping the get_weight() utility which is not a model builder.
    names = []
    for attr_name, attr_value in models.__dict__.items():
        if not callable(attr_value):
            continue
        if attr_name[0].lower() != attr_name[0] or attr_name[0] == "_":
            continue
        if attr_name == "get_weight":
            continue
        names.append(attr_name)
    return names
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
......@@ -135,10 +126,10 @@ class TestFxFeatureExtraction:
eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
@pytest.mark.parametrize("model_name", get_available_models())
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_build_fx_feature_extractor(self, model_name):
set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).eval()
model = models.get_model(model_name, **self.model_defaults).eval()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
# Check that it works with both a list and dict for return nodes
self._create_feature_extractor(
......@@ -172,9 +163,9 @@ class TestFxFeatureExtraction:
train_nodes, _ = get_graph_node_names(model)
assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
@pytest.mark.parametrize("model_name", get_available_models())
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_forward_backward(self, model_name):
model = models.__dict__[model_name](**self.model_defaults).train()
model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
......@@ -211,10 +202,10 @@ class TestFxFeatureExtraction:
for k in ilg_out.keys():
assert ilg_out[k].equal(fgn_out[k])
@pytest.mark.parametrize("model_name", get_available_models())
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_jit_forward_backward(self, model_name):
set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).train()
model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
......
import importlib
import os
import pytest
import test_models as TM
import torch
from torchvision import models
from torchvision.models._api import Weights, WeightsEnum
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
......@@ -15,23 +14,52 @@ run_if_test_with_extended = pytest.mark.skipif(
)
def _get_parent_module(model_fn):
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
module = importlib.import_module(parent_module_name)
return module
@pytest.mark.parametrize(
    "name, model_class",
    [
        ("resnet50", models.ResNet),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
        ("raft_large", models.optical_flow.RAFT),
        ("quantized_resnet50", models.quantization.QuantizableResNet),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
        ("mvit_v1_b", models.video.MViT),
    ],
)
def test_get_model(name, model_class):
    # Building a model by its registered name should produce an instance of
    # the expected class.
    instance = models.get_model(name)
    assert isinstance(instance, model_class)
def _get_model_weights(model_fn):
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
try:
return next(
v
@pytest.mark.parametrize(
    "name, weight",
    [
        ("resnet50", models.ResNet50_Weights),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
        ("raft_large", models.optical_flow.Raft_Large_Weights),
        ("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
        ("mvit_v1_b", models.video.MViT_V1_B_Weights),
    ],
)
def test_get_model_weights(name, weight):
    # Looking up a model by its registered name should resolve to the
    # expected weights enum class.
    retrieved = models.get_model_weights(name)
    assert retrieved == weight
@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
)
except StopIteration:
return None
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]
a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module))
assert len(b) > 0
assert a == b
@pytest.mark.parametrize(
......@@ -55,27 +83,27 @@ def test_get_weight(name, weight):
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
weights_enum = get_model_weights(model_fn)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
......@@ -112,7 +140,7 @@ def test_schema_meta_validation(model_fn):
module_name = model_fn.__module__.split(".")[-2]
expected_fields = defaults["all"] | defaults[module_name]
weights_enum = _get_model_weights(model_fn)
weights_enum = get_model_weights(model_fn)
if len(weights_enum) == 0:
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
......@@ -153,17 +181,17 @@ def test_schema_meta_validation(model_fn):
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video)
+ TM.get_models_from_module(models.optical_flow),
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_transforms_jit(model_fn):
model_name = model_fn.__name__
weights_enum = _get_model_weights(model_fn)
weights_enum = get_model_weights(model_fn)
if len(weights_enum) == 0:
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
......
......@@ -16,18 +16,15 @@ import torch.nn as nn
from _utils_internal import get_relative_path
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from torchvision import models
from torchvision.models._api import find_model, list_models
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
def get_models_from_module(module):
    # TODO add a registration mechanism to torchvision.models
    # Gather the public lowercase callables of *module*, skipping the
    # get_weight() utility which is not a model builder.
    builders = []
    for attr_name, attr_value in module.__dict__.items():
        if not callable(attr_value):
            continue
        if attr_name[0].lower() != attr_name[0] or attr_name[0] == "_":
            continue
        if attr_name == "get_weight":
            continue
        builders.append(attr_value)
    return builders
def list_model_fns(module):
    """Return the builder callables for every model registered under *module*."""
    fns = []
    for model_name in list_models(module):
        fns.append(find_model(model_name))
    return fns
@pytest.fixture
......@@ -597,7 +594,7 @@ def test_vitc_models(model_fn, dev):
test_classification_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", get_models_from_module(models))
@pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_classification_model(model_fn, dev):
set_rng_seed(0)
......@@ -633,7 +630,7 @@ def test_classification_model(model_fn, dev):
_check_input_backprop(model, x)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.segmentation))
@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_segmentation_model(model_fn, dev):
set_rng_seed(0)
......@@ -695,7 +692,7 @@ def test_segmentation_model(model_fn, dev):
_check_input_backprop(model, x)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_detection_model(model_fn, dev):
set_rng_seed(0)
......@@ -793,7 +790,7 @@ def test_detection_model(model_fn, dev):
_check_input_backprop(model, model_input)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_validation(model_fn):
set_rng_seed(0)
model = model_fn(num_classes=50, weights=None, weights_backbone=None)
......@@ -822,7 +819,7 @@ def test_detection_model_validation(model_fn):
model(x, targets=targets)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.video))
@pytest.mark.parametrize("model_fn", list_model_fns(models.video))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_video_model(model_fn, dev):
set_rng_seed(0)
......@@ -868,7 +865,7 @@ def test_video_model(model_fn, dev):
),
reason="This Pytorch Build has not been built with fbgemm and qnnpack",
)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.quantization))
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
def test_quantized_classification_model(model_fn):
set_rng_seed(0)
defaults = {
......@@ -917,7 +914,7 @@ def test_quantized_classification_model(model_fn):
torch.ao.quantization.convert(model, inplace=True)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
model_name = model_fn.__name__
max_trainable = _model_tests_values[model_name]["max_trainable"]
......@@ -930,9 +927,9 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
@needs_cuda
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small))
@pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True))
def test_raft(model_builder, scripted):
def test_raft(model_fn, scripted):
torch.manual_seed(0)
......@@ -942,7 +939,7 @@ def test_raft(model_builder, scripted):
# reduced to 1)
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)
model = model_builder(corr_block=corr_block).eval().to("cuda")
model = model_fn(corr_block=corr_block).eval().to("cuda")
if scripted:
model = torch.jit.script(model)
......@@ -954,7 +951,7 @@ def test_raft(model_builder, scripted):
flow_pred = preds[-1]
# Tolerance is fairly high, but there are 2 * H * W outputs to check
# The .pkl files were generated on the AWS cluster; on the CI it looks like the results are slightly different
_assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1)
_assert_expected(flow_pred, name=model_fn.__name__, atol=1e-2, rtol=1)
if __name__ == "__main__":
......
......@@ -99,8 +99,8 @@ class TestModelsDetectionNegativeSamples:
],
)
def test_forward_negative_sample_frcnn(self, name):
model = torchvision.models.detection.__dict__[name](
weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
model = torchvision.models.get_model(
name, weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
)
images, targets = self._make_empty_sample()
......
import pytest
import test_models as TM
import torch
import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo
from common_utils import cpu_and_gpu, set_rng_seed
from torchvision.prototype import models
@pytest.mark.parametrize("model_builder", (raft_stereo.raft_stereo_base, raft_stereo.raft_stereo_realtime))
@pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_raft_stereo(model_builder, model_mode, dev):
def test_raft_stereo(model_fn, model_mode, dev):
# A simple test to make sure the model can do forward pass and jit scriptable
set_rng_seed(0)
# Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output
# get the idea from test_models.test_raft
corr_pyramid = raft_stereo.CorrPyramid1d(num_levels=2)
corr_block = raft_stereo.CorrBlock1d(num_levels=2, radius=2)
model = model_builder(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev)
corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2)
corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2)
model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev)
if model_mode == "scripted":
model = torch.jit.script(model)
......@@ -35,4 +35,4 @@ def test_raft_stereo(model_builder, model_mode, dev):
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
# Test against expected file output
TM._assert_expected(depth_pred, name=model_builder.__name__, atol=1e-2, rtol=1e-2)
TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
......@@ -14,4 +14,4 @@ from .vgg import *
from .vision_transformer import *
from .swin_transformer import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_weight
from ._api import get_model, get_model_weights, get_weight, list_models
......@@ -3,14 +3,17 @@ import inspect
import sys
from dataclasses import dataclass, fields
from inspect import signature
from typing import Any, Callable, cast, Dict, Mapping
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
from torch import nn
from torchvision._utils import StrEnum
from .._internally_replaced_utils import load_state_dict_from_url
__all__ = ["WeightsEnum", "Weights", "get_weight"]
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"]
@dataclass
......@@ -75,7 +78,9 @@ class WeightsEnum(StrEnum):
def get_weight(name: str) -> WeightsEnum:
"""
Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
.. betastatus:: function
Args:
name (str): The name of the weight enum entry.
......@@ -107,10 +112,29 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name)
W = TypeVar("W", bound=WeightsEnum)


def get_model_weights(model: Union[Callable, str]) -> W:
    """
    Returns the weights enum class associated to the given model.

    .. betastatus:: function

    Args:
        model (callable or str): The model builder function or the name under which it is registered.

    Returns:
        weights_enum (W): The weights enum class associated with the model.
    """
    # A string is resolved to its registered builder first; a callable is
    # inspected directly by _get_enum_from_fn.
    if isinstance(model, str):
        model = find_model(model)
    return cast(W, _get_enum_from_fn(model))
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
"""
Internal method that gets the weight enum of a specific model builder method.
Might be removed after the handle_legacy_interface is removed.
Args:
fn (Callable): The builder method used to create the model.
......@@ -140,3 +164,63 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
)
return cast(WeightsEnum, weights_enum)
M = TypeVar("M", bound=nn.Module)

# Global registry mapping a public model name to its builder function.
BUILTIN_MODELS: Dict[str, Callable[..., nn.Module]] = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    """
    Decorator that registers a model builder under ``name`` (or under the
    builder's ``__name__`` when no name is provided).

    Args:
        name (str, optional): The name to register the builder under.

    Raises:
        ValueError: If an entry is already registered under the same name.
    """

    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
        # Normalize to lower-case so registration agrees with find_model(),
        # which lower-cases the requested name before the lookup; otherwise a
        # mixed-case registration would be unreachable.
        key = name if name is not None else fn.__name__
        key = key.lower()
        if key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{key}'.")
        BUILTIN_MODELS[key] = fn
        return fn

    return wrapper


def list_models(module: Optional[ModuleType] = None) -> List[str]:
    """
    Returns a list with the names of registered models.

    .. betastatus:: function

    Args:
        module (ModuleType, optional): The module from which we want to extract the available models.

    Returns:
        models (list): A list with the names of available models.
    """
    # When a module is given, keep only builders whose defining module lives
    # directly under it (compare the parent package name).
    models = [
        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
    ]
    return sorted(models)


def find_model(name: str) -> Callable[..., M]:
    """
    Returns the model builder registered under ``name`` (case-insensitive).

    Args:
        name (str): The name under which the model is registered.

    Raises:
        ValueError: If no model is registered under that name.
    """
    name = name.lower()
    try:
        return BUILTIN_MODELS[name]
    except KeyError:
        # Hide the KeyError context: the ValueError is the meaningful signal.
        raise ValueError(f"Unknown model {name}") from None


def get_model(name: str, **config: Any) -> M:
    """
    Gets the model name and configuration and returns an instantiated model.

    .. betastatus:: function

    Args:
        name (str): The name under which the model is registered.
        **config (Any): parameters passed to the model builder method.

    Returns:
        model (nn.Module): The initialized model.
    """
    fn = find_model(name)
    return fn(**config)
......@@ -6,7 +6,7 @@ import torch.nn as nn
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
"""AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.
......
......@@ -9,7 +9,7 @@ from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -278,6 +278,7 @@ class ConvNeXt_Large_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
"""ConvNeXt Tiny model architecture from the
......@@ -308,6 +309,7 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
*, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -340,6 +342,7 @@ def convnext_small(
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
"""ConvNeXt Base model architecture from the
......@@ -370,6 +373,7 @@ def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress:
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
*, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
......
......@@ -11,7 +11,7 @@ from torch import Tensor
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -337,6 +337,7 @@ class DenseNet201_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-121 model from
......@@ -362,6 +363,7 @@ def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-161 model from
......@@ -387,6 +389,7 @@ def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-169 model from
......@@ -412,6 +415,7 @@ def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-201 model from
......
......@@ -7,7 +7,7 @@ from torchvision.ops import MultiScaleRoIAlign
from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
......@@ -451,6 +451,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
......@@ -569,6 +570,7 @@ def fasterrcnn_resnet50_fpn(
return model
@register_model()
def fasterrcnn_resnet50_fpn_v2(
*,
weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
......@@ -685,6 +687,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
return model
@register_model()
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
......@@ -758,6 +761,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
)
@register_model()
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
......
......@@ -11,7 +11,7 @@ from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_o
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
......@@ -666,6 +666,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface(
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
......
......@@ -6,7 +6,7 @@ from torchvision.ops import MultiScaleRoIAlign
from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
......@@ -353,6 +353,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface(
weights=(
"pretrained",
......
......@@ -6,7 +6,7 @@ from torchvision.ops import MultiScaleRoIAlign
from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
......@@ -396,6 +396,7 @@ class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
......@@ -503,6 +504,7 @@ def maskrcnn_resnet50_fpn(
return model
@register_model()
def maskrcnn_resnet50_fpn_v2(
*,
weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
......
......@@ -11,7 +11,7 @@ from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
......@@ -715,6 +715,7 @@ class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface(
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
......@@ -817,6 +818,7 @@ def retinanet_resnet50_fpn(
return model
@register_model()
def retinanet_resnet50_fpn_v2(
*,
weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
......
......@@ -9,7 +9,7 @@ from torch import nn, Tensor
from ...ops import boxes as box_ops
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..vgg import VGG, vgg16, VGG16_Weights
......@@ -568,6 +568,7 @@ def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int):
return SSDFeatureExtractorVGG(backbone, highres)
@register_model()
@handle_legacy_interface(
weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
......
......@@ -10,7 +10,7 @@ from ...ops.misc import Conv2dNormActivation
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .. import mobilenet
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
......@@ -204,6 +204,7 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface(
weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
......
......@@ -12,7 +12,7 @@ from torchvision.ops import StochasticDepth
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
......@@ -729,6 +729,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -757,6 +758,7 @@ def efficientnet_b0(
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1(
*, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -785,6 +787,7 @@ def efficientnet_b1(
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2(
*, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -813,6 +816,7 @@ def efficientnet_b2(
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3(
*, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -841,6 +845,7 @@ def efficientnet_b3(
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4(
*, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -869,6 +874,7 @@ def efficientnet_b4(
return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5(
*, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -905,6 +911,7 @@ def efficientnet_b5(
)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6(
*, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -941,6 +948,7 @@ def efficientnet_b6(
)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7(
*, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -977,6 +985,7 @@ def efficientnet_b7(
)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
*, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -1014,6 +1023,7 @@ def efficientnet_v2_s(
)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
*, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -1051,6 +1061,7 @@ def efficientnet_v2_m(
)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
*, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
......
......@@ -10,7 +10,7 @@ from torch import Tensor
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -296,6 +296,7 @@ class GoogLeNet_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
"""GoogLeNet (Inception v1) model architecture from
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment