Unverified Commit 18cf5ab1 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

More Multiweight support cleanups (#4948)

* Updated the link for densenet recipe.

* Set default value of `num_classes` and `num_keypoints` to `None`

* Provide helper methods for parameter checks to reduce duplicate code.

* Throw errors on silent config overwrites from weight meta-data and legacy builders.

* Changing order of arguments + fixing mypy.

* Make the builders fully BC.

* Add "default" weights support that returns always the best weights.
parent 09e759ea
...@@ -31,14 +31,19 @@ def get_models_with_module_names(module): ...@@ -31,14 +31,19 @@ def get_models_with_module_names(module):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn, weight", "model_fn, name, weight",
[ [
(models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2), (models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1),
(models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1), (models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_RefV1",
models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
),
], ],
) )
def test_get_weight(model_fn, weight): def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, weight.name) == weight assert models._api.get_weight(model_fn, name) == weight
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
......
...@@ -30,6 +30,7 @@ class WeightEntry: ...@@ -30,6 +30,7 @@ class WeightEntry:
url: str url: str
transforms: Callable transforms: Callable
meta: Dict[str, Any] meta: Dict[str, Any]
default: bool
class Weights(Enum): class Weights(Enum):
...@@ -59,7 +60,7 @@ class Weights(Enum): ...@@ -59,7 +60,7 @@ class Weights(Enum):
@classmethod @classmethod
def from_str(cls, value: str) -> "Weights": def from_str(cls, value: str) -> "Weights":
for v in cls: for v in cls:
if v._name_ == value: if v._name_ == value or (value == "default" and v.default):
return v return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
......
import warnings
from typing import Any, Dict, Optional, TypeVar
from ._api import Weights
W = TypeVar("W", bound=Weights)
V = TypeVar("V")
def _deprecated_param(
kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W]
) -> Optional[W]:
warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
if kwargs.pop(deprecated_param):
if default_value is not None:
return default_value
else:
raise ValueError("No checkpoint is available for model.")
else:
return None
def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None:
warnings.warn(
f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'"
+ " instead."
)
kwargs[deprecated_param] = default_value
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
if param in kwargs:
if kwargs[param] != new_value:
raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
else:
kwargs[param] = new_value
def _ovewrite_value_param(param: Optional[V], new_value: V) -> V:
if param is not None:
if param != new_value:
raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.")
return new_value
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet from ...models.alexnet import AlexNet
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"] __all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
...@@ -25,16 +25,19 @@ class AlexNetWeights(Weights): ...@@ -25,16 +25,19 @@ class AlexNetWeights(Weights):
"acc@1": 56.522, "acc@1": 56.522,
"acc@5": 79.066, "acc@5": 79.066,
}, },
default=True,
) )
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1)
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = AlexNetWeights.verify(weights) weights = AlexNetWeights.verify(weights)
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = AlexNet(**kwargs) model = AlexNet(**kwargs)
......
import re import re
import warnings
from functools import partial from functools import partial
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
...@@ -10,6 +9,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -10,6 +9,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet from ...models.densenet import DenseNet
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -53,7 +53,7 @@ def _densenet( ...@@ -53,7 +53,7 @@ def _densenet(
**kwargs: Any, **kwargs: Any,
) -> DenseNet: ) -> DenseNet:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
...@@ -67,7 +67,7 @@ _COMMON_META = { ...@@ -67,7 +67,7 @@ _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
"recipe": None, # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch "recipe": "https://github.com/pytorch/vision/pull/116",
} }
...@@ -80,6 +80,7 @@ class DenseNet121Weights(Weights): ...@@ -80,6 +80,7 @@ class DenseNet121Weights(Weights):
"acc@1": 74.434, "acc@1": 74.434,
"acc@5": 91.972, "acc@5": 91.972,
}, },
default=True,
) )
...@@ -92,6 +93,7 @@ class DenseNet161Weights(Weights): ...@@ -92,6 +93,7 @@ class DenseNet161Weights(Weights):
"acc@1": 77.138, "acc@1": 77.138,
"acc@5": 93.560, "acc@5": 93.560,
}, },
default=True,
) )
...@@ -104,6 +106,7 @@ class DenseNet169Weights(Weights): ...@@ -104,6 +106,7 @@ class DenseNet169Weights(Weights):
"acc@1": 75.600, "acc@1": 75.600,
"acc@5": 92.806, "acc@5": 92.806,
}, },
default=True,
) )
...@@ -116,40 +119,45 @@ class DenseNet201Weights(Weights): ...@@ -116,40 +119,45 @@ class DenseNet201Weights(Weights):
"acc@1": 76.896, "acc@1": 76.896,
"acc@5": 93.370, "acc@5": 93.370,
}, },
default=True,
) )
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community)
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet121Weights.verify(weights) weights = DenseNet121Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community)
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet161Weights.verify(weights) weights = DenseNet161Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community)
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet169Weights.verify(weights) weights = DenseNet169Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community)
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet201Weights.verify(weights) weights = DenseNet201Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
import warnings
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
...@@ -15,6 +14,7 @@ from ....models.detection.faster_rcnn import ( ...@@ -15,6 +14,7 @@ from ....models.detection.faster_rcnn import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
...@@ -45,6 +45,7 @@ class FasterRCNNResNet50FPNWeights(Weights): ...@@ -45,6 +45,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0, "map": 37.0,
}, },
default=True,
) )
...@@ -57,6 +58,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights): ...@@ -57,6 +58,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8, "map": 32.8,
}, },
default=True,
) )
...@@ -69,29 +71,36 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights): ...@@ -69,29 +71,36 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8, "map": 22.8,
}, },
default=True,
) )
def fasterrcnn_resnet50_fpn( def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None, weights: Optional[FasterRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 91, num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_RefV1)
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNResNet50FPNWeights.verify(weights) weights = FasterRCNNResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
...@@ -110,16 +119,18 @@ def fasterrcnn_resnet50_fpn( ...@@ -110,16 +119,18 @@ def fasterrcnn_resnet50_fpn(
def _fasterrcnn_mobilenet_v3_large_fpn( def _fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None, weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]],
weights_backbone: Optional[MobileNetV3LargeWeights] = None, progress: bool,
progress: bool = True, num_classes: Optional[int],
num_classes: int = 91, weights_backbone: Optional[MobileNetV3LargeWeights],
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int],
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3 weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3
...@@ -149,19 +160,23 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -149,19 +160,23 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None, weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 91, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1)
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
defaults = { defaults = {
...@@ -171,9 +186,9 @@ def fasterrcnn_mobilenet_v3_large_fpn( ...@@ -171,9 +186,9 @@ def fasterrcnn_mobilenet_v3_large_fpn(
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn( return _fasterrcnn_mobilenet_v3_large_fpn(
weights, weights,
weights_backbone,
progress, progress,
num_classes, num_classes,
weights_backbone,
trainable_backbone_layers, trainable_backbone_layers,
**kwargs, **kwargs,
) )
...@@ -181,19 +196,23 @@ def fasterrcnn_mobilenet_v3_large_fpn( ...@@ -181,19 +196,23 @@ def fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_320_fpn( def fasterrcnn_mobilenet_v3_large_320_fpn(
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None, weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 91, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1)
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
defaults = { defaults = {
...@@ -207,9 +226,9 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( ...@@ -207,9 +226,9 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn( return _fasterrcnn_mobilenet_v3_large_fpn(
weights, weights,
weights_backbone,
progress, progress,
num_classes, num_classes,
weights_backbone,
trainable_backbone_layers, trainable_backbone_layers,
**kwargs, **kwargs,
) )
import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
...@@ -12,6 +11,7 @@ from ....models.detection.keypoint_rcnn import ( ...@@ -12,6 +11,7 @@ from ....models.detection.keypoint_rcnn import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
...@@ -35,6 +35,7 @@ class KeypointRCNNResNet50FPNWeights(Weights): ...@@ -35,6 +35,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
"box_map": 50.6, "box_map": 50.6,
"kp_map": 61.1, "kp_map": 61.1,
}, },
default=False,
) )
Coco_RefV1 = WeightEntry( Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
...@@ -45,37 +46,45 @@ class KeypointRCNNResNet50FPNWeights(Weights): ...@@ -45,37 +46,45 @@ class KeypointRCNNResNet50FPNWeights(Weights):
"box_map": 54.6, "box_map": 54.6,
"kp_map": 65.0, "kp_map": 65.0,
}, },
default=True,
) )
def keypointrcnn_resnet50_fpn( def keypointrcnn_resnet50_fpn(
weights: Optional[KeypointRCNNResNet50FPNWeights] = None, weights: Optional[KeypointRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 2, num_classes: Optional[int] = None,
num_keypoints: int = 17, num_keypoints: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> KeypointRCNN: ) -> KeypointRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1
pretrained = kwargs.pop("pretrained") if kwargs["pretrained"] == "legacy":
if type(pretrained) == str and pretrained == "legacy": default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy kwargs["pretrained"] = True
elif type(pretrained) == bool and pretrained: weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1
else:
weights = None
weights = KeypointRCNNResNet50FPNWeights.verify(weights) weights = KeypointRCNNResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
num_keypoints = len(weights.meta["keypoint_names"]) num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"]))
else:
if num_classes is None:
num_classes = 2
if num_keypoints is None:
num_keypoints = 17
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
......
import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
...@@ -13,6 +12,7 @@ from ....models.detection.mask_rcnn import ( ...@@ -13,6 +12,7 @@ from ....models.detection.mask_rcnn import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
...@@ -34,29 +34,36 @@ class MaskRCNNResNet50FPNWeights(Weights): ...@@ -34,29 +34,36 @@ class MaskRCNNResNet50FPNWeights(Weights):
"box_map": 37.9, "box_map": 37.9,
"mask_map": 34.6, "mask_map": 34.6,
}, },
default=True,
) )
def maskrcnn_resnet50_fpn( def maskrcnn_resnet50_fpn(
weights: Optional[MaskRCNNResNet50FPNWeights] = None, weights: Optional[MaskRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 91, num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> MaskRCNN: ) -> MaskRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_RefV1)
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = MaskRCNNResNet50FPNWeights.verify(weights) weights = MaskRCNNResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
......
import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
...@@ -14,6 +13,7 @@ from ....models.detection.retinanet import ( ...@@ -14,6 +13,7 @@ from ....models.detection.retinanet import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
...@@ -34,29 +34,36 @@ class RetinaNetResNet50FPNWeights(Weights): ...@@ -34,29 +34,36 @@ class RetinaNetResNet50FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4, "map": 36.4,
}, },
default=True,
) )
def retinanet_resnet50_fpn( def retinanet_resnet50_fpn(
weights: Optional[RetinaNetResNet50FPNWeights] = None, weights: Optional[RetinaNetResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 91, num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> RetinaNet: ) -> RetinaNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_RefV1)
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = RetinaNetResNet50FPNWeights.verify(weights) weights = RetinaNetResNet50FPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
......
...@@ -12,6 +12,7 @@ from ....models.detection.ssd import ( ...@@ -12,6 +12,7 @@ from ....models.detection.ssd import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..vgg import VGG16Weights, vgg16 from ..vgg import VGG16Weights, vgg16
...@@ -32,24 +33,29 @@ class SSD300VGG16Weights(Weights): ...@@ -32,24 +33,29 @@ class SSD300VGG16Weights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
"map": 25.1, "map": 25.1,
}, },
default=True,
) )
def ssd300_vgg16( def ssd300_vgg16(
weights: Optional[SSD300VGG16Weights] = None, weights: Optional[SSD300VGG16Weights] = None,
weights_backbone: Optional[VGG16Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 91, num_classes: Optional[int] = None,
weights_backbone: Optional[VGG16Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> SSD: ) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_RefV1)
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSD300VGG16Weights.verify(weights) weights = SSD300VGG16Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", VGG16Weights.ImageNet1K_Features
)
weights_backbone = VGG16Weights.verify(weights_backbone) weights_backbone = VGG16Weights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
...@@ -57,7 +63,9 @@ def ssd300_vgg16( ...@@ -57,7 +63,9 @@ def ssd300_vgg16(
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
......
...@@ -17,6 +17,7 @@ from ....models.detection.ssdlite import ( ...@@ -17,6 +17,7 @@ from ....models.detection.ssdlite import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
...@@ -37,25 +38,30 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights): ...@@ -37,25 +38,30 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
"map": 21.3, "map": 21.3,
}, },
default=True,
) )
def ssdlite320_mobilenet_v3_large( def ssdlite320_mobilenet_v3_large(
weights: Optional[SSDlite320MobileNetV3LargeFPNWeights] = None, weights: Optional[SSDlite320MobileNetV3LargeFPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 91, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any, **kwargs: Any,
) -> SSD: ) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1)
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights) weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
...@@ -63,7 +69,9 @@ def ssdlite320_mobilenet_v3_large( ...@@ -63,7 +69,9 @@ def ssdlite320_mobilenet_v3_large(
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6 weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
......
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -9,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -9,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig from ...models.efficientnet import EfficientNet, MBConvConfig
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -41,7 +41,7 @@ def _efficientnet( ...@@ -41,7 +41,7 @@ def _efficientnet(
**kwargs: Any, **kwargs: Any,
) -> EfficientNet: ) -> EfficientNet:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
inverted_residual_setting = [ inverted_residual_setting = [
...@@ -79,6 +79,7 @@ class EfficientNetB0Weights(Weights): ...@@ -79,6 +79,7 @@ class EfficientNetB0Weights(Weights):
"acc@1": 77.692, "acc@1": 77.692,
"acc@5": 93.532, "acc@5": 93.532,
}, },
default=True,
) )
...@@ -92,6 +93,7 @@ class EfficientNetB1Weights(Weights): ...@@ -92,6 +93,7 @@ class EfficientNetB1Weights(Weights):
"acc@1": 78.642, "acc@1": 78.642,
"acc@5": 94.186, "acc@5": 94.186,
}, },
default=True,
) )
...@@ -105,6 +107,7 @@ class EfficientNetB2Weights(Weights): ...@@ -105,6 +107,7 @@ class EfficientNetB2Weights(Weights):
"acc@1": 80.608, "acc@1": 80.608,
"acc@5": 95.310, "acc@5": 95.310,
}, },
default=True,
) )
...@@ -118,6 +121,7 @@ class EfficientNetB3Weights(Weights): ...@@ -118,6 +121,7 @@ class EfficientNetB3Weights(Weights):
"acc@1": 82.008, "acc@1": 82.008,
"acc@5": 96.054, "acc@5": 96.054,
}, },
default=True,
) )
...@@ -131,6 +135,7 @@ class EfficientNetB4Weights(Weights): ...@@ -131,6 +135,7 @@ class EfficientNetB4Weights(Weights):
"acc@1": 83.384, "acc@1": 83.384,
"acc@5": 96.594, "acc@5": 96.594,
}, },
default=True,
) )
...@@ -144,6 +149,7 @@ class EfficientNetB5Weights(Weights): ...@@ -144,6 +149,7 @@ class EfficientNetB5Weights(Weights):
"acc@1": 83.444, "acc@1": 83.444,
"acc@5": 96.628, "acc@5": 96.628,
}, },
default=True,
) )
...@@ -157,6 +163,7 @@ class EfficientNetB6Weights(Weights): ...@@ -157,6 +163,7 @@ class EfficientNetB6Weights(Weights):
"acc@1": 84.008, "acc@1": 84.008,
"acc@5": 96.916, "acc@5": 96.916,
}, },
default=True,
) )
...@@ -170,66 +177,79 @@ class EfficientNetB7Weights(Weights): ...@@ -170,66 +177,79 @@ class EfficientNetB7Weights(Weights):
"acc@1": 84.122, "acc@1": 84.122,
"acc@5": 96.908, "acc@5": 96.908,
}, },
default=True,
) )
def efficientnet_b0( def efficientnet_b0(
weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_TimmV1)
weights = EfficientNetB0Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB0Weights.verify(weights) weights = EfficientNetB0Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b1( def efficientnet_b1(
weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_TimmV1)
weights = EfficientNetB1Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB1Weights.verify(weights) weights = EfficientNetB1Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b2( def efficientnet_b2(
weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_TimmV1)
weights = EfficientNetB2Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB2Weights.verify(weights) weights = EfficientNetB2Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b3( def efficientnet_b3(
weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_TimmV1)
weights = EfficientNetB3Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB3Weights.verify(weights) weights = EfficientNetB3Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b4( def efficientnet_b4(
weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_TimmV1)
weights = EfficientNetB4Weights.ImageNet1K_TimmV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB4Weights.verify(weights) weights = EfficientNetB4Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
def efficientnet_b5( def efficientnet_b5(
weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_TFV1)
weights = EfficientNetB5Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB5Weights.verify(weights) weights = EfficientNetB5Weights.verify(weights)
return _efficientnet( return _efficientnet(
width_mult=1.6, width_mult=1.6,
depth_mult=2.2, depth_mult=2.2,
...@@ -244,10 +264,12 @@ def efficientnet_b5( ...@@ -244,10 +264,12 @@ def efficientnet_b5(
def efficientnet_b6( def efficientnet_b6(
weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_TFV1)
weights = EfficientNetB6Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB6Weights.verify(weights) weights = EfficientNetB6Weights.verify(weights)
return _efficientnet( return _efficientnet(
width_mult=1.8, width_mult=1.8,
depth_mult=2.6, depth_mult=2.6,
...@@ -262,10 +284,12 @@ def efficientnet_b6( ...@@ -262,10 +284,12 @@ def efficientnet_b6(
def efficientnet_b7( def efficientnet_b7(
weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_TFV1)
weights = EfficientNetB7Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = EfficientNetB7Weights.verify(weights) weights = EfficientNetB7Weights.verify(weights)
return _efficientnet( return _efficientnet(
width_mult=2.0, width_mult=2.0,
depth_mult=3.1, depth_mult=3.1,
......
...@@ -8,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"] __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"]
...@@ -25,22 +26,24 @@ class GoogLeNetWeights(Weights): ...@@ -25,22 +26,24 @@ class GoogLeNetWeights(Weights):
"acc@1": 69.778, "acc@1": 69.778,
"acc@5": 89.530, "acc@5": 89.530,
}, },
default=True,
) )
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_TFV1)
weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights) weights = GoogLeNetWeights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
if "transform_input" not in kwargs: if "transform_input" not in kwargs:
kwargs["transform_input"] = True _ovewrite_named_param(kwargs, "transform_input", True)
kwargs["aux_logits"] = True _ovewrite_named_param(kwargs, "aux_logits", True)
kwargs["init_weights"] = False _ovewrite_named_param(kwargs, "init_weights", False)
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = GoogLeNet(**kwargs) model = GoogLeNet(**kwargs)
......
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"] __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]
...@@ -25,22 +25,24 @@ class InceptionV3Weights(Weights): ...@@ -25,22 +25,24 @@ class InceptionV3Weights(Weights):
"acc@1": 77.294, "acc@1": 77.294,
"acc@5": 93.450, "acc@5": 93.450,
}, },
default=True,
) )
def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_TFV1)
weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = InceptionV3Weights.verify(weights) weights = InceptionV3Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", True) original_aux_logits = kwargs.get("aux_logits", True)
if weights is not None: if weights is not None:
if "transform_input" not in kwargs: if "transform_input" not in kwargs:
kwargs["transform_input"] = True _ovewrite_named_param(kwargs, "transform_input", True)
kwargs["aux_logits"] = True _ovewrite_named_param(kwargs, "aux_logits", True)
kwargs["init_weights"] = False _ovewrite_named_param(kwargs, "init_weights", False)
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = Inception3(**kwargs) model = Inception3(**kwargs)
......
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet from ...models.mnasnet import MNASNet
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -40,6 +40,7 @@ class MNASNet0_5Weights(Weights): ...@@ -40,6 +40,7 @@ class MNASNet0_5Weights(Weights):
"acc@1": 67.734, "acc@1": 67.734,
"acc@5": 87.490, "acc@5": 87.490,
}, },
default=True,
) )
...@@ -57,6 +58,7 @@ class MNASNet1_0Weights(Weights): ...@@ -57,6 +58,7 @@ class MNASNet1_0Weights(Weights):
"acc@1": 73.456, "acc@1": 73.456,
"acc@5": 91.510, "acc@5": 91.510,
}, },
default=True,
) )
...@@ -67,7 +69,7 @@ class MNASNet1_3Weights(Weights): ...@@ -67,7 +69,7 @@ class MNASNet1_3Weights(Weights):
def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet: def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MNASNet(alpha, **kwargs) model = MNASNet(alpha, **kwargs)
...@@ -78,41 +80,40 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: ...@@ -78,41 +80,40 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_Community)
weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet0_5Weights.verify(weights) weights = MNASNet0_5Weights.verify(weights)
return _mnasnet(0.5, weights, progress, **kwargs) return _mnasnet(0.5, weights, progress, **kwargs)
def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", None)
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet0_75")
weights = MNASNet0_75Weights.verify(weights) weights = MNASNet0_75Weights.verify(weights)
return _mnasnet(0.75, weights, progress, **kwargs) return _mnasnet(0.75, weights, progress, **kwargs)
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_Community)
weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet1_0Weights.verify(weights) weights = MNASNet1_0Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs) return _mnasnet(1.0, weights, progress, **kwargs)
def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", None)
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type mnasnet1_3")
weights = MNASNet1_3Weights.verify(weights) weights = MNASNet1_3Weights.verify(weights)
return _mnasnet(1.3, weights, progress, **kwargs) return _mnasnet(1.3, weights, progress, **kwargs)
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2 from ...models.mobilenetv2 import MobileNetV2
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"] __all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]
...@@ -25,17 +25,19 @@ class MobileNetV2Weights(Weights): ...@@ -25,17 +25,19 @@ class MobileNetV2Weights(Weights):
"acc@1": 71.878, "acc@1": 71.878,
"acc@5": 90.286, "acc@5": 90.286,
}, },
default=True,
) )
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_RefV1)
weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV2Weights.verify(weights) weights = MobileNetV2Weights.verify(weights)
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MobileNetV2(**kwargs) model = MobileNetV2(**kwargs)
......
import warnings
from functools import partial from functools import partial
from typing import Any, Optional, List from typing import Any, Optional, List
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -27,7 +27,7 @@ def _mobilenet_v3( ...@@ -27,7 +27,7 @@ def _mobilenet_v3(
**kwargs: Any, **kwargs: Any,
) -> MobileNetV3: ) -> MobileNetV3:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
...@@ -54,6 +54,7 @@ class MobileNetV3LargeWeights(Weights): ...@@ -54,6 +54,7 @@ class MobileNetV3LargeWeights(Weights):
"acc@1": 74.042, "acc@1": 74.042,
"acc@5": 91.340, "acc@5": 91.340,
}, },
default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
...@@ -64,6 +65,7 @@ class MobileNetV3LargeWeights(Weights): ...@@ -64,6 +65,7 @@ class MobileNetV3LargeWeights(Weights):
"acc@1": 75.274, "acc@1": 75.274,
"acc@5": 92.566, "acc@5": 92.566,
}, },
default=True,
) )
...@@ -77,15 +79,17 @@ class MobileNetV3SmallWeights(Weights): ...@@ -77,15 +79,17 @@ class MobileNetV3SmallWeights(Weights):
"acc@1": 67.668, "acc@1": 67.668,
"acc@5": 87.402, "acc@5": 87.402,
}, },
default=True,
) )
def mobilenet_v3_large( def mobilenet_v3_large(
weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_RefV1)
weights = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3LargeWeights.verify(weights) weights = MobileNetV3LargeWeights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
...@@ -95,9 +99,10 @@ def mobilenet_v3_large( ...@@ -95,9 +99,10 @@ def mobilenet_v3_large(
def mobilenet_v3_small( def mobilenet_v3_small(
weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_RefV1)
weights = MobileNetV3SmallWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3SmallWeights.verify(weights) weights = MobileNetV3SmallWeights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
......
...@@ -12,6 +12,7 @@ from ....models.quantization.googlenet import ( ...@@ -12,6 +12,7 @@ from ....models.quantization.googlenet import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..googlenet import GoogLeNetWeights from ..googlenet import GoogLeNetWeights
...@@ -37,6 +38,7 @@ class QuantizedGoogLeNetWeights(Weights): ...@@ -37,6 +38,7 @@ class QuantizedGoogLeNetWeights(Weights):
"acc@1": 69.826, "acc@1": 69.826,
"acc@5": 89.404, "acc@5": 89.404,
}, },
default=True,
) )
...@@ -46,13 +48,13 @@ def googlenet( ...@@ -46,13 +48,13 @@ def googlenet(
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableGoogLeNet: ) -> QuantizableGoogLeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") default_value = (
if kwargs.pop("pretrained"): QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
weights = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1 )
else: weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
weights = None
if quantize: if quantize:
weights = QuantizedGoogLeNetWeights.verify(weights) weights = QuantizedGoogLeNetWeights.verify(weights)
else: else:
...@@ -61,12 +63,12 @@ def googlenet( ...@@ -61,12 +63,12 @@ def googlenet(
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
if "transform_input" not in kwargs: if "transform_input" not in kwargs:
kwargs["transform_input"] = True _ovewrite_named_param(kwargs, "transform_input", True)
kwargs["aux_logits"] = True _ovewrite_named_param(kwargs, "aux_logits", True)
kwargs["init_weights"] = False _ovewrite_named_param(kwargs, "init_weights", False)
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta: if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"] _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm") backend = kwargs.pop("backend", "fbgemm")
model = QuantizableGoogLeNet(**kwargs) model = QuantizableGoogLeNet(**kwargs)
......
import warnings
from functools import partial from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
...@@ -12,6 +11,7 @@ from ....models.quantization.inception import ( ...@@ -12,6 +11,7 @@ from ....models.quantization.inception import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..inception import InceptionV3Weights from ..inception import InceptionV3Weights
...@@ -37,6 +37,7 @@ class QuantizedInceptionV3Weights(Weights): ...@@ -37,6 +37,7 @@ class QuantizedInceptionV3Weights(Weights):
"acc@1": 77.176, "acc@1": 77.176,
"acc@5": 93.354, "acc@5": 93.354,
}, },
default=True,
) )
...@@ -46,15 +47,13 @@ def inception_v3( ...@@ -46,15 +47,13 @@ def inception_v3(
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableInception3: ) -> QuantizableInception3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") default_value = (
if kwargs.pop("pretrained"):
weights = (
QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1 QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
) )
else: weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
weights = None
if quantize: if quantize:
weights = QuantizedInceptionV3Weights.verify(weights) weights = QuantizedInceptionV3Weights.verify(weights)
else: else:
...@@ -63,11 +62,11 @@ def inception_v3( ...@@ -63,11 +62,11 @@ def inception_v3(
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
if "transform_input" not in kwargs: if "transform_input" not in kwargs:
kwargs["transform_input"] = True _ovewrite_named_param(kwargs, "transform_input", True)
kwargs["aux_logits"] = True _ovewrite_named_param(kwargs, "aux_logits", True)
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta: if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"] _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm") backend = kwargs.pop("backend", "fbgemm")
model = QuantizableInception3(**kwargs) model = QuantizableInception3(**kwargs)
......
import warnings
from functools import partial from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
...@@ -13,6 +12,7 @@ from ....models.quantization.mobilenetv2 import ( ...@@ -13,6 +12,7 @@ from ....models.quantization.mobilenetv2 import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..mobilenetv2 import MobileNetV2Weights from ..mobilenetv2 import MobileNetV2Weights
...@@ -38,6 +38,7 @@ class QuantizedMobileNetV2Weights(Weights): ...@@ -38,6 +38,7 @@ class QuantizedMobileNetV2Weights(Weights):
"acc@1": 71.658, "acc@1": 71.658,
"acc@5": 90.150, "acc@5": 90.150,
}, },
default=True,
) )
...@@ -47,26 +48,22 @@ def mobilenet_v2( ...@@ -47,26 +48,22 @@ def mobilenet_v2(
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
) -> QuantizableMobileNetV2: ) -> QuantizableMobileNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") default_value = (
if kwargs.pop("pretrained"): QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 if quantize else MobileNetV2Weights.ImageNet1K_RefV1
weights = (
QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
if quantize
else MobileNetV2Weights.ImageNet1K_RefV1
) )
else: weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
weights = None
if quantize: if quantize:
weights = QuantizedMobileNetV2Weights.verify(weights) weights = QuantizedMobileNetV2Weights.verify(weights)
else: else:
weights = MobileNetV2Weights.verify(weights) weights = MobileNetV2Weights.verify(weights)
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta: if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"] _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack") backend = kwargs.pop("backend", "qnnpack")
model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment