Unverified commit 18cf5ab1 authored by Vasilis Vryniotis, committed by GitHub

More Multiweight support cleanups (#4948)

* Updated the link for the densenet recipe.

* Set default value of `num_classes` and `num_keypoints` to `None`

* Provide helper methods for parameter checks to reduce duplicate code.

* Throw errors on silent config overwrites from weight meta-data and legacy builders.

* Changing order of arguments + fixing mypy.

* Make the builders fully backwards compatible (BC).

* Add "default" weights support that always returns the best weights.
parent 09e759ea
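
For context, a minimal sketch of how the new "default" alias is exercised, based on the test and `_api.py` changes below (the `torchvision.prototype` import path is assumed; this snippet is illustrative, not part of the patch):

# "default" resolves to the weight entry flagged with default=True.
from torchvision.prototype import models

best = models.ResNet50Weights.from_str("default")          # ImageNet1K_RefV2, per the test below
same = models._api.get_weight(models.resnet50, "default")  # same lookup through the helper under test
assert best == same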
......@@ -31,14 +31,19 @@ def get_models_with_module_names(module):
@pytest.mark.parametrize(
"model_fn, weight",
"model_fn, name, weight",
[
(models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2),
(models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1),
(models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1),
(models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_RefV1",
models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
),
],
)
def test_get_weight(model_fn, weight):
assert models._api.get_weight(model_fn, weight.name) == weight
def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
......
......@@ -30,6 +30,7 @@ class WeightEntry:
url: str
transforms: Callable
meta: Dict[str, Any]
default: bool
class Weights(Enum):
......@@ -59,7 +60,7 @@ class Weights(Enum):
@classmethod
def from_str(cls, value: str) -> "Weights":
for v in cls:
if v._name_ == value:
if v._name_ == value or (value == "default" and v.default):
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
......
import warnings
from typing import Any, Dict, Optional, TypeVar
from ._api import Weights
W = TypeVar("W", bound=Weights)
V = TypeVar("V")
def _deprecated_param(
kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W]
) -> Optional[W]:
warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
if kwargs.pop(deprecated_param):
if default_value is not None:
return default_value
else:
raise ValueError("No checkpoint is available for model.")
else:
return None
def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None:
warnings.warn(
f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'"
+ " instead."
)
kwargs[deprecated_param] = default_value
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
if param in kwargs:
if kwargs[param] != new_value:
raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
else:
kwargs[param] = new_value
def _ovewrite_value_param(param: Optional[V], new_value: V) -> V:
if param is not None:
if param != new_value:
raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.")
return new_value
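
For context, a minimal sketch of the overwrite semantics these helpers introduce, runnable against the `_utils` module above (illustrative, not part of the patch):

kwargs = {"num_classes": 10}
try:
    _ovewrite_named_param(kwargs, "num_classes", 1000)  # conflicts with an explicit user value
except ValueError as err:
    print(err)  # The parameter 'num_classes' expected value 1000 but got 10 instead.

kwargs = {}
_ovewrite_named_param(kwargs, "num_classes", 1000)      # missing key: filled in without complaint
assert kwargs["num_classes"] == 1000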
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
......@@ -25,16 +25,19 @@ class AlexNetWeights(Weights):
"acc@1": 56.522,
"acc@5": 79.066,
},
default=True,
)
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1)
weights = AlexNetWeights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = AlexNet(**kwargs)
......
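
For context, the rewritten alexnet builder above stays fully backwards compatible; the calls below are all expected to load the same checkpoint, the first two with deprecation warnings (prototype import path assumed; illustrative, not part of the patch):

from torchvision.prototype.models import alexnet, AlexNetWeights

m1 = alexnet(True)                                     # legacy positional `pretrained`
m2 = alexnet(pretrained=True)                          # legacy keyword `pretrained`
m3 = alexnet(weights=AlexNetWeights.ImageNet1K_RefV1)  # new multi-weight API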
import re
import warnings
from functools import partial
from typing import Any, Optional, Tuple
......@@ -10,6 +9,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -53,7 +53,7 @@ def _densenet(
**kwargs: Any,
) -> DenseNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
......@@ -67,7 +67,7 @@ _COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": None, # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
"recipe": "https://github.com/pytorch/vision/pull/116",
}
......@@ -80,6 +80,7 @@ class DenseNet121Weights(Weights):
"acc@1": 74.434,
"acc@5": 91.972,
},
default=True,
)
......@@ -92,6 +93,7 @@ class DenseNet161Weights(Weights):
"acc@1": 77.138,
"acc@5": 93.560,
},
default=True,
)
......@@ -104,6 +106,7 @@ class DenseNet169Weights(Weights):
"acc@1": 75.600,
"acc@5": 92.806,
},
default=True,
)
......@@ -116,40 +119,45 @@ class DenseNet201Weights(Weights):
"acc@1": 76.896,
"acc@5": 93.370,
},
default=True,
)
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community)
weights = DenseNet121Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community)
weights = DenseNet161Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community)
weights = DenseNet169Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community)
weights = DenseNet201Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
import warnings
from typing import Any, Optional, Union
from torchvision.prototype.transforms import CocoEval
......@@ -15,6 +14,7 @@ from ....models.detection.faster_rcnn import (
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50
......@@ -45,6 +45,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
default=True,
)
......@@ -57,6 +58,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
default=True,
)
......@@ -69,29 +71,36 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
default=True,
)
def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True,
num_classes: int = 91,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_RefV1)
weights = FasterRCNNResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
......@@ -110,16 +119,18 @@ def fasterrcnn_resnet50_fpn(
def _fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]],
progress: bool,
num_classes: Optional[int],
weights_backbone: Optional[MobileNetV3LargeWeights],
trainable_backbone_layers: Optional[int],
**kwargs: Any,
) -> FasterRCNN:
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3
......@@ -149,19 +160,23 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1)
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
defaults = {
......@@ -171,9 +186,9 @@ def fasterrcnn_mobilenet_v3_large_fpn(
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
weights_backbone,
progress,
num_classes,
weights_backbone,
trainable_backbone_layers,
**kwargs,
)
......@@ -181,19 +196,23 @@ def fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_320_fpn(
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1)
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
defaults = {
......@@ -207,9 +226,9 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
weights_backbone,
progress,
num_classes,
weights_backbone,
trainable_backbone_layers,
**kwargs,
)
import warnings
from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
......@@ -12,6 +11,7 @@ from ....models.detection.keypoint_rcnn import (
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50
......@@ -35,6 +35,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
"box_map": 50.6,
"kp_map": 61.1,
},
default=False,
)
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
......@@ -45,37 +46,45 @@ class KeypointRCNNResNet50FPNWeights(Weights):
"box_map": 54.6,
"kp_map": 65.0,
},
default=True,
)
def keypointrcnn_resnet50_fpn(
weights: Optional[KeypointRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True,
num_classes: int = 2,
num_keypoints: int = 17,
num_classes: Optional[int] = None,
num_keypoints: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> KeypointRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
pretrained = kwargs.pop("pretrained")
if type(pretrained) == str and pretrained == "legacy":
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
elif type(pretrained) == bool and pretrained:
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1
else:
weights = None
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1
if kwargs["pretrained"] == "legacy":
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
kwargs["pretrained"] = True
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)
weights = KeypointRCNNResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_keypoints = len(weights.meta["keypoint_names"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"]))
else:
if num_classes is None:
num_classes = 2
if num_keypoints is None:
num_keypoints = 17
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
......
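
For context, with `num_classes` and `num_keypoints` now defaulting to `None`, the keypoint R-CNN builder above behaves roughly as follows (prototype import path assumed; illustrative, not part of the patch):

from torchvision.prototype.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNNResNet50FPNWeights

keypointrcnn_resnet50_fpn()                     # no weights: falls back to num_classes=2, num_keypoints=17
keypointrcnn_resnet50_fpn(pretrained="legacy")  # warns and resolves to Coco_RefV1_Legacy
keypointrcnn_resnet50_fpn(
    weights=KeypointRCNNResNet50FPNWeights.Coco_RefV1,
    num_classes=5,                              # conflicts with the checkpoint meta-data, raises ValueError
)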
import warnings
from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
......@@ -13,6 +12,7 @@ from ....models.detection.mask_rcnn import (
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50
......@@ -34,29 +34,36 @@ class MaskRCNNResNet50FPNWeights(Weights):
"box_map": 37.9,
"mask_map": 34.6,
},
default=True,
)
def maskrcnn_resnet50_fpn(
weights: Optional[MaskRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True,
num_classes: int = 91,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> MaskRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_RefV1)
weights = MaskRCNNResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
......
import warnings
from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
......@@ -14,6 +13,7 @@ from ....models.detection.retinanet import (
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50
......@@ -34,29 +34,36 @@ class RetinaNetResNet50FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
default=True,
)
def retinanet_resnet50_fpn(
weights: Optional[RetinaNetResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True,
num_classes: int = 91,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> RetinaNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_RefV1)
weights = RetinaNetResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
......
......@@ -12,6 +12,7 @@ from ....models.detection.ssd import (
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..vgg import VGG16Weights, vgg16
......@@ -32,24 +33,29 @@ class SSD300VGG16Weights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
"map": 25.1,
},
default=True,
)
def ssd300_vgg16(
weights: Optional[SSD300VGG16Weights] = None,
weights_backbone: Optional[VGG16Weights] = None,
progress: bool = True,
num_classes: int = 91,
num_classes: Optional[int] = None,
weights_backbone: Optional[VGG16Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_RefV1)
weights = SSD300VGG16Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", VGG16Weights.ImageNet1K_Features
)
weights_backbone = VGG16Weights.verify(weights_backbone)
if "size" in kwargs:
......@@ -57,7 +63,9 @@ def ssd300_vgg16(
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
......
......@@ -17,6 +17,7 @@ from ....models.detection.ssdlite import (
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
......@@ -37,25 +38,30 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
"map": 21.3,
},
default=True,
)
def ssdlite320_mobilenet_v3_large(
weights: Optional[SSDlite320MobileNetV3LargeFPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1)
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if "size" in kwargs:
......@@ -63,7 +69,9 @@ def ssdlite320_mobilenet_v3_large(
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -9,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -41,7 +41,7 @@ def _efficientnet(
**kwargs: Any,
) -> EfficientNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
inverted_residual_setting = [
......@@ -79,6 +79,7 @@ class EfficientNetB0Weights(Weights):
"acc@1": 77.692,
"acc@5": 93.532,
},
default=True,
)
......@@ -92,6 +93,7 @@ class EfficientNetB1Weights(Weights):
"acc@1": 78.642,
"acc@5": 94.186,
},
default=True,
)
......@@ -105,6 +107,7 @@ class EfficientNetB2Weights(Weights):
"acc@1": 80.608,
"acc@5": 95.310,
},
default=True,
)
......@@ -118,6 +121,7 @@ class EfficientNetB3Weights(Weights):
"acc@1": 82.008,
"acc@5": 96.054,
},
default=True,
)
......@@ -131,6 +135,7 @@ class EfficientNetB4Weights(Weights):
"acc@1": 83.384,
"acc@5": 96.594,
},
default=True,
)
......@@ -144,6 +149,7 @@ class EfficientNetB5Weights(Weights):
"acc@1": 83.444,
"acc@5": 96.628,
},
default=True,
)
......@@ -157,6 +163,7 @@ class EfficientNetB6Weights(Weights):
"acc@1": 84.008,
"acc@5": 96.916,
},
default=True,
)
......@@ -170,66 +177,79 @@ class EfficientNetB7Weights(Weights):
"acc@1": 84.122,
"acc@5": 96.908,
},
default=True,
)
def efficientnet_b0(
weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB0Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_TimmV1)
weights = EfficientNetB0Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b1(
weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB1Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_TimmV1)
weights = EfficientNetB1Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b2(
weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB2Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_TimmV1)
weights = EfficientNetB2Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b3(
weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB3Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_TimmV1)
weights = EfficientNetB3Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b4(
weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB4Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_TimmV1)
weights = EfficientNetB4Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
def efficientnet_b5(
weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB5Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_TFV1)
weights = EfficientNetB5Weights.verify(weights)
return _efficientnet(
width_mult=1.6,
depth_mult=2.2,
......@@ -244,10 +264,12 @@ def efficientnet_b5(
def efficientnet_b6(
weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB6Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_TFV1)
weights = EfficientNetB6Weights.verify(weights)
return _efficientnet(
width_mult=1.8,
depth_mult=2.6,
......@@ -262,10 +284,12 @@ def efficientnet_b6(
def efficientnet_b7(
weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = EfficientNetB7Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_TFV1)
weights = EfficientNetB7Weights.verify(weights)
return _efficientnet(
width_mult=2.0,
depth_mult=3.1,
......
......@@ -8,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"]
......@@ -25,22 +26,24 @@ class GoogLeNetWeights(Weights):
"acc@1": 69.778,
"acc@5": 89.530,
},
default=True,
)
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_TFV1)
weights = GoogLeNetWeights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "init_weights", False)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = GoogLeNet(**kwargs)
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]
......@@ -25,22 +25,24 @@ class InceptionV3Weights(Weights):
"acc@1": 77.294,
"acc@5": 93.450,
},
default=True,
)
def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_TFV1)
weights = InceptionV3Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", True)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "init_weights", False)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = Inception3(**kwargs)
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -40,6 +40,7 @@ class MNASNet0_5Weights(Weights):
"acc@1": 67.734,
"acc@5": 87.490,
},
default=True,
)
......@@ -57,6 +58,7 @@ class MNASNet1_0Weights(Weights):
"acc@1": 73.456,
"acc@5": 91.510,
},
default=True,
)
......@@ -67,7 +69,7 @@ class MNASNet1_3Weights(Weights):
def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MNASNet(alpha, **kwargs)
......@@ -78,41 +80,40 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_Community)
weights = MNASNet0_5Weights.verify(weights)
return _mnasnet(0.5, weights, progress, **kwargs)
def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet0_75")
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet0_75Weights.verify(weights)
return _mnasnet(0.75, weights, progress, **kwargs)
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_Community)
weights = MNASNet1_0Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs)
def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet1_3")
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet1_3Weights.verify(weights)
return _mnasnet(1.3, weights, progress, **kwargs)
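
For context, variants without published weights keep rejecting `pretrained=True`; the error now comes from `_deprecated_param` receiving a `None` default (prototype import path assumed; illustrative, not part of the patch):

from torchvision.prototype.models import mnasnet0_75

try:
    mnasnet0_75(pretrained=True)  # deprecation warning, then failure
except ValueError as err:
    print(err)  # No checkpoint is available for model.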
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]
......@@ -25,17 +25,19 @@ class MobileNetV2Weights(Weights):
"acc@1": 71.878,
"acc@5": 90.286,
},
default=True,
)
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_RefV1)
weights = MobileNetV2Weights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MobileNetV2(**kwargs)
......
import warnings
from functools import partial
from typing import Any, Optional, List
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -27,7 +27,7 @@ def _mobilenet_v3(
**kwargs: Any,
) -> MobileNetV3:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
......@@ -54,6 +54,7 @@ class MobileNetV3LargeWeights(Weights):
"acc@1": 74.042,
"acc@5": 91.340,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
......@@ -64,6 +65,7 @@ class MobileNetV3LargeWeights(Weights):
"acc@1": 75.274,
"acc@5": 92.566,
},
default=True,
)
......@@ -77,15 +79,17 @@ class MobileNetV3SmallWeights(Weights):
"acc@1": 67.668,
"acc@5": 87.402,
},
default=True,
)
def mobilenet_v3_large(
weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_RefV1)
weights = MobileNetV3LargeWeights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
......@@ -95,9 +99,10 @@ def mobilenet_v3_large(
def mobilenet_v3_small(
weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MobileNetV3SmallWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_RefV1)
weights = MobileNetV3SmallWeights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
......
......@@ -12,6 +12,7 @@ from ....models.quantization.googlenet import (
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..googlenet import GoogLeNetWeights
......@@ -37,6 +38,7 @@ class QuantizedGoogLeNetWeights(Weights):
"acc@1": 69.826,
"acc@5": 89.404,
},
default=True,
)
......@@ -46,13 +48,13 @@ def googlenet(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableGoogLeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
else:
weights = None
default_value = (
QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedGoogLeNetWeights.verify(weights)
else:
......@@ -61,12 +63,12 @@ def googlenet(
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "init_weights", False)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableGoogLeNet(**kwargs)
......
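
For context, the quantized builders keep the same back-compat behaviour, selecting quantized or float weights depending on `quantize` (prototype import path assumed; illustrative, not part of the patch):

from torchvision.prototype.models.quantization import googlenet

q = googlenet(pretrained=True, quantize=True)  # QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1
f = googlenet(pretrained=True)                 # GoogLeNetWeights.ImageNet1K_TFV1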
import warnings
from functools import partial
from typing import Any, Optional, Union
......@@ -12,6 +11,7 @@ from ....models.quantization.inception import (
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..inception import InceptionV3Weights
......@@ -37,6 +37,7 @@ class QuantizedInceptionV3Weights(Weights):
"acc@1": 77.176,
"acc@5": 93.354,
},
default=True,
)
......@@ -46,15 +47,13 @@ def inception_v3(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableInception3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
)
else:
weights = None
default_value = (
QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedInceptionV3Weights.verify(weights)
else:
......@@ -63,11 +62,11 @@ def inception_v3(
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableInception3(**kwargs)
......
import warnings
from functools import partial
from typing import Any, Optional, Union
......@@ -13,6 +12,7 @@ from ....models.quantization.mobilenetv2 import (
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..mobilenetv2 import MobileNetV2Weights
......@@ -38,6 +38,7 @@ class QuantizedMobileNetV2Weights(Weights):
"acc@1": 71.658,
"acc@5": 90.150,
},
default=True,
)
......@@ -47,26 +48,22 @@ def mobilenet_v2(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
if quantize
else MobileNetV2Weights.ImageNet1K_RefV1
)
else:
weights = None
default_value = (
QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 if quantize else MobileNetV2Weights.ImageNet1K_RefV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedMobileNetV2Weights.verify(weights)
else:
weights = MobileNetV2Weights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")
model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
......