Unverified commit 588e9b5e, authored by Philip Meier, committed by GitHub
Browse files

simplify model builders (#5001)



* simplify model builders

* cleanup

* refactor kwonly to pos or kw handling

* put weight verification back

* revert num categories checks

* fix default weights

* cleanup

* remove manual parameter map

* refactor decorator interface

* address review comments

* cleanup

* refactor callable default

* fix type annotation

* process ungrouped models

* cleanup

* more cleanup

* use decorator for detection models

* add decorator for quantization models

* add decorator for segmentation models

* add decorator for video models

* remove old helpers

* fix resnet50

* Adding verification back on InceptionV3

* Add kwargs in DeepLabV3

* Add kwargs on FCN

* Fix typing on Deeplab

* Fix typing on FCN
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 3ceaff14
......@@ -5,6 +5,8 @@ import test_models as TM
import torch
from common_utils import cpu_and_gpu, run_on_env_var
from torchvision.prototype import models
from torchvision.prototype.models._api import WeightsEnum, Weights
from torchvision.prototype.models._utils import handle_legacy_interface
run_if_test_with_prototype = run_on_env_var(
"PYTORCH_TEST_WITH_PROTOTYPE",
......@@ -164,3 +166,87 @@ def test_old_vs_new_factory(model_fn, dev):
def test_smoke():
import torchvision.prototype.models # noqa: F401
# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
class TestWeights(WeightsEnum):
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
@pytest.mark.parametrize(
"kwargs",
[
pytest.param(dict(), id="empty"),
pytest.param(dict(weights=None), id="None"),
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
],
)
def test_no_warn(self, kwargs):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
builder(**kwargs)
@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_pos(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
with pytest.warns(UserWarning, match="positional"):
builder(pretrained)
@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_kw(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained)
@pytest.mark.parametrize("pretrained", (True, False))
@pytest.mark.parametrize("positional", (True, False))
def test_equivalent_behavior_weights(self, pretrained, positional):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
builder(*args, **kwargs)
def test_multi_params(self):
weights_params = ("weights", "weights_other")
pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]
@handle_legacy_interface(
**{
weights_param: (pretrained_param, self.TestWeights.Sentinel)
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
}
)
def builder(*, weights=None, weights_other=None):
pass
for pretrained_param in pretrained_params:
with pytest.warns(UserWarning, match="deprecated"):
builder(**{pretrained_param: True})
def test_default_callable(self):
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
)
)
def builder(*, weights=None, flag):
pass
with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained=True, flag=True)
with pytest.raises(ValueError, match="weights"):
builder(pretrained=True, flag=False)
import functools
import warnings
from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union
from ._api import WeightsEnum
from torch import nn
from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw
from ._api import WeightsEnum
W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V")
def _deprecated_param(
    kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W]
) -> Optional[W]:
    """Translate a legacy boolean parameter into the new weights value.

    Emits a deprecation warning, pops ``deprecated_param`` from *kwargs*, and maps a
    truthy legacy value to *default_value* (raising when no default weights exist) and
    a falsy one to ``None``.
    """
    warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
    legacy_value = kwargs.pop(deprecated_param)
    if not legacy_value:
        return None
    # A truthy legacy value (e.g. pretrained=True) requires default weights to exist.
    if default_value is None:
        raise ValueError("No checkpoint is available for model.")
    return default_value
def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
"""Decorates a model builder with the new interface to make it compatible with the old.
In particular this handles two things:
1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
:func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
Args:
**weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
should be accessed with :meth:`~dict.get`.
"""
def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
@kwonly_to_pos_or_kw
@functools.wraps(builder)
def inner_wrapper(*args: Any, **kwargs: Any) -> M:
for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr]
# If neither the weights nor the pretrained parameter as passed, or the weights argument already use
# the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the
# weight argument, since it is a valid value.
sentinel = object()
weights_arg = kwargs.get(weights_param, sentinel)
if (
(weights_param not in kwargs and pretrained_param not in kwargs)
or isinstance(weights_arg, WeightsEnum)
or weights_arg is None
):
continue
# If the pretrained parameter was passed as positional argument, it is now mapped to
# `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
# signature to infer the names of positionally passed arguments and thus has no knowledge that there
# used to be a pretrained parameter.
pretrained_positional = weights_arg is not sentinel
if pretrained_positional:
# We put the pretrained argument under its legacy name in the keyword argument dictionary to have a
# unified access to the value if the default value is a callable.
kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
else:
pretrained_arg = kwargs[pretrained_param]
if pretrained_arg:
default_weights_arg = default(kwargs) if callable(default) else default
if not isinstance(default_weights_arg, WeightsEnum):
raise ValueError(f"No weights available for model {builder.__name__}")
else:
default_weights_arg = None
if not pretrained_positional:
warnings.warn(
f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
)
msg = (
f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
f"to get the most up-to-date weights."
)
warnings.warn(msg)
del kwargs[pretrained_param]
kwargs[weights_param] = default_weights_arg
return builder(*args, **kwargs)
return inner_wrapper
def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None:
    """Warn about a positionally-passed legacy parameter and stash its value.

    Records *default_value* under the legacy name in *kwargs* so downstream handling
    can process it as if it had been passed by keyword.
    """
    message = (
        f"The positional parameter '{deprecated_param}' is deprecated, "
        f"please use keyword parameter '{new_param}' instead."
    )
    warnings.warn(message)
    kwargs[deprecated_param] = default_value
return outer_wrapper
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
......@@ -29,11 +29,8 @@ class AlexNet_Weights(WeightsEnum):
default = ImageNet1K_V1
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
weights = AlexNet_Weights.verify(weights)
if weights is not None:
......
......@@ -9,7 +9,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -123,41 +123,29 @@ class DenseNet201_Weights(WeightsEnum):
default = ImageNet1K_V1
def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet121_Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet161_Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet169_Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet201_Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
......@@ -14,7 +14,7 @@ from ....models.detection.faster_rcnn import (
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet50_Weights, resnet50
......@@ -75,7 +75,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def fasterrcnn_resnet50_fpn(
*,
weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -83,17 +88,7 @@ def fasterrcnn_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
......@@ -119,6 +114,7 @@ def fasterrcnn_resnet50_fpn(
def _fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
progress: bool,
num_classes: Optional[int],
......@@ -158,7 +154,12 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
return model
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -166,17 +167,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1)
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = {
......@@ -185,16 +176,21 @@ def fasterrcnn_mobilenet_v3_large_fpn(
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
progress,
num_classes,
weights_backbone,
trainable_backbone_layers,
weights=weights,
progress=progress,
num_classes=num_classes,
weights_backbone=weights_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -202,19 +198,8 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1
)
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = {
......@@ -227,10 +212,10 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
progress,
num_classes,
weights_backbone,
trainable_backbone_layers,
weights=weights,
progress=progress,
num_classes=num_classes,
weights_backbone=weights_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
......@@ -11,7 +11,7 @@ from ....models.detection.keypoint_rcnn import (
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
......@@ -49,7 +49,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
default = Coco_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
if kwargs["pretrained"] == "legacy"
else KeypointRCNN_ResNet50_FPN_Weights.Coco_V1,
),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def keypointrcnn_resnet50_fpn(
*,
weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -58,21 +68,7 @@ def keypointrcnn_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> KeypointRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_V1
if kwargs["pretrained"] == "legacy":
default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
kwargs["pretrained"] = True
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
......
......@@ -12,7 +12,7 @@ from ....models.detection.mask_rcnn import (
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
......@@ -38,7 +38,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def maskrcnn_resnet50_fpn(
*,
weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -46,17 +51,7 @@ def maskrcnn_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> MaskRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
......
......@@ -13,7 +13,7 @@ from ....models.detection.retinanet import (
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
......@@ -38,7 +38,12 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def retinanet_resnet50_fpn(
*,
weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -46,17 +51,7 @@ def retinanet_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> RetinaNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNet_ResNet50_FPN_Weights.Coco_V1)
weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
......
......@@ -12,7 +12,7 @@ from ....models.detection.ssd import (
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..vgg import VGG16_Weights, vgg16
......@@ -37,7 +37,12 @@ class SSD300_VGG16_Weights(WeightsEnum):
default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", SSD300_VGG16_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", VGG16_Weights.ImageNet1K_Features),
)
def ssd300_vgg16(
*,
weights: Optional[SSD300_VGG16_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -45,17 +50,7 @@ def ssd300_vgg16(
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300_VGG16_Weights.Coco_V1)
weights = SSD300_VGG16_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", VGG16_Weights.ImageNet1K_Features
)
weights_backbone = VGG16_Weights.verify(weights_backbone)
if "size" in kwargs:
......
......@@ -17,7 +17,7 @@ from ....models.detection.ssdlite import (
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
......@@ -42,7 +42,12 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def ssdlite320_mobilenet_v3_large(
*,
weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -51,17 +56,7 @@ def ssdlite320_mobilenet_v3_large(
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1)
weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if "size" in kwargs:
......
......@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -181,73 +181,55 @@ class EfficientNet_B7_Weights(WeightsEnum):
default = ImageNet1K_V1
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.ImageNet1K_V1))
def efficientnet_b0(
weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B0_Weights.ImageNet1K_V1)
weights = EfficientNet_B0_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.ImageNet1K_V1))
def efficientnet_b1(
weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B1_Weights.ImageNet1K_V1)
weights = EfficientNet_B1_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.ImageNet1K_V1))
def efficientnet_b2(
weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B2_Weights.ImageNet1K_V1)
weights = EfficientNet_B2_Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.ImageNet1K_V1))
def efficientnet_b3(
weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B3_Weights.ImageNet1K_V1)
weights = EfficientNet_B3_Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.ImageNet1K_V1))
def efficientnet_b4(
weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B4_Weights.ImageNet1K_V1)
weights = EfficientNet_B4_Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.ImageNet1K_V1))
def efficientnet_b5(
weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B5_Weights.ImageNet1K_V1)
weights = EfficientNet_B5_Weights.verify(weights)
return _efficientnet(
......@@ -261,13 +243,10 @@ def efficientnet_b5(
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.ImageNet1K_V1))
def efficientnet_b6(
weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B6_Weights.ImageNet1K_V1)
weights = EfficientNet_B6_Weights.verify(weights)
return _efficientnet(
......@@ -281,13 +260,10 @@ def efficientnet_b6(
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.ImageNet1K_V1))
def efficientnet_b7(
weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B7_Weights.ImageNet1K_V1)
weights = EfficientNet_B7_Weights.verify(weights)
return _efficientnet(
......
......@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
......@@ -30,11 +30,8 @@ class GoogLeNet_Weights(WeightsEnum):
default = ImageNet1K_V1
def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNet_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.ImageNet1K_V1))
def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
weights = GoogLeNet_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
......@@ -29,11 +29,8 @@ class Inception_V3_Weights(WeightsEnum):
default = ImageNet1K_V1
def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", Inception_V3_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.ImageNet1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
weights = Inception_V3_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", True)
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -79,41 +79,29 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
return model
def mnasnet0_5(weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.ImageNet1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet0_5_Weights.verify(weights)
return _mnasnet(0.5, weights, progress, **kwargs)
def mnasnet0_75(weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
@handle_legacy_interface(weights=("pretrained", None))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet0_75_Weights.verify(weights)
return _mnasnet(0.75, weights, progress, **kwargs)
def mnasnet1_0(weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.ImageNet1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet1_0_Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs)
def mnasnet1_3(weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
@handle_legacy_interface(weights=("pretrained", None))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet1_3_Weights.verify(weights)
return _mnasnet(1.3, weights, progress, **kwargs)
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
......@@ -29,11 +29,10 @@ class MobileNet_V2_Weights(WeightsEnum):
default = ImageNet1K_V1
def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V2_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.ImageNet1K_V1))
def mobilenet_v2(
*, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
weights = MobileNet_V2_Weights.verify(weights)
if weights is not None:
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -82,26 +82,20 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
default = ImageNet1K_V1
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.ImageNet1K_V1))
def mobilenet_v3_large(
weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Large_Weights.ImageNet1K_V1)
weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.ImageNet1K_V1))
def mobilenet_v3_small(
weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Small_Weights.ImageNet1K_V1)
weights = MobileNet_V3_Small_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
......
......@@ -12,7 +12,7 @@ from ....models.quantization.googlenet import (
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..googlenet import GoogLeNet_Weights
......@@ -42,21 +42,22 @@ class GoogLeNet_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else GoogLeNet_Weights.ImageNet1K_V1,
)
)
def googlenet(
*,
weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableGoogLeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNet_Weights.ImageNet1K_V1
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = GoogLeNet_QuantizedWeights.verify(weights)
else:
weights = GoogLeNet_Weights.verify(weights)
weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
......
......@@ -11,7 +11,7 @@ from ....models.quantization.inception import (
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..inception import Inception_V3_Weights
......@@ -41,23 +41,22 @@ class Inception_V3_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else Inception_V3_Weights.ImageNet1K_V1,
)
)
def inception_v3(
*,
weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableInception3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else Inception_V3_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = Inception_V3_QuantizedWeights.verify(weights)
else:
weights = Inception_V3_Weights.verify(weights)
weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
......
......@@ -12,7 +12,7 @@ from ....models.quantization.mobilenetv2 import (
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv2 import MobileNet_V2_Weights
......@@ -42,23 +42,22 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
default = ImageNet1K_QNNPACK_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1
if kwargs.get("quantize", False)
else MobileNet_V2_Weights.ImageNet1K_V1,
)
)
def mobilenet_v2(
*,
weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize else MobileNet_V2_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = MobileNet_V2_QuantizedWeights.verify(weights)
else:
weights = MobileNet_V2_Weights.verify(weights)
weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
......@@ -13,7 +13,7 @@ from ....models.quantization.mobilenetv3 import (
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf
......@@ -75,25 +75,22 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
default = ImageNet1K_QNNPACK_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1
if kwargs.get("quantize", False)
else MobileNet_V3_Large_Weights.ImageNet1K_V1,
)
)
def mobilenet_v3_large(
*,
weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1
if quantize
else MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = MobileNet_V3_Large_QuantizedWeights.verify(weights)
else:
weights = MobileNet_V3_Large_Weights.verify(weights)
weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment