Unverified Commit 0a919dbb authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add registration mechanism for models (#6333)

* Model registration mechanism.

* Add overwrite options to the dataset prototype registration mechanism.

* Adding example models.

* Fix module filtering

* Fix linter

* Fix docs

* Make name optional if same as model builder

* Apply updates from code-review.

* fix minor bug

* Adding getter for model weight enum

* Support both strings and callables on get_model_weight.

* linter fixes

* Fixing mypy.

* Renaming `get_model_weight` to `get_model_weights`

* Registering all classification models.

* Registering all video models.

* Registering all detection models.

* Registering all optical flow models.

* Fixing mypy.

* Registering all segmentation models.

* Registering all quantization models.

* Fixing linter

* Registering all prototype depth perception models.

* Adding tests and updating existing tests.

* Fix linters

* Fix tests.

* Add beta annotation on docs.

* Fix tests.

* Apply changes from code-review.

* Adding documentation.

* Fix docs.
parent 63870514
...@@ -120,6 +120,46 @@ behavior, such as batch normalization. To switch between these modes, use ...@@ -120,6 +120,46 @@ behavior, such as batch normalization. To switch between these modes, use
# Set model to eval mode # Set model to eval mode
model.eval() model.eval()
Model Registration Mechanism
----------------------------
.. betastatus:: registration mechanism
As of v0.14, TorchVision offers a new model registration mechanism which allows retreaving models
and weights by their names. Here are a few examples on how to use them:
.. code:: python
# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
Here are the available public methods of the model registration mechanism:
.. currentmodule:: torchvision.models
.. autosummary::
:toctree: generated/
:template: function.rst
get_model
get_model_weights
get_weight
list_models
Using models from Hub Using models from Hub
--------------------- ---------------------
......
...@@ -11,15 +11,6 @@ from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilen ...@@ -11,15 +11,6 @@ from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilen
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
def get_available_models():
# TODO add a registration mechanism to torchvision.models
return [
k
for k, v in models.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name): def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
...@@ -135,10 +126,10 @@ class TestFxFeatureExtraction: ...@@ -135,10 +126,10 @@ class TestFxFeatureExtraction:
eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)] eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10) return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
@pytest.mark.parametrize("model_name", get_available_models()) @pytest.mark.parametrize("model_name", models.list_models(models))
def test_build_fx_feature_extractor(self, model_name): def test_build_fx_feature_extractor(self, model_name):
set_rng_seed(0) set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).eval() model = models.get_model(model_name, **self.model_defaults).eval()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model) train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
# Check that it works with both a list and dict for return nodes # Check that it works with both a list and dict for return nodes
self._create_feature_extractor( self._create_feature_extractor(
...@@ -172,9 +163,9 @@ class TestFxFeatureExtraction: ...@@ -172,9 +163,9 @@ class TestFxFeatureExtraction:
train_nodes, _ = get_graph_node_names(model) train_nodes, _ = get_graph_node_names(model)
assert all(a == b for a, b in zip(train_nodes, test_module_nodes)) assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
@pytest.mark.parametrize("model_name", get_available_models()) @pytest.mark.parametrize("model_name", models.list_models(models))
def test_forward_backward(self, model_name): def test_forward_backward(self, model_name):
model = models.__dict__[model_name](**self.model_defaults).train() model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model) train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor( model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
...@@ -211,10 +202,10 @@ class TestFxFeatureExtraction: ...@@ -211,10 +202,10 @@ class TestFxFeatureExtraction:
for k in ilg_out.keys(): for k in ilg_out.keys():
assert ilg_out[k].equal(fgn_out[k]) assert ilg_out[k].equal(fgn_out[k])
@pytest.mark.parametrize("model_name", get_available_models()) @pytest.mark.parametrize("model_name", models.list_models(models))
def test_jit_forward_backward(self, model_name): def test_jit_forward_backward(self, model_name):
set_rng_seed(0) set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).train() model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model) train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor( model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
......
import importlib
import os import os
import pytest import pytest
import test_models as TM import test_models as TM
import torch import torch
from torchvision import models from torchvision import models
from torchvision.models._api import Weights, WeightsEnum from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface from torchvision.models._utils import handle_legacy_interface
...@@ -15,23 +14,52 @@ run_if_test_with_extended = pytest.mark.skipif( ...@@ -15,23 +14,52 @@ run_if_test_with_extended = pytest.mark.skipif(
) )
def _get_parent_module(model_fn): @pytest.mark.parametrize(
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1]) "name, model_class",
module = importlib.import_module(parent_module_name) [
return module ("resnet50", models.ResNet),
("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
("raft_large", models.optical_flow.RAFT),
("quantized_resnet50", models.quantization.QuantizableResNet),
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
("mvit_v1_b", models.video.MViT),
],
)
def test_get_model(name, model_class):
assert isinstance(models.get_model(name), model_class)
def _get_model_weights(model_fn): @pytest.mark.parametrize(
module = _get_parent_module(model_fn) "name, weight",
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" [
try: ("resnet50", models.ResNet50_Weights),
return next( ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
v ("raft_large", models.optical_flow.Raft_Large_Weights),
("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
("mvit_v1_b", models.video.MViT_V1_B_Weights),
],
)
def test_get_model_weights(name, weight):
assert models.get_model_weights(name) == weight
@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items() for k, v in module.__dict__.items()
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__ if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
) ]
except StopIteration:
return None a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module))
assert len(b) > 0
assert a == b
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -55,27 +83,27 @@ def test_get_weight(name, weight): ...@@ -55,27 +83,27 @@ def test_get_weight(name, weight):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn", "model_fn",
TM.get_models_from_module(models) TM.list_model_fns(models)
+ TM.get_models_from_module(models.detection) + TM.list_model_fns(models.detection)
+ TM.get_models_from_module(models.quantization) + TM.list_model_fns(models.quantization)
+ TM.get_models_from_module(models.segmentation) + TM.list_model_fns(models.segmentation)
+ TM.get_models_from_module(models.video) + TM.list_model_fns(models.video)
+ TM.get_models_from_module(models.optical_flow), + TM.list_model_fns(models.optical_flow),
) )
def test_naming_conventions(model_fn): def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn) weights_enum = get_model_weights(model_fn)
assert weights_enum is not None assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT") assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn", "model_fn",
TM.get_models_from_module(models) TM.list_model_fns(models)
+ TM.get_models_from_module(models.detection) + TM.list_model_fns(models.detection)
+ TM.get_models_from_module(models.quantization) + TM.list_model_fns(models.quantization)
+ TM.get_models_from_module(models.segmentation) + TM.list_model_fns(models.segmentation)
+ TM.get_models_from_module(models.video) + TM.list_model_fns(models.video)
+ TM.get_models_from_module(models.optical_flow), + TM.list_model_fns(models.optical_flow),
) )
@run_if_test_with_extended @run_if_test_with_extended
def test_schema_meta_validation(model_fn): def test_schema_meta_validation(model_fn):
...@@ -112,7 +140,7 @@ def test_schema_meta_validation(model_fn): ...@@ -112,7 +140,7 @@ def test_schema_meta_validation(model_fn):
module_name = model_fn.__module__.split(".")[-2] module_name = model_fn.__module__.split(".")[-2]
expected_fields = defaults["all"] | defaults[module_name] expected_fields = defaults["all"] | defaults[module_name]
weights_enum = _get_model_weights(model_fn) weights_enum = get_model_weights(model_fn)
if len(weights_enum) == 0: if len(weights_enum) == 0:
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
...@@ -153,17 +181,17 @@ def test_schema_meta_validation(model_fn): ...@@ -153,17 +181,17 @@ def test_schema_meta_validation(model_fn):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn", "model_fn",
TM.get_models_from_module(models) TM.list_model_fns(models)
+ TM.get_models_from_module(models.detection) + TM.list_model_fns(models.detection)
+ TM.get_models_from_module(models.quantization) + TM.list_model_fns(models.quantization)
+ TM.get_models_from_module(models.segmentation) + TM.list_model_fns(models.segmentation)
+ TM.get_models_from_module(models.video) + TM.list_model_fns(models.video)
+ TM.get_models_from_module(models.optical_flow), + TM.list_model_fns(models.optical_flow),
) )
@run_if_test_with_extended @run_if_test_with_extended
def test_transforms_jit(model_fn): def test_transforms_jit(model_fn):
model_name = model_fn.__name__ model_name = model_fn.__name__
weights_enum = _get_model_weights(model_fn) weights_enum = get_model_weights(model_fn)
if len(weights_enum) == 0: if len(weights_enum) == 0:
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
......
...@@ -16,18 +16,15 @@ import torch.nn as nn ...@@ -16,18 +16,15 @@ import torch.nn as nn
from _utils_internal import get_relative_path from _utils_internal import get_relative_path
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from torchvision import models from torchvision import models
from torchvision.models._api import find_model, list_models
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
def get_models_from_module(module): def list_model_fns(module):
# TODO add a registration mechanism to torchvision.models return [find_model(name) for name in list_models(module)]
return [
v
for k, v in module.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]
@pytest.fixture @pytest.fixture
...@@ -597,7 +594,7 @@ def test_vitc_models(model_fn, dev): ...@@ -597,7 +594,7 @@ def test_vitc_models(model_fn, dev):
test_classification_model(model_fn, dev) test_classification_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", get_models_from_module(models)) @pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_classification_model(model_fn, dev): def test_classification_model(model_fn, dev):
set_rng_seed(0) set_rng_seed(0)
...@@ -633,7 +630,7 @@ def test_classification_model(model_fn, dev): ...@@ -633,7 +630,7 @@ def test_classification_model(model_fn, dev):
_check_input_backprop(model, x) _check_input_backprop(model, x)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.segmentation)) @pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_segmentation_model(model_fn, dev): def test_segmentation_model(model_fn, dev):
set_rng_seed(0) set_rng_seed(0)
...@@ -695,7 +692,7 @@ def test_segmentation_model(model_fn, dev): ...@@ -695,7 +692,7 @@ def test_segmentation_model(model_fn, dev):
_check_input_backprop(model, x) _check_input_backprop(model, x)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) @pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_detection_model(model_fn, dev): def test_detection_model(model_fn, dev):
set_rng_seed(0) set_rng_seed(0)
...@@ -793,7 +790,7 @@ def test_detection_model(model_fn, dev): ...@@ -793,7 +790,7 @@ def test_detection_model(model_fn, dev):
_check_input_backprop(model, model_input) _check_input_backprop(model, model_input)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) @pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_validation(model_fn): def test_detection_model_validation(model_fn):
set_rng_seed(0) set_rng_seed(0)
model = model_fn(num_classes=50, weights=None, weights_backbone=None) model = model_fn(num_classes=50, weights=None, weights_backbone=None)
...@@ -822,7 +819,7 @@ def test_detection_model_validation(model_fn): ...@@ -822,7 +819,7 @@ def test_detection_model_validation(model_fn):
model(x, targets=targets) model(x, targets=targets)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.video)) @pytest.mark.parametrize("model_fn", list_model_fns(models.video))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_video_model(model_fn, dev): def test_video_model(model_fn, dev):
set_rng_seed(0) set_rng_seed(0)
...@@ -868,7 +865,7 @@ def test_video_model(model_fn, dev): ...@@ -868,7 +865,7 @@ def test_video_model(model_fn, dev):
), ),
reason="This Pytorch Build has not been built with fbgemm and qnnpack", reason="This Pytorch Build has not been built with fbgemm and qnnpack",
) )
@pytest.mark.parametrize("model_fn", get_models_from_module(models.quantization)) @pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
def test_quantized_classification_model(model_fn): def test_quantized_classification_model(model_fn):
set_rng_seed(0) set_rng_seed(0)
defaults = { defaults = {
...@@ -917,7 +914,7 @@ def test_quantized_classification_model(model_fn): ...@@ -917,7 +914,7 @@ def test_quantized_classification_model(model_fn):
torch.ao.quantization.convert(model, inplace=True) torch.ao.quantization.convert(model, inplace=True)
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) @pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading): def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
model_name = model_fn.__name__ model_name = model_fn.__name__
max_trainable = _model_tests_values[model_name]["max_trainable"] max_trainable = _model_tests_values[model_name]["max_trainable"]
...@@ -930,9 +927,9 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load ...@@ -930,9 +927,9 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
@needs_cuda @needs_cuda
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small)) @pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize("scripted", (False, True))
def test_raft(model_builder, scripted): def test_raft(model_fn, scripted):
torch.manual_seed(0) torch.manual_seed(0)
...@@ -942,7 +939,7 @@ def test_raft(model_builder, scripted): ...@@ -942,7 +939,7 @@ def test_raft(model_builder, scripted):
# reduced to 1) # reduced to 1)
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2) corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)
model = model_builder(corr_block=corr_block).eval().to("cuda") model = model_fn(corr_block=corr_block).eval().to("cuda")
if scripted: if scripted:
model = torch.jit.script(model) model = torch.jit.script(model)
...@@ -954,7 +951,7 @@ def test_raft(model_builder, scripted): ...@@ -954,7 +951,7 @@ def test_raft(model_builder, scripted):
flow_pred = preds[-1] flow_pred = preds[-1]
# Tolerance is fairly high, but there are 2 * H * W outputs to check # Tolerance is fairly high, but there are 2 * H * W outputs to check
# The .pkl were generated on the AWS cluter, on the CI it looks like the resuts are slightly different # The .pkl were generated on the AWS cluter, on the CI it looks like the resuts are slightly different
_assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1) _assert_expected(flow_pred, name=model_fn.__name__, atol=1e-2, rtol=1)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -99,8 +99,8 @@ class TestModelsDetectionNegativeSamples: ...@@ -99,8 +99,8 @@ class TestModelsDetectionNegativeSamples:
], ],
) )
def test_forward_negative_sample_frcnn(self, name): def test_forward_negative_sample_frcnn(self, name):
model = torchvision.models.detection.__dict__[name]( model = torchvision.models.get_model(
weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 name, weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
) )
images, targets = self._make_empty_sample() images, targets = self._make_empty_sample()
......
import pytest import pytest
import test_models as TM import test_models as TM
import torch import torch
import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo
from common_utils import cpu_and_gpu, set_rng_seed from common_utils import cpu_and_gpu, set_rng_seed
from torchvision.prototype import models
@pytest.mark.parametrize("model_builder", (raft_stereo.raft_stereo_base, raft_stereo.raft_stereo_realtime)) @pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo))
@pytest.mark.parametrize("model_mode", ("standard", "scripted")) @pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_raft_stereo(model_builder, model_mode, dev): def test_raft_stereo(model_fn, model_mode, dev):
# A simple test to make sure the model can do forward pass and jit scriptable # A simple test to make sure the model can do forward pass and jit scriptable
set_rng_seed(0) set_rng_seed(0)
# Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output # Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output
# get the idea from test_models.test_raft # get the idea from test_models.test_raft
corr_pyramid = raft_stereo.CorrPyramid1d(num_levels=2) corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2)
corr_block = raft_stereo.CorrBlock1d(num_levels=2, radius=2) corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2)
model = model_builder(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev) model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev)
if model_mode == "scripted": if model_mode == "scripted":
model = torch.jit.script(model) model = torch.jit.script(model)
...@@ -35,4 +35,4 @@ def test_raft_stereo(model_builder, model_mode, dev): ...@@ -35,4 +35,4 @@ def test_raft_stereo(model_builder, model_mode, dev):
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}" ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
# Test against expected file output # Test against expected file output
TM._assert_expected(depth_pred, name=model_builder.__name__, atol=1e-2, rtol=1e-2) TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
...@@ -14,4 +14,4 @@ from .vgg import * ...@@ -14,4 +14,4 @@ from .vgg import *
from .vision_transformer import * from .vision_transformer import *
from .swin_transformer import * from .swin_transformer import *
from . import detection, optical_flow, quantization, segmentation, video from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_weight from ._api import get_model, get_model_weights, get_weight, list_models
...@@ -3,14 +3,17 @@ import inspect ...@@ -3,14 +3,17 @@ import inspect
import sys import sys
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from inspect import signature from inspect import signature
from typing import Any, Callable, cast, Dict, Mapping from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
from torch import nn
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
__all__ = ["WeightsEnum", "Weights", "get_weight"] __all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"]
@dataclass @dataclass
...@@ -75,7 +78,9 @@ class WeightsEnum(StrEnum): ...@@ -75,7 +78,9 @@ class WeightsEnum(StrEnum):
def get_weight(name: str) -> WeightsEnum: def get_weight(name: str) -> WeightsEnum:
""" """
Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
.. betastatus:: function
Args: Args:
name (str): The name of the weight enum entry. name (str): The name of the weight enum entry.
...@@ -107,10 +112,29 @@ def get_weight(name: str) -> WeightsEnum: ...@@ -107,10 +112,29 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name) return weights_enum.from_str(value_name)
W = TypeVar("W", bound=WeightsEnum)
def get_model_weights(model: Union[Callable, str]) -> W:
"""
Retuns the weights enum class associated to the given model.
.. betastatus:: function
Args:
name (callable or str): The model builder function or the name under which it is registered.
Returns:
weights_enum (W): The weights enum class associated with the model.
"""
if isinstance(model, str):
model = find_model(model)
return cast(W, _get_enum_from_fn(model))
def _get_enum_from_fn(fn: Callable) -> WeightsEnum: def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
""" """
Internal method that gets the weight enum of a specific model builder method. Internal method that gets the weight enum of a specific model builder method.
Might be removed after the handle_legacy_interface is removed.
Args: Args:
fn (Callable): The builder method used to create the model. fn (Callable): The builder method used to create the model.
...@@ -140,3 +164,63 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: ...@@ -140,3 +164,63 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
) )
return cast(WeightsEnum, weights_enum) return cast(WeightsEnum, weights_enum)
M = TypeVar("M", bound=nn.Module)
BUILTIN_MODELS = {}
def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
key = name if name is not None else fn.__name__
if key in BUILTIN_MODELS:
raise ValueError(f"An entry is already registered under the name '{key}'.")
BUILTIN_MODELS[key] = fn
return fn
return wrapper
def list_models(module: Optional[ModuleType] = None) -> List[str]:
"""
Returns a list with the names of registered models.
.. betastatus:: function
Args:
module (ModuleType, optional): The module from which we want to extract the available models.
Returns:
models (list): A list with the names of available models.
"""
models = [
k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
]
return sorted(models)
def find_model(name: str) -> Callable[..., M]:
name = name.lower()
try:
fn = BUILTIN_MODELS[name]
except KeyError:
raise ValueError(f"Unknown model {name}")
return fn
def get_model(name: str, **config: Any) -> M:
"""
Gets the model name and configuration and returns an instantiated model.
.. betastatus:: function
Args:
name (str): The name under which the model is registered.
**config (Any): parameters passed to the model builder method.
Returns:
model (nn.Module): The initialized model.
"""
fn = find_model(name)
return fn(**config)
...@@ -6,7 +6,7 @@ import torch.nn as nn ...@@ -6,7 +6,7 @@ import torch.nn as nn
from ..transforms._presets import ImageClassification from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface from ._utils import _ovewrite_named_param, handle_legacy_interface
...@@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum): ...@@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
"""AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__. """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.
......
...@@ -9,7 +9,7 @@ from ..ops.misc import Conv2dNormActivation, Permute ...@@ -9,7 +9,7 @@ from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface from ._utils import _ovewrite_named_param, handle_legacy_interface
...@@ -278,6 +278,7 @@ class ConvNeXt_Large_Weights(WeightsEnum): ...@@ -278,6 +278,7 @@ class ConvNeXt_Large_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
"""ConvNeXt Tiny model architecture from the """ConvNeXt Tiny model architecture from the
...@@ -308,6 +309,7 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: ...@@ -308,6 +309,7 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small( def convnext_small(
*, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -340,6 +342,7 @@ def convnext_small( ...@@ -340,6 +342,7 @@ def convnext_small(
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
"""ConvNeXt Base model architecture from the """ConvNeXt Base model architecture from the
...@@ -370,6 +373,7 @@ def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: ...@@ -370,6 +373,7 @@ def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress:
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large( def convnext_large(
*, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
......
...@@ -11,7 +11,7 @@ from torch import Tensor ...@@ -11,7 +11,7 @@ from torch import Tensor
from ..transforms._presets import ImageClassification from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface from ._utils import _ovewrite_named_param, handle_legacy_interface
...@@ -337,6 +337,7 @@ class DenseNet201_Weights(WeightsEnum): ...@@ -337,6 +337,7 @@ class DenseNet201_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-121 model from r"""Densenet-121 model from
...@@ -362,6 +363,7 @@ def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool ...@@ -362,6 +363,7 @@ def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-161 model from r"""Densenet-161 model from
...@@ -387,6 +389,7 @@ def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool ...@@ -387,6 +389,7 @@ def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-169 model from r"""Densenet-169 model from
...@@ -412,6 +415,7 @@ def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool ...@@ -412,6 +415,7 @@ def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-201 model from r"""Densenet-201 model from
......
...@@ -7,7 +7,7 @@ from torchvision.ops import MultiScaleRoIAlign ...@@ -7,7 +7,7 @@ from torchvision.ops import MultiScaleRoIAlign
from ...ops import misc as misc_nn_ops from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection from ...transforms._presets import ObjectDetection
from .._api import Weights, WeightsEnum from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
...@@ -451,6 +451,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): ...@@ -451,6 +451,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1 DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
...@@ -569,6 +570,7 @@ def fasterrcnn_resnet50_fpn( ...@@ -569,6 +570,7 @@ def fasterrcnn_resnet50_fpn(
return model return model
@register_model()
def fasterrcnn_resnet50_fpn_v2( def fasterrcnn_resnet50_fpn_v2(
*, *,
weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None, weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
...@@ -685,6 +687,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -685,6 +687,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
return model return model
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
...@@ -758,6 +761,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( ...@@ -758,6 +761,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
) )
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
......
...@@ -11,7 +11,7 @@ from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_o ...@@ -11,7 +11,7 @@ from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_o
from ...ops.feature_pyramid_network import LastLevelP6P7 from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights from ..resnet import resnet50, ResNet50_Weights
...@@ -666,6 +666,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): ...@@ -666,6 +666,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1 DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
......
...@@ -6,7 +6,7 @@ from torchvision.ops import MultiScaleRoIAlign ...@@ -6,7 +6,7 @@ from torchvision.ops import MultiScaleRoIAlign
from ...ops import misc as misc_nn_ops from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection from ...transforms._presets import ObjectDetection
from .._api import Weights, WeightsEnum from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _ovewrite_value_param, handle_legacy_interface from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights from ..resnet import resnet50, ResNet50_Weights
...@@ -353,6 +353,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -353,6 +353,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1 DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=( weights=(
"pretrained", "pretrained",
......
...@@ -6,7 +6,7 @@ from torchvision.ops import MultiScaleRoIAlign ...@@ -6,7 +6,7 @@ from torchvision.ops import MultiScaleRoIAlign
from ...ops import misc as misc_nn_ops from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection from ...transforms._presets import ObjectDetection
from .._api import Weights, WeightsEnum from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights from ..resnet import resnet50, ResNet50_Weights
...@@ -396,6 +396,7 @@ class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): ...@@ -396,6 +396,7 @@ class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
DEFAULT = COCO_V1 DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1), weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
...@@ -503,6 +504,7 @@ def maskrcnn_resnet50_fpn( ...@@ -503,6 +504,7 @@ def maskrcnn_resnet50_fpn(
return model return model
@register_model()
def maskrcnn_resnet50_fpn_v2( def maskrcnn_resnet50_fpn_v2(
*, *,
weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None, weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
......
...@@ -11,7 +11,7 @@ from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss ...@@ -11,7 +11,7 @@ from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
from ...ops.feature_pyramid_network import LastLevelP6P7 from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights from ..resnet import resnet50, ResNet50_Weights
...@@ -715,6 +715,7 @@ class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum): ...@@ -715,6 +715,7 @@ class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
DEFAULT = COCO_V1 DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
...@@ -817,6 +818,7 @@ def retinanet_resnet50_fpn( ...@@ -817,6 +818,7 @@ def retinanet_resnet50_fpn(
return model return model
@register_model()
def retinanet_resnet50_fpn_v2( def retinanet_resnet50_fpn_v2(
*, *,
weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None, weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
......
...@@ -9,7 +9,7 @@ from torch import nn, Tensor ...@@ -9,7 +9,7 @@ from torch import nn, Tensor
from ...ops import boxes as box_ops from ...ops import boxes as box_ops
from ...transforms._presets import ObjectDetection from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..vgg import VGG, vgg16, VGG16_Weights from ..vgg import VGG, vgg16, VGG16_Weights
...@@ -568,6 +568,7 @@ def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int): ...@@ -568,6 +568,7 @@ def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int):
return SSDFeatureExtractorVGG(backbone, highres) return SSDFeatureExtractorVGG(backbone, highres)
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
......
...@@ -10,7 +10,7 @@ from ...ops.misc import Conv2dNormActivation ...@@ -10,7 +10,7 @@ from ...ops.misc import Conv2dNormActivation
from ...transforms._presets import ObjectDetection from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once from ...utils import _log_api_usage_once
from .. import mobilenet from .. import mobilenet
from .._api import Weights, WeightsEnum from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
...@@ -204,6 +204,7 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -204,6 +204,7 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
DEFAULT = COCO_V1 DEFAULT = COCO_V1
@register_model()
@handle_legacy_interface( @handle_legacy_interface(
weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
......
...@@ -12,7 +12,7 @@ from torchvision.ops import StochasticDepth ...@@ -12,7 +12,7 @@ from torchvision.ops import StochasticDepth
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
...@@ -729,6 +729,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): ...@@ -729,6 +729,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0( def efficientnet_b0(
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -757,6 +758,7 @@ def efficientnet_b0( ...@@ -757,6 +758,7 @@ def efficientnet_b0(
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1( def efficientnet_b1(
*, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -785,6 +787,7 @@ def efficientnet_b1( ...@@ -785,6 +787,7 @@ def efficientnet_b1(
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2( def efficientnet_b2(
*, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -813,6 +816,7 @@ def efficientnet_b2( ...@@ -813,6 +816,7 @@ def efficientnet_b2(
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3( def efficientnet_b3(
*, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -841,6 +845,7 @@ def efficientnet_b3( ...@@ -841,6 +845,7 @@ def efficientnet_b3(
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4( def efficientnet_b4(
*, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -869,6 +874,7 @@ def efficientnet_b4( ...@@ -869,6 +874,7 @@ def efficientnet_b4(
return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5( def efficientnet_b5(
*, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -905,6 +911,7 @@ def efficientnet_b5( ...@@ -905,6 +911,7 @@ def efficientnet_b5(
) )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6( def efficientnet_b6(
*, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -941,6 +948,7 @@ def efficientnet_b6( ...@@ -941,6 +948,7 @@ def efficientnet_b6(
) )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7( def efficientnet_b7(
*, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -977,6 +985,7 @@ def efficientnet_b7( ...@@ -977,6 +985,7 @@ def efficientnet_b7(
) )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s( def efficientnet_v2_s(
*, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -1014,6 +1023,7 @@ def efficientnet_v2_s( ...@@ -1014,6 +1023,7 @@ def efficientnet_v2_s(
) )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m( def efficientnet_v2_m(
*, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
...@@ -1051,6 +1061,7 @@ def efficientnet_v2_m( ...@@ -1051,6 +1061,7 @@ def efficientnet_v2_m(
) )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l( def efficientnet_v2_l(
*, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
......
...@@ -10,7 +10,7 @@ from torch import Tensor ...@@ -10,7 +10,7 @@ from torch import Tensor
from ..transforms._presets import ImageClassification from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface from ._utils import _ovewrite_named_param, handle_legacy_interface
...@@ -296,6 +296,7 @@ class GoogLeNet_Weights(WeightsEnum): ...@@ -296,6 +296,7 @@ class GoogLeNet_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
"""GoogLeNet (Inception v1) model architecture from """GoogLeNet (Inception v1) model architecture from
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment