Unverified Commit b9da6db4 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Cleanup namings of Multi-weights classes and enums (#5003)

* Rename classes Weights => WeightsEnum and WeightEntry => Weights.

* Make enum values follow the naming convention `_V1`, `_V2` etc

* Cleanup the Enum class naming conventions.

* Add a test to check naming conventions.
parent b3cdec1f
......@@ -11,15 +11,15 @@ from ....models.quantization.mobilenetv3 import (
QuantizableMobileNetV3,
_replace_relu,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf
from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf
__all__ = [
"QuantizableMobileNetV3",
"QuantizedMobileNetV3LargeWeights",
"MobileNet_V3_Large_QuantizedWeights",
"mobilenet_v3_large",
]
......@@ -27,7 +27,7 @@ __all__ = [
def _mobilenet_v3_model(
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
quantize: bool,
**kwargs: Any,
......@@ -56,8 +56,8 @@ def _mobilenet_v3_model(
return model
class QuantizedMobileNetV3LargeWeights(Weights):
ImageNet1K_QNNPACK_RefV1 = WeightEntry(
class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
ImageNet1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -67,7 +67,7 @@ class QuantizedMobileNetV3LargeWeights(Weights):
"backend": "qnnpack",
"quantization": "qat",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
"unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1,
"unquantized": MobileNet_V3_Large_Weights.ImageNet1K_V1,
"acc@1": 73.004,
"acc@5": 90.858,
},
......@@ -76,7 +76,7 @@ class QuantizedMobileNetV3LargeWeights(Weights):
def mobilenet_v3_large(
weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None,
weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -85,15 +85,15 @@ def mobilenet_v3_large(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1
if quantize
else MobileNetV3LargeWeights.ImageNet1K_RefV1
else MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedMobileNetV3LargeWeights.verify(weights)
weights = MobileNet_V3_Large_QuantizedWeights.verify(weights)
else:
weights = MobileNetV3LargeWeights.verify(weights)
weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
......@@ -11,17 +11,17 @@ from ....models.quantization.resnet import (
_replace_relu,
quantize_model,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights
from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
__all__ = [
"QuantizableResNet",
"QuantizedResNet18Weights",
"QuantizedResNet50Weights",
"QuantizedResNeXt101_32x8dWeights",
"ResNet18_QuantizedWeights",
"ResNet50_QuantizedWeights",
"ResNeXt101_32X8D_QuantizedWeights",
"resnet18",
"resnet50",
"resnext101_32x8d",
......@@ -31,7 +31,7 @@ __all__ = [
def _resnet(
block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
layers: List[int],
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
quantize: bool,
**kwargs: Any,
......@@ -63,13 +63,13 @@ _COMMON_META = {
}
class QuantizedResNet18Weights(Weights):
ImageNet1K_FBGEMM_RefV1 = WeightEntry(
class ResNet18_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_COMMON_META,
"unquantized": ResNet18Weights.ImageNet1K_RefV1,
"unquantized": ResNet18_Weights.ImageNet1K_V1,
"acc@1": 69.494,
"acc@5": 88.882,
},
......@@ -77,24 +77,24 @@ class QuantizedResNet18Weights(Weights):
)
class QuantizedResNet50Weights(Weights):
ImageNet1K_FBGEMM_RefV1 = WeightEntry(
class ResNet50_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_COMMON_META,
"unquantized": ResNet50Weights.ImageNet1K_RefV1,
"unquantized": ResNet50_Weights.ImageNet1K_V1,
"acc@1": 75.920,
"acc@5": 92.814,
},
default=False,
)
ImageNet1K_FBGEMM_RefV2 = WeightEntry(
ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"unquantized": ResNet50Weights.ImageNet1K_RefV2,
"unquantized": ResNet50_Weights.ImageNet1K_V2,
"acc@1": 80.282,
"acc@5": 94.976,
},
......@@ -102,24 +102,24 @@ class QuantizedResNet50Weights(Weights):
)
class QuantizedResNeXt101_32x8dWeights(Weights):
ImageNet1K_FBGEMM_RefV1 = WeightEntry(
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_COMMON_META,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
"unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V1,
"acc@1": 78.986,
"acc@5": 94.480,
},
default=False,
)
ImageNet1K_FBGEMM_RefV2 = WeightEntry(
ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV2,
"unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V2,
"acc@1": 82.574,
"acc@5": 96.132,
},
......@@ -128,7 +128,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
def resnet18(
weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None,
weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -136,20 +136,18 @@ def resnet18(
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
)
default_value = ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18_Weights.ImageNet1K_V1
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedResNet18Weights.verify(weights)
weights = ResNet18_QuantizedWeights.verify(weights)
else:
weights = ResNet18Weights.verify(weights)
weights = ResNet18_Weights.verify(weights)
return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
def resnet50(
weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -157,20 +155,18 @@ def resnet50(
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
)
default_value = ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50_Weights.ImageNet1K_V1
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedResNet50Weights.verify(weights)
weights = ResNet50_QuantizedWeights.verify(weights)
else:
weights = ResNet50Weights.verify(weights)
weights = ResNet50_Weights.verify(weights)
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
def resnext101_32x8d(
weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None,
weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -179,15 +175,15 @@ def resnext101_32x8d(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize
else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
else ResNeXt101_32X8D_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedResNeXt101_32x8dWeights.verify(weights)
weights = ResNeXt101_32X8D_QuantizedWeights.verify(weights)
else:
weights = ResNeXt101_32x8dWeights.verify(weights)
weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
......
......@@ -9,16 +9,16 @@ from ....models.quantization.shufflenetv2 import (
_replace_relu,
quantize_model,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
__all__ = [
"QuantizableShuffleNetV2",
"QuantizedShuffleNetV2_x0_5Weights",
"QuantizedShuffleNetV2_x1_0Weights",
"ShuffleNet_V2_X0_5_QuantizedWeights",
"ShuffleNet_V2_X1_0_QuantizedWeights",
"shufflenet_v2_x0_5",
"shufflenet_v2_x1_0",
]
......@@ -27,7 +27,7 @@ __all__ = [
def _shufflenetv2(
stages_repeats: List[int],
stages_out_channels: List[int],
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
quantize: bool,
**kwargs: Any,
......@@ -59,13 +59,13 @@ _COMMON_META = {
}
class QuantizedShuffleNetV2_x0_5Weights(Weights):
ImageNet1K_FBGEMM_Community = WeightEntry(
class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_COMMON_META,
"unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community,
"unquantized": ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1,
"acc@1": 57.972,
"acc@5": 79.780,
},
......@@ -73,13 +73,13 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights):
)
class QuantizedShuffleNetV2_x1_0Weights(Weights):
ImageNet1K_FBGEMM_Community = WeightEntry(
class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_COMMON_META,
"unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community,
"unquantized": ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1,
"acc@1": 68.360,
"acc@5": 87.582,
},
......@@ -88,7 +88,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights):
def shufflenet_v2_x0_5(
weights: Optional[Union[QuantizedShuffleNetV2_x0_5Weights, ShuffleNetV2_x0_5Weights]] = None,
weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -97,21 +97,21 @@ def shufflenet_v2_x0_5(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize
else ShuffleNetV2_x0_5Weights.ImageNet1K_Community
else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights)
weights = ShuffleNet_V2_X0_5_QuantizedWeights.verify(weights)
else:
weights = ShuffleNetV2_x0_5Weights.verify(weights)
weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs)
def shufflenet_v2_x1_0(
weights: Optional[Union[QuantizedShuffleNetV2_x1_0Weights, ShuffleNetV2_x1_0Weights]] = None,
weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -120,14 +120,14 @@ def shufflenet_v2_x1_0(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1
if quantize
else ShuffleNetV2_x1_0Weights.ImageNet1K_Community
else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights)
weights = ShuffleNet_V2_X1_0_QuantizedWeights.verify(weights)
else:
weights = ShuffleNetV2_x1_0Weights.verify(weights)
weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs)
......@@ -6,27 +6,27 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"RegNet",
"RegNet_y_400mfWeights",
"RegNet_y_800mfWeights",
"RegNet_y_1_6gfWeights",
"RegNet_y_3_2gfWeights",
"RegNet_y_8gfWeights",
"RegNet_y_16gfWeights",
"RegNet_y_32gfWeights",
"RegNet_x_400mfWeights",
"RegNet_x_800mfWeights",
"RegNet_x_1_6gfWeights",
"RegNet_x_3_2gfWeights",
"RegNet_x_8gfWeights",
"RegNet_x_16gfWeights",
"RegNet_x_32gfWeights",
"RegNet_Y_400MF_Weights",
"RegNet_Y_800MF_Weights",
"RegNet_Y_1_6GF_Weights",
"RegNet_Y_3_2GF_Weights",
"RegNet_Y_8GF_Weights",
"RegNet_Y_16GF_Weights",
"RegNet_Y_32GF_Weights",
"RegNet_X_400MF_Weights",
"RegNet_X_800MF_Weights",
"RegNet_X_1_6GF_Weights",
"RegNet_X_3_2GF_Weights",
"RegNet_X_8GF_Weights",
"RegNet_X_16GF_Weights",
"RegNet_X_32GF_Weights",
"regnet_y_400mf",
"regnet_y_800mf",
"regnet_y_1_6gf",
......@@ -48,7 +48,7 @@ _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpo
def _regnet(
block_params: BlockParams,
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> RegNet:
......@@ -64,8 +64,8 @@ def _regnet(
return model
class RegNet_y_400mfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_Y_400MF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -78,8 +78,8 @@ class RegNet_y_400mfWeights(Weights):
)
class RegNet_y_800mfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_Y_800MF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -92,8 +92,8 @@ class RegNet_y_800mfWeights(Weights):
)
class RegNet_y_1_6gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_Y_1_6GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -106,8 +106,8 @@ class RegNet_y_1_6gfWeights(Weights):
)
class RegNet_y_3_2gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_Y_3_2GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -120,8 +120,8 @@ class RegNet_y_3_2gfWeights(Weights):
)
class RegNet_y_8gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_Y_8GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -134,8 +134,8 @@ class RegNet_y_8gfWeights(Weights):
)
class RegNet_y_16gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_Y_16GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -148,8 +148,8 @@ class RegNet_y_16gfWeights(Weights):
)
class RegNet_y_32gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_Y_32GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -162,8 +162,8 @@ class RegNet_y_32gfWeights(Weights):
)
class RegNet_x_400mfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_X_400MF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -176,8 +176,8 @@ class RegNet_x_400mfWeights(Weights):
)
class RegNet_x_800mfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_X_800MF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -190,8 +190,8 @@ class RegNet_x_800mfWeights(Weights):
)
class RegNet_x_1_6gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_X_1_6GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -204,8 +204,8 @@ class RegNet_x_1_6gfWeights(Weights):
)
class RegNet_x_3_2gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_X_3_2GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -218,8 +218,8 @@ class RegNet_x_3_2gfWeights(Weights):
)
class RegNet_x_8gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_X_8GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -232,8 +232,8 @@ class RegNet_x_8gfWeights(Weights):
)
class RegNet_x_16gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_X_16GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -246,8 +246,8 @@ class RegNet_x_16gfWeights(Weights):
)
class RegNet_x_32gfWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class RegNet_X_32GF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -260,34 +260,34 @@ class RegNet_x_32gfWeights(Weights):
)
def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_RefV1)
weights = RegNet_y_400mfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_400MF_Weights.ImageNet1K_V1)
weights = RegNet_Y_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_y_800mf(weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_RefV1)
weights = RegNet_y_800mfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_800MF_Weights.ImageNet1K_V1)
weights = RegNet_Y_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_1_6gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_1_6GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
......@@ -295,12 +295,12 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo
return _regnet(params, weights, progress, **kwargs)
def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_3_2gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_3_2gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_3_2GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
......@@ -308,12 +308,12 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo
return _regnet(params, weights, progress, **kwargs)
def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_8gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_8GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_8GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
......@@ -321,12 +321,12 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool =
return _regnet(params, weights, progress, **kwargs)
def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_16gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_16GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_16GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
......@@ -334,12 +334,12 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool
return _regnet(params, weights, progress, **kwargs)
def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_32gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_32GF_Weights.ImageNet1K_V1)
weights = RegNet_Y_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
......@@ -347,78 +347,78 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool
return _regnet(params, weights, progress, **kwargs)
def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_x_400mf(weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_RefV1)
weights = RegNet_x_400mfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_400MF_Weights.ImageNet1K_V1)
weights = RegNet_X_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_x_800mf(weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_RefV1)
weights = RegNet_x_800mfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_800MF_Weights.ImageNet1K_V1)
weights = RegNet_X_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_x_1_6gf(weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_1_6gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_1_6GF_Weights.ImageNet1K_V1)
weights = RegNet_X_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_x_3_2gf(weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_3_2gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_3_2GF_Weights.ImageNet1K_V1)
weights = RegNet_X_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_x_8gf(weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_8gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_8GF_Weights.ImageNet1K_V1)
weights = RegNet_X_8GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_x_16gf(weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_16gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_16GF_Weights.ImageNet1K_V1)
weights = RegNet_X_16GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
def regnet_x_32gf(weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_32gfWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_32GF_Weights.ImageNet1K_V1)
weights = RegNet_X_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
return _regnet(params, weights, progress, **kwargs)
......@@ -5,22 +5,22 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"ResNet",
"ResNet18Weights",
"ResNet34Weights",
"ResNet50Weights",
"ResNet101Weights",
"ResNet152Weights",
"ResNeXt50_32x4dWeights",
"ResNeXt101_32x8dWeights",
"WideResNet50_2Weights",
"WideResNet101_2Weights",
"ResNet18_Weights",
"ResNet34_Weights",
"ResNet50_Weights",
"ResNet101_Weights",
"ResNet152_Weights",
"ResNeXt50_32X4D_Weights",
"ResNeXt101_32X8D_Weights",
"Wide_ResNet50_2_Weights",
"Wide_ResNet101_2_Weights",
"resnet18",
"resnet34",
"resnet50",
......@@ -36,7 +36,7 @@ __all__ = [
def _resnet(
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> ResNet:
......@@ -54,8 +54,8 @@ def _resnet(
_COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class ResNet18Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class ResNet18_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -68,8 +68,8 @@ class ResNet18Weights(Weights):
)
class ResNet34Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class ResNet34_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -82,8 +82,8 @@ class ResNet34Weights(Weights):
)
class ResNet50Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class ResNet50_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -94,7 +94,7 @@ class ResNet50Weights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -107,8 +107,8 @@ class ResNet50Weights(Weights):
)
class ResNet101Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class ResNet101_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -119,7 +119,7 @@ class ResNet101Weights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -132,8 +132,8 @@ class ResNet101Weights(Weights):
)
class ResNet152Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class ResNet152_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -144,7 +144,7 @@ class ResNet152Weights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -157,8 +157,8 @@ class ResNet152Weights(Weights):
)
class ResNeXt50_32x4dWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class ResNeXt50_32X4D_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -169,7 +169,7 @@ class ResNeXt50_32x4dWeights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -182,8 +182,8 @@ class ResNeXt50_32x4dWeights(Weights):
)
class ResNeXt101_32x8dWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class ResNeXt101_32X8D_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -194,7 +194,7 @@ class ResNeXt101_32x8dWeights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -207,8 +207,8 @@ class ResNeXt101_32x8dWeights(Weights):
)
class WideResNet50_2Weights(Weights):
ImageNet1K_Community = WeightEntry(
class Wide_ResNet50_2_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -219,7 +219,7 @@ class WideResNet50_2Weights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -232,8 +232,8 @@ class WideResNet50_2Weights(Weights):
)
class WideResNet101_2Weights(Weights):
ImageNet1K_Community = WeightEntry(
class Wide_ResNet101_2_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -244,7 +244,7 @@ class WideResNet101_2Weights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -257,97 +257,101 @@ class WideResNet101_2Weights(Weights):
)
def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_RefV1)
weights = ResNet18Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18_Weights.ImageNet1K_V1)
weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def resnet34(weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_RefV1)
weights = ResNet34Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34_Weights.ImageNet1K_V1)
weights = ResNet34_Weights.verify(weights)
return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def resnet50(weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_RefV1)
weights = ResNet50Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50_Weights.ImageNet1K_V1)
weights = ResNet50_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def resnet101(weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_RefV1)
weights = ResNet101Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101_Weights.ImageNet1K_V1)
weights = ResNet101_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def resnet152(weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_RefV1)
weights = ResNet152Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152_Weights.ImageNet1K_V1)
weights = ResNet152_Weights.verify(weights)
return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_RefV1)
weights = ResNeXt50_32x4dWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32X4D_Weights.ImageNet1K_V1)
weights = ResNeXt50_32X4D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 4)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def resnext101_32x8d(
weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_RefV1)
weights = ResNeXt101_32x8dWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32X8D_Weights.ImageNet1K_V1)
weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def wide_resnet50_2(weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_Community)
weights = WideResNet50_2Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet50_2_Weights.ImageNet1K_V1)
weights = Wide_ResNet50_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
def wide_resnet101_2(
weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_Community)
weights = WideResNet101_2Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet101_2_Weights.ImageNet1K_V1)
weights = Wide_ResNet101_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
......@@ -5,19 +5,19 @@ from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import resnet50, resnet101
from ..resnet import ResNet50Weights, ResNet101Weights
from ..resnet import ResNet50_Weights, ResNet101_Weights
__all__ = [
"DeepLabV3",
"DeepLabV3ResNet50Weights",
"DeepLabV3ResNet101Weights",
"DeepLabV3MobileNetV3LargeWeights",
"DeepLabV3_ResNet50_Weights",
"DeepLabV3_ResNet101_Weights",
"DeepLabV3_MobileNet_V3_Large_Weights",
"deeplabv3_mobilenet_v3_large",
"deeplabv3_resnet50",
"deeplabv3_resnet101",
......@@ -30,8 +30,8 @@ _COMMON_META = {
}
class DeepLabV3ResNet50Weights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
class DeepLabV3_ResNet50_Weights(WeightsEnum):
CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(VocEval, resize_size=520),
meta={
......@@ -44,8 +44,8 @@ class DeepLabV3ResNet50Weights(Weights):
)
class DeepLabV3ResNet101Weights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
class DeepLabV3_ResNet101_Weights(WeightsEnum):
CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(VocEval, resize_size=520),
meta={
......@@ -58,8 +58,8 @@ class DeepLabV3ResNet101Weights(Weights):
)
class DeepLabV3MobileNetV3LargeWeights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(VocEval, resize_size=520),
meta={
......@@ -73,25 +73,25 @@ class DeepLabV3MobileNetV3LargeWeights(Weights):
def deeplabv3_resnet50(
weights: Optional[DeepLabV3ResNet50Weights] = None,
weights: Optional[DeepLabV3_ResNet50_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1)
weights = DeepLabV3ResNet50Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......@@ -110,25 +110,25 @@ def deeplabv3_resnet50(
def deeplabv3_resnet101(
weights: Optional[DeepLabV3ResNet101Weights] = None,
weights: Optional[DeepLabV3_ResNet101_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None,
weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1)
weights = DeepLabV3ResNet101Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1)
weights = DeepLabV3_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
)
weights_backbone = ResNet101Weights.verify(weights_backbone)
weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......@@ -147,27 +147,27 @@ def deeplabv3_resnet101(
def deeplabv3_mobilenet_v3_large(
weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None,
weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1
kwargs, "pretrained", "weights", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
)
weights = DeepLabV3MobileNetV3LargeWeights.verify(weights)
weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......
......@@ -5,13 +5,13 @@ from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"]
_COMMON_META = {
......@@ -20,8 +20,8 @@ _COMMON_META = {
}
class FCNResNet50Weights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
class FCN_ResNet50_Weights(WeightsEnum):
CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(VocEval, resize_size=520),
meta={
......@@ -34,8 +34,8 @@ class FCNResNet50Weights(Weights):
)
class FCNResNet101Weights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
class FCN_ResNet101_Weights(WeightsEnum):
CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(VocEval, resize_size=520),
meta={
......@@ -49,25 +49,25 @@ class FCNResNet101Weights(Weights):
def fcn_resnet50(
weights: Optional[FCNResNet50Weights] = None,
weights: Optional[FCN_ResNet50_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any,
) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_RefV1)
weights = FCNResNet50Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet50_Weights.CocoWithVocLabels_V1)
weights = FCN_ResNet50_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......@@ -86,25 +86,25 @@ def fcn_resnet50(
def fcn_resnet101(
weights: Optional[FCNResNet101Weights] = None,
weights: Optional[FCN_ResNet101_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None,
weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any,
) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_RefV1)
weights = FCNResNet101Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet101_Weights.CocoWithVocLabels_V1)
weights = FCN_ResNet101_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1
)
weights_backbone = ResNet101Weights.verify(weights_backbone)
weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......
......@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
__all__ = ["LRASPP", "LRASPPMobileNetV3LargeWeights", "lraspp_mobilenet_v3_large"]
__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]
class LRASPPMobileNetV3LargeWeights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
CocoWithVocLabels_V1 = Weights(
url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
transforms=partial(VocEval, resize_size=520),
meta={
......@@ -30,10 +30,10 @@ class LRASPPMobileNetV3LargeWeights(Weights):
def lraspp_mobilenet_v3_large(
weights: Optional[LRASPPMobileNetV3LargeWeights] = None,
weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any,
) -> LRASPP:
if kwargs.pop("aux_loss", False):
......@@ -43,16 +43,16 @@ def lraspp_mobilenet_v3_large(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(
kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1
kwargs, "pretrained", "weights", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1
)
weights = LRASPPMobileNetV3LargeWeights.verify(weights)
weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......
......@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"ShuffleNetV2",
"ShuffleNetV2_x0_5Weights",
"ShuffleNetV2_x1_0Weights",
"ShuffleNetV2_x1_5Weights",
"ShuffleNetV2_x2_0Weights",
"ShuffleNet_V2_X0_5_Weights",
"ShuffleNet_V2_X1_0_Weights",
"ShuffleNet_V2_X1_5_Weights",
"ShuffleNet_V2_X2_0_Weights",
"shufflenet_v2_x0_5",
"shufflenet_v2_x1_0",
"shufflenet_v2_x1_5",
......@@ -24,7 +24,7 @@ __all__ = [
def _shufflenetv2(
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
*args: Any,
**kwargs: Any,
......@@ -48,8 +48,8 @@ _COMMON_META = {
}
class ShuffleNetV2_x0_5Weights(Weights):
ImageNet1K_Community = WeightEntry(
class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -61,8 +61,8 @@ class ShuffleNetV2_x0_5Weights(Weights):
)
class ShuffleNetV2_x1_0Weights(Weights):
ImageNet1K_Community = WeightEntry(
class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -74,57 +74,57 @@ class ShuffleNetV2_x1_0Weights(Weights):
)
class ShuffleNetV2_x1_5Weights(Weights):
class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
pass
class ShuffleNetV2_x2_0Weights(Weights):
class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
pass
def shufflenet_v2_x0_5(
weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_Community)
weights = ShuffleNetV2_x0_5Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1)
weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenet_v2_x1_0(
weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x1_0Weights.ImageNet1K_Community)
weights = ShuffleNetV2_x1_0Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1)
weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenet_v2_x1_5(
weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNetV2_x1_5Weights.verify(weights)
weights = ShuffleNet_V2_X1_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenet_v2_x2_0(
weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNetV2_x2_0Weights.verify(weights)
weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
......@@ -5,12 +5,12 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
_COMMON_META = {
......@@ -21,8 +21,8 @@ _COMMON_META = {
}
class SqueezeNet1_0Weights(Weights):
ImageNet1K_Community = WeightEntry(
class SqueezeNet1_0_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -34,8 +34,8 @@ class SqueezeNet1_0Weights(Weights):
)
class SqueezeNet1_1Weights(Weights):
ImageNet1K_Community = WeightEntry(
class SqueezeNet1_1_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -47,12 +47,12 @@ class SqueezeNet1_1Weights(Weights):
)
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_Community)
weights = SqueezeNet1_0Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0_Weights.ImageNet1K_V1)
weights = SqueezeNet1_0_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......@@ -65,12 +65,12 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool
return model
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
def squeezenet1_1(weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_Community)
weights = SqueezeNet1_1Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1_Weights.ImageNet1K_V1)
weights = SqueezeNet1_1_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
......@@ -5,21 +5,21 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"VGG",
"VGG11Weights",
"VGG11BNWeights",
"VGG13Weights",
"VGG13BNWeights",
"VGG16Weights",
"VGG16BNWeights",
"VGG19Weights",
"VGG19BNWeights",
"VGG11_Weights",
"VGG11_BN_Weights",
"VGG13_Weights",
"VGG13_BN_Weights",
"VGG16_Weights",
"VGG16_BN_Weights",
"VGG19_Weights",
"VGG19_BN_Weights",
"vgg11",
"vgg11_bn",
"vgg13",
......@@ -31,7 +31,7 @@ __all__ = [
]
def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
......@@ -48,8 +48,8 @@ _COMMON_META = {
}
class VGG11Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG11_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -61,8 +61,8 @@ class VGG11Weights(Weights):
)
class VGG11BNWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG11_BN_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -74,8 +74,8 @@ class VGG11BNWeights(Weights):
)
class VGG13Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG13_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -87,8 +87,8 @@ class VGG13Weights(Weights):
)
class VGG13BNWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG13_BN_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -100,8 +100,8 @@ class VGG13BNWeights(Weights):
)
class VGG16Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG16_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -114,7 +114,7 @@ class VGG16Weights(Weights):
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Only the `features` weights have proper values, those on the
# `classifier` module are filled with nans.
ImageNet1K_Features = WeightEntry(
ImageNet1K_Features = Weights(
url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
transforms=partial(
ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0)
......@@ -131,8 +131,8 @@ class VGG16Weights(Weights):
)
class VGG16BNWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG16_BN_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -144,8 +144,8 @@ class VGG16BNWeights(Weights):
)
class VGG19Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG19_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -157,8 +157,8 @@ class VGG19Weights(Weights):
)
class VGG19BNWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class VGG19_BN_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -170,81 +170,81 @@ class VGG19BNWeights(Weights):
)
def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_RefV1)
weights = VGG11Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_Weights.ImageNet1K_V1)
weights = VGG11_Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg11_bn(weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_RefV1)
weights = VGG11BNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_BN_Weights.ImageNet1K_V1)
weights = VGG11_BN_Weights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg13(weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_RefV1)
weights = VGG13Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_Weights.ImageNet1K_V1)
weights = VGG13_Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg13_bn(weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_RefV1)
weights = VGG13BNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_BN_Weights.ImageNet1K_V1)
weights = VGG13_BN_Weights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg16(weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_RefV1)
weights = VGG16Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_Weights.ImageNet1K_V1)
weights = VGG16_Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg16_bn(weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_RefV1)
weights = VGG16BNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_BN_Weights.ImageNet1K_V1)
weights = VGG16_BN_Weights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg19(weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_RefV1)
weights = VGG19Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_Weights.ImageNet1K_V1)
weights = VGG19_Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
def vgg19_bn(weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_RefV1)
weights = VGG19BNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_BN_Weights.ImageNet1K_V1)
weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs)
......@@ -15,16 +15,16 @@ from ....models.video.resnet import (
R2Plus1dStem,
VideoResNet,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"VideoResNet",
"R3D_18Weights",
"MC3_18Weights",
"R2Plus1D_18Weights",
"R3D_18_Weights",
"MC3_18_Weights",
"R2Plus1D_18_Weights",
"r3d_18",
"mc3_18",
"r2plus1d_18",
......@@ -36,7 +36,7 @@ def _video_resnet(
conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
layers: List[int],
stem: Callable[..., nn.Module],
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> VideoResNet:
......@@ -59,8 +59,8 @@ _COMMON_META = {
}
class R3D_18Weights(Weights):
Kinetics400_RefV1 = WeightEntry(
class R3D_18_Weights(WeightsEnum):
Kinetics400_V1 = Weights(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
......@@ -72,8 +72,8 @@ class R3D_18Weights(Weights):
)
class MC3_18Weights(Weights):
Kinetics400_RefV1 = WeightEntry(
class MC3_18_Weights(WeightsEnum):
Kinetics400_V1 = Weights(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
......@@ -85,8 +85,8 @@ class MC3_18Weights(Weights):
)
class R2Plus1D_18Weights(Weights):
Kinetics400_RefV1 = WeightEntry(
class R2Plus1D_18_Weights(WeightsEnum):
Kinetics400_V1 = Weights(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
......@@ -98,12 +98,12 @@ class R2Plus1D_18Weights(Weights):
)
def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_RefV1)
weights = R3D_18Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18_Weights.Kinetics400_V1)
weights = R3D_18_Weights.verify(weights)
return _video_resnet(
BasicBlock,
......@@ -116,12 +116,12 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa
)
def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_RefV1)
weights = MC3_18Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18_Weights.Kinetics400_V1)
weights = MC3_18_Weights.verify(weights)
return _video_resnet(
BasicBlock,
......@@ -134,12 +134,12 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa
)
def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
def r2plus1d_18(weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_RefV1)
weights = R2Plus1D_18Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18_Weights.Kinetics400_V1)
weights = R2Plus1D_18_Weights.verify(weights)
return _video_resnet(
BasicBlock,
......
......@@ -11,16 +11,16 @@ import torch
import torch.nn as nn
from torch import Tensor
from ._api import Weights
from ._api import WeightsEnum
from ._utils import _deprecated_param, _deprecated_positional
__all__ = [
"VisionTransformer",
"VisionTransformer_B_16Weights",
"VisionTransformer_B_32Weights",
"VisionTransformer_L_16Weights",
"VisionTransformer_L_32Weights",
"ViT_B_16_Weights",
"ViT_B_32_Weights",
"ViT_L_16_Weights",
"ViT_L_32_Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
......@@ -231,22 +231,22 @@ class VisionTransformer(nn.Module):
return x
class VisionTransformer_B_16Weights(Weights):
class ViT_B_16_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_b_16
pass
class VisionTransformer_B_32Weights(Weights):
class ViT_B_32_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_b_32
pass
class VisionTransformer_L_16Weights(Weights):
class ViT_L_16_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_l_16
pass
class VisionTransformer_L_32Weights(Weights):
class ViT_L_32_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in vit_l_32
pass
......@@ -257,7 +257,7 @@ def _vision_transformer(
num_heads: int,
hidden_dim: int,
mlp_dim: int,
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> VisionTransformer:
......@@ -279,15 +279,13 @@ def _vision_transformer(
return model
def vit_b_16(
weights: Optional[VisionTransformer_B_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
weights (ViT_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
......@@ -295,7 +293,7 @@ def vit_b_16(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_16Weights.verify(weights)
weights = ViT_B_16_Weights.verify(weights)
return _vision_transformer(
patch_size=16,
......@@ -309,15 +307,13 @@ def vit_b_16(
)
def vit_b_32(
weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
weights (ViT_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
......@@ -325,7 +321,7 @@ def vit_b_32(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_32Weights.verify(weights)
weights = ViT_B_32_Weights.verify(weights)
return _vision_transformer(
patch_size=32,
......@@ -339,15 +335,13 @@ def vit_b_32(
)
def vit_l_16(
weights: Optional[VisionTransformer_L_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
weights (ViT_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
......@@ -355,7 +349,7 @@ def vit_l_16(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_16Weights.verify(weights)
weights = ViT_L_16_Weights.verify(weights)
return _vision_transformer(
patch_size=16,
......@@ -369,15 +363,13 @@ def vit_l_16(
)
def vit_l_32(
weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
weights (ViT_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
......@@ -385,7 +377,7 @@ def vit_l_32(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_32Weights.verify(weights)
weights = ViT_L_32_Weights.verify(weights)
return _vision_transformer(
patch_size=32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment