Unverified Commit 1deb2ec2 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Cleanup Models prototype implementation (#4940)

* Disable WeightEntry to pass-through `Weights.verify()`

* Rename `Weights.state_dict()` to `Weights.get_state_dict()`

* Add TODO for missing doc.

* Moving warning messages for googlenet.

* Upper-case global `_COMMON_META` var

* Replace argument with parameter in all warnings.
parent 30f4d108
......@@ -50,7 +50,7 @@ class Weights(Enum):
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj)
elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry):
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
)
......@@ -63,7 +63,7 @@ class Weights(Enum):
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
def state_dict(self, progress: bool) -> OrderedDict:
def get_state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress)
def __repr__(self):
......@@ -90,7 +90,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' argument.")
raise ValueError("The method is missing the 'weights' parameter.")
ann = signature(fn).parameters["weights"].annotation
weights_class = None
......
......@@ -30,7 +30,7 @@ class AlexNetWeights(Weights):
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = AlexNetWeights.verify(weights)
if weights is not None:
......@@ -39,6 +39,6 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k
model = AlexNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -34,7 +34,7 @@ def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
state_dict = weights.state_dict(progress=progress)
state_dict = weights.get_state_dict(progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
......@@ -63,11 +63,11 @@ def _densenet(
return model
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": None, # weights ported from LuaTorch
"recipe": None, # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
}
......@@ -76,7 +76,7 @@ class DenseNet121Weights(Weights):
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 74.434,
"acc@5": 91.972,
},
......@@ -88,7 +88,7 @@ class DenseNet161Weights(Weights):
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 77.138,
"acc@5": 93.560,
},
......@@ -100,7 +100,7 @@ class DenseNet169Weights(Weights):
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 75.600,
"acc@5": 92.806,
},
......@@ -112,7 +112,7 @@ class DenseNet201Weights(Weights):
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 76.896,
"acc@5": 93.370,
},
......@@ -121,7 +121,7 @@ class DenseNet201Weights(Weights):
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet121Weights.verify(weights)
......@@ -130,7 +130,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet161Weights.verify(weights)
......@@ -139,7 +139,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet169Weights.verify(weights)
......@@ -148,7 +148,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet201Weights.verify(weights)
......
......@@ -30,7 +30,7 @@ __all__ = [
]
_common_meta = {
_COMMON_META = {
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
......@@ -41,7 +41,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
......@@ -53,7 +53,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
......@@ -65,7 +65,7 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
......@@ -81,11 +81,11 @@ def fasterrcnn_resnet50_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)
......@@ -102,7 +102,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)
......@@ -142,7 +142,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -156,11 +156,11 @@ def fasterrcnn_mobilenet_v3_large_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
......@@ -188,11 +188,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
......
......@@ -22,7 +22,7 @@ __all__ = [
]
_common_meta = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
_COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
class KeypointRCNNResNet50FPNWeights(Weights):
......@@ -30,7 +30,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/1606",
"box_map": 50.6,
"kp_map": 61.1,
......@@ -40,7 +40,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval,
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
"box_map": 54.6,
"kp_map": 65.0,
......@@ -58,7 +58,7 @@ def keypointrcnn_resnet50_fpn(
**kwargs: Any,
) -> KeypointRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
pretrained = kwargs.pop("pretrained")
if type(pretrained) == str and pretrained == "legacy":
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
......@@ -68,7 +68,7 @@ def keypointrcnn_resnet50_fpn(
weights = None
weights = KeypointRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)
......@@ -86,7 +86,7 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)
......
......@@ -46,11 +46,11 @@ def maskrcnn_resnet50_fpn(
**kwargs: Any,
) -> MaskRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = MaskRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)
......@@ -67,7 +67,7 @@ def maskrcnn_resnet50_fpn(
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)
......
......@@ -46,11 +46,11 @@ def retinanet_resnet50_fpn(
**kwargs: Any,
) -> RetinaNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = RetinaNetResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)
......@@ -70,7 +70,7 @@ def retinanet_resnet50_fpn(
model = RetinaNet(backbone, num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNetResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0)
......
......@@ -44,16 +44,16 @@ def ssd300_vgg16(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSD300VGG16Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
weights_backbone = VGG16Weights.verify(weights_backbone)
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
if weights is not None:
weights_backbone = None
......@@ -81,6 +81,6 @@ def ssd300_vgg16(
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -50,16 +50,16 @@ def ssdlite320_mobilenet_v3_large(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.")
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
if weights is not None:
weights_backbone = None
......@@ -114,6 +114,6 @@ def ssdlite320_mobilenet_v3_large(
)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -57,12 +57,12 @@ def _efficientnet(
model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {
_COMMON_META = {
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BICUBIC,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
......@@ -74,7 +74,7 @@ class EfficientNetB0Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (224, 224),
"acc@1": 77.692,
"acc@5": 93.532,
......@@ -87,7 +87,7 @@ class EfficientNetB1Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (240, 240),
"acc@1": 78.642,
"acc@5": 94.186,
......@@ -100,7 +100,7 @@ class EfficientNetB2Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (288, 288),
"acc@1": 80.608,
"acc@5": 95.310,
......@@ -113,7 +113,7 @@ class EfficientNetB3Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (300, 300),
"acc@1": 82.008,
"acc@5": 96.054,
......@@ -126,7 +126,7 @@ class EfficientNetB4Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (380, 380),
"acc@1": 83.384,
"acc@5": 96.594,
......@@ -139,7 +139,7 @@ class EfficientNetB5Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (456, 456),
"acc@1": 83.444,
"acc@5": 96.628,
......@@ -152,7 +152,7 @@ class EfficientNetB6Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (528, 528),
"acc@1": 84.008,
"acc@5": 96.916,
......@@ -165,7 +165,7 @@ class EfficientNetB7Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
meta={
**_common_meta,
**_COMMON_META,
"size": (600, 600),
"acc@1": 84.122,
"acc@5": 96.908,
......@@ -177,7 +177,7 @@ def efficientnet_b0(
weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB0Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB0Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
......@@ -187,7 +187,7 @@ def efficientnet_b1(
weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB1Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB1Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
......@@ -197,7 +197,7 @@ def efficientnet_b2(
weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB2Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB2Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
......@@ -207,7 +207,7 @@ def efficientnet_b3(
weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB3Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB3Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
......@@ -217,7 +217,7 @@ def efficientnet_b4(
weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB4Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB4Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
......@@ -227,7 +227,7 @@ def efficientnet_b5(
weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB5Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB5Weights.verify(weights)
return _efficientnet(
......@@ -245,7 +245,7 @@ def efficientnet_b6(
weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB6Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB6Weights.verify(weights)
return _efficientnet(
......@@ -263,7 +263,7 @@ def efficientnet_b7(
weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB7Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB7Weights.verify(weights)
return _efficientnet(
......
......@@ -30,7 +30,7 @@ class GoogLeNetWeights(Weights):
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights)
......@@ -38,10 +38,6 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True,
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if original_aux_logits:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
......@@ -49,10 +45,14 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True,
model = GoogLeNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
else:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
return model
......@@ -30,7 +30,7 @@ class InceptionV3Weights(Weights):
def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = InceptionV3Weights.verify(weights)
......@@ -45,7 +45,7 @@ def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool =
model = Inception3(**kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
......
......@@ -23,7 +23,7 @@ __all__ = [
]
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -36,7 +36,7 @@ class MNASNet0_5Weights(Weights):
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 67.734,
"acc@5": 87.490,
},
......@@ -53,7 +53,7 @@ class MNASNet1_0Weights(Weights):
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 73.456,
"acc@5": 91.510,
},
......@@ -72,14 +72,14 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
model = MNASNet(alpha, **kwargs)
if weights:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet0_5Weights.verify(weights)
......@@ -89,7 +89,7 @@ def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = Tru
def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet0_75")
......@@ -100,7 +100,7 @@ def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = T
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet1_0Weights.verify(weights)
......@@ -109,7 +109,7 @@ def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = Tru
def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet1_3")
......
......@@ -30,7 +30,7 @@ class MobileNetV2Weights(Weights):
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV2Weights.verify(weights)
......@@ -40,6 +40,6 @@ def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool =
model = MobileNetV2(**kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -32,12 +32,12 @@ def _mobilenet_v3(
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -50,7 +50,7 @@ class MobileNetV3LargeWeights(Weights):
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 74.042,
"acc@5": 91.340,
},
......@@ -62,7 +62,7 @@ class MobileNetV3SmallWeights(Weights):
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 67.668,
"acc@5": 87.402,
},
......@@ -73,7 +73,7 @@ def mobilenet_v3_large(
weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3LargeWeights.verify(weights)
......@@ -85,7 +85,7 @@ def mobilenet_v3_small(
weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV3SmallWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3SmallWeights.verify(weights)
......
......@@ -47,7 +47,7 @@ def googlenet(
**kwargs: Any,
) -> QuantizableGoogLeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
else:
......@@ -62,10 +62,6 @@ def googlenet(
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if original_aux_logits:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
......@@ -79,10 +75,14 @@ def googlenet(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
else:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
return model
......@@ -47,7 +47,7 @@ def inception_v3(
**kwargs: Any,
) -> QuantizableInception3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
......@@ -79,7 +79,7 @@ def inception_v3(
if quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if not quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
......
......@@ -48,7 +48,7 @@ def mobilenet_v2(
**kwargs: Any,
) -> QuantizableMobileNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
......@@ -75,6 +75,6 @@ def mobilenet_v2(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -47,7 +47,7 @@ def _mobilenet_v3_model(
torch.quantization.prepare_qat(model, inplace=True)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
if quantize:
torch.quantization.convert(model, inplace=True)
......@@ -81,7 +81,7 @@ def mobilenet_v3_large(
**kwargs: Any,
) -> QuantizableMobileNetV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
......
......@@ -48,12 +48,12 @@ def _resnet(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -68,7 +68,7 @@ class QuantizedResNet18Weights(Weights):
url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"unquantized": ResNet18Weights.ImageNet1K_RefV1,
"acc@1": 69.494,
"acc@5": 88.882,
......@@ -81,7 +81,7 @@ class QuantizedResNet50Weights(Weights):
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"unquantized": ResNet50Weights.ImageNet1K_RefV1,
"acc@1": 75.920,
"acc@5": 92.814,
......@@ -94,7 +94,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
"acc@1": 78.986,
"acc@5": 94.480,
......@@ -109,7 +109,7 @@ def resnet18(
**kwargs: Any,
) -> QuantizableResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
else:
......@@ -130,7 +130,7 @@ def resnet50(
**kwargs: Any,
) -> QuantizableResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
else:
......@@ -151,7 +151,7 @@ def resnext101_32x8d(
**kwargs: Any,
) -> QuantizableResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment