Unverified Commit b9da6db4 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Cleanup namings of Multi-weights classes and enums (#5003)

* Rename classes Weights => WeightsEnum and WeightEntry => Weights.

* Make enum values follow the naming convention `_V1`, `_V2` etc

* Cleanup the Enum class naming conventions.

* Add a test to check naming conventions.
parent b3cdec1f
...@@ -11,15 +11,15 @@ from ....models.quantization.mobilenetv3 import ( ...@@ -11,15 +11,15 @@ from ....models.quantization.mobilenetv3 import (
QuantizableMobileNetV3, QuantizableMobileNetV3,
_replace_relu, _replace_relu,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf
__all__ = [ __all__ = [
"QuantizableMobileNetV3", "QuantizableMobileNetV3",
"QuantizedMobileNetV3LargeWeights", "MobileNet_V3_Large_QuantizedWeights",
"mobilenet_v3_large", "mobilenet_v3_large",
] ]
...@@ -27,7 +27,7 @@ __all__ = [ ...@@ -27,7 +27,7 @@ __all__ = [
def _mobilenet_v3_model( def _mobilenet_v3_model(
inverted_residual_setting: List[InvertedResidualConfig], inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int, last_channel: int,
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
quantize: bool, quantize: bool,
**kwargs: Any, **kwargs: Any,
...@@ -56,8 +56,8 @@ def _mobilenet_v3_model( ...@@ -56,8 +56,8 @@ def _mobilenet_v3_model(
return model return model
class QuantizedMobileNetV3LargeWeights(Weights): class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
ImageNet1K_QNNPACK_RefV1 = WeightEntry( ImageNet1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -67,7 +67,7 @@ class QuantizedMobileNetV3LargeWeights(Weights): ...@@ -67,7 +67,7 @@ class QuantizedMobileNetV3LargeWeights(Weights):
"backend": "qnnpack", "backend": "qnnpack",
"quantization": "qat", "quantization": "qat",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
"unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1, "unquantized": MobileNet_V3_Large_Weights.ImageNet1K_V1,
"acc@1": 73.004, "acc@1": 73.004,
"acc@5": 90.858, "acc@5": 90.858,
}, },
...@@ -76,7 +76,7 @@ class QuantizedMobileNetV3LargeWeights(Weights): ...@@ -76,7 +76,7 @@ class QuantizedMobileNetV3LargeWeights(Weights):
def mobilenet_v3_large( def mobilenet_v3_large(
weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None, weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -85,15 +85,15 @@ def mobilenet_v3_large( ...@@ -85,15 +85,15 @@ def mobilenet_v3_large(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = (
QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1 MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1
if quantize if quantize
else MobileNetV3LargeWeights.ImageNet1K_RefV1 else MobileNet_V3_Large_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedMobileNetV3LargeWeights.verify(weights) weights = MobileNet_V3_Large_QuantizedWeights.verify(weights)
else: else:
weights = MobileNetV3LargeWeights.verify(weights) weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
...@@ -11,17 +11,17 @@ from ....models.quantization.resnet import ( ...@@ -11,17 +11,17 @@ from ....models.quantization.resnet import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
__all__ = [ __all__ = [
"QuantizableResNet", "QuantizableResNet",
"QuantizedResNet18Weights", "ResNet18_QuantizedWeights",
"QuantizedResNet50Weights", "ResNet50_QuantizedWeights",
"QuantizedResNeXt101_32x8dWeights", "ResNeXt101_32X8D_QuantizedWeights",
"resnet18", "resnet18",
"resnet50", "resnet50",
"resnext101_32x8d", "resnext101_32x8d",
...@@ -31,7 +31,7 @@ __all__ = [ ...@@ -31,7 +31,7 @@ __all__ = [
def _resnet( def _resnet(
block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
layers: List[int], layers: List[int],
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
quantize: bool, quantize: bool,
**kwargs: Any, **kwargs: Any,
...@@ -63,13 +63,13 @@ _COMMON_META = { ...@@ -63,13 +63,13 @@ _COMMON_META = {
} }
class QuantizedResNet18Weights(Weights): class ResNet18_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_RefV1 = WeightEntry( ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"unquantized": ResNet18Weights.ImageNet1K_RefV1, "unquantized": ResNet18_Weights.ImageNet1K_V1,
"acc@1": 69.494, "acc@1": 69.494,
"acc@5": 88.882, "acc@5": 88.882,
}, },
...@@ -77,24 +77,24 @@ class QuantizedResNet18Weights(Weights): ...@@ -77,24 +77,24 @@ class QuantizedResNet18Weights(Weights):
) )
class QuantizedResNet50Weights(Weights): class ResNet50_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_RefV1 = WeightEntry( ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"unquantized": ResNet50Weights.ImageNet1K_RefV1, "unquantized": ResNet50_Weights.ImageNet1K_V1,
"acc@1": 75.920, "acc@1": 75.920,
"acc@5": 92.814, "acc@5": 92.814,
}, },
default=False, default=False,
) )
ImageNet1K_FBGEMM_RefV2 = WeightEntry( ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"unquantized": ResNet50Weights.ImageNet1K_RefV2, "unquantized": ResNet50_Weights.ImageNet1K_V2,
"acc@1": 80.282, "acc@1": 80.282,
"acc@5": 94.976, "acc@5": 94.976,
}, },
...@@ -102,24 +102,24 @@ class QuantizedResNet50Weights(Weights): ...@@ -102,24 +102,24 @@ class QuantizedResNet50Weights(Weights):
) )
class QuantizedResNeXt101_32x8dWeights(Weights): class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_RefV1 = WeightEntry( ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1, "unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V1,
"acc@1": 78.986, "acc@1": 78.986,
"acc@5": 94.480, "acc@5": 94.480,
}, },
default=False, default=False,
) )
ImageNet1K_FBGEMM_RefV2 = WeightEntry( ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV2, "unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V2,
"acc@1": 82.574, "acc@1": 82.574,
"acc@5": 96.132, "acc@5": 96.132,
}, },
...@@ -128,7 +128,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights): ...@@ -128,7 +128,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
def resnet18( def resnet18(
weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None, weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -136,20 +136,18 @@ def resnet18( ...@@ -136,20 +136,18 @@ def resnet18(
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18_Weights.ImageNet1K_V1
QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedResNet18Weights.verify(weights) weights = ResNet18_QuantizedWeights.verify(weights)
else: else:
weights = ResNet18Weights.verify(weights) weights = ResNet18_Weights.verify(weights)
return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
def resnet50( def resnet50(
weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None, weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -157,20 +155,18 @@ def resnet50( ...@@ -157,20 +155,18 @@ def resnet50(
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50_Weights.ImageNet1K_V1
QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedResNet50Weights.verify(weights) weights = ResNet50_QuantizedWeights.verify(weights)
else: else:
weights = ResNet50Weights.verify(weights) weights = ResNet50_Weights.verify(weights)
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
def resnext101_32x8d( def resnext101_32x8d(
weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None, weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -179,15 +175,15 @@ def resnext101_32x8d( ...@@ -179,15 +175,15 @@ def resnext101_32x8d(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = (
QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1 ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize if quantize
else ResNeXt101_32x8dWeights.ImageNet1K_RefV1 else ResNeXt101_32X8D_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedResNeXt101_32x8dWeights.verify(weights) weights = ResNeXt101_32X8D_QuantizedWeights.verify(weights)
else: else:
weights = ResNeXt101_32x8dWeights.verify(weights) weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8) _ovewrite_named_param(kwargs, "width_per_group", 8)
......
...@@ -9,16 +9,16 @@ from ....models.quantization.shufflenetv2 import ( ...@@ -9,16 +9,16 @@ from ....models.quantization.shufflenetv2 import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
__all__ = [ __all__ = [
"QuantizableShuffleNetV2", "QuantizableShuffleNetV2",
"QuantizedShuffleNetV2_x0_5Weights", "ShuffleNet_V2_X0_5_QuantizedWeights",
"QuantizedShuffleNetV2_x1_0Weights", "ShuffleNet_V2_X1_0_QuantizedWeights",
"shufflenet_v2_x0_5", "shufflenet_v2_x0_5",
"shufflenet_v2_x1_0", "shufflenet_v2_x1_0",
] ]
...@@ -27,7 +27,7 @@ __all__ = [ ...@@ -27,7 +27,7 @@ __all__ = [
def _shufflenetv2( def _shufflenetv2(
stages_repeats: List[int], stages_repeats: List[int],
stages_out_channels: List[int], stages_out_channels: List[int],
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
quantize: bool, quantize: bool,
**kwargs: Any, **kwargs: Any,
...@@ -59,13 +59,13 @@ _COMMON_META = { ...@@ -59,13 +59,13 @@ _COMMON_META = {
} }
class QuantizedShuffleNetV2_x0_5Weights(Weights): class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_Community = WeightEntry( ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community, "unquantized": ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1,
"acc@1": 57.972, "acc@1": 57.972,
"acc@5": 79.780, "acc@5": 79.780,
}, },
...@@ -73,13 +73,13 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights): ...@@ -73,13 +73,13 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights):
) )
class QuantizedShuffleNetV2_x1_0Weights(Weights): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_Community = WeightEntry( ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community, "unquantized": ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1,
"acc@1": 68.360, "acc@1": 68.360,
"acc@5": 87.582, "acc@5": 87.582,
}, },
...@@ -88,7 +88,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights): ...@@ -88,7 +88,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights):
def shufflenet_v2_x0_5( def shufflenet_v2_x0_5(
weights: Optional[Union[QuantizedShuffleNetV2_x0_5Weights, ShuffleNetV2_x0_5Weights]] = None, weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -97,21 +97,21 @@ def shufflenet_v2_x0_5( ...@@ -97,21 +97,21 @@ def shufflenet_v2_x0_5(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = (
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize if quantize
else ShuffleNetV2_x0_5Weights.ImageNet1K_Community else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights) weights = ShuffleNet_V2_X0_5_QuantizedWeights.verify(weights)
else: else:
weights = ShuffleNetV2_x0_5Weights.verify(weights) weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs) return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs)
def shufflenet_v2_x1_0( def shufflenet_v2_x1_0(
weights: Optional[Union[QuantizedShuffleNetV2_x1_0Weights, ShuffleNetV2_x1_0Weights]] = None, weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -120,14 +120,14 @@ def shufflenet_v2_x1_0( ...@@ -120,14 +120,14 @@ def shufflenet_v2_x1_0(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = (
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize if quantize
else ShuffleNetV2_x1_0Weights.ImageNet1K_Community else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights) weights = ShuffleNet_V2_X1_0_QuantizedWeights.verify(weights)
else: else:
weights = ShuffleNetV2_x1_0Weights.verify(weights) weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs) return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs)
...@@ -6,27 +6,27 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -6,27 +6,27 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams from ...models.regnet import RegNet, BlockParams
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"RegNet", "RegNet",
"RegNet_y_400mfWeights", "RegNet_Y_400MF_Weights",
"RegNet_y_800mfWeights", "RegNet_Y_800MF_Weights",
"RegNet_y_1_6gfWeights", "RegNet_Y_1_6GF_Weights",
"RegNet_y_3_2gfWeights", "RegNet_Y_3_2GF_Weights",
"RegNet_y_8gfWeights", "RegNet_Y_8GF_Weights",
"RegNet_y_16gfWeights", "RegNet_Y_16GF_Weights",
"RegNet_y_32gfWeights", "RegNet_Y_32GF_Weights",
"RegNet_x_400mfWeights", "RegNet_X_400MF_Weights",
"RegNet_x_800mfWeights", "RegNet_X_800MF_Weights",
"RegNet_x_1_6gfWeights", "RegNet_X_1_6GF_Weights",
"RegNet_x_3_2gfWeights", "RegNet_X_3_2GF_Weights",
"RegNet_x_8gfWeights", "RegNet_X_8GF_Weights",
"RegNet_x_16gfWeights", "RegNet_X_16GF_Weights",
"RegNet_x_32gfWeights", "RegNet_X_32GF_Weights",
"regnet_y_400mf", "regnet_y_400mf",
"regnet_y_800mf", "regnet_y_800mf",
"regnet_y_1_6gf", "regnet_y_1_6gf",
...@@ -48,7 +48,7 @@ _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpo ...@@ -48,7 +48,7 @@ _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpo
def _regnet( def _regnet(
block_params: BlockParams, block_params: BlockParams,
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
**kwargs: Any, **kwargs: Any,
) -> RegNet: ) -> RegNet:
...@@ -64,8 +64,8 @@ def _regnet( ...@@ -64,8 +64,8 @@ def _regnet(
return model return model
class RegNet_y_400mfWeights(Weights): class RegNet_Y_400MF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -78,8 +78,8 @@ class RegNet_y_400mfWeights(Weights): ...@@ -78,8 +78,8 @@ class RegNet_y_400mfWeights(Weights):
) )
class RegNet_y_800mfWeights(Weights): class RegNet_Y_800MF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -92,8 +92,8 @@ class RegNet_y_800mfWeights(Weights): ...@@ -92,8 +92,8 @@ class RegNet_y_800mfWeights(Weights):
) )
class RegNet_y_1_6gfWeights(Weights): class RegNet_Y_1_6GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -106,8 +106,8 @@ class RegNet_y_1_6gfWeights(Weights): ...@@ -106,8 +106,8 @@ class RegNet_y_1_6gfWeights(Weights):
) )
class RegNet_y_3_2gfWeights(Weights): class RegNet_Y_3_2GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -120,8 +120,8 @@ class RegNet_y_3_2gfWeights(Weights): ...@@ -120,8 +120,8 @@ class RegNet_y_3_2gfWeights(Weights):
) )
class RegNet_y_8gfWeights(Weights): class RegNet_Y_8GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -134,8 +134,8 @@ class RegNet_y_8gfWeights(Weights): ...@@ -134,8 +134,8 @@ class RegNet_y_8gfWeights(Weights):
) )
class RegNet_y_16gfWeights(Weights): class RegNet_Y_16GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -148,8 +148,8 @@ class RegNet_y_16gfWeights(Weights): ...@@ -148,8 +148,8 @@ class RegNet_y_16gfWeights(Weights):
) )
class RegNet_y_32gfWeights(Weights): class RegNet_Y_32GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -162,8 +162,8 @@ class RegNet_y_32gfWeights(Weights): ...@@ -162,8 +162,8 @@ class RegNet_y_32gfWeights(Weights):
) )
class RegNet_x_400mfWeights(Weights): class RegNet_X_400MF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -176,8 +176,8 @@ class RegNet_x_400mfWeights(Weights): ...@@ -176,8 +176,8 @@ class RegNet_x_400mfWeights(Weights):
) )
class RegNet_x_800mfWeights(Weights): class RegNet_X_800MF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -190,8 +190,8 @@ class RegNet_x_800mfWeights(Weights): ...@@ -190,8 +190,8 @@ class RegNet_x_800mfWeights(Weights):
) )
class RegNet_x_1_6gfWeights(Weights): class RegNet_X_1_6GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -204,8 +204,8 @@ class RegNet_x_1_6gfWeights(Weights): ...@@ -204,8 +204,8 @@ class RegNet_x_1_6gfWeights(Weights):
) )
class RegNet_x_3_2gfWeights(Weights): class RegNet_X_3_2GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -218,8 +218,8 @@ class RegNet_x_3_2gfWeights(Weights): ...@@ -218,8 +218,8 @@ class RegNet_x_3_2gfWeights(Weights):
) )
class RegNet_x_8gfWeights(Weights): class RegNet_X_8GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -232,8 +232,8 @@ class RegNet_x_8gfWeights(Weights): ...@@ -232,8 +232,8 @@ class RegNet_x_8gfWeights(Weights):
) )
class RegNet_x_16gfWeights(Weights): class RegNet_X_16GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -246,8 +246,8 @@ class RegNet_x_16gfWeights(Weights): ...@@ -246,8 +246,8 @@ class RegNet_x_16gfWeights(Weights):
) )
class RegNet_x_32gfWeights(Weights): class RegNet_X_32GF_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -260,34 +260,34 @@ class RegNet_x_32gfWeights(Weights): ...@@ -260,34 +260,34 @@ class RegNet_x_32gfWeights(Weights):
) )
def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_400MF_Weights.ImageNet1K_V1)
weights = RegNet_y_400mfWeights.verify(weights) weights = RegNet_Y_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_800mf(weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_800MF_Weights.ImageNet1K_V1)
weights = RegNet_y_800mfWeights.verify(weights) weights = RegNet_Y_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_1_6GF_Weights.ImageNet1K_V1)
weights = RegNet_y_1_6gfWeights.verify(weights) weights = RegNet_Y_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
...@@ -295,12 +295,12 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo ...@@ -295,12 +295,12 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_3_2gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_3_2GF_Weights.ImageNet1K_V1)
weights = RegNet_y_3_2gfWeights.verify(weights) weights = RegNet_Y_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
...@@ -308,12 +308,12 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo ...@@ -308,12 +308,12 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_8GF_Weights.ImageNet1K_V1)
weights = RegNet_y_8gfWeights.verify(weights) weights = RegNet_Y_8GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
...@@ -321,12 +321,12 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = ...@@ -321,12 +321,12 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool =
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_16GF_Weights.ImageNet1K_V1)
weights = RegNet_y_16gfWeights.verify(weights) weights = RegNet_Y_16GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
...@@ -334,12 +334,12 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool ...@@ -334,12 +334,12 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_32GF_Weights.ImageNet1K_V1)
weights = RegNet_y_32gfWeights.verify(weights) weights = RegNet_Y_32GF_Weights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
...@@ -347,78 +347,78 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool ...@@ -347,78 +347,78 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_400mf(weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_400MF_Weights.ImageNet1K_V1)
weights = RegNet_x_400mfWeights.verify(weights) weights = RegNet_X_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_800mf(weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_800MF_Weights.ImageNet1K_V1)
weights = RegNet_x_800mfWeights.verify(weights) weights = RegNet_X_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_1_6gf(weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_1_6GF_Weights.ImageNet1K_V1)
weights = RegNet_x_1_6gfWeights.verify(weights) weights = RegNet_X_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_3_2gf(weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_3_2GF_Weights.ImageNet1K_V1)
weights = RegNet_x_3_2gfWeights.verify(weights) weights = RegNet_X_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_8gf(weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_8GF_Weights.ImageNet1K_V1)
weights = RegNet_x_8gfWeights.verify(weights) weights = RegNet_X_8GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_16gf(weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_16GF_Weights.ImageNet1K_V1)
weights = RegNet_x_16gfWeights.verify(weights) weights = RegNet_X_16GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_32gf(weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_32GF_Weights.ImageNet1K_V1)
weights = RegNet_x_32gfWeights.verify(weights) weights = RegNet_X_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
...@@ -5,22 +5,22 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,22 +5,22 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"ResNet", "ResNet",
"ResNet18Weights", "ResNet18_Weights",
"ResNet34Weights", "ResNet34_Weights",
"ResNet50Weights", "ResNet50_Weights",
"ResNet101Weights", "ResNet101_Weights",
"ResNet152Weights", "ResNet152_Weights",
"ResNeXt50_32x4dWeights", "ResNeXt50_32X4D_Weights",
"ResNeXt101_32x8dWeights", "ResNeXt101_32X8D_Weights",
"WideResNet50_2Weights", "Wide_ResNet50_2_Weights",
"WideResNet101_2Weights", "Wide_ResNet101_2_Weights",
"resnet18", "resnet18",
"resnet34", "resnet34",
"resnet50", "resnet50",
...@@ -36,7 +36,7 @@ __all__ = [ ...@@ -36,7 +36,7 @@ __all__ = [
def _resnet( def _resnet(
block: Type[Union[BasicBlock, Bottleneck]], block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int], layers: List[int],
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
**kwargs: Any, **kwargs: Any,
) -> ResNet: ) -> ResNet:
...@@ -54,8 +54,8 @@ def _resnet( ...@@ -54,8 +54,8 @@ def _resnet(
_COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class ResNet18Weights(Weights): class ResNet18_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet18-f37072fd.pth", url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -68,8 +68,8 @@ class ResNet18Weights(Weights): ...@@ -68,8 +68,8 @@ class ResNet18Weights(Weights):
) )
class ResNet34Weights(Weights): class ResNet34_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet34-b627a593.pth", url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -82,8 +82,8 @@ class ResNet34Weights(Weights): ...@@ -82,8 +82,8 @@ class ResNet34Weights(Weights):
) )
class ResNet50Weights(Weights): class ResNet50_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet50-0676ba61.pth", url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -94,7 +94,7 @@ class ResNet50Weights(Weights): ...@@ -94,7 +94,7 @@ class ResNet50Weights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -107,8 +107,8 @@ class ResNet50Weights(Weights): ...@@ -107,8 +107,8 @@ class ResNet50Weights(Weights):
) )
class ResNet101Weights(Weights): class ResNet101_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet101-63fe2227.pth", url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -119,7 +119,7 @@ class ResNet101Weights(Weights): ...@@ -119,7 +119,7 @@ class ResNet101Weights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -132,8 +132,8 @@ class ResNet101Weights(Weights): ...@@ -132,8 +132,8 @@ class ResNet101Weights(Weights):
) )
class ResNet152Weights(Weights): class ResNet152_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet152-394f9c45.pth", url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -144,7 +144,7 @@ class ResNet152Weights(Weights): ...@@ -144,7 +144,7 @@ class ResNet152Weights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth", url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -157,8 +157,8 @@ class ResNet152Weights(Weights): ...@@ -157,8 +157,8 @@ class ResNet152Weights(Weights):
) )
class ResNeXt50_32x4dWeights(Weights): class ResNeXt50_32X4D_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -169,7 +169,7 @@ class ResNeXt50_32x4dWeights(Weights): ...@@ -169,7 +169,7 @@ class ResNeXt50_32x4dWeights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -182,8 +182,8 @@ class ResNeXt50_32x4dWeights(Weights): ...@@ -182,8 +182,8 @@ class ResNeXt50_32x4dWeights(Weights):
) )
class ResNeXt101_32x8dWeights(Weights): class ResNeXt101_32X8D_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -194,7 +194,7 @@ class ResNeXt101_32x8dWeights(Weights): ...@@ -194,7 +194,7 @@ class ResNeXt101_32x8dWeights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -207,8 +207,8 @@ class ResNeXt101_32x8dWeights(Weights): ...@@ -207,8 +207,8 @@ class ResNeXt101_32x8dWeights(Weights):
) )
class WideResNet50_2Weights(Weights): class Wide_ResNet50_2_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -219,7 +219,7 @@ class WideResNet50_2Weights(Weights): ...@@ -219,7 +219,7 @@ class WideResNet50_2Weights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -232,8 +232,8 @@ class WideResNet50_2Weights(Weights): ...@@ -232,8 +232,8 @@ class WideResNet50_2Weights(Weights):
) )
class WideResNet101_2Weights(Weights): class Wide_ResNet101_2_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -244,7 +244,7 @@ class WideResNet101_2Weights(Weights): ...@@ -244,7 +244,7 @@ class WideResNet101_2Weights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -257,97 +257,101 @@ class WideResNet101_2Weights(Weights): ...@@ -257,97 +257,101 @@ class WideResNet101_2Weights(Weights):
) )
def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18_Weights.ImageNet1K_V1)
weights = ResNet18Weights.verify(weights) weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet34(weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34_Weights.ImageNet1K_V1)
weights = ResNet34Weights.verify(weights) weights = ResNet34_Weights.verify(weights)
return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet50(weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50_Weights.ImageNet1K_V1)
weights = ResNet50Weights.verify(weights) weights = ResNet50_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet101(weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101_Weights.ImageNet1K_V1)
weights = ResNet101Weights.verify(weights) weights = ResNet101_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet152(weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152_Weights.ImageNet1K_V1)
weights = ResNet152Weights.verify(weights) weights = ResNet152_Weights.verify(weights)
return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32X4D_Weights.ImageNet1K_V1)
weights = ResNeXt50_32x4dWeights.verify(weights) weights = ResNeXt50_32X4D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 4) _ovewrite_named_param(kwargs, "width_per_group", 4)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnext101_32x8d(
weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32X8D_Weights.ImageNet1K_V1)
weights = ResNeXt101_32x8dWeights.verify(weights) weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8) _ovewrite_named_param(kwargs, "width_per_group", 8)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def wide_resnet50_2(weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet50_2_Weights.ImageNet1K_V1)
weights = WideResNet50_2Weights.verify(weights) weights = Wide_ResNet50_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def wide_resnet101_2(
weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet101_2_Weights.ImageNet1K_V1)
weights = WideResNet101_2Weights.verify(weights) weights = Wide_ResNet101_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
...@@ -5,19 +5,19 @@ from torchvision.prototype.transforms import VocEval ...@@ -5,19 +5,19 @@ from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import resnet50, resnet101 from ..resnet import resnet50, resnet101
from ..resnet import ResNet50Weights, ResNet101Weights from ..resnet import ResNet50_Weights, ResNet101_Weights
__all__ = [ __all__ = [
"DeepLabV3", "DeepLabV3",
"DeepLabV3ResNet50Weights", "DeepLabV3_ResNet50_Weights",
"DeepLabV3ResNet101Weights", "DeepLabV3_ResNet101_Weights",
"DeepLabV3MobileNetV3LargeWeights", "DeepLabV3_MobileNet_V3_Large_Weights",
"deeplabv3_mobilenet_v3_large", "deeplabv3_mobilenet_v3_large",
"deeplabv3_resnet50", "deeplabv3_resnet50",
"deeplabv3_resnet101", "deeplabv3_resnet101",
...@@ -30,8 +30,8 @@ _COMMON_META = { ...@@ -30,8 +30,8 @@ _COMMON_META = {
} }
class DeepLabV3ResNet50Weights(Weights): class DeepLabV3_ResNet50_Weights(WeightsEnum):
CocoWithVocLabels_RefV1 = WeightEntry( CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
...@@ -44,8 +44,8 @@ class DeepLabV3ResNet50Weights(Weights): ...@@ -44,8 +44,8 @@ class DeepLabV3ResNet50Weights(Weights):
) )
class DeepLabV3ResNet101Weights(Weights): class DeepLabV3_ResNet101_Weights(WeightsEnum):
CocoWithVocLabels_RefV1 = WeightEntry( CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
...@@ -58,8 +58,8 @@ class DeepLabV3ResNet101Weights(Weights): ...@@ -58,8 +58,8 @@ class DeepLabV3ResNet101Weights(Weights):
) )
class DeepLabV3MobileNetV3LargeWeights(Weights): class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
CocoWithVocLabels_RefV1 = WeightEntry( CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
...@@ -73,25 +73,25 @@ class DeepLabV3MobileNetV3LargeWeights(Weights): ...@@ -73,25 +73,25 @@ class DeepLabV3MobileNetV3LargeWeights(Weights):
def deeplabv3_resnet50( def deeplabv3_resnet50(
weights: Optional[DeepLabV3ResNet50Weights] = None, weights: Optional[DeepLabV3_ResNet50_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3ResNet50Weights.verify(weights) weights = DeepLabV3_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -110,25 +110,25 @@ def deeplabv3_resnet50( ...@@ -110,25 +110,25 @@ def deeplabv3_resnet50(
def deeplabv3_resnet101( def deeplabv3_resnet101(
weights: Optional[DeepLabV3ResNet101Weights] = None, weights: Optional[DeepLabV3_ResNet101_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None, weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3ResNet101Weights.verify(weights) weights = DeepLabV3_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet101Weights.verify(weights_backbone) weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -147,27 +147,27 @@ def deeplabv3_resnet101( ...@@ -147,27 +147,27 @@ def deeplabv3_resnet101(
def deeplabv3_mobilenet_v3_large( def deeplabv3_mobilenet_v3_large(
weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None, weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param( weights = _deprecated_param(
kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 kwargs, "pretrained", "weights", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
) )
weights = DeepLabV3MobileNetV3LargeWeights.verify(weights) weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
) )
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
......
...@@ -5,13 +5,13 @@ from torchvision.prototype.transforms import VocEval ...@@ -5,13 +5,13 @@ from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet from ....models.segmentation.fcn import FCN, _fcn_resnet
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"] __all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"]
_COMMON_META = { _COMMON_META = {
...@@ -20,8 +20,8 @@ _COMMON_META = { ...@@ -20,8 +20,8 @@ _COMMON_META = {
} }
class FCNResNet50Weights(Weights): class FCN_ResNet50_Weights(WeightsEnum):
CocoWithVocLabels_RefV1 = WeightEntry( CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
...@@ -34,8 +34,8 @@ class FCNResNet50Weights(Weights): ...@@ -34,8 +34,8 @@ class FCNResNet50Weights(Weights):
) )
class FCNResNet101Weights(Weights): class FCN_ResNet101_Weights(WeightsEnum):
CocoWithVocLabels_RefV1 = WeightEntry( CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
...@@ -49,25 +49,25 @@ class FCNResNet101Weights(Weights): ...@@ -49,25 +49,25 @@ class FCNResNet101Weights(Weights):
def fcn_resnet50( def fcn_resnet50(
weights: Optional[FCNResNet50Weights] = None, weights: Optional[FCN_ResNet50_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet50_Weights.CocoWithVocLabels_V1)
weights = FCNResNet50Weights.verify(weights) weights = FCN_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -86,25 +86,25 @@ def fcn_resnet50( ...@@ -86,25 +86,25 @@ def fcn_resnet50(
def fcn_resnet101( def fcn_resnet101(
weights: Optional[FCNResNet101Weights] = None, weights: Optional[FCN_ResNet101_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None, weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet101_Weights.CocoWithVocLabels_V1)
weights = FCNResNet101Weights.verify(weights) weights = FCN_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet101Weights.verify(weights_backbone) weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
......
...@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import VocEval ...@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
__all__ = ["LRASPP", "LRASPPMobileNetV3LargeWeights", "lraspp_mobilenet_v3_large"] __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]
class LRASPPMobileNetV3LargeWeights(Weights): class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
CocoWithVocLabels_RefV1 = WeightEntry( CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
...@@ -30,10 +30,10 @@ class LRASPPMobileNetV3LargeWeights(Weights): ...@@ -30,10 +30,10 @@ class LRASPPMobileNetV3LargeWeights(Weights):
def lraspp_mobilenet_v3_large( def lraspp_mobilenet_v3_large(
weights: Optional[LRASPPMobileNetV3LargeWeights] = None, weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> LRASPP: ) -> LRASPP:
if kwargs.pop("aux_loss", False): if kwargs.pop("aux_loss", False):
...@@ -43,16 +43,16 @@ def lraspp_mobilenet_v3_large( ...@@ -43,16 +43,16 @@ def lraspp_mobilenet_v3_large(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param( weights = _deprecated_param(
kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 kwargs, "pretrained", "weights", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
) )
weights = LRASPPMobileNetV3LargeWeights.verify(weights) weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
) )
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
......
...@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2 from ...models.shufflenetv2 import ShuffleNetV2
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"ShuffleNetV2", "ShuffleNetV2",
"ShuffleNetV2_x0_5Weights", "ShuffleNet_V2_X0_5_Weights",
"ShuffleNetV2_x1_0Weights", "ShuffleNet_V2_X1_0_Weights",
"ShuffleNetV2_x1_5Weights", "ShuffleNet_V2_X1_5_Weights",
"ShuffleNetV2_x2_0Weights", "ShuffleNet_V2_X2_0_Weights",
"shufflenet_v2_x0_5", "shufflenet_v2_x0_5",
"shufflenet_v2_x1_0", "shufflenet_v2_x1_0",
"shufflenet_v2_x1_5", "shufflenet_v2_x1_5",
...@@ -24,7 +24,7 @@ __all__ = [ ...@@ -24,7 +24,7 @@ __all__ = [
def _shufflenetv2( def _shufflenetv2(
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
...@@ -48,8 +48,8 @@ _COMMON_META = { ...@@ -48,8 +48,8 @@ _COMMON_META = {
} }
class ShuffleNetV2_x0_5Weights(Weights): class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -61,8 +61,8 @@ class ShuffleNetV2_x0_5Weights(Weights): ...@@ -61,8 +61,8 @@ class ShuffleNetV2_x0_5Weights(Weights):
) )
class ShuffleNetV2_x1_0Weights(Weights): class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -74,57 +74,57 @@ class ShuffleNetV2_x1_0Weights(Weights): ...@@ -74,57 +74,57 @@ class ShuffleNetV2_x1_0Weights(Weights):
) )
class ShuffleNetV2_x1_5Weights(Weights): class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
pass pass
class ShuffleNetV2_x2_0Weights(Weights): class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
pass pass
def shufflenet_v2_x0_5( def shufflenet_v2_x0_5(
weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1)
weights = ShuffleNetV2_x0_5Weights.verify(weights) weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenet_v2_x1_0( def shufflenet_v2_x1_0(
weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x1_0Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1)
weights = ShuffleNetV2_x1_0Weights.verify(weights) weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenet_v2_x1_5( def shufflenet_v2_x1_5(
weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNetV2_x1_5Weights.verify(weights) weights = ShuffleNet_V2_X1_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenet_v2_x2_0( def shufflenet_v2_x2_0(
weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNetV2_x2_0Weights.verify(weights) weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
...@@ -5,12 +5,12 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,12 +5,12 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet from ...models.squeezenet import SqueezeNet
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"] __all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
_COMMON_META = { _COMMON_META = {
...@@ -21,8 +21,8 @@ _COMMON_META = { ...@@ -21,8 +21,8 @@ _COMMON_META = {
} }
class SqueezeNet1_0Weights(Weights): class SqueezeNet1_0_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -34,8 +34,8 @@ class SqueezeNet1_0Weights(Weights): ...@@ -34,8 +34,8 @@ class SqueezeNet1_0Weights(Weights):
) )
class SqueezeNet1_1Weights(Weights): class SqueezeNet1_1_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -47,12 +47,12 @@ class SqueezeNet1_1Weights(Weights): ...@@ -47,12 +47,12 @@ class SqueezeNet1_1Weights(Weights):
) )
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0_Weights.ImageNet1K_V1)
weights = SqueezeNet1_0Weights.verify(weights) weights = SqueezeNet1_0_Weights.verify(weights)
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
...@@ -65,12 +65,12 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool ...@@ -65,12 +65,12 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool
return model return model
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_1(weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1_Weights.ImageNet1K_V1)
weights = SqueezeNet1_1Weights.verify(weights) weights = SqueezeNet1_1_Weights.verify(weights)
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
...@@ -5,21 +5,21 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,21 +5,21 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs from ...models.vgg import VGG, make_layers, cfgs
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"VGG", "VGG",
"VGG11Weights", "VGG11_Weights",
"VGG11BNWeights", "VGG11_BN_Weights",
"VGG13Weights", "VGG13_Weights",
"VGG13BNWeights", "VGG13_BN_Weights",
"VGG16Weights", "VGG16_Weights",
"VGG16BNWeights", "VGG16_BN_Weights",
"VGG19Weights", "VGG19_Weights",
"VGG19BNWeights", "VGG19_BN_Weights",
"vgg11", "vgg11",
"vgg11_bn", "vgg11_bn",
"vgg13", "vgg13",
...@@ -31,7 +31,7 @@ __all__ = [ ...@@ -31,7 +31,7 @@ __all__ = [
] ]
def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
...@@ -48,8 +48,8 @@ _COMMON_META = { ...@@ -48,8 +48,8 @@ _COMMON_META = {
} }
class VGG11Weights(Weights): class VGG11_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11-8a719046.pth", url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -61,8 +61,8 @@ class VGG11Weights(Weights): ...@@ -61,8 +61,8 @@ class VGG11Weights(Weights):
) )
class VGG11BNWeights(Weights): class VGG11_BN_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -74,8 +74,8 @@ class VGG11BNWeights(Weights): ...@@ -74,8 +74,8 @@ class VGG11BNWeights(Weights):
) )
class VGG13Weights(Weights): class VGG13_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13-19584684.pth", url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -87,8 +87,8 @@ class VGG13Weights(Weights): ...@@ -87,8 +87,8 @@ class VGG13Weights(Weights):
) )
class VGG13BNWeights(Weights): class VGG13_BN_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -100,8 +100,8 @@ class VGG13BNWeights(Weights): ...@@ -100,8 +100,8 @@ class VGG13BNWeights(Weights):
) )
class VGG16Weights(Weights): class VGG16_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16-397923af.pth", url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -114,7 +114,7 @@ class VGG16Weights(Weights): ...@@ -114,7 +114,7 @@ class VGG16Weights(Weights):
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Only the `features` weights have proper values, those on the # same input standardization method as the paper. Only the `features` weights have proper values, those on the
# `classifier` module are filled with nans. # `classifier` module are filled with nans.
ImageNet1K_Features = WeightEntry( ImageNet1K_Features = Weights(
url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
transforms=partial( transforms=partial(
ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0) ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0)
...@@ -131,8 +131,8 @@ class VGG16Weights(Weights): ...@@ -131,8 +131,8 @@ class VGG16Weights(Weights):
) )
class VGG16BNWeights(Weights): class VGG16_BN_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -144,8 +144,8 @@ class VGG16BNWeights(Weights): ...@@ -144,8 +144,8 @@ class VGG16BNWeights(Weights):
) )
class VGG19Weights(Weights): class VGG19_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -157,8 +157,8 @@ class VGG19Weights(Weights): ...@@ -157,8 +157,8 @@ class VGG19Weights(Weights):
) )
class VGG19BNWeights(Weights): class VGG19_BN_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -170,81 +170,81 @@ class VGG19BNWeights(Weights): ...@@ -170,81 +170,81 @@ class VGG19BNWeights(Weights):
) )
def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_Weights.ImageNet1K_V1)
weights = VGG11Weights.verify(weights) weights = VGG11_Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs) return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11_bn(weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_BN_Weights.ImageNet1K_V1)
weights = VGG11BNWeights.verify(weights) weights = VGG11_BN_Weights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs) return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13(weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_Weights.ImageNet1K_V1)
weights = VGG13Weights.verify(weights) weights = VGG13_Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs) return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13_bn(weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_BN_Weights.ImageNet1K_V1)
weights = VGG13BNWeights.verify(weights) weights = VGG13_BN_Weights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs) return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16(weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_Weights.ImageNet1K_V1)
weights = VGG16Weights.verify(weights) weights = VGG16_Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs) return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16_bn(weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_BN_Weights.ImageNet1K_V1)
weights = VGG16BNWeights.verify(weights) weights = VGG16_BN_Weights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs) return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19(weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_Weights.ImageNet1K_V1)
weights = VGG19Weights.verify(weights) weights = VGG19_Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs) return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19_bn(weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_BN_Weights.ImageNet1K_V1)
weights = VGG19BNWeights.verify(weights) weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs) return _vgg("E", True, weights, progress, **kwargs)
...@@ -15,16 +15,16 @@ from ....models.video.resnet import ( ...@@ -15,16 +15,16 @@ from ....models.video.resnet import (
R2Plus1dStem, R2Plus1dStem,
VideoResNet, VideoResNet,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES from .._meta import _KINETICS400_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"VideoResNet", "VideoResNet",
"R3D_18Weights", "R3D_18_Weights",
"MC3_18Weights", "MC3_18_Weights",
"R2Plus1D_18Weights", "R2Plus1D_18_Weights",
"r3d_18", "r3d_18",
"mc3_18", "mc3_18",
"r2plus1d_18", "r2plus1d_18",
...@@ -36,7 +36,7 @@ def _video_resnet( ...@@ -36,7 +36,7 @@ def _video_resnet(
conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
layers: List[int], layers: List[int],
stem: Callable[..., nn.Module], stem: Callable[..., nn.Module],
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
**kwargs: Any, **kwargs: Any,
) -> VideoResNet: ) -> VideoResNet:
...@@ -59,8 +59,8 @@ _COMMON_META = { ...@@ -59,8 +59,8 @@ _COMMON_META = {
} }
class R3D_18Weights(Weights): class R3D_18_Weights(WeightsEnum):
Kinetics400_RefV1 = WeightEntry( Kinetics400_V1 = Weights(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={ meta={
...@@ -72,8 +72,8 @@ class R3D_18Weights(Weights): ...@@ -72,8 +72,8 @@ class R3D_18Weights(Weights):
) )
class MC3_18Weights(Weights): class MC3_18_Weights(WeightsEnum):
Kinetics400_RefV1 = WeightEntry( Kinetics400_V1 = Weights(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={ meta={
...@@ -85,8 +85,8 @@ class MC3_18Weights(Weights): ...@@ -85,8 +85,8 @@ class MC3_18Weights(Weights):
) )
class R2Plus1D_18Weights(Weights): class R2Plus1D_18_Weights(WeightsEnum):
Kinetics400_RefV1 = WeightEntry( Kinetics400_V1 = Weights(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={ meta={
...@@ -98,12 +98,12 @@ class R2Plus1D_18Weights(Weights): ...@@ -98,12 +98,12 @@ class R2Plus1D_18Weights(Weights):
) )
def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18_Weights.Kinetics400_V1)
weights = R3D_18Weights.verify(weights) weights = R3D_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
BasicBlock, BasicBlock,
...@@ -116,12 +116,12 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa ...@@ -116,12 +116,12 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa
) )
def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18_Weights.Kinetics400_V1)
weights = MC3_18Weights.verify(weights) weights = MC3_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
BasicBlock, BasicBlock,
...@@ -134,12 +134,12 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa ...@@ -134,12 +134,12 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa
) )
def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def r2plus1d_18(weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18_Weights.Kinetics400_V1)
weights = R2Plus1D_18Weights.verify(weights) weights = R2Plus1D_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
BasicBlock, BasicBlock,
......
...@@ -11,16 +11,16 @@ import torch ...@@ -11,16 +11,16 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from ._api import Weights from ._api import WeightsEnum
from ._utils import _deprecated_param, _deprecated_positional from ._utils import _deprecated_param, _deprecated_positional
__all__ = [ __all__ = [
"VisionTransformer", "VisionTransformer",
"VisionTransformer_B_16Weights", "ViT_B_16_Weights",
"VisionTransformer_B_32Weights", "ViT_B_32_Weights",
"VisionTransformer_L_16Weights", "ViT_L_16_Weights",
"VisionTransformer_L_32Weights", "ViT_L_32_Weights",
"vit_b_16", "vit_b_16",
"vit_b_32", "vit_b_32",
"vit_l_16", "vit_l_16",
...@@ -231,22 +231,22 @@ class VisionTransformer(nn.Module): ...@@ -231,22 +231,22 @@ class VisionTransformer(nn.Module):
return x return x
class VisionTransformer_B_16Weights(Weights): class ViT_B_16_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_b_16 # If a default model is added here the corresponding changes need to be done in vit_b_16
pass pass
class VisionTransformer_B_32Weights(Weights): class ViT_B_32_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_b_32 # If a default model is added here the corresponding changes need to be done in vit_b_32
pass pass
class VisionTransformer_L_16Weights(Weights): class ViT_L_16_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_l_16 # If a default model is added here the corresponding changes need to be done in vit_l_16
pass pass
class VisionTransformer_L_32Weights(Weights): class ViT_L_32_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_l_32 # If a default model is added here the corresponding changes need to be done in vit_l_32
pass pass
...@@ -257,7 +257,7 @@ def _vision_transformer( ...@@ -257,7 +257,7 @@ def _vision_transformer(
num_heads: int, num_heads: int,
hidden_dim: int, hidden_dim: int,
mlp_dim: int, mlp_dim: int,
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
**kwargs: Any, **kwargs: Any,
) -> VisionTransformer: ) -> VisionTransformer:
...@@ -279,15 +279,13 @@ def _vision_transformer( ...@@ -279,15 +279,13 @@ def _vision_transformer(
return model return model
def vit_b_16( def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights: Optional[VisionTransformer_B_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
""" """
Constructs a vit_b_16 architecture from Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
weights (VisionTransformer_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet. weights (ViT_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
...@@ -295,7 +293,7 @@ def vit_b_16( ...@@ -295,7 +293,7 @@ def vit_b_16(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_16Weights.verify(weights) weights = ViT_B_16_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
patch_size=16, patch_size=16,
...@@ -309,15 +307,13 @@ def vit_b_16( ...@@ -309,15 +307,13 @@ def vit_b_16(
) )
def vit_b_32( def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
""" """
Constructs a vit_b_32 architecture from Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
weights (VisionTransformer_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet. weights (ViT_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
...@@ -325,7 +321,7 @@ def vit_b_32( ...@@ -325,7 +321,7 @@ def vit_b_32(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_32Weights.verify(weights) weights = ViT_B_32_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
patch_size=32, patch_size=32,
...@@ -339,15 +335,13 @@ def vit_b_32( ...@@ -339,15 +335,13 @@ def vit_b_32(
) )
def vit_l_16( def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights: Optional[VisionTransformer_L_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
""" """
Constructs a vit_l_16 architecture from Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. weights (ViT_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
...@@ -355,7 +349,7 @@ def vit_l_16( ...@@ -355,7 +349,7 @@ def vit_l_16(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_16Weights.verify(weights) weights = ViT_L_16_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
patch_size=16, patch_size=16,
...@@ -369,15 +363,13 @@ def vit_l_16( ...@@ -369,15 +363,13 @@ def vit_l_16(
) )
def vit_l_32( def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
""" """
Constructs a vit_l_32 architecture from Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. weights (ViT_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None. Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" """
...@@ -385,7 +377,7 @@ def vit_l_32( ...@@ -385,7 +377,7 @@ def vit_l_32(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_32Weights.verify(weights) weights = ViT_L_32_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
patch_size=32, patch_size=32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment