Unverified Commit 588e9b5e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

simplify model builders (#5001)



* simplify model builders

* cleanup

* refactor kwonly to pos or kw handling

* put weight verification back

* revert num categories checks

* fix default weights

* cleanup

* remove manual parameter map

* refactor decorator interface

* address review comments

* cleanup

* refactor callable default

* fix type annotation

* process ungrouped models

* cleanup

* mroe cleanup

* use decorator for detection models

* add decorator for quantization models

* add decorator for segmentation  models

* add decorator for video  models

* remove old helpers

* fix resnet50

* Adding verification back on InceptionV3

* Add kwargs in DeeplabeV3

* Add kwargs on FCN

* Fix typing on Deeplab

* Fix typing on FCN
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 3ceaff14
...@@ -13,7 +13,7 @@ from ....models.quantization.resnet import ( ...@@ -13,7 +13,7 @@ from ....models.quantization.resnet import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
...@@ -125,63 +125,62 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ...@@ -125,63 +125,62 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V2 default = ImageNet1K_FBGEMM_V2
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ResNet18_Weights.ImageNet1K_V1,
)
)
def resnet18( def resnet18(
*,
weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableResNet: ) -> QuantizableResNet:
if type(weights) == bool and weights: weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18_Weights.ImageNet1K_V1
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = ResNet18_QuantizedWeights.verify(weights)
else:
weights = ResNet18_Weights.verify(weights)
return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ResNet50_Weights.ImageNet1K_V1,
)
)
def resnet50( def resnet50(
*,
weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableResNet: ) -> QuantizableResNet:
if type(weights) == bool and weights: weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50_Weights.ImageNet1K_V1
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = ResNet50_QuantizedWeights.verify(weights)
else:
weights = ResNet50_Weights.verify(weights)
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ResNeXt101_32X8D_Weights.ImageNet1K_V1,
)
)
def resnext101_32x8d( def resnext101_32x8d(
*,
weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableResNet: ) -> QuantizableResNet:
if type(weights) == bool and weights: weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize
else ResNeXt101_32X8D_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = ResNeXt101_32X8D_QuantizedWeights.verify(weights)
else:
weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8) _ovewrite_named_param(kwargs, "width_per_group", 8)
......
...@@ -11,7 +11,7 @@ from ....models.quantization.shufflenetv2 import ( ...@@ -11,7 +11,7 @@ from ....models.quantization.shufflenetv2 import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
...@@ -27,6 +27,7 @@ __all__ = [ ...@@ -27,6 +27,7 @@ __all__ = [
def _shufflenetv2( def _shufflenetv2(
stages_repeats: List[int], stages_repeats: List[int],
stages_out_channels: List[int], stages_out_channels: List[int],
*,
weights: Optional[WeightsEnum], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
quantize: bool, quantize: bool,
...@@ -87,47 +88,43 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): ...@@ -87,47 +88,43 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
default = ImageNet1K_FBGEMM_V1 default = ImageNet1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1,
)
)
def shufflenet_v2_x0_5( def shufflenet_v2_x0_5(
*,
weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableShuffleNetV2: ) -> QuantizableShuffleNetV2:
if type(weights) == bool and weights: weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True) return _shufflenetv2(
if "pretrained" in kwargs: [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
default_value = (
ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize
else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = ShuffleNet_V2_X0_5_QuantizedWeights.verify(weights)
else:
weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs)
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1,
)
)
def shufflenet_v2_x1_0( def shufflenet_v2_x1_0(
*,
weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableShuffleNetV2: ) -> QuantizableShuffleNetV2:
if type(weights) == bool and weights: weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights)
_deprecated_positional(kwargs, "pretrained", "weights", True) return _shufflenetv2(
if "pretrained" in kwargs: [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
default_value = (
ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize
else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = ShuffleNet_V2_X1_0_QuantizedWeights.verify(weights)
else:
weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs)
...@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,7 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams from ...models.regnet import RegNet, BlockParams
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -260,33 +260,24 @@ class RegNet_X_32GF_Weights(WeightsEnum): ...@@ -260,33 +260,24 @@ class RegNet_X_32GF_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_400MF_Weights.ImageNet1K_V1)
weights = RegNet_Y_400MF_Weights.verify(weights) weights = RegNet_Y_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_800mf(weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_800MF_Weights.ImageNet1K_V1)
weights = RegNet_Y_800MF_Weights.verify(weights) weights = RegNet_Y_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_1_6GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_1_6GF_Weights.verify(weights) weights = RegNet_Y_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -295,11 +286,8 @@ def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: b ...@@ -295,11 +286,8 @@ def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: b
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_3_2GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_3_2GF_Weights.verify(weights) weights = RegNet_Y_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -308,11 +296,8 @@ def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: b ...@@ -308,11 +296,8 @@ def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: b
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_8GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_8GF_Weights.verify(weights) weights = RegNet_Y_8GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -321,11 +306,8 @@ def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool ...@@ -321,11 +306,8 @@ def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_16GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_16GF_Weights.verify(weights) weights = RegNet_Y_16GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -334,11 +316,8 @@ def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: boo ...@@ -334,11 +316,8 @@ def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: boo
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_32GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_32GF_Weights.verify(weights) weights = RegNet_Y_32GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -347,77 +326,56 @@ def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: boo ...@@ -347,77 +326,56 @@ def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: boo
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_400mf(weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_400MF_Weights.ImageNet1K_V1)
weights = RegNet_X_400MF_Weights.verify(weights) weights = RegNet_X_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_800mf(weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_800MF_Weights.ImageNet1K_V1)
weights = RegNet_X_800MF_Weights.verify(weights) weights = RegNet_X_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_1_6gf(weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_1_6GF_Weights.ImageNet1K_V1)
weights = RegNet_X_1_6GF_Weights.verify(weights) weights = RegNet_X_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_3_2gf(weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_3_2GF_Weights.ImageNet1K_V1)
weights = RegNet_X_3_2GF_Weights.verify(weights) weights = RegNet_X_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_8gf(weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_8GF_Weights.ImageNet1K_V1)
weights = RegNet_X_8GF_Weights.verify(weights) weights = RegNet_X_8GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_16gf(weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_16GF_Weights.ImageNet1K_V1)
weights = RegNet_X_16GF_Weights.verify(weights) weights = RegNet_X_16GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_32gf(weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: @handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_32GF_Weights.ImageNet1K_V1)
weights = RegNet_X_32GF_Weights.verify(weights) weights = RegNet_X_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -250,61 +250,45 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ...@@ -250,61 +250,45 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
default = ImageNet1K_V2 default = ImageNet1K_V2
def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18_Weights.ImageNet1K_V1)
weights = ResNet18_Weights.verify(weights) weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
def resnet34(weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @handle_legacy_interface(weights=("pretrained", ResNet34_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34_Weights.ImageNet1K_V1)
weights = ResNet34_Weights.verify(weights) weights = ResNet34_Weights.verify(weights)
return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet50(weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @handle_legacy_interface(weights=("pretrained", ResNet50_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50_Weights.ImageNet1K_V1)
weights = ResNet50_Weights.verify(weights) weights = ResNet50_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet101(weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @handle_legacy_interface(weights=("pretrained", ResNet101_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101_Weights.ImageNet1K_V1)
weights = ResNet101_Weights.verify(weights) weights = ResNet101_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def resnet152(weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @handle_legacy_interface(weights=("pretrained", ResNet152_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152_Weights.ImageNet1K_V1)
weights = ResNet152_Weights.verify(weights) weights = ResNet152_Weights.verify(weights)
return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def resnext50_32x4d(
_deprecated_positional(kwargs, "pretrained", "weights", True) *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
if "pretrained" in kwargs: ) -> ResNet:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32X4D_Weights.ImageNet1K_V1)
weights = ResNeXt50_32X4D_Weights.verify(weights) weights = ResNeXt50_32X4D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "groups", 32)
...@@ -312,13 +296,10 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: ...@@ -312,13 +296,10 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress:
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.ImageNet1K_V1))
def resnext101_32x8d( def resnext101_32x8d(
weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet: ) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32X8D_Weights.ImageNet1K_V1)
weights = ResNeXt101_32X8D_Weights.verify(weights) weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "groups", 32)
...@@ -326,24 +307,20 @@ def resnext101_32x8d( ...@@ -326,24 +307,20 @@ def resnext101_32x8d(
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def wide_resnet50_2(weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: @handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def wide_resnet50_2(
_deprecated_positional(kwargs, "pretrained", "weights", True) *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
if "pretrained" in kwargs: ) -> ResNet:
weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet50_2_Weights.ImageNet1K_V1)
weights = Wide_ResNet50_2_Weights.verify(weights) weights = Wide_ResNet50_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.ImageNet1K_V1))
def wide_resnet101_2( def wide_resnet101_2(
weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet: ) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet101_2_Weights.ImageNet1K_V1)
weights = Wide_ResNet101_2_Weights.verify(weights) weights = Wide_ResNet101_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import resnet50, resnet101 from ..resnet import resnet50, resnet101
from ..resnet import ResNet50_Weights, ResNet101_Weights from ..resnet import ResNet50_Weights, ResNet101_Weights
...@@ -72,7 +72,12 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -72,7 +72,12 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
default = CocoWithVocLabels_V1 default = CocoWithVocLabels_V1
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def deeplabv3_resnet50( def deeplabv3_resnet50(
*,
weights: Optional[DeepLabV3_ResNet50_Weights] = None, weights: Optional[DeepLabV3_ResNet50_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -80,17 +85,7 @@ def deeplabv3_resnet50( ...@@ -80,17 +85,7 @@ def deeplabv3_resnet50(
weights_backbone: Optional[ResNet50_Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3_ResNet50_Weights.verify(weights) weights = DeepLabV3_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
...@@ -109,7 +104,12 @@ def deeplabv3_resnet50( ...@@ -109,7 +104,12 @@ def deeplabv3_resnet50(
return model return model
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.ImageNet1K_V1),
)
def deeplabv3_resnet101( def deeplabv3_resnet101(
*,
weights: Optional[DeepLabV3_ResNet101_Weights] = None, weights: Optional[DeepLabV3_ResNet101_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -117,17 +117,7 @@ def deeplabv3_resnet101( ...@@ -117,17 +117,7 @@ def deeplabv3_resnet101(
weights_backbone: Optional[ResNet101_Weights] = None, weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3_ResNet101_Weights.verify(weights) weights = DeepLabV3_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
)
weights_backbone = ResNet101_Weights.verify(weights_backbone) weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
...@@ -146,7 +136,12 @@ def deeplabv3_resnet101( ...@@ -146,7 +136,12 @@ def deeplabv3_resnet101(
return model return model
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def deeplabv3_mobilenet_v3_large( def deeplabv3_mobilenet_v3_large(
*,
weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -154,19 +149,7 @@ def deeplabv3_mobilenet_v3_large( ...@@ -154,19 +149,7 @@ def deeplabv3_mobilenet_v3_large(
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
)
weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet from ....models.segmentation.fcn import FCN, _fcn_resnet
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
...@@ -48,7 +48,12 @@ class FCN_ResNet101_Weights(WeightsEnum): ...@@ -48,7 +48,12 @@ class FCN_ResNet101_Weights(WeightsEnum):
default = CocoWithVocLabels_V1 default = CocoWithVocLabels_V1
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet50_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def fcn_resnet50( def fcn_resnet50(
*,
weights: Optional[FCN_ResNet50_Weights] = None, weights: Optional[FCN_ResNet50_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -56,17 +61,7 @@ def fcn_resnet50( ...@@ -56,17 +61,7 @@ def fcn_resnet50(
weights_backbone: Optional[ResNet50_Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet50_Weights.CocoWithVocLabels_V1)
weights = FCN_ResNet50_Weights.verify(weights) weights = FCN_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50_Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
...@@ -85,7 +80,12 @@ def fcn_resnet50( ...@@ -85,7 +80,12 @@ def fcn_resnet50(
return model return model
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet101_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.ImageNet1K_V1),
)
def fcn_resnet101( def fcn_resnet101(
*,
weights: Optional[FCN_ResNet101_Weights] = None, weights: Optional[FCN_ResNet101_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -93,17 +93,7 @@ def fcn_resnet101( ...@@ -93,17 +93,7 @@ def fcn_resnet101(
weights_backbone: Optional[ResNet101_Weights] = None, weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet101_Weights.CocoWithVocLabels_V1)
weights = FCN_ResNet101_Weights.verify(weights) weights = FCN_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
)
weights_backbone = ResNet101_Weights.verify(weights_backbone) weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
...@@ -29,7 +29,12 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -29,7 +29,12 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
default = CocoWithVocLabels_V1 default = CocoWithVocLabels_V1
@handle_legacy_interface(
weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
)
def lraspp_mobilenet_v3_large( def lraspp_mobilenet_v3_large(
*,
weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
...@@ -39,19 +44,7 @@ def lraspp_mobilenet_v3_large( ...@@ -39,19 +44,7 @@ def lraspp_mobilenet_v3_large(
if kwargs.pop("aux_loss", False): if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss") raise NotImplementedError("This model does not use auxiliary loss")
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
)
weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2 from ...models.shufflenetv2 import ShuffleNetV2
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -82,49 +82,37 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum): ...@@ -82,49 +82,37 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
pass pass
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1))
def shufflenet_v2_x0_5( def shufflenet_v2_x0_5(
weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1)
weights = ShuffleNet_V2_X0_5_Weights.verify(weights) weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1))
def shufflenet_v2_x1_0( def shufflenet_v2_x1_0(
weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1)
weights = ShuffleNet_V2_X1_0_Weights.verify(weights) weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x1_5( def shufflenet_v2_x1_5(
weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNet_V2_X1_5_Weights.verify(weights) weights = ShuffleNet_V2_X1_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x2_0( def shufflenet_v2_x2_0(
weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNet_V2_X2_0_Weights.verify(weights) weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet from ...models.squeezenet import SqueezeNet
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] __all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
...@@ -47,11 +47,10 @@ class SqueezeNet1_1_Weights(WeightsEnum): ...@@ -47,11 +47,10 @@ class SqueezeNet1_1_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: @handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def squeezenet1_0(
_deprecated_positional(kwargs, "pretrained", "weights", True) *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
if "pretrained" in kwargs: ) -> SqueezeNet:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0_Weights.ImageNet1K_V1)
weights = SqueezeNet1_0_Weights.verify(weights) weights = SqueezeNet1_0_Weights.verify(weights)
if weights is not None: if weights is not None:
...@@ -65,11 +64,10 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: boo ...@@ -65,11 +64,10 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: boo
return model return model
def squeezenet1_1(weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: @handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def squeezenet1_1(
_deprecated_positional(kwargs, "pretrained", "weights", True) *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
if "pretrained" in kwargs: ) -> SqueezeNet:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1_Weights.ImageNet1K_V1)
weights = SqueezeNet1_1_Weights.verify(weights) weights = SqueezeNet1_1_Weights.verify(weights)
if weights is not None: if weights is not None:
......
...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs from ...models.vgg import VGG, make_layers, cfgs
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -169,81 +169,57 @@ class VGG19_BN_Weights(WeightsEnum): ...@@ -169,81 +169,57 @@ class VGG19_BN_Weights(WeightsEnum):
default = ImageNet1K_V1 default = ImageNet1K_V1
def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG11_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_Weights.ImageNet1K_V1)
weights = VGG11_Weights.verify(weights) weights = VGG11_Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs) return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_BN_Weights.ImageNet1K_V1)
weights = VGG11_BN_Weights.verify(weights) weights = VGG11_BN_Weights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs) return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG13_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_Weights.ImageNet1K_V1)
weights = VGG13_Weights.verify(weights) weights = VGG13_Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs) return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_BN_Weights.ImageNet1K_V1)
weights = VGG13_BN_Weights.verify(weights) weights = VGG13_BN_Weights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs) return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG16_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_Weights.ImageNet1K_V1)
weights = VGG16_Weights.verify(weights) weights = VGG16_Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs) return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_BN_Weights.ImageNet1K_V1)
weights = VGG16_BN_Weights.verify(weights) weights = VGG16_BN_Weights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs) return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG19_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_Weights.ImageNet1K_V1)
weights = VGG19_Weights.verify(weights) weights = VGG19_Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs) return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.ImageNet1K_V1))
if type(weights) == bool and weights: def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_BN_Weights.ImageNet1K_V1)
weights = VGG19_BN_Weights.verify(weights) weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs) return _vgg("E", True, weights, progress, **kwargs)
...@@ -17,7 +17,7 @@ from ....models.video.resnet import ( ...@@ -17,7 +17,7 @@ from ....models.video.resnet import (
) )
from .._api import WeightsEnum, Weights from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES from .._meta import _KINETICS400_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -98,11 +98,8 @@ class R2Plus1D_18_Weights(WeightsEnum): ...@@ -98,11 +98,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
default = Kinetics400_V1 default = Kinetics400_V1
def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: @handle_legacy_interface(weights=("pretrained", R3D_18_Weights.Kinetics400_V1))
if type(weights) == bool and weights: def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18_Weights.Kinetics400_V1)
weights = R3D_18_Weights.verify(weights) weights = R3D_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
...@@ -116,11 +113,8 @@ def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kw ...@@ -116,11 +113,8 @@ def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kw
) )
def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.Kinetics400_V1))
if type(weights) == bool and weights: def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18_Weights.Kinetics400_V1)
weights = MC3_18_Weights.verify(weights) weights = MC3_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
...@@ -134,11 +128,8 @@ def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kw ...@@ -134,11 +128,8 @@ def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kw
) )
def r2plus1d_18(weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.Kinetics400_V1))
if type(weights) == bool and weights: def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18_Weights.Kinetics400_V1)
weights = R2Plus1D_18_Weights.verify(weights) weights = R2Plus1D_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
......
...@@ -12,7 +12,7 @@ import torch.nn as nn ...@@ -12,7 +12,7 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from ._api import WeightsEnum from ._api import WeightsEnum
from ._utils import _deprecated_param, _deprecated_positional from ._utils import handle_legacy_interface
__all__ = [ __all__ = [
...@@ -279,7 +279,8 @@ def _vision_transformer( ...@@ -279,7 +279,8 @@ def _vision_transformer(
return model return model
def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: @handle_legacy_interface(weights=("pretrained", None))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_b_16 architecture from Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
...@@ -289,10 +290,6 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, ...@@ -289,10 +290,6 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True,
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_B_16_Weights.verify(weights) weights = ViT_B_16_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
...@@ -307,7 +304,8 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, ...@@ -307,7 +304,8 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True,
) )
def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: @handle_legacy_interface(weights=("pretrained", None))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_b_32 architecture from Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
...@@ -317,10 +315,6 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, ...@@ -317,10 +315,6 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True,
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_B_32_Weights.verify(weights) weights = ViT_B_32_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
...@@ -335,7 +329,8 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, ...@@ -335,7 +329,8 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True,
) )
def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: @handle_legacy_interface(weights=("pretrained", None))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_l_16 architecture from Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
...@@ -345,10 +340,6 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, ...@@ -345,10 +340,6 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True,
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_L_16_Weights.verify(weights) weights = ViT_L_16_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
...@@ -363,7 +354,8 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, ...@@ -363,7 +354,8 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True,
) )
def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: @handle_legacy_interface(weights=("pretrained", None))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_l_32 architecture from Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
...@@ -373,10 +365,6 @@ def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, ...@@ -373,10 +365,6 @@ def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True,
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ViT_L_32_Weights.verify(weights) weights = ViT_L_32_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
......
import collections.abc import collections.abc
import difflib import difflib
import enum import enum
import functools
import inspect
import os import os
import os.path import os.path
import textwrap import textwrap
import warnings
from typing import Collection, Sequence, Callable, Any, Iterator, NoReturn, Mapping, TypeVar, Iterable, Tuple, cast from typing import Collection, Sequence, Callable, Any, Iterator, NoReturn, Mapping, TypeVar, Iterable, Tuple, cast
__all__ = [ __all__ = [
...@@ -13,6 +16,7 @@ __all__ = [ ...@@ -13,6 +16,7 @@ __all__ = [
"FrozenMapping", "FrozenMapping",
"make_repr", "make_repr",
"FrozenBunch", "FrozenBunch",
"kwonly_to_pos_or_kw",
] ]
...@@ -126,3 +130,54 @@ class FrozenBunch(FrozenMapping): ...@@ -126,3 +130,54 @@ class FrozenBunch(FrozenMapping):
def __repr__(self) -> str: def __repr__(self) -> str:
return make_repr(type(self).__name__, self.items()) return make_repr(type(self).__name__, self.items())
def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
"""Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
.. code::
def old_fn(foo, bar, baz=None):
...
def new_fn(foo, *, bar, baz=None):
...
Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
and at the same time warn the user of the deprecation, this decorator can be used:
.. code::
@kwonly_to_pos_or_kw
def new_fn(foo, *, bar, baz=None):
...
new_fn("foo", "bar, "baz")
"""
params = inspect.signature(fn).parameters
try:
keyword_only_start_idx = next(
idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
)
except StopIteration:
raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> D:
args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
if keyword_only_args:
keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
warnings.warn(
f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
f"parameter(s) is deprecated. Please use keyword parameter(s) instead."
)
kwargs.update(keyword_only_kwargs)
return fn(*args, **kwargs)
return wrapper
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment