Unverified Commit 588e9b5e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

simplify model builders (#5001)



* simplify model builders

* cleanup

* refactor kwonly to pos or kw handling

* put weight verification back

* revert num categories checks

* fix default weights

* cleanup

* remove manual parameter map

* refactor decorator interface

* address review comments

* cleanup

* refactor callable default

* fix type annotation

* process ungrouped models

* cleanup

* mroe cleanup

* use decorator for detection models

* add decorator for quantization models

* add decorator for segmentation  models

* add decorator for video  models

* remove old helpers

* fix resnet50

* Adding verification back on InceptionV3

* Add kwargs in DeeplabeV3

* Add kwargs on FCN

* Fix typing on Deeplab

* Fix typing on FCN
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 3ceaff14
...@@ -5,6 +5,8 @@ import test_models as TM ...@@ -5,6 +5,8 @@ import test_models as TM
import torch import torch
from common_utils import cpu_and_gpu, run_on_env_var from common_utils import cpu_and_gpu, run_on_env_var
from torchvision.prototype import models from torchvision.prototype import models
from torchvision.prototype.models._api import WeightsEnum, Weights
from torchvision.prototype.models._utils import handle_legacy_interface
run_if_test_with_prototype = run_on_env_var( run_if_test_with_prototype = run_on_env_var(
"PYTORCH_TEST_WITH_PROTOTYPE", "PYTORCH_TEST_WITH_PROTOTYPE",
...@@ -164,3 +166,87 @@ def test_old_vs_new_factory(model_fn, dev): ...@@ -164,3 +166,87 @@ def test_old_vs_new_factory(model_fn, dev):
def test_smoke(): def test_smoke():
import torchvision.prototype.models # noqa: F401 import torchvision.prototype.models # noqa: F401
# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
class TestWeights(WeightsEnum):
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
@pytest.mark.parametrize(
"kwargs",
[
pytest.param(dict(), id="empty"),
pytest.param(dict(weights=None), id="None"),
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
],
)
def test_no_warn(self, kwargs):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
builder(**kwargs)
@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_pos(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
with pytest.warns(UserWarning, match="positional"):
builder(pretrained)
@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_kw(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained)
@pytest.mark.parametrize("pretrained", (True, False))
@pytest.mark.parametrize("positional", (True, False))
def test_equivalent_behavior_weights(self, pretrained, positional):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass
args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
builder(*args, **kwargs)
def test_multi_params(self):
weights_params = ("weights", "weights_other")
pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]
@handle_legacy_interface(
**{
weights_param: (pretrained_param, self.TestWeights.Sentinel)
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
}
)
def builder(*, weights=None, weights_other=None):
pass
for pretrained_param in pretrained_params:
with pytest.warns(UserWarning, match="deprecated"):
builder(**{pretrained_param: True})
def test_default_callable(self):
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
)
)
def builder(*, weights=None, flag):
pass
with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained=True, flag=True)
with pytest.raises(ValueError, match="weights"):
builder(pretrained=True, flag=False)
import functools
import warnings import warnings
from typing import Any, Dict, Optional, TypeVar from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union
from ._api import WeightsEnum from torch import nn
from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw
from ._api import WeightsEnum
W = TypeVar("W", bound=WeightsEnum) W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V") V = TypeVar("V")
def _deprecated_param( def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W] """Decorates a model builder with the new interface to make it compatible with the old.
) -> Optional[W]:
warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.") In particular this handles two things:
if kwargs.pop(deprecated_param):
if default_value is not None: 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
return default_value :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
else: 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
raise ValueError("No checkpoint is available for model.") ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
Args:
**weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
should be accessed with :meth:`~dict.get`.
"""
def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
@kwonly_to_pos_or_kw
@functools.wraps(builder)
def inner_wrapper(*args: Any, **kwargs: Any) -> M:
for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr]
# If neither the weights nor the pretrained parameter as passed, or the weights argument already use
# the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the
# weight argument, since it is a valid value.
sentinel = object()
weights_arg = kwargs.get(weights_param, sentinel)
if (
(weights_param not in kwargs and pretrained_param not in kwargs)
or isinstance(weights_arg, WeightsEnum)
or weights_arg is None
):
continue
# If the pretrained parameter was passed as positional argument, it is now mapped to
# `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
# signature to infer the names of positionally passed arguments and thus has no knowledge that there
# used to be a pretrained parameter.
pretrained_positional = weights_arg is not sentinel
if pretrained_positional:
# We put the pretrained argument under its legacy name in the keyword argument dictionary to have a
# unified access to the value if the default value is a callable.
kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
else: else:
return None pretrained_arg = kwargs[pretrained_param]
if pretrained_arg:
default_weights_arg = default(kwargs) if callable(default) else default
if not isinstance(default_weights_arg, WeightsEnum):
raise ValueError(f"No weights available for model {builder.__name__}")
else:
default_weights_arg = None
def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None: if not pretrained_positional:
warnings.warn( warnings.warn(
f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'" f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
+ " instead." )
msg = (
f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
f"to get the most up-to-date weights."
) )
kwargs[deprecated_param] = default_value warnings.warn(msg)
del kwargs[pretrained_param]
kwargs[weights_param] = default_weights_arg
return builder(*args, **kwargs)
return inner_wrapper
return outer_wrapper
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet from ...models.alexnet import AlexNet
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
...@@ -29,11 +29,8 @@ class AlexNet_Weights(WeightsEnum): ...@@ -29,11 +29,8 @@ class AlexNet_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
weights = AlexNet_Weights.verify(weights) weights = AlexNet_Weights.verify(weights)
if weights is not None: if weights is not None:
......
...@@ -9,7 +9,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -9,7 +9,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet from ...models.densenet import DenseNet
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -123,41 +123,29 @@ class DenseNet201_Weights(WeightsEnum): ...@@ -123,41 +123,29 @@ class DenseNet201_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: @handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
weights = DenseNet121_Weights.verify(weights) weights = DenseNet121_Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: @handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
weights = DenseNet161_Weights.verify(weights) weights = DenseNet161_Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: @handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
weights = DenseNet169_Weights.verify(weights) weights = DenseNet169_Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: @handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
weights = DenseNet201_Weights.verify(weights) weights = DenseNet201_Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
...@@ -14,7 +14,7 @@ from ....models.detection.faster_rcnn import ( ...@@ -14,7 +14,7 @@ from ....models.detection.faster_rcnn import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet50_Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
...@@ -75,7 +75,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): ...@@ -75,7 +75,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
default = Coco_V1 default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def fasterrcnn_resnet50_fpn( def fasterrcnn_resnet50_fpn(
*,
weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -83,17 +88,7 @@ def fasterrcnn_resnet50_fpn( ...@@ -83,17 +88,7 @@ def fasterrcnn_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
...@@ -119,6 +114,7 @@ def fasterrcnn_resnet50_fpn( ...@@ -119,6 +114,7 @@ def fasterrcnn_resnet50_fpn(
def _fasterrcnn_mobilenet_v3_large_fpn( def _fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
progress: bool, progress: bool,
num_classes: Optional[int], num_classes: Optional[int],
...@@ -158,7 +154,12 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -158,7 +154,12 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
return model return model
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -166,17 +167,7 @@ def fasterrcnn_mobilenet_v3_large_fpn( ...@@ -166,17 +167,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1)
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = { defaults = {
...@@ -185,16 +176,21 @@ def fasterrcnn_mobilenet_v3_large_fpn( ...@@ -185,16 +176,21 @@ def fasterrcnn_mobilenet_v3_large_fpn(
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn( return _fasterrcnn_mobilenet_v3_large_fpn(
weights, weights=weights,
progress, progress=progress,
num_classes, num_classes=num_classes,
weights_backbone, weights_backbone=weights_backbone,
trainable_backbone_layers, trainable_backbone_layers=trainable_backbone_layers,
**kwargs, **kwargs,
) )
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn( def fasterrcnn_mobilenet_v3_large_320_fpn(
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -202,19 +198,8 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( ...@@ -202,19 +198,8 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1
)
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = { defaults = {
...@@ -227,10 +212,10 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( ...@@ -227,10 +212,10 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn( return _fasterrcnn_mobilenet_v3_large_fpn(
weights, weights=weights,
progress, progress=progress,
num_classes, num_classes=num_classes,
weights_backbone, weights_backbone=weights_backbone,
trainable_backbone_layers, trainable_backbone_layers=trainable_backbone_layers,
**kwargs, **kwargs,
) )
...@@ -11,7 +11,7 @@ from ....models.detection.keypoint_rcnn import ( ...@@ -11,7 +11,7 @@ from ....models.detection.keypoint_rcnn import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
...@@ -49,7 +49,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -49,7 +49,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
default = Coco_V1 default = Coco_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
if kwargs["pretrained"] == "legacy"
else KeypointRCNN_ResNet50_FPN_Weights.Coco_V1,
),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def keypointrcnn_resnet50_fpn( def keypointrcnn_resnet50_fpn(
*,
weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -58,21 +68,7 @@ def keypointrcnn_resnet50_fpn( ...@@ -58,21 +68,7 @@ def keypointrcnn_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> KeypointRCNN: ) -> KeypointRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_V1
if kwargs["pretrained"] == "legacy":
default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
kwargs["pretrained"] = True
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
......
...@@ -12,7 +12,7 @@ from ....models.detection.mask_rcnn import ( ...@@ -12,7 +12,7 @@ from ....models.detection.mask_rcnn import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
...@@ -38,7 +38,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -38,7 +38,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
default = Coco_V1 default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def maskrcnn_resnet50_fpn( def maskrcnn_resnet50_fpn(
*,
weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -46,17 +51,7 @@ def maskrcnn_resnet50_fpn( ...@@ -46,17 +51,7 @@ def maskrcnn_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> MaskRCNN: ) -> MaskRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
......
...@@ -13,7 +13,7 @@ from ....models.detection.retinanet import ( ...@@ -13,7 +13,7 @@ from ....models.detection.retinanet import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
...@@ -38,7 +38,12 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): ...@@ -38,7 +38,12 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
default = Coco_V1 default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def retinanet_resnet50_fpn( def retinanet_resnet50_fpn(
*,
weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -46,17 +51,7 @@ def retinanet_resnet50_fpn( ...@@ -46,17 +51,7 @@ def retinanet_resnet50_fpn(
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> RetinaNet: ) -> RetinaNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNet_ResNet50_FPN_Weights.Coco_V1)
weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
......
...@@ -12,7 +12,7 @@ from ....models.detection.ssd import ( ...@@ -12,7 +12,7 @@ from ....models.detection.ssd import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..vgg import VGG16_Weights, vgg16 from ..vgg import VGG16_Weights, vgg16
...@@ -37,7 +37,12 @@ class SSD300_VGG16_Weights(WeightsEnum): ...@@ -37,7 +37,12 @@ class SSD300_VGG16_Weights(WeightsEnum):
default = Coco_V1 default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", SSD300_VGG16_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", VGG16_Weights.ImageNet1K_Features),
)
def ssd300_vgg16( def ssd300_vgg16(
*,
weights: Optional[SSD300_VGG16_Weights] = None, weights: Optional[SSD300_VGG16_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -45,17 +50,7 @@ def ssd300_vgg16( ...@@ -45,17 +50,7 @@ def ssd300_vgg16(
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> SSD: ) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300_VGG16_Weights.Coco_V1)
weights = SSD300_VGG16_Weights.verify(weights) weights = SSD300_VGG16_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", VGG16_Weights.ImageNet1K_Features
)
weights_backbone = VGG16_Weights.verify(weights_backbone) weights_backbone = VGG16_Weights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
......
...@@ -17,7 +17,7 @@ from ....models.detection.ssdlite import ( ...@@ -17,7 +17,7 @@ from ....models.detection.ssdlite import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
...@@ -42,7 +42,12 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -42,7 +42,12 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
default = Coco_V1 default = Coco_V1
@handle_legacy_interface(
weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def ssdlite320_mobilenet_v3_large( def ssdlite320_mobilenet_v3_large(
*,
weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -51,17 +56,7 @@ def ssdlite320_mobilenet_v3_large( ...@@ -51,17 +56,7 @@ def ssdlite320_mobilenet_v3_large(
norm_layer: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any, **kwargs: Any,
) -> SSD: ) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1)
weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
......
...@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig from ...models.efficientnet import EfficientNet, MBConvConfig
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -181,73 +181,55 @@ class EfficientNet_B7_Weights(WeightsEnum): ...@@ -181,73 +181,55 @@ class EfficientNet_B7_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.ImageNet1K_V1))
def efficientnet_b0( def efficientnet_b0(
weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B0_Weights.ImageNet1K_V1)
weights = EfficientNet_B0_Weights.verify(weights) weights = EfficientNet_B0_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.ImageNet1K_V1))
def efficientnet_b1( def efficientnet_b1(
weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B1_Weights.ImageNet1K_V1)
weights = EfficientNet_B1_Weights.verify(weights) weights = EfficientNet_B1_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.ImageNet1K_V1))
def efficientnet_b2( def efficientnet_b2(
weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B2_Weights.ImageNet1K_V1)
weights = EfficientNet_B2_Weights.verify(weights) weights = EfficientNet_B2_Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.ImageNet1K_V1))
def efficientnet_b3( def efficientnet_b3(
weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B3_Weights.ImageNet1K_V1)
weights = EfficientNet_B3_Weights.verify(weights) weights = EfficientNet_B3_Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.ImageNet1K_V1))
def efficientnet_b4( def efficientnet_b4(
weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B4_Weights.ImageNet1K_V1)
weights = EfficientNet_B4_Weights.verify(weights) weights = EfficientNet_B4_Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.ImageNet1K_V1))
def efficientnet_b5( def efficientnet_b5(
weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B5_Weights.ImageNet1K_V1)
weights = EfficientNet_B5_Weights.verify(weights) weights = EfficientNet_B5_Weights.verify(weights)
return _efficientnet( return _efficientnet(
...@@ -261,13 +243,10 @@ def efficientnet_b5( ...@@ -261,13 +243,10 @@ def efficientnet_b5(
) )
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.ImageNet1K_V1))
def efficientnet_b6( def efficientnet_b6(
weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B6_Weights.ImageNet1K_V1)
weights = EfficientNet_B6_Weights.verify(weights) weights = EfficientNet_B6_Weights.verify(weights)
return _efficientnet( return _efficientnet(
...@@ -281,13 +260,10 @@ def efficientnet_b6( ...@@ -281,13 +260,10 @@ def efficientnet_b6(
) )
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.ImageNet1K_V1))
def efficientnet_b7( def efficientnet_b7(
weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B7_Weights.ImageNet1K_V1)
weights = EfficientNet_B7_Weights.verify(weights) weights = EfficientNet_B7_Weights.verify(weights)
return _efficientnet( return _efficientnet(
......
...@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
...@@ -30,11 +30,8 @@ class GoogLeNet_Weights(WeightsEnum): ...@@ -30,11 +30,8 @@ class GoogLeNet_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: @handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNet_Weights.ImageNet1K_V1)
weights = GoogLeNet_Weights.verify(weights) weights = GoogLeNet_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
...@@ -29,11 +29,8 @@ class Inception_V3_Weights(WeightsEnum): ...@@ -29,11 +29,8 @@ class Inception_V3_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: @handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", Inception_V3_Weights.ImageNet1K_V1)
weights = Inception_V3_Weights.verify(weights) weights = Inception_V3_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", True) original_aux_logits = kwargs.get("aux_logits", True)
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet from ...models.mnasnet import MNASNet
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -79,41 +79,29 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa ...@@ -79,41 +79,29 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
return model return model
def mnasnet0_5(weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: @handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5_Weights.ImageNet1K_V1)
weights = MNASNet0_5_Weights.verify(weights) weights = MNASNet0_5_Weights.verify(weights)
return _mnasnet(0.5, weights, progress, **kwargs) return _mnasnet(0.5, weights, progress, **kwargs)
def mnasnet0_75(weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: @handle_legacy_interface(weights=("pretrained", None))
if type(weights) == bool and weights: def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet0_75_Weights.verify(weights) weights = MNASNet0_75_Weights.verify(weights)
return _mnasnet(0.75, weights, progress, **kwargs) return _mnasnet(0.75, weights, progress, **kwargs)
def mnasnet1_0(weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: @handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0_Weights.ImageNet1K_V1)
weights = MNASNet1_0_Weights.verify(weights) weights = MNASNet1_0_Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs) return _mnasnet(1.0, weights, progress, **kwargs)
def mnasnet1_3(weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: @handle_legacy_interface(weights=("pretrained", None))
if type(weights) == bool and weights: def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet1_3_Weights.verify(weights) weights = MNASNet1_3_Weights.verify(weights)
return _mnasnet(1.3, weights, progress, **kwargs) return _mnasnet(1.3, weights, progress, **kwargs)
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2 from ...models.mobilenetv2 import MobileNetV2
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
...@@ -29,11 +29,10 @@ class MobileNet_V2_Weights(WeightsEnum): ...@@ -29,11 +29,10 @@ class MobileNet_V2_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def mobilenet_v2(
_deprecated_positional(kwargs, "pretrained", "weights", True) *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
if "pretrained" in kwargs: ) -> MobileNetV2:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V2_Weights.ImageNet1K_V1)
weights = MobileNet_V2_Weights.verify(weights) weights = MobileNet_V2_Weights.verify(weights)
if weights is not None: if weights is not None:
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -82,26 +82,20 @@ class MobileNet_V3_Small_Weights(WeightsEnum): ...@@ -82,26 +82,20 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.ImageNet1K_V1))
def mobilenet_v3_large( def mobilenet_v3_large(
weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Large_Weights.ImageNet1K_V1)
weights = MobileNet_V3_Large_Weights.verify(weights) weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.ImageNet1K_V1))
def mobilenet_v3_small( def mobilenet_v3_small(
weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Small_Weights.ImageNet1K_V1)
weights = MobileNet_V3_Small_Weights.verify(weights) weights = MobileNet_V3_Small_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
......
...@@ -12,7 +12,7 @@ from ....models.quantization.googlenet import ( ...@@ -12,7 +12,7 @@ from ....models.quantization.googlenet import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..googlenet import GoogLeNet_Weights from ..googlenet import GoogLeNet_Weights
...@@ -42,21 +42,22 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): ...@@ -42,21 +42,22 @@ class GoogLeNet_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V1 default = ImageNet1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else GoogLeNet_Weights.ImageNet1K_V1,
)
)
def googlenet( def googlenet(
*,
weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableGoogLeNet: ) -> QuantizableGoogLeNet:
if type(weights) == bool and weights: weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNet_Weights.ImageNet1K_V1
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = GoogLeNet_QuantizedWeights.verify(weights)
else:
weights = GoogLeNet_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
......
...@@ -11,7 +11,7 @@ from ....models.quantization.inception import ( ...@@ -11,7 +11,7 @@ from ....models.quantization.inception import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..inception import Inception_V3_Weights from ..inception import Inception_V3_Weights
...@@ -41,23 +41,22 @@ class Inception_V3_QuantizedWeights(WeightsEnum): ...@@ -41,23 +41,22 @@ class Inception_V3_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V1 default = ImageNet1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else Inception_V3_Weights.ImageNet1K_V1,
)
)
def inception_v3( def inception_v3(
*,
weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableInception3: ) -> QuantizableInception3:
if type(weights) == bool and weights: weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else Inception_V3_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = Inception_V3_QuantizedWeights.verify(weights)
else:
weights = Inception_V3_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
......
...@@ -12,7 +12,7 @@ from ....models.quantization.mobilenetv2 import ( ...@@ -12,7 +12,7 @@ from ....models.quantization.mobilenetv2 import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv2 import MobileNet_V2_Weights from ..mobilenetv2 import MobileNet_V2_Weights
...@@ -42,23 +42,22 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): ...@@ -42,23 +42,22 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
default = ImageNet1K_QNNPACK_V1 default = ImageNet1K_QNNPACK_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1
if kwargs.get("quantize", False)
else MobileNet_V2_Weights.ImageNet1K_V1,
)
)
def mobilenet_v2( def mobilenet_v2(
*,
weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableMobileNetV2: ) -> QuantizableMobileNetV2:
if type(weights) == bool and weights: weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize else MobileNet_V2_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = MobileNet_V2_QuantizedWeights.verify(weights)
else:
weights = MobileNet_V2_Weights.verify(weights)
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
...@@ -13,7 +13,7 @@ from ....models.quantization.mobilenetv3 import ( ...@@ -13,7 +13,7 @@ from ....models.quantization.mobilenetv3 import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf
...@@ -75,25 +75,22 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): ...@@ -75,25 +75,22 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
default = ImageNet1K_QNNPACK_V1 default = ImageNet1K_QNNPACK_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1
if kwargs.get("quantize", False)
else MobileNet_V3_Large_Weights.ImageNet1K_V1,
)
)
def mobilenet_v3_large( def mobilenet_v3_large(
*,
weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableMobileNetV3: ) -> QuantizableMobileNetV3:
if type(weights) == bool and weights: weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1
if quantize
else MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = MobileNet_V3_Large_QuantizedWeights.verify(weights)
else:
weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment