Unverified Commit 588e9b5e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

simplify model builders (#5001)



* simplify model builders

* cleanup

* refactor kwonly to pos or kw handling

* put weight verification back

* revert num categories checks

* fix default weights

* cleanup

* remove manual parameter map

* refactor decorator interface

* address review comments

* cleanup

* refactor callable default

* fix type annotation

* process ungrouped models

* cleanup

* more cleanup

* use decorator for detection models

* add decorator for quantization models

* add decorator for segmentation models

* add decorator for video models

* remove old helpers

* fix resnet50

* Adding verification back on InceptionV3

* Add kwargs in DeepLabV3

* Add kwargs on FCN

* Fix typing on Deeplab

* Fix typing on FCN
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 3ceaff14
......@@ -13,7 +13,7 @@ from ....models.quantization.resnet import (
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
......@@ -125,63 +125,62 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V2
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet18_Weights.ImageNet1K_V1,
    )
)
def resnet18(
    *,
    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a ResNet-18, optionally quantized.

    Args:
        weights: pretrained weights to load; a quantized or float enum
            depending on ``quantize``. ``None`` means random initialization.
        progress: display a progress bar while downloading weights.
        quantize: if ``True``, build the quantized variant of the model.
        **kwargs: forwarded to the model builder.
    """
    # Pick the enum matching the requested variant, then validate the value.
    weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet50_Weights.ImageNet1K_V1,
    )
)
def resnet50(
    *,
    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a ResNet-50, optionally quantized.

    Args:
        weights: pretrained weights to load; a quantized or float enum
            depending on ``quantize``. ``None`` means random initialization.
        progress: display a progress bar while downloading weights.
        quantize: if ``True``, build the quantized variant of the model.
        **kwargs: forwarded to the model builder.
    """
    # Pick the enum matching the requested variant, then validate the value.
    weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ResNeXt101_32X8D_Weights.ImageNet1K_V1,
)
)
def resnext101_32x8d(
*,
weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize
else ResNeXt101_32X8D_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = ResNeXt101_32X8D_QuantizedWeights.verify(weights)
else:
weights = ResNeXt101_32X8D_Weights.verify(weights)
weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
......
......@@ -11,7 +11,7 @@ from ....models.quantization.shufflenetv2 import (
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
......@@ -27,6 +27,7 @@ __all__ = [
def _shufflenetv2(
stages_repeats: List[int],
stages_out_channels: List[int],
*,
weights: Optional[WeightsEnum],
progress: bool,
quantize: bool,
......@@ -87,47 +88,43 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V1
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1,
    )
)
def shufflenet_v2_x0_5(
    *,
    weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """Build a ShuffleNetV2 x0.5, optionally quantized.

    Args:
        weights: pretrained weights to load; a quantized or float enum
            depending on ``quantize``. ``None`` means random initialization.
        progress: display a progress bar while downloading weights.
        quantize: if ``True``, build the quantized variant of the model.
        **kwargs: forwarded to the model builder.
    """
    # Pick the enum matching the requested variant, then validate the value.
    weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights)

    return _shufflenetv2(
        [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1,
    )
)
def shufflenet_v2_x1_0(
    *,
    weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """Build a ShuffleNetV2 x1.0, optionally quantized.

    Args:
        weights: pretrained weights to load; a quantized or float enum
            depending on ``quantize``. ``None`` means random initialization.
        progress: display a progress bar while downloading weights.
        quantize: if ``True``, build the quantized variant of the model.
        **kwargs: forwarded to the model builder.
    """
    # Pick the enum matching the requested variant, then validate the value.
    weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights)

    return _shufflenetv2(
        [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )
......@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -260,33 +260,24 @@ class RegNet_X_32GF_Weights(WeightsEnum):
default = ImageNet1K_V1
@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.ImageNet1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetY-400MF model, optionally loading pretrained weights."""
    weights = RegNet_Y_400MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.ImageNet1K_V1))
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetY-800MF model, optionally loading pretrained weights."""
    weights = RegNet_Y_800MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_1_6GF_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.ImageNet1K_V1))
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params(
......@@ -295,11 +286,8 @@ def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: b
return _regnet(params, weights, progress, **kwargs)
def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_3_2GF_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.ImageNet1K_V1))
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params(
......@@ -308,11 +296,8 @@ def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: b
return _regnet(params, weights, progress, **kwargs)
def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_8GF_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.ImageNet1K_V1))
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_8GF_Weights.verify(weights)
params = BlockParams.from_init_params(
......@@ -321,11 +306,8 @@ def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool
return _regnet(params, weights, progress, **kwargs)
def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_16GF_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.ImageNet1K_V1))
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_16GF_Weights.verify(weights)
params = BlockParams.from_init_params(
......@@ -334,11 +316,8 @@ def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: boo
return _regnet(params, weights, progress, **kwargs)
def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_32GF_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.ImageNet1K_V1))
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(
......@@ -347,77 +326,56 @@ def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: boo
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetX-400MF model, optionally loading pretrained weights."""
    weights = RegNet_X_400MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.ImageNet1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetX-800MF model, optionally loading pretrained weights."""
    weights = RegNet_X_800MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.ImageNet1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetX-1.6GF model, optionally loading pretrained weights."""
    weights = RegNet_X_1_6GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.ImageNet1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetX-3.2GF model, optionally loading pretrained weights."""
    weights = RegNet_X_3_2GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.ImageNet1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetX-8GF model, optionally loading pretrained weights."""
    weights = RegNet_X_8GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.ImageNet1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """Build a RegNetX-16GF model, optionally loading pretrained weights."""
    weights = RegNet_X_16GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
def regnet_x_32gf(weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_32GF_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.ImageNet1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -250,61 +250,45 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
default = ImageNet1K_V2
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.ImageNet1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """Build a ResNet-18 model, optionally loading pretrained weights."""
    weights = ResNet18_Weights.verify(weights)

    return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.ImageNet1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """Build a ResNet-34 model, optionally loading pretrained weights."""
    weights = ResNet34_Weights.verify(weights)

    return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.ImageNet1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """Build a ResNet-50 model, optionally loading pretrained weights."""
    weights = ResNet50_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.ImageNet1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """Build a ResNet-101 model, optionally loading pretrained weights."""
    weights = ResNet101_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.ImageNet1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """Build a ResNet-152 model, optionally loading pretrained weights."""
    weights = ResNet152_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32X4D_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.ImageNet1K_V1))
def resnext50_32x4d(
*, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
weights = ResNeXt50_32X4D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
......@@ -312,13 +296,10 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress:
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.ImageNet1K_V1))
def resnext101_32x8d(
weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32X8D_Weights.ImageNet1K_V1)
weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
......@@ -326,24 +307,20 @@ def resnext101_32x8d(
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.ImageNet1K_V1))
def wide_resnet50_2(
    *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Build a Wide ResNet-50-2 model, optionally loading pretrained weights."""
    weights = Wide_ResNet50_2_Weights.verify(weights)

    # The "wide" variant doubles the bottleneck width relative to plain ResNet-50.
    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.ImageNet1K_V1))
def wide_resnet101_2(
weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet101_2_Weights.ImageNet1K_V1)
weights = Wide_ResNet101_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import resnet50, resnet101
from ..resnet import ResNet50_Weights, ResNet101_Weights
......@@ -72,7 +72,12 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
default = CocoWithVocLabels_V1
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def deeplabv3_resnet50(
*,
weights: Optional[DeepLabV3_ResNet50_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -80,17 +85,7 @@ def deeplabv3_resnet50(
weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
......@@ -109,7 +104,12 @@ def deeplabv3_resnet50(
return model
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.ImageNet1K_V1),
)
def deeplabv3_resnet101(
*,
weights: Optional[DeepLabV3_ResNet101_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -117,17 +117,7 @@ def deeplabv3_resnet101(
weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
)
weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None:
......@@ -146,7 +136,12 @@ def deeplabv3_resnet101(
return model
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def deeplabv3_mobilenet_v3_large(
*,
weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -154,19 +149,7 @@ def deeplabv3_mobilenet_v3_large(
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
)
weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None:
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
......@@ -48,7 +48,12 @@ class FCN_ResNet101_Weights(WeightsEnum):
default = CocoWithVocLabels_V1
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet50_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def fcn_resnet50(
*,
weights: Optional[FCN_ResNet50_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -56,17 +61,7 @@ def fcn_resnet50(
weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any,
) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet50_Weights.CocoWithVocLabels_V1)
weights = FCN_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
......@@ -85,7 +80,12 @@ def fcn_resnet50(
return model
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet101_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.ImageNet1K_V1),
)
def fcn_resnet101(
*,
weights: Optional[FCN_ResNet101_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -93,17 +93,7 @@ def fcn_resnet101(
weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any,
) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet101_Weights.CocoWithVocLabels_V1)
weights = FCN_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
)
weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None:
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
......@@ -29,7 +29,12 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
default = CocoWithVocLabels_V1
@handle_legacy_interface(
weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def lraspp_mobilenet_v3_large(
*,
weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
......@@ -39,19 +44,7 @@ def lraspp_mobilenet_v3_large(
if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss")
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
)
weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None:
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -82,49 +82,37 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
pass
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1))
def shufflenet_v2_x0_5(
weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1)
weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1))
def shufflenet_v2_x1_0(
weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1)
weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x1_5(
weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNet_V2_X1_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x2_0(
weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
*, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
......@@ -47,11 +47,10 @@ class SqueezeNet1_1_Weights(WeightsEnum):
default = ImageNet1K_V1
def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.ImageNet1K_V1))
def squeezenet1_0(
*, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
weights = SqueezeNet1_0_Weights.verify(weights)
if weights is not None:
......@@ -65,11 +64,10 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: boo
return model
def squeezenet1_1(weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.ImageNet1K_V1))
def squeezenet1_1(
*, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
weights = SqueezeNet1_1_Weights.verify(weights)
if weights is not None:
......
......@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -169,81 +169,57 @@ class VGG19_BN_Weights(WeightsEnum):
default = ImageNet1K_V1
def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.ImageNet1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG11_Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_BN_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.ImageNet1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG11_BN_Weights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.ImageNet1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG13_Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_BN_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.ImageNet1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG13_BN_Weights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.ImageNet1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG16_Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_BN_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.ImageNet1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG16_BN_Weights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.ImageNet1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG19_Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_BN_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.ImageNet1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs)
......@@ -17,7 +17,7 @@ from ....models.video.resnet import (
)
from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from .._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
......@@ -98,11 +98,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
default = Kinetics400_V1
def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18_Weights.Kinetics400_V1)
@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.Kinetics400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
weights = R3D_18_Weights.verify(weights)
return _video_resnet(
......@@ -116,11 +113,8 @@ def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kw
)
def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18_Weights.Kinetics400_V1)
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.Kinetics400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
weights = MC3_18_Weights.verify(weights)
return _video_resnet(
......@@ -134,11 +128,8 @@ def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kw
)
def r2plus1d_18(weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18_Weights.Kinetics400_V1)
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.Kinetics400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
weights = R2Plus1D_18_Weights.verify(weights)
return _video_resnet(
......
......@@ -12,7 +12,7 @@ import torch.nn as nn
from torch import Tensor
from ._api import WeightsEnum
from ._utils import _deprecated_param, _deprecated_positional
from ._utils import handle_legacy_interface
__all__ = [
......@@ -279,7 +279,8 @@ def _vision_transformer(
return model
def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
@handle_legacy_interface(weights=("pretrained", None))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
......@@ -289,10 +290,6 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True,
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_B_16_Weights.verify(weights)
return _vision_transformer(
......@@ -307,7 +304,8 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True,
)
def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
@handle_legacy_interface(weights=("pretrained", None))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
......@@ -317,10 +315,6 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True,
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_B_32_Weights.verify(weights)
return _vision_transformer(
......@@ -335,7 +329,8 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True,
)
def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
@handle_legacy_interface(weights=("pretrained", None))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
......@@ -345,10 +340,6 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True,
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_L_16_Weights.verify(weights)
return _vision_transformer(
......@@ -363,7 +354,8 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True,
)
def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
@handle_legacy_interface(weights=("pretrained", None))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
......@@ -373,10 +365,6 @@ def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True,
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_L_32_Weights.verify(weights)
return _vision_transformer(
......
import collections.abc
import difflib
import enum
import functools
import inspect
import os
import os.path
import textwrap
import warnings
from typing import Collection, Sequence, Callable, Any, Iterator, NoReturn, Mapping, TypeVar, Iterable, Tuple, cast
__all__ = [
......@@ -13,6 +16,7 @@ __all__ = [
"FrozenMapping",
"make_repr",
"FrozenBunch",
"kwonly_to_pos_or_kw",
]
......@@ -126,3 +130,54 @@ class FrozenBunch(FrozenMapping):
def __repr__(self) -> str:
    # Delegate to the module's shared make_repr helper so this container formats
    # consistently with the other Frozen* types in this file.
    return make_repr(type(self).__name__, self.items())
def kwonly_to_pos_or_kw(fn: "Callable[..., D]") -> "Callable[..., D]":
    """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.

    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:

    .. code::

        def old_fn(foo, bar, baz=None):
            ...

        def new_fn(foo, *, bar, baz=None):
            ...

    Calling ``old_fn("foo", "bar", "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep
    BC and at the same time warn the user of the deprecation, this decorator can be used:

    .. code::

        @kwonly_to_pos_or_kw
        def new_fn(foo, *, bar, baz=None):
            ...

        new_fn("foo", "bar", "baz")

    Raises:
        TypeError: If ``fn`` has no keyword-only parameters.
    """
    # Inspect the signature once and reuse it both for the split index and for the
    # names of the keyword-only parameters (the original inspected it twice).
    params = inspect.signature(fn).parameters

    try:
        keyword_only_start_idx = next(
            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
        )
    except StopIteration:
        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
    keyword_only_params = tuple(params)[keyword_only_start_idx:]

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> "D":
        # Anything beyond the positional-or-keyword slots must have been intended
        # for the keyword-only parameters; re-route it through kwargs with a warning.
        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
        if keyword_only_args:
            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
            warnings.warn(
                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
                f"parameter(s) is deprecated. Please use keyword parameter(s) instead."
            )
            kwargs.update(keyword_only_kwargs)

        return fn(*args, **kwargs)

    return wrapper
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment