Unverified Commit 0a919dbb authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add registration mechanism for models (#6333)

* Model registration mechanism.

* Add overwrite options to the dataset prototype registration mechanism.

* Adding example models.

* Fix module filtering

* Fix linter

* Fix docs

* Make name optional if same as model builder

* Apply updates from code-review.

* fix minor bug

* Adding getter for model weight enum

* Support both strings and callables on get_model_weight.

* linter fixes

* Fixing mypy.

* Renaming `get_model_weight` to `get_model_weights`

* Registering all classification models.

* Registering all video models.

* Registering all detection models.

* Registering all optical flow models.

* Fixing mypy.

* Registering all segmentation models.

* Registering all quantization models.

* Fixing linter

* Registering all prototype depth perception models.

* Adding tests and updating existing tests.

* Fix linters

* Fix tests.

* Add beta annotation on docs.

* Fix tests.

* Apply changes from code-review.

* Adding documentation.

* Fix docs.
parent 63870514
......@@ -9,7 +9,7 @@ from torch import nn, Tensor
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -428,6 +428,7 @@ class Inception_V3_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
"""
......
......@@ -8,7 +8,7 @@ from torch import Tensor
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -314,6 +314,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
return model
@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 0.5 from
......@@ -341,6 +342,7 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool =
return _mnasnet(0.5, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 0.75 from
......@@ -368,6 +370,7 @@ def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool
return _mnasnet(0.75, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 1.0 from
......@@ -395,6 +398,7 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool =
return _mnasnet(1.0, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 1.3 from
......
......@@ -8,7 +8,7 @@ from torch import nn, Tensor
from ..ops.misc import Conv2dNormActivation
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
......@@ -238,6 +238,7 @@ class MobileNet_V2_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V2
@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
*, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
......
......@@ -8,7 +8,7 @@ from torch import nn, Tensor
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
......@@ -371,6 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
*, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -401,6 +402,7 @@ def mobilenet_v3_large(
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
*, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
......
......@@ -10,7 +10,7 @@ from torchvision.ops import Conv2dNormActivation
from ...transforms._presets import OpticalFlow
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._utils import handle_legacy_interface
from ._utils import grid_sample, make_coords_grid, upsample_flow
......@@ -800,6 +800,7 @@ def _raft(
return model
@register_model()
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT:
"""RAFT model from
......@@ -855,6 +856,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
)
@register_model()
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
"""RAFT "small" model from
......
......@@ -8,7 +8,7 @@ from torch import Tensor
from torch.nn import functional as F
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux
......@@ -132,6 +132,7 @@ class GoogLeNet_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_FBGEMM_V1
@register_model(name="quantized_googlenet")
@handle_legacy_interface(
weights=(
"pretrained",
......
......@@ -10,7 +10,7 @@ from torchvision.models import inception as inception_module
from torchvision.models.inception import Inception_V3_Weights, InceptionOutputs
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from .utils import _fuse_modules, _replace_relu, quantize_model
......@@ -198,6 +198,7 @@ class Inception_V3_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_FBGEMM_V1
@register_model(name="quantized_inception_v3")
@handle_legacy_interface(
weights=(
"pretrained",
......
......@@ -7,7 +7,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNet_V2_Weight
from ...ops.misc import Conv2dNormActivation
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from .utils import _fuse_modules, _replace_relu, quantize_model
......@@ -89,6 +89,7 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_QNNPACK_V1
@register_model(name="quantized_mobilenet_v2")
@handle_legacy_interface(
weights=(
"pretrained",
......
......@@ -7,7 +7,7 @@ from torch.ao.quantization import DeQuantStub, QuantStub
from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..mobilenetv3 import (
......@@ -184,6 +184,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_QNNPACK_V1
@register_model(name="quantized_mobilenet_v3_large")
@handle_legacy_interface(
weights=(
"pretrained",
......
......@@ -15,7 +15,7 @@ from torchvision.models.resnet import (
)
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from .utils import _fuse_modules, _replace_relu, quantize_model
......@@ -268,6 +268,7 @@ class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_FBGEMM_V1
@register_model(name="quantized_resnet18")
@handle_legacy_interface(
weights=(
"pretrained",
......@@ -317,6 +318,7 @@ def resnet18(
return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
@register_model(name="quantized_resnet50")
@handle_legacy_interface(
weights=(
"pretrained",
......@@ -366,6 +368,7 @@ def resnet50(
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
@register_model(name="quantized_resnext101_32x8d")
@handle_legacy_interface(
weights=(
"pretrained",
......@@ -417,6 +420,7 @@ def resnext101_32x8d(
return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
@register_model(name="quantized_resnext101_64x4d")
def resnext101_64x4d(
*,
weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
......
......@@ -7,7 +7,7 @@ from torch import Tensor
from torchvision.models import shufflenetv2
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..shufflenetv2 import (
......@@ -203,6 +203,7 @@ class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_FBGEMM_V1
@register_model(name="quantized_shufflenet_v2_x0_5")
@handle_legacy_interface(
weights=(
"pretrained",
......@@ -256,6 +257,7 @@ def shufflenet_v2_x0_5(
)
@register_model(name="quantized_shufflenet_v2_x1_0")
@handle_legacy_interface(
weights=(
"pretrained",
......@@ -309,6 +311,7 @@ def shufflenet_v2_x1_0(
)
@register_model(name="quantized_shufflenet_v2_x1_5")
def shufflenet_v2_x1_5(
*,
weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None,
......@@ -354,6 +357,7 @@ def shufflenet_v2_x1_5(
)
@register_model(name="quantized_shufflenet_v2_x2_0")
def shufflenet_v2_x2_0(
*,
weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None,
......
......@@ -9,7 +9,7 @@ from torch import nn, Tensor
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
......@@ -1101,6 +1101,7 @@ class RegNet_X_32GF_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V2
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1126,6 +1127,7 @@ def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1151,6 +1153,7 @@ def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1178,6 +1181,7 @@ def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1205,6 +1209,7 @@ def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1232,6 +1237,7 @@ def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bo
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1259,6 +1265,7 @@ def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress:
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1286,6 +1293,7 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress:
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1313,6 +1321,7 @@ def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1338,6 +1347,7 @@ def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1363,6 +1373,7 @@ def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1392,6 +1403,7 @@ def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1421,6 +1433,7 @@ def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1450,6 +1463,7 @@ def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bo
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......@@ -1479,6 +1493,7 @@ def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress:
return _regnet(params, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
......
......@@ -7,7 +7,7 @@ from torch import Tensor
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -645,6 +645,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V2
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
......@@ -670,6 +671,7 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
......@@ -695,6 +697,7 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru
return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
......@@ -726,6 +729,7 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
......@@ -757,6 +761,7 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
......@@ -788,6 +793,7 @@ def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = T
return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
def resnext50_32x4d(
*, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -817,6 +823,7 @@ def resnext50_32x4d(
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
def resnext101_32x8d(
*, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -846,6 +853,7 @@ def resnext101_32x8d(
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
@register_model()
def resnext101_64x4d(
*, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
......@@ -874,6 +882,7 @@ def resnext101_64x4d(
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
*, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -907,6 +916,7 @@ def wide_resnet50_2(
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
def wide_resnet101_2(
*, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
......
......@@ -6,7 +6,7 @@ from torch import nn
from torch.nn import functional as F
from ...transforms._presets import SemanticSegmentation
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3
......@@ -218,6 +218,7 @@ def _deeplabv3_mobilenetv3(
return DeepLabV3(backbone, classifier, aux_classifier)
@register_model()
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
......@@ -273,6 +274,7 @@ def deeplabv3_resnet50(
return model
@register_model()
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
......@@ -328,6 +330,7 @@ def deeplabv3_resnet101(
return model
@register_model()
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
......
......@@ -4,7 +4,7 @@ from typing import Any, Optional
from torch import nn
from ...transforms._presets import SemanticSegmentation
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights
......@@ -110,6 +110,7 @@ def _fcn_resnet(
return FCN(backbone, classifier, aux_classifier)
@register_model()
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
......@@ -168,6 +169,7 @@ def fcn_resnet50(
return model
@register_model()
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
......
......@@ -7,7 +7,7 @@ from torch.nn import functional as F
from ...transforms._presets import SemanticSegmentation
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3
......@@ -117,6 +117,7 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
DEFAULT = COCO_WITH_VOC_LABELS_V1
@register_model()
@handle_legacy_interface(
weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
......
......@@ -7,7 +7,7 @@ from torch import Tensor
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -276,6 +276,7 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x0_5(
*, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -306,6 +307,7 @@ def shufflenet_v2_x0_5(
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_0(
*, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -336,6 +338,7 @@ def shufflenet_v2_x1_0(
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_5(
*, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -366,6 +369,7 @@ def shufflenet_v2_x1_5(
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x2_0(
*, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
......
......@@ -7,7 +7,7 @@ import torch.nn.init as init
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -159,6 +159,7 @@ class SqueezeNet1_1_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
def squeezenet1_0(
*, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
......@@ -187,6 +188,7 @@ def squeezenet1_0(
return _squeezenet("1_0", weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
def squeezenet1_1(
*, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
......
......@@ -9,7 +9,7 @@ from ..ops.misc import MLP, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param
......@@ -515,6 +515,7 @@ class Swin_B_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_tiny architecture from
......@@ -551,6 +552,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
)
@register_model()
def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_small architecture from
......@@ -587,6 +589,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
)
@register_model()
def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
"""
Constructs a swin_base architecture from
......
......@@ -6,7 +6,7 @@ import torch.nn as nn
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -285,6 +285,7 @@ class VGG19_BN_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......@@ -310,6 +311,7 @@ def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **k
return _vgg("A", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......@@ -335,6 +337,7 @@ def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = Tru
return _vgg("A", True, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......@@ -360,6 +363,7 @@ def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **k
return _vgg("B", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......@@ -385,6 +389,7 @@ def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = Tru
return _vgg("B", True, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......@@ -410,6 +415,7 @@ def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **k
return _vgg("D", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......@@ -435,6 +441,7 @@ def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = Tru
return _vgg("D", True, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......@@ -460,6 +467,7 @@ def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **k
return _vgg("E", False, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
"""VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment