Unverified Commit 1deb2ec2 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Cleanup Models prototype implementation (#4940)

* Disable WeightEntry to pass-through `Weights.verify()`

* Rename `Weights.state_dict()` to `Weights.get_state_dict()`

* Add TODO for missing doc.

* Moving warning messages for googlenet.

* Upper-case global `_COMMON_META` var

* Replace argument with parameter in all warnings.
parent 30f4d108
...@@ -50,7 +50,7 @@ class Weights(Enum): ...@@ -50,7 +50,7 @@ class Weights(Enum):
if obj is not None: if obj is not None:
if type(obj) is str: if type(obj) is str:
obj = cls.from_str(obj) obj = cls.from_str(obj)
elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry): elif not isinstance(obj, cls):
raise TypeError( raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
) )
...@@ -63,7 +63,7 @@ class Weights(Enum): ...@@ -63,7 +63,7 @@ class Weights(Enum):
return v return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
def state_dict(self, progress: bool) -> OrderedDict: def get_state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress) return load_state_dict_from_url(self.url, progress=progress)
def __repr__(self): def __repr__(self):
...@@ -90,7 +90,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights: ...@@ -90,7 +90,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
""" """
sig = signature(fn) sig = signature(fn)
if "weights" not in sig.parameters: if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' argument.") raise ValueError("The method is missing the 'weights' parameter.")
ann = signature(fn).parameters["weights"].annotation ann = signature(fn).parameters["weights"].annotation
weights_class = None weights_class = None
......
...@@ -30,7 +30,7 @@ class AlexNetWeights(Weights): ...@@ -30,7 +30,7 @@ class AlexNetWeights(Weights):
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = AlexNetWeights.verify(weights) weights = AlexNetWeights.verify(weights)
if weights is not None: if weights is not None:
...@@ -39,6 +39,6 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k ...@@ -39,6 +39,6 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k
model = AlexNet(**kwargs) model = AlexNet(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -34,7 +34,7 @@ def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None ...@@ -34,7 +34,7 @@ def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
) )
state_dict = weights.state_dict(progress=progress) state_dict = weights.get_state_dict(progress=progress)
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
res = pattern.match(key) res = pattern.match(key)
if res: if res:
...@@ -63,11 +63,11 @@ def _densenet( ...@@ -63,11 +63,11 @@ def _densenet(
return model return model
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
"recipe": None, # weights ported from LuaTorch "recipe": None, # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
} }
...@@ -76,7 +76,7 @@ class DenseNet121Weights(Weights): ...@@ -76,7 +76,7 @@ class DenseNet121Weights(Weights):
url="https://download.pytorch.org/models/densenet121-a639ec97.pth", url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 74.434, "acc@1": 74.434,
"acc@5": 91.972, "acc@5": 91.972,
}, },
...@@ -88,7 +88,7 @@ class DenseNet161Weights(Weights): ...@@ -88,7 +88,7 @@ class DenseNet161Weights(Weights):
url="https://download.pytorch.org/models/densenet161-8d451a50.pth", url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 77.138, "acc@1": 77.138,
"acc@5": 93.560, "acc@5": 93.560,
}, },
...@@ -100,7 +100,7 @@ class DenseNet169Weights(Weights): ...@@ -100,7 +100,7 @@ class DenseNet169Weights(Weights):
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 75.600, "acc@1": 75.600,
"acc@5": 92.806, "acc@5": 92.806,
}, },
...@@ -112,7 +112,7 @@ class DenseNet201Weights(Weights): ...@@ -112,7 +112,7 @@ class DenseNet201Weights(Weights):
url="https://download.pytorch.org/models/densenet201-c1103571.pth", url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 76.896, "acc@1": 76.896,
"acc@5": 93.370, "acc@5": 93.370,
}, },
...@@ -121,7 +121,7 @@ class DenseNet201Weights(Weights): ...@@ -121,7 +121,7 @@ class DenseNet201Weights(Weights):
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet121Weights.verify(weights) weights = DenseNet121Weights.verify(weights)
...@@ -130,7 +130,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T ...@@ -130,7 +130,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet161Weights.verify(weights) weights = DenseNet161Weights.verify(weights)
...@@ -139,7 +139,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T ...@@ -139,7 +139,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet169Weights.verify(weights) weights = DenseNet169Weights.verify(weights)
...@@ -148,7 +148,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T ...@@ -148,7 +148,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet201Weights.verify(weights) weights = DenseNet201Weights.verify(weights)
......
...@@ -30,7 +30,7 @@ __all__ = [ ...@@ -30,7 +30,7 @@ __all__ = [
] ]
_common_meta = { _COMMON_META = {
"categories": _COCO_CATEGORIES, "categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
} }
...@@ -41,7 +41,7 @@ class FasterRCNNResNet50FPNWeights(Weights): ...@@ -41,7 +41,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0, "map": 37.0,
}, },
...@@ -53,7 +53,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights): ...@@ -53,7 +53,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8, "map": 32.8,
}, },
...@@ -65,7 +65,7 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights): ...@@ -65,7 +65,7 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8, "map": 22.8,
}, },
...@@ -81,11 +81,11 @@ def fasterrcnn_resnet50_fpn( ...@@ -81,11 +81,11 @@ def fasterrcnn_resnet50_fpn(
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNResNet50FPNWeights.verify(weights) weights = FasterRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
...@@ -102,7 +102,7 @@ def fasterrcnn_resnet50_fpn( ...@@ -102,7 +102,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1: if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
...@@ -142,7 +142,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -142,7 +142,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -156,11 +156,11 @@ def fasterrcnn_mobilenet_v3_large_fpn( ...@@ -156,11 +156,11 @@ def fasterrcnn_mobilenet_v3_large_fpn(
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
...@@ -188,11 +188,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( ...@@ -188,11 +188,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
......
...@@ -22,7 +22,7 @@ __all__ = [ ...@@ -22,7 +22,7 @@ __all__ = [
] ]
_common_meta = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES} _COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
class KeypointRCNNResNet50FPNWeights(Weights): class KeypointRCNNResNet50FPNWeights(Weights):
...@@ -30,7 +30,7 @@ class KeypointRCNNResNet50FPNWeights(Weights): ...@@ -30,7 +30,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/1606", "recipe": "https://github.com/pytorch/vision/issues/1606",
"box_map": 50.6, "box_map": 50.6,
"kp_map": 61.1, "kp_map": 61.1,
...@@ -40,7 +40,7 @@ class KeypointRCNNResNet50FPNWeights(Weights): ...@@ -40,7 +40,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
"box_map": 54.6, "box_map": 54.6,
"kp_map": 65.0, "kp_map": 65.0,
...@@ -58,7 +58,7 @@ def keypointrcnn_resnet50_fpn( ...@@ -58,7 +58,7 @@ def keypointrcnn_resnet50_fpn(
**kwargs: Any, **kwargs: Any,
) -> KeypointRCNN: ) -> KeypointRCNN:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
pretrained = kwargs.pop("pretrained") pretrained = kwargs.pop("pretrained")
if type(pretrained) == str and pretrained == "legacy": if type(pretrained) == str and pretrained == "legacy":
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
...@@ -68,7 +68,7 @@ def keypointrcnn_resnet50_fpn( ...@@ -68,7 +68,7 @@ def keypointrcnn_resnet50_fpn(
weights = None weights = None
weights = KeypointRCNNResNet50FPNWeights.verify(weights) weights = KeypointRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
...@@ -86,7 +86,7 @@ def keypointrcnn_resnet50_fpn( ...@@ -86,7 +86,7 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1: if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
......
...@@ -46,11 +46,11 @@ def maskrcnn_resnet50_fpn( ...@@ -46,11 +46,11 @@ def maskrcnn_resnet50_fpn(
**kwargs: Any, **kwargs: Any,
) -> MaskRCNN: ) -> MaskRCNN:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = MaskRCNNResNet50FPNWeights.verify(weights) weights = MaskRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
...@@ -67,7 +67,7 @@ def maskrcnn_resnet50_fpn( ...@@ -67,7 +67,7 @@ def maskrcnn_resnet50_fpn(
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1: if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
......
...@@ -46,11 +46,11 @@ def retinanet_resnet50_fpn( ...@@ -46,11 +46,11 @@ def retinanet_resnet50_fpn(
**kwargs: Any, **kwargs: Any,
) -> RetinaNet: ) -> RetinaNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = RetinaNetResNet50FPNWeights.verify(weights) weights = RetinaNetResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
...@@ -70,7 +70,7 @@ def retinanet_resnet50_fpn( ...@@ -70,7 +70,7 @@ def retinanet_resnet50_fpn(
model = RetinaNet(backbone, num_classes, **kwargs) model = RetinaNet(backbone, num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNetResNet50FPNWeights.Coco_RefV1: if weights == RetinaNetResNet50FPNWeights.Coco_RefV1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
......
...@@ -44,16 +44,16 @@ def ssd300_vgg16( ...@@ -44,16 +44,16 @@ def ssd300_vgg16(
**kwargs: Any, **kwargs: Any,
) -> SSD: ) -> SSD:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSD300VGG16Weights.verify(weights) weights = SSD300VGG16Weights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
weights_backbone = VGG16Weights.verify(weights_backbone) weights_backbone = VGG16Weights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.") warnings.warn("The size of the model is already fixed; ignoring the parameter.")
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -81,6 +81,6 @@ def ssd300_vgg16( ...@@ -81,6 +81,6 @@ def ssd300_vgg16(
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -50,16 +50,16 @@ def ssdlite320_mobilenet_v3_large( ...@@ -50,16 +50,16 @@ def ssdlite320_mobilenet_v3_large(
**kwargs: Any, **kwargs: Any,
) -> SSD: ) -> SSD:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights) weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the argument.") warnings.warn("The size of the model is already fixed; ignoring the parameter.")
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -114,6 +114,6 @@ def ssdlite320_mobilenet_v3_large( ...@@ -114,6 +114,6 @@ def ssdlite320_mobilenet_v3_large(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -57,12 +57,12 @@ def _efficientnet( ...@@ -57,12 +57,12 @@ def _efficientnet(
model = EfficientNet(inverted_residual_setting, dropout, **kwargs) model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = { _COMMON_META = {
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BICUBIC, "interpolation": InterpolationMode.BICUBIC,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
...@@ -74,7 +74,7 @@ class EfficientNetB0Weights(Weights): ...@@ -74,7 +74,7 @@ class EfficientNetB0Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (224, 224), "size": (224, 224),
"acc@1": 77.692, "acc@1": 77.692,
"acc@5": 93.532, "acc@5": 93.532,
...@@ -87,7 +87,7 @@ class EfficientNetB1Weights(Weights): ...@@ -87,7 +87,7 @@ class EfficientNetB1Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (240, 240), "size": (240, 240),
"acc@1": 78.642, "acc@1": 78.642,
"acc@5": 94.186, "acc@5": 94.186,
...@@ -100,7 +100,7 @@ class EfficientNetB2Weights(Weights): ...@@ -100,7 +100,7 @@ class EfficientNetB2Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (288, 288), "size": (288, 288),
"acc@1": 80.608, "acc@1": 80.608,
"acc@5": 95.310, "acc@5": 95.310,
...@@ -113,7 +113,7 @@ class EfficientNetB3Weights(Weights): ...@@ -113,7 +113,7 @@ class EfficientNetB3Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (300, 300), "size": (300, 300),
"acc@1": 82.008, "acc@1": 82.008,
"acc@5": 96.054, "acc@5": 96.054,
...@@ -126,7 +126,7 @@ class EfficientNetB4Weights(Weights): ...@@ -126,7 +126,7 @@ class EfficientNetB4Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (380, 380), "size": (380, 380),
"acc@1": 83.384, "acc@1": 83.384,
"acc@5": 96.594, "acc@5": 96.594,
...@@ -139,7 +139,7 @@ class EfficientNetB5Weights(Weights): ...@@ -139,7 +139,7 @@ class EfficientNetB5Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (456, 456), "size": (456, 456),
"acc@1": 83.444, "acc@1": 83.444,
"acc@5": 96.628, "acc@5": 96.628,
...@@ -152,7 +152,7 @@ class EfficientNetB6Weights(Weights): ...@@ -152,7 +152,7 @@ class EfficientNetB6Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (528, 528), "size": (528, 528),
"acc@1": 84.008, "acc@1": 84.008,
"acc@5": 96.916, "acc@5": 96.916,
...@@ -165,7 +165,7 @@ class EfficientNetB7Weights(Weights): ...@@ -165,7 +165,7 @@ class EfficientNetB7Weights(Weights):
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
**_common_meta, **_COMMON_META,
"size": (600, 600), "size": (600, 600),
"acc@1": 84.122, "acc@1": 84.122,
"acc@5": 96.908, "acc@5": 96.908,
...@@ -177,7 +177,7 @@ def efficientnet_b0( ...@@ -177,7 +177,7 @@ def efficientnet_b0(
weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB0Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB0Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB0Weights.verify(weights) weights = EfficientNetB0Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
...@@ -187,7 +187,7 @@ def efficientnet_b1( ...@@ -187,7 +187,7 @@ def efficientnet_b1(
weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB1Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB1Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB1Weights.verify(weights) weights = EfficientNetB1Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
...@@ -197,7 +197,7 @@ def efficientnet_b2( ...@@ -197,7 +197,7 @@ def efficientnet_b2(
weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB2Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB2Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB2Weights.verify(weights) weights = EfficientNetB2Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
...@@ -207,7 +207,7 @@ def efficientnet_b3( ...@@ -207,7 +207,7 @@ def efficientnet_b3(
weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB3Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB3Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB3Weights.verify(weights) weights = EfficientNetB3Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
...@@ -217,7 +217,7 @@ def efficientnet_b4( ...@@ -217,7 +217,7 @@ def efficientnet_b4(
weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB4Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None weights = EfficientNetB4Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB4Weights.verify(weights) weights = EfficientNetB4Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
...@@ -227,7 +227,7 @@ def efficientnet_b5( ...@@ -227,7 +227,7 @@ def efficientnet_b5(
weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB5Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = EfficientNetB5Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB5Weights.verify(weights) weights = EfficientNetB5Weights.verify(weights)
return _efficientnet( return _efficientnet(
...@@ -245,7 +245,7 @@ def efficientnet_b6( ...@@ -245,7 +245,7 @@ def efficientnet_b6(
weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB6Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = EfficientNetB6Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB6Weights.verify(weights) weights = EfficientNetB6Weights.verify(weights)
return _efficientnet( return _efficientnet(
...@@ -263,7 +263,7 @@ def efficientnet_b7( ...@@ -263,7 +263,7 @@ def efficientnet_b7(
weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB7Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = EfficientNetB7Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB7Weights.verify(weights) weights = EfficientNetB7Weights.verify(weights)
return _efficientnet( return _efficientnet(
......
...@@ -30,7 +30,7 @@ class GoogLeNetWeights(Weights): ...@@ -30,7 +30,7 @@ class GoogLeNetWeights(Weights):
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights) weights = GoogLeNetWeights.verify(weights)
...@@ -38,10 +38,6 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, ...@@ -38,10 +38,6 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True,
if weights is not None: if weights is not None:
if "transform_input" not in kwargs: if "transform_input" not in kwargs:
kwargs["transform_input"] = True kwargs["transform_input"] = True
if original_aux_logits:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
kwargs["aux_logits"] = True kwargs["aux_logits"] = True
kwargs["init_weights"] = False kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"]) kwargs["num_classes"] = len(weights.meta["categories"])
...@@ -49,10 +45,14 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, ...@@ -49,10 +45,14 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True,
model = GoogLeNet(**kwargs) model = GoogLeNet(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.aux1 = None # type: ignore[assignment] model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment]
else:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
return model return model
...@@ -30,7 +30,7 @@ class InceptionV3Weights(Weights): ...@@ -30,7 +30,7 @@ class InceptionV3Weights(Weights):
def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = InceptionV3Weights.verify(weights) weights = InceptionV3Weights.verify(weights)
...@@ -45,7 +45,7 @@ def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = ...@@ -45,7 +45,7 @@ def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool =
model = Inception3(**kwargs) model = Inception3(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.AuxLogits = None model.AuxLogits = None
......
...@@ -23,7 +23,7 @@ __all__ = [ ...@@ -23,7 +23,7 @@ __all__ = [
] ]
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -36,7 +36,7 @@ class MNASNet0_5Weights(Weights): ...@@ -36,7 +36,7 @@ class MNASNet0_5Weights(Weights):
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 67.734, "acc@1": 67.734,
"acc@5": 87.490, "acc@5": 87.490,
}, },
...@@ -53,7 +53,7 @@ class MNASNet1_0Weights(Weights): ...@@ -53,7 +53,7 @@ class MNASNet1_0Weights(Weights):
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 73.456, "acc@1": 73.456,
"acc@5": 91.510, "acc@5": 91.510,
}, },
...@@ -72,14 +72,14 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: ...@@ -72,14 +72,14 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
model = MNASNet(alpha, **kwargs) model = MNASNet(alpha, **kwargs)
if weights: if weights:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet0_5Weights.verify(weights) weights = MNASNet0_5Weights.verify(weights)
...@@ -89,7 +89,7 @@ def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = Tru ...@@ -89,7 +89,7 @@ def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = Tru
def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet0_75") raise ValueError("No checkpoint is available for model type mnasnet0_75")
...@@ -100,7 +100,7 @@ def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = T ...@@ -100,7 +100,7 @@ def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = T
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet1_0Weights.verify(weights) weights = MNASNet1_0Weights.verify(weights)
...@@ -109,7 +109,7 @@ def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = Tru ...@@ -109,7 +109,7 @@ def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = Tru
def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet1_3") raise ValueError("No checkpoint is available for model type mnasnet1_3")
......
...@@ -30,7 +30,7 @@ class MobileNetV2Weights(Weights): ...@@ -30,7 +30,7 @@ class MobileNetV2Weights(Weights):
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV2Weights.verify(weights) weights = MobileNetV2Weights.verify(weights)
...@@ -40,6 +40,6 @@ def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = ...@@ -40,6 +40,6 @@ def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool =
model = MobileNetV2(**kwargs) model = MobileNetV2(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -32,12 +32,12 @@ def _mobilenet_v3( ...@@ -32,12 +32,12 @@ def _mobilenet_v3(
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -50,7 +50,7 @@ class MobileNetV3LargeWeights(Weights): ...@@ -50,7 +50,7 @@ class MobileNetV3LargeWeights(Weights):
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 74.042, "acc@1": 74.042,
"acc@5": 91.340, "acc@5": 91.340,
}, },
...@@ -62,7 +62,7 @@ class MobileNetV3SmallWeights(Weights): ...@@ -62,7 +62,7 @@ class MobileNetV3SmallWeights(Weights):
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 67.668, "acc@1": 67.668,
"acc@5": 87.402, "acc@5": 87.402,
}, },
...@@ -73,7 +73,7 @@ def mobilenet_v3_large( ...@@ -73,7 +73,7 @@ def mobilenet_v3_large(
weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3LargeWeights.verify(weights) weights = MobileNetV3LargeWeights.verify(weights)
...@@ -85,7 +85,7 @@ def mobilenet_v3_small( ...@@ -85,7 +85,7 @@ def mobilenet_v3_small(
weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV3SmallWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = MobileNetV3SmallWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3SmallWeights.verify(weights) weights = MobileNetV3SmallWeights.verify(weights)
......
...@@ -47,7 +47,7 @@ def googlenet( ...@@ -47,7 +47,7 @@ def googlenet(
**kwargs: Any, **kwargs: Any,
) -> QuantizableGoogLeNet: ) -> QuantizableGoogLeNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1 weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
else: else:
...@@ -62,10 +62,6 @@ def googlenet( ...@@ -62,10 +62,6 @@ def googlenet(
if weights is not None: if weights is not None:
if "transform_input" not in kwargs: if "transform_input" not in kwargs:
kwargs["transform_input"] = True kwargs["transform_input"] = True
if original_aux_logits:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
kwargs["aux_logits"] = True kwargs["aux_logits"] = True
kwargs["init_weights"] = False kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"]) kwargs["num_classes"] = len(weights.meta["categories"])
...@@ -79,10 +75,14 @@ def googlenet( ...@@ -79,10 +75,14 @@ def googlenet(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.aux1 = None # type: ignore[assignment] model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment]
else:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
return model return model
...@@ -47,7 +47,7 @@ def inception_v3( ...@@ -47,7 +47,7 @@ def inception_v3(
**kwargs: Any, **kwargs: Any,
) -> QuantizableInception3: ) -> QuantizableInception3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = ( weights = (
QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1 QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
...@@ -79,7 +79,7 @@ def inception_v3( ...@@ -79,7 +79,7 @@ def inception_v3(
if quantize and not original_aux_logits: if quantize and not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.AuxLogits = None model.AuxLogits = None
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if not quantize and not original_aux_logits: if not quantize and not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.AuxLogits = None model.AuxLogits = None
......
...@@ -48,7 +48,7 @@ def mobilenet_v2( ...@@ -48,7 +48,7 @@ def mobilenet_v2(
**kwargs: Any, **kwargs: Any,
) -> QuantizableMobileNetV2: ) -> QuantizableMobileNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = ( weights = (
QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
...@@ -75,6 +75,6 @@ def mobilenet_v2( ...@@ -75,6 +75,6 @@ def mobilenet_v2(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -47,7 +47,7 @@ def _mobilenet_v3_model( ...@@ -47,7 +47,7 @@ def _mobilenet_v3_model(
torch.quantization.prepare_qat(model, inplace=True) torch.quantization.prepare_qat(model, inplace=True)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if quantize: if quantize:
torch.quantization.convert(model, inplace=True) torch.quantization.convert(model, inplace=True)
...@@ -81,7 +81,7 @@ def mobilenet_v3_large( ...@@ -81,7 +81,7 @@ def mobilenet_v3_large(
**kwargs: Any, **kwargs: Any,
) -> QuantizableMobileNetV3: ) -> QuantizableMobileNetV3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = ( weights = (
QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1 QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
......
...@@ -48,12 +48,12 @@ def _resnet( ...@@ -48,12 +48,12 @@ def _resnet(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -68,7 +68,7 @@ class QuantizedResNet18Weights(Weights): ...@@ -68,7 +68,7 @@ class QuantizedResNet18Weights(Weights):
url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"unquantized": ResNet18Weights.ImageNet1K_RefV1, "unquantized": ResNet18Weights.ImageNet1K_RefV1,
"acc@1": 69.494, "acc@1": 69.494,
"acc@5": 88.882, "acc@5": 88.882,
...@@ -81,7 +81,7 @@ class QuantizedResNet50Weights(Weights): ...@@ -81,7 +81,7 @@ class QuantizedResNet50Weights(Weights):
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"unquantized": ResNet50Weights.ImageNet1K_RefV1, "unquantized": ResNet50Weights.ImageNet1K_RefV1,
"acc@1": 75.920, "acc@1": 75.920,
"acc@5": 92.814, "acc@5": 92.814,
...@@ -94,7 +94,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights): ...@@ -94,7 +94,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1, "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
"acc@1": 78.986, "acc@1": 78.986,
"acc@5": 94.480, "acc@5": 94.480,
...@@ -109,7 +109,7 @@ def resnet18( ...@@ -109,7 +109,7 @@ def resnet18(
**kwargs: Any, **kwargs: Any,
) -> QuantizableResNet: ) -> QuantizableResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1 weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
else: else:
...@@ -130,7 +130,7 @@ def resnet50( ...@@ -130,7 +130,7 @@ def resnet50(
**kwargs: Any, **kwargs: Any,
) -> QuantizableResNet: ) -> QuantizableResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1 weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
else: else:
...@@ -151,7 +151,7 @@ def resnext101_32x8d( ...@@ -151,7 +151,7 @@ def resnext101_32x8d(
**kwargs: Any, **kwargs: Any,
) -> QuantizableResNet: ) -> QuantizableResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = ( weights = (
QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1 QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment