Unverified Commit 8886a3cf authored by Yoshitomo Matsubara, committed by GitHub

Rename prototype weight names to comply with PEP8 (#5257)



* renamed ImageNet weights

* renamed COCO weights

* renamed COCO with VOC labels weights

* renamed Kinetics 400 weights

* renamed default to DEFAULT

* update test

* fix typos

* update test

* update test

* update test

* indent as w was weight_enum

* revert

* Adding back the capitalization test
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent c6722307
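
The gist of the change: prototype weight enum members move from mixed-case names (ImageNet1K_V1, Coco_V1, default) to PEP8-style upper-case constants (IMAGENET1K_V1, COCO_V1, DEFAULT). A minimal before/after sketch based on the ResNet50 entries touched below; the prototype import path is an assumption:

from torchvision.prototype import models  # prototype namespace at the time of this PR; path is an assumption

# Before: mixed-case member names
#   models.ResNet50_Weights.ImageNet1K_V1
#   models.ResNet50_Weights.default
# After: upper-case constants
weights = models.ResNet50_Weights.IMAGENET1K_V1
best = models.ResNet50_Weights.DEFAULT  # alias for IMAGENET1K_V2, the most up-to-date weights
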
......@@ -54,15 +54,15 @@ def _build_model(fn, **kwargs):
@pytest.mark.parametrize(
"name, weight",
[
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
(
"ResNet50_QuantizedWeights.default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
"ResNet50_QuantizedWeights.DEFAULT",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
),
(
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
),
],
)
......@@ -83,7 +83,7 @@ def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
print(weights_enum)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
@pytest.mark.parametrize(
......@@ -117,13 +117,14 @@ def test_schema_meta_validation(model_fn):
problematic_weights = {}
incorrect_params = []
bad_names = []
for w in weights_enum:
missing_fields = fields - set(w.meta.keys())
if missing_fields:
problematic_weights[w] = missing_fields
if w == weights_enum.default:
if w == weights_enum.DEFAULT:
if module_name == "quantization":
# parametes() cound doesn't work well with quantization, so we check against the non-quantized
# parameters() count doesn't work well with quantization, so we check against the non-quantized
unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
incorrect_params.append(w)
......@@ -131,11 +132,14 @@ def test_schema_meta_validation(model_fn):
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
incorrect_params.append(w)
else:
if w.meta.get("num_params") != weights_enum.default.meta.get("num_params"):
if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
incorrect_params.append(w)
if not w.name.isupper():
bad_names.append(w)
assert not problematic_weights
assert not incorrect_params
assert not bad_names
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
......
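
For reference, the num_params validation above reduces to a one-line parameter count over the instantiated model; a standalone sketch using the stable resnet50 builder:

from torchvision.models import resnet50

model = resnet50()
num_params = sum(p.numel() for p in model.parameters())  # what the test compares against weights.meta["num_params"]
print(num_params)
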
......@@ -80,7 +80,7 @@ class WeightsEnum(Enum):
def get_weight(name: str) -> WeightsEnum:
"""
Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"
Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
Args:
name (str): The name of the weight enum entry.
......
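
A hedged usage sketch of the string-based lookup with the renamed entries (the docstring above confirms the signature; the import location of the prototype helper is an assumption):

from torchvision.prototype.models import get_weight  # import path is an assumption

weights = get_weight("ResNet50_Weights.IMAGENET1K_V1")
print(weights.meta["acc@1"], weights.meta["acc@5"])  # weight entries carry their metadata
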
......@@ -78,7 +78,7 @@ def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[D
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
f"to get the most up-to-date weights."
)
warnings.warn(msg)
......
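
In practice, the updated warning steers legacy callers toward the DEFAULT alias. A rough migration sketch, assuming the prototype resnet50 builder (not the literal warning text):

from torchvision.prototype import models  # prototype namespace; path is an assumption

model = models.resnet50(pretrained=True)  # legacy call: still accepted, but triggers the warning above

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)  # explicit weights
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)  # most up-to-date weights
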
......@@ -14,7 +14,7 @@ __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
class AlexNet_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -31,10 +31,10 @@ class AlexNet_Weights(WeightsEnum):
"acc@5": 79.066,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
weights = AlexNet_Weights.verify(weights)
......
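
The DEFAULT = IMAGENET1K_V1 line relies on standard Enum alias semantics: a member assigned an existing value becomes an alias rather than a new entry. A simplified stand-in with plain string values (the real class derives from WeightsEnum and holds Weights objects):

from enum import Enum

class AlexNet_Weights(Enum):  # simplified stand-in for the class above
    IMAGENET1K_V1 = "alexnet-owt-7be5be79.pth"
    DEFAULT = IMAGENET1K_V1  # duplicate value -> alias, not a separate member

assert AlexNet_Weights.DEFAULT is AlexNet_Weights.IMAGENET1K_V1
assert [w.name for w in AlexNet_Weights] == ["IMAGENET1K_V1"]  # iteration skips aliases
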
......@@ -178,7 +178,7 @@ class ConvNeXt(nn.Module):
class ConvNeXt_Tiny_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
meta={
......@@ -195,10 +195,10 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
"acc@5": 96.146,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
r"""ConvNeXt model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
......
......@@ -76,7 +76,7 @@ _COMMON_META = {
class DenseNet121_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -86,11 +86,11 @@ class DenseNet121_Weights(WeightsEnum):
"acc@5": 91.972,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class DenseNet161_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -100,11 +100,11 @@ class DenseNet161_Weights(WeightsEnum):
"acc@5": 93.560,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class DenseNet169_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -114,11 +114,11 @@ class DenseNet169_Weights(WeightsEnum):
"acc@5": 92.806,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class DenseNet201_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -128,31 +128,31 @@ class DenseNet201_Weights(WeightsEnum):
"acc@5": 93.370,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet121_Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet161_Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet169_Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet201_Weights.verify(weights)
......
......@@ -40,7 +40,7 @@ _COMMON_META = {
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
......@@ -50,11 +50,11 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map": 37.0,
},
)
default = Coco_V1
DEFAULT = COCO_V1
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
......@@ -64,11 +64,11 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
"map": 32.8,
},
)
default = Coco_V1
DEFAULT = COCO_V1
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
......@@ -78,12 +78,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
"map": 22.8,
},
)
default = Coco_V1
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
*,
......@@ -113,7 +113,7 @@ def fasterrcnn_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
......@@ -161,8 +161,8 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
*,
......@@ -192,8 +192,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
*,
......
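
For the detection builders, both the model weights and the backbone weights follow the new naming; a hedged construction sketch (prototype import paths are assumptions):

from torchvision.prototype.models import ResNet50_Weights  # path is an assumption
from torchvision.prototype.models.detection import (  # path is an assumption
    FasterRCNN_ResNet50_FPN_Weights,
    fasterrcnn_resnet50_fpn,
)

model = fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1,
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,  # same upper-case scheme for backbones
)
model.eval()
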
......@@ -38,12 +38,12 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
"map": 39.2,
},
)
default = COCO_V1
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcos_resnet50_fpn(
*,
......
......@@ -34,7 +34,7 @@ _COMMON_META = {
class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_Legacy = Weights(
COCO_LEGACY = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval,
meta={
......@@ -45,7 +45,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_kp": 61.1,
},
)
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval,
meta={
......@@ -56,17 +56,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_kp": 65.0,
},
)
default = Coco_V1
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
if kwargs["pretrained"] == "legacy"
else KeypointRCNN_ResNet50_FPN_Weights.Coco_V1,
else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
*,
......@@ -101,7 +101,7 @@ def keypointrcnn_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
......@@ -24,7 +24,7 @@ __all__ = [
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
transforms=CocoEval,
meta={
......@@ -39,12 +39,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_mask": 34.6,
},
)
default = Coco_V1
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def maskrcnn_resnet50_fpn(
*,
......@@ -74,7 +74,7 @@ def maskrcnn_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
......@@ -25,7 +25,7 @@ __all__ = [
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
transforms=CocoEval,
meta={
......@@ -39,12 +39,12 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
"map": 36.4,
},
)
default = Coco_V1
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
*,
......@@ -77,7 +77,7 @@ def retinanet_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNet_ResNet50_FPN_Weights.Coco_V1:
if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
......@@ -23,7 +23,7 @@ __all__ = [
class SSD300_VGG16_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
transforms=CocoEval,
meta={
......@@ -38,12 +38,12 @@ class SSD300_VGG16_Weights(WeightsEnum):
"map": 25.1,
},
)
default = Coco_V1
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", SSD300_VGG16_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", VGG16_Weights.ImageNet1K_Features),
weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
)
def ssd300_vgg16(
*,
......
......@@ -28,7 +28,7 @@ __all__ = [
class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
transforms=CocoEval,
meta={
......@@ -43,12 +43,12 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
"map": 21.3,
},
)
default = Coco_V1
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def ssdlite320_mobilenet_v3_large(
*,
......
......@@ -74,7 +74,7 @@ _COMMON_META = {
class EfficientNet_B0_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -85,11 +85,11 @@ class EfficientNet_B0_Weights(WeightsEnum):
"acc@5": 93.532,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class EfficientNet_B1_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -100,7 +100,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
"acc@5": 94.186,
},
)
ImageNet1K_V2 = Weights(
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR),
meta={
......@@ -113,11 +113,11 @@ class EfficientNet_B1_Weights(WeightsEnum):
"acc@5": 94.934,
},
)
default = ImageNet1K_V2
DEFAULT = IMAGENET1K_V2
class EfficientNet_B2_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -128,11 +128,11 @@ class EfficientNet_B2_Weights(WeightsEnum):
"acc@5": 95.310,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class EfficientNet_B3_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -143,11 +143,11 @@ class EfficientNet_B3_Weights(WeightsEnum):
"acc@5": 96.054,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class EfficientNet_B4_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -158,11 +158,11 @@ class EfficientNet_B4_Weights(WeightsEnum):
"acc@5": 96.594,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class EfficientNet_B5_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -173,11 +173,11 @@ class EfficientNet_B5_Weights(WeightsEnum):
"acc@5": 96.628,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class EfficientNet_B6_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -188,11 +188,11 @@ class EfficientNet_B6_Weights(WeightsEnum):
"acc@5": 96.916,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class EfficientNet_B7_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -203,10 +203,10 @@ class EfficientNet_B7_Weights(WeightsEnum):
"acc@5": 96.908,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......@@ -215,7 +215,7 @@ def efficientnet_b0(
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1(
*, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......@@ -224,7 +224,7 @@ def efficientnet_b1(
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2(
*, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......@@ -233,7 +233,7 @@ def efficientnet_b2(
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3(
*, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......@@ -242,7 +242,7 @@ def efficientnet_b3(
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4(
*, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......@@ -251,7 +251,7 @@ def efficientnet_b4(
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5(
*, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......@@ -268,7 +268,7 @@ def efficientnet_b5(
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6(
*, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......@@ -285,7 +285,7 @@ def efficientnet_b6(
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7(
*, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
......
......@@ -15,7 +15,7 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weig
class GoogLeNet_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -32,10 +32,10 @@ class GoogLeNet_Weights(WeightsEnum):
"acc@5": 89.530,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
weights = GoogLeNet_Weights.verify(weights)
......
......@@ -14,7 +14,7 @@ __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_
class Inception_V3_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
meta={
......@@ -31,10 +31,10 @@ class Inception_V3_Weights(WeightsEnum):
"acc@5": 93.450,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
weights = Inception_V3_Weights.verify(weights)
......
......@@ -36,7 +36,7 @@ _COMMON_META = {
class MNASNet0_5_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -46,7 +46,7 @@ class MNASNet0_5_Weights(WeightsEnum):
"acc@5": 87.490,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class MNASNet0_75_Weights(WeightsEnum):
......@@ -55,7 +55,7 @@ class MNASNet0_75_Weights(WeightsEnum):
class MNASNet1_0_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -65,7 +65,7 @@ class MNASNet1_0_Weights(WeightsEnum):
"acc@5": 91.510,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
class MNASNet1_3_Weights(WeightsEnum):
......@@ -85,7 +85,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
return model
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet0_5_Weights.verify(weights)
......@@ -99,7 +99,7 @@ def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool
return _mnasnet(0.75, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet1_0_Weights.verify(weights)
......
......@@ -14,7 +14,7 @@ __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
class MobileNet_V2_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -31,10 +31,10 @@ class MobileNet_V2_Weights(WeightsEnum):
"acc@5": 90.286,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
*, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
......
......@@ -49,7 +49,7 @@ _COMMON_META = {
class MobileNet_V3_Large_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -60,7 +60,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
"acc@5": 91.340,
},
)
ImageNet1K_V2 = Weights(
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -71,11 +71,11 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
"acc@5": 92.566,
},
)
default = ImageNet1K_V2
DEFAULT = IMAGENET1K_V2
class MobileNet_V3_Small_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -86,10 +86,10 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
"acc@5": 87.402,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
*, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
......@@ -99,7 +99,7 @@ def mobilenet_v3_large(
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
*, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
......
......@@ -115,7 +115,7 @@ class Raft_Large_Weights(WeightsEnum):
},
)
default = C_T_V2
DEFAULT = C_T_V2
class Raft_Small_Weights(WeightsEnum):
......@@ -148,7 +148,7 @@ class Raft_Small_Weights(WeightsEnum):
},
)
default = C_T_V2
DEFAULT = C_T_V2
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
......