Unverified commit b9da6db4 authored by Vasilis Vryniotis, committed by GitHub

Clean up naming of multi-weights classes and enums (#5003)

* Rename the classes Weights => WeightsEnum and WeightEntry => Weights (see the usage sketch below).

* Make enum values follow the naming convention `_V1`, `_V2`, etc.

* Clean up the Enum class naming conventions.

* Add a test to check the naming conventions.
parent b3cdec1f
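For orientation before the diff: a minimal sketch of the naming scheme after this commit, assuming the prototype namespace `torchvision.prototype.models` is importable. The class and entry names come from the diff below; everything else is illustrative.

```python
# Sketch only (not part of the patch): the naming scheme after this commit.
from torchvision.prototype import models

# Enum containers now follow `<Builder>_Weights` / `<Builder>_QuantizedWeights`
# (previously e.g. ResNet50Weights / QuantizedResNet50Weights), and their entries
# follow `_V1`, `_V2`, ... (previously e.g. ImageNet1K_RefV1, ImageNet1K_Community).
weights = models.ResNet50_Weights.ImageNet1K_V1

# Builders accept the enum entry directly; this call would download the checkpoint.
model = models.resnet50(weights=weights)
```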
......@@ -18,6 +18,12 @@ def _get_original_model(model_fn):
return module.__dict__[model_fn.__name__]
def _get_parent_module(model_fn):
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
module = importlib.import_module(parent_module_name)
return module
def _build_model(fn, **kwargs):
try:
model = fn(**kwargs)
......@@ -29,20 +35,20 @@ def _build_model(fn, **kwargs):
return model.eval()
def get_models_with_module_names(module):
module_name = module.__name__.split(".")[-1]
return [(fn, module_name) for fn in TM.get_models_from_module(module)]
@pytest.mark.parametrize(
"model_fn, name, weight",
[
(models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1),
(models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2),
(models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
(models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
(
models.quantization.resnet50,
"default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_RefV1",
models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
"ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
),
],
)
......@@ -50,6 +56,21 @@ def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video),
)
def test_naming_conventions(model_fn):
model_name = model_fn.__name__
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
......@@ -85,16 +106,16 @@ def test_video_model(model_fn, dev):
@pytest.mark.parametrize(
"model_fn, module_name",
get_models_with_module_names(models)
+ get_models_with_module_names(models.detection)
+ get_models_with_module_names(models.quantization)
+ get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video),
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video),
)
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_old_vs_new_factory(model_fn, module_name, dev):
def test_old_vs_new_factory(model_fn, dev):
defaults = {
"models": {
"input_shape": (1, 3, 224, 224),
......@@ -114,6 +135,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
},
}
model_name = model_fn.__name__
module_name = model_fn.__module__.split(".")[-2]
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
......
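The convention enforced by `test_naming_conventions` above can be restated compactly: stripping the `_Weights` (or, for quantized models, `_QuantizedWeights`) suffix from an enum class name and lower-casing the result must yield the builder name. A tiny standalone illustration using names that appear in this diff; it is not part of the test suite.

```python
# Restatement of the check in test_naming_conventions: enum class name minus its
# suffix, lower-cased, must equal the model builder name.
def follows_convention(builder_name: str, enum_name: str, suffix: str = "_Weights") -> bool:
    return enum_name.replace(suffix, "").lower() == builder_name


assert follows_convention("resnet50", "ResNet50_Weights")
assert follows_convention("mobilenet_v3_large", "MobileNet_V3_Large_Weights")
assert follows_convention("resnet50", "ResNet50_QuantizedWeights", suffix="_QuantizedWeights")
```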
......@@ -7,11 +7,11 @@ from typing import Any, Callable, Dict
from ..._internally_replaced_utils import load_state_dict_from_url
__all__ = ["Weights", "WeightEntry", "get_weight"]
__all__ = ["WeightsEnum", "Weights", "get_weight"]
@dataclass
class WeightEntry:
class Weights:
"""
This class is used to group important attributes associated with the pre-trained weights.
......@@ -33,17 +33,17 @@ class WeightEntry:
default: bool
class Weights(Enum):
class WeightsEnum(Enum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
`WeightEntry`.
`Weights`.
Args:
value (WeightEntry): The data class entry with the weight information.
value (Weights): The data class entry with the weight information.
"""
def __init__(self, value: WeightEntry):
def __init__(self, value: Weights):
self._value_ = value
@classmethod
......@@ -58,7 +58,7 @@ class Weights(Enum):
return obj
@classmethod
def from_str(cls, value: str) -> "Weights":
def from_str(cls, value: str) -> "WeightsEnum":
for v in cls:
if v._name_ == value or (value == "default" and v.default):
return v
......@@ -71,14 +71,14 @@ class Weights(Enum):
return f"{self.__class__.__name__}.{self._name_}"
def __getattr__(self, name):
# Be able to fetch WeightEntry attributes directly
for f in fields(WeightEntry):
# Be able to fetch Weights attributes directly
for f in fields(Weights):
if f.name == name:
return object.__getattribute__(self.value, name)
return super().__getattr__(name)
def get_weight(fn: Callable, weight_name: str) -> Weights:
def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
"""
Gets the weight enum of a specific model builder method and weight name combination.
......@@ -87,32 +87,32 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
weight_name (str): The name of the weight enum entry of the specific model.
Returns:
Weights: The requested weight enum.
WeightsEnum: The requested weight enum.
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' parameter.")
ann = signature(fn).parameters["weights"].annotation
weights_class = None
if isinstance(ann, type) and issubclass(ann, Weights):
weights_class = ann
weights_enum = None
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
weights_enum = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, Weights):
if isinstance(t, type) and issubclass(t, WeightsEnum):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
weights_class = t
weights_enum = t
break
if weights_class is None:
if weights_enum is None:
raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)
return weights_class.from_str(weight_name)
return weights_enum.from_str(weight_name)
import warnings
from typing import Any, Dict, Optional, TypeVar
from ._api import Weights
from ._api import WeightsEnum
W = TypeVar("W", bound=Weights)
W = TypeVar("W", bound=WeightsEnum)
V = TypeVar("V")
......
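To make the `_api.py` changes above concrete: `Weights` is now the dataclass carrying `url`, `transforms`, `meta` and `default`; `WeightsEnum` is the enum whose members hold a `Weights` value and proxy its fields through `__getattr__`; and `get_weight` resolves an enum member from a builder plus an entry name. A hedged usage sketch, assuming the prototype namespace `torchvision.prototype.models`:

```python
# Illustrative only; mirrors test_get_weight and the _api.py code above.
from torchvision.prototype import models
from torchvision.prototype.models._api import get_weight

# Resolve an enum member from a builder and an entry name.
w = get_weight(models.resnet50, "ImageNet1K_V1")
assert w is models.ResNet50_Weights.ImageNet1K_V1

# "default" resolves to the entry flagged with default=True.
assert get_weight(models.resnet50, "default") is models.ResNet50_Weights.ImageNet1K_V2

# Members proxy the fields of their Weights value (see __getattr__ above).
print(w.url)
print(len(w.meta["categories"]))  # 1000 ImageNet-1K categories
```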
......@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
class AlexNetWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class AlexNet_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -29,12 +29,12 @@ class AlexNetWeights(Weights):
)
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1)
weights = AlexNetWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
weights = AlexNet_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
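The backwards-compatibility pattern shown for `alexnet` above repeats in every builder touched by this patch: a legacy boolean or `pretrained` argument is mapped to the default `_V1` entry via `_deprecated_param` / `_deprecated_positional`, and `verify` then validates the value. A hedged sketch of the call paths, assuming the prototype namespace:

```python
# Sketch of the legacy vs. new call paths kept by the builders in this patch;
# the two pre-trained calls would download the checkpoint.
from torchvision.prototype.models import AlexNet_Weights, alexnet

new_style = alexnet(weights=AlexNet_Weights.ImageNet1K_V1)  # explicit enum entry
legacy = alexnet(pretrained=True)  # deprecated; resolved to AlexNet_Weights.ImageNet1K_V1
untrained = alexnet(weights=None)  # no pre-trained weights
```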
......@@ -7,17 +7,17 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"DenseNet",
"DenseNet121Weights",
"DenseNet161Weights",
"DenseNet169Weights",
"DenseNet201Weights",
"DenseNet121_Weights",
"DenseNet161_Weights",
"DenseNet169_Weights",
"DenseNet201_Weights",
"densenet121",
"densenet161",
"densenet169",
......@@ -25,7 +25,7 @@ __all__ = [
]
def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None:
def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
......@@ -48,7 +48,7 @@ def _densenet(
growth_rate: int,
block_config: Tuple[int, int, int, int],
num_init_features: int,
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> DenseNet:
......@@ -71,8 +71,8 @@ _COMMON_META = {
}
class DenseNet121Weights(Weights):
ImageNet1K_Community = WeightEntry(
class DenseNet121_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -84,8 +84,8 @@ class DenseNet121Weights(Weights):
)
class DenseNet161Weights(Weights):
ImageNet1K_Community = WeightEntry(
class DenseNet161_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -97,8 +97,8 @@ class DenseNet161Weights(Weights):
)
class DenseNet169Weights(Weights):
ImageNet1K_Community = WeightEntry(
class DenseNet169_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -110,8 +110,8 @@ class DenseNet169Weights(Weights):
)
class DenseNet201Weights(Weights):
ImageNet1K_Community = WeightEntry(
class DenseNet201_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -123,41 +123,41 @@ class DenseNet201Weights(Weights):
)
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community)
weights = DenseNet121Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
weights = DenseNet121_Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community)
weights = DenseNet161Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
weights = DenseNet161_Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community)
weights = DenseNet169Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
weights = DenseNet169_Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community)
weights = DenseNet201Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
weights = DenseNet201_Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
......@@ -12,18 +12,18 @@ from ....models.detection.faster_rcnn import (
misc_nn_ops,
overwrite_eps,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"FasterRCNN",
"FasterRCNNResNet50FPNWeights",
"FasterRCNNMobileNetV3LargeFPNWeights",
"FasterRCNNMobileNetV3Large320FPNWeights",
"FasterRCNN_ResNet50_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
......@@ -36,8 +36,8 @@ _COMMON_META = {
}
class FasterRCNNResNet50FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
......@@ -49,8 +49,8 @@ class FasterRCNNResNet50FPNWeights(Weights):
)
class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
Coco_RefV1 = WeightEntry(
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
......@@ -62,8 +62,8 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
)
class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
......@@ -76,25 +76,25 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None,
weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_RefV1)
weights = FasterRCNNResNet50FPNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......@@ -112,17 +112,17 @@ def fasterrcnn_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0)
return model
def _fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]],
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
progress: bool,
num_classes: Optional[int],
weights_backbone: Optional[MobileNetV3LargeWeights],
weights_backbone: Optional[MobileNet_V3_Large_Weights],
trainable_backbone_layers: Optional[int],
**kwargs: Any,
) -> FasterRCNN:
......@@ -159,25 +159,25 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1)
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1)
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = {
"rpn_score_thresh": 0.05,
......@@ -195,25 +195,27 @@ def fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_320_fpn(
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None,
weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1)
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
weights = _deprecated_param(
kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1
)
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = {
"min_size": 320,
......
......@@ -9,15 +9,15 @@ from ....models.detection.keypoint_rcnn import (
misc_nn_ops,
overwrite_eps,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"KeypointRCNN",
"KeypointRCNNResNet50FPNWeights",
"KeypointRCNN_ResNet50_FPN_Weights",
"keypointrcnn_resnet50_fpn",
]
......@@ -25,8 +25,8 @@ __all__ = [
_COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
class KeypointRCNNResNet50FPNWeights(Weights):
Coco_RefV1_Legacy = WeightEntry(
class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_Legacy = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval,
meta={
......@@ -37,7 +37,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
},
default=False,
)
Coco_RefV1 = WeightEntry(
Coco_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval,
meta={
......@@ -51,30 +51,30 @@ class KeypointRCNNResNet50FPNWeights(Weights):
def keypointrcnn_resnet50_fpn(
weights: Optional[KeypointRCNNResNet50FPNWeights] = None,
weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
num_keypoints: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> KeypointRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1
default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_V1
if kwargs["pretrained"] == "legacy":
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
kwargs["pretrained"] = True
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)
weights = KeypointRCNNResNet50FPNWeights.verify(weights)
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......@@ -96,7 +96,7 @@ def keypointrcnn_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1:
if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0)
return model
......@@ -10,21 +10,21 @@ from ....models.detection.mask_rcnn import (
misc_nn_ops,
overwrite_eps,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"MaskRCNN",
"MaskRCNNResNet50FPNWeights",
"MaskRCNN_ResNet50_FPN_Weights",
"maskrcnn_resnet50_fpn",
]
class MaskRCNNResNet50FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
transforms=CocoEval,
meta={
......@@ -39,25 +39,25 @@ class MaskRCNNResNet50FPNWeights(Weights):
def maskrcnn_resnet50_fpn(
weights: Optional[MaskRCNNResNet50FPNWeights] = None,
weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> MaskRCNN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_RefV1)
weights = MaskRCNNResNet50FPNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......@@ -75,7 +75,7 @@ def maskrcnn_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1:
if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0)
return model
......@@ -11,21 +11,21 @@ from ....models.detection.retinanet import (
misc_nn_ops,
overwrite_eps,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"RetinaNet",
"RetinaNetResNet50FPNWeights",
"RetinaNet_ResNet50_FPN_Weights",
"retinanet_resnet50_fpn",
]
class RetinaNetResNet50FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
transforms=CocoEval,
meta={
......@@ -39,25 +39,25 @@ class RetinaNetResNet50FPNWeights(Weights):
def retinanet_resnet50_fpn(
weights: Optional[RetinaNetResNet50FPNWeights] = None,
weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> RetinaNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_RefV1)
weights = RetinaNetResNet50FPNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNet_ResNet50_FPN_Weights.Coco_V1)
weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
......@@ -78,7 +78,7 @@ def retinanet_resnet50_fpn(
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNetResNet50FPNWeights.Coco_RefV1:
if weights == RetinaNet_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0)
return model
......@@ -10,20 +10,20 @@ from ....models.detection.ssd import (
DefaultBoxGenerator,
SSD,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..vgg import VGG16Weights, vgg16
from ..vgg import VGG16_Weights, vgg16
__all__ = [
"SSD300VGG16Weights",
"SSD300_VGG16_Weights",
"ssd300_vgg16",
]
class SSD300VGG16Weights(Weights):
Coco_RefV1 = WeightEntry(
class SSD300_VGG16_Weights(WeightsEnum):
Coco_V1 = Weights(
url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
transforms=CocoEval,
meta={
......@@ -38,25 +38,25 @@ class SSD300VGG16Weights(Weights):
def ssd300_vgg16(
weights: Optional[SSD300VGG16Weights] = None,
weights: Optional[SSD300_VGG16_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[VGG16Weights] = None,
weights_backbone: Optional[VGG16_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> SSD:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_RefV1)
weights = SSD300VGG16Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300_VGG16_Weights.Coco_V1)
weights = SSD300_VGG16_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", VGG16Weights.ImageNet1K_Features
kwargs, "pretrained_backbone", "weights_backbone", VGG16_Weights.ImageNet1K_Features
)
weights_backbone = VGG16Weights.verify(weights_backbone)
weights_backbone = VGG16_Weights.verify(weights_backbone)
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
......
......@@ -15,20 +15,20 @@ from ....models.detection.ssdlite import (
SSD,
SSDLiteHead,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
__all__ = [
"SSDlite320MobileNetV3LargeFPNWeights",
"SSDLite320_MobileNet_V3_Large_Weights",
"ssdlite320_mobilenet_v3_large",
]
class SSDlite320MobileNetV3LargeFPNWeights(Weights):
Coco_RefV1 = WeightEntry(
class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
Coco_V1 = Weights(
url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
transforms=CocoEval,
meta={
......@@ -43,10 +43,10 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights):
def ssdlite320_mobilenet_v3_large(
weights: Optional[SSDlite320MobileNetV3LargeFPNWeights] = None,
weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
......@@ -54,15 +54,15 @@ def ssdlite320_mobilenet_v3_large(
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1)
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1)
weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
......
......@@ -6,21 +6,21 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"EfficientNet",
"EfficientNetB0Weights",
"EfficientNetB1Weights",
"EfficientNetB2Weights",
"EfficientNetB3Weights",
"EfficientNetB4Weights",
"EfficientNetB5Weights",
"EfficientNetB6Weights",
"EfficientNetB7Weights",
"EfficientNet_B0_Weights",
"EfficientNet_B1_Weights",
"EfficientNet_B2_Weights",
"EfficientNet_B3_Weights",
"EfficientNet_B4_Weights",
"EfficientNet_B5_Weights",
"EfficientNet_B6_Weights",
"EfficientNet_B7_Weights",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
......@@ -36,7 +36,7 @@ def _efficientnet(
width_mult: float,
depth_mult: float,
dropout: float,
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> EfficientNet:
......@@ -69,8 +69,8 @@ _COMMON_META = {
}
class EfficientNetB0Weights(Weights):
ImageNet1K_TimmV1 = WeightEntry(
class EfficientNet_B0_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -83,8 +83,8 @@ class EfficientNetB0Weights(Weights):
)
class EfficientNetB1Weights(Weights):
ImageNet1K_TimmV1 = WeightEntry(
class EfficientNet_B1_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -97,8 +97,8 @@ class EfficientNetB1Weights(Weights):
)
class EfficientNetB2Weights(Weights):
ImageNet1K_TimmV1 = WeightEntry(
class EfficientNet_B2_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -111,8 +111,8 @@ class EfficientNetB2Weights(Weights):
)
class EfficientNetB3Weights(Weights):
ImageNet1K_TimmV1 = WeightEntry(
class EfficientNet_B3_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -125,8 +125,8 @@ class EfficientNetB3Weights(Weights):
)
class EfficientNetB4Weights(Weights):
ImageNet1K_TimmV1 = WeightEntry(
class EfficientNet_B4_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -139,8 +139,8 @@ class EfficientNetB4Weights(Weights):
)
class EfficientNetB5Weights(Weights):
ImageNet1K_TFV1 = WeightEntry(
class EfficientNet_B5_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -153,8 +153,8 @@ class EfficientNetB5Weights(Weights):
)
class EfficientNetB6Weights(Weights):
ImageNet1K_TFV1 = WeightEntry(
class EfficientNet_B6_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -167,8 +167,8 @@ class EfficientNetB6Weights(Weights):
)
class EfficientNetB7Weights(Weights):
ImageNet1K_TFV1 = WeightEntry(
class EfficientNet_B7_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
meta={
......@@ -182,73 +182,73 @@ class EfficientNetB7Weights(Weights):
def efficientnet_b0(
weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_TimmV1)
weights = EfficientNetB0Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B0_Weights.ImageNet1K_V1)
weights = EfficientNet_B0_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b1(
weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_TimmV1)
weights = EfficientNetB1Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B1_Weights.ImageNet1K_V1)
weights = EfficientNet_B1_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b2(
weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_TimmV1)
weights = EfficientNetB2Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B2_Weights.ImageNet1K_V1)
weights = EfficientNet_B2_Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b3(
weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_TimmV1)
weights = EfficientNetB3Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B3_Weights.ImageNet1K_V1)
weights = EfficientNet_B3_Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b4(
weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_TimmV1)
weights = EfficientNetB4Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B4_Weights.ImageNet1K_V1)
weights = EfficientNet_B4_Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
def efficientnet_b5(
weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_TFV1)
weights = EfficientNetB5Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B5_Weights.ImageNet1K_V1)
weights = EfficientNet_B5_Weights.verify(weights)
return _efficientnet(
width_mult=1.6,
......@@ -262,13 +262,13 @@ def efficientnet_b5(
def efficientnet_b6(
weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_TFV1)
weights = EfficientNetB6Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B6_Weights.ImageNet1K_V1)
weights = EfficientNet_B6_Weights.verify(weights)
return _efficientnet(
width_mult=1.8,
......@@ -282,13 +282,13 @@ def efficientnet_b6(
def efficientnet_b7(
weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any
weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_TFV1)
weights = EfficientNetB7Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B7_Weights.ImageNet1K_V1)
weights = EfficientNet_B7_Weights.verify(weights)
return _efficientnet(
width_mult=2.0,
......
......@@ -6,16 +6,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"]
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
class GoogLeNetWeights(Weights):
ImageNet1K_TFV1 = WeightEntry(
class GoogLeNet_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -30,12 +30,12 @@ class GoogLeNetWeights(Weights):
)
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_TFV1)
weights = GoogLeNetWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNet_Weights.ImageNet1K_V1)
weights = GoogLeNet_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
......
......@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
class InceptionV3Weights(Weights):
ImageNet1K_TFV1 = WeightEntry(
class Inception_V3_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
meta={
......@@ -29,12 +29,12 @@ class InceptionV3Weights(Weights):
)
def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_TFV1)
weights = InceptionV3Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", Inception_V3_Weights.ImageNet1K_V1)
weights = Inception_V3_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", True)
if weights is not None:
......
......@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"MNASNet",
"MNASNet0_5Weights",
"MNASNet0_75Weights",
"MNASNet1_0Weights",
"MNASNet1_3Weights",
"MNASNet0_5_Weights",
"MNASNet0_75_Weights",
"MNASNet1_0_Weights",
"MNASNet1_3_Weights",
"mnasnet0_5",
"mnasnet0_75",
"mnasnet1_0",
......@@ -31,8 +31,8 @@ _COMMON_META = {
}
class MNASNet0_5Weights(Weights):
ImageNet1K_Community = WeightEntry(
class MNASNet0_5_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -44,13 +44,13 @@ class MNASNet0_5Weights(Weights):
)
class MNASNet0_75Weights(Weights):
class MNASNet0_75_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet0_75
pass
class MNASNet1_0Weights(Weights):
ImageNet1K_Community = WeightEntry(
class MNASNet1_0_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -62,12 +62,12 @@ class MNASNet1_0Weights(Weights):
)
class MNASNet1_3Weights(Weights):
class MNASNet1_3_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet1_3
pass
def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet:
def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......@@ -79,41 +79,41 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
return model
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
def mnasnet0_5(weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_Community)
weights = MNASNet0_5Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5_Weights.ImageNet1K_V1)
weights = MNASNet0_5_Weights.verify(weights)
return _mnasnet(0.5, weights, progress, **kwargs)
def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
def mnasnet0_75(weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet0_75Weights.verify(weights)
weights = MNASNet0_75_Weights.verify(weights)
return _mnasnet(0.75, weights, progress, **kwargs)
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
def mnasnet1_0(weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_Community)
weights = MNASNet1_0Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0_Weights.ImageNet1K_V1)
weights = MNASNet1_0_Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs)
def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
def mnasnet1_3(weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet1_3Weights.verify(weights)
weights = MNASNet1_3_Weights.verify(weights)
return _mnasnet(1.3, weights, progress, **kwargs)
......@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]
__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
class MobileNetV2Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class MobileNet_V2_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -29,12 +29,12 @@ class MobileNetV2Weights(Weights):
)
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_RefV1)
weights = MobileNetV2Weights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V2_Weights.ImageNet1K_V1)
weights = MobileNet_V2_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
......@@ -5,15 +5,15 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import Weights, WeightEntry
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
"MobileNetV3",
"MobileNetV3LargeWeights",
"MobileNetV3SmallWeights",
"MobileNet_V3_Large_Weights",
"MobileNet_V3_Small_Weights",
"mobilenet_v3_large",
"mobilenet_v3_small",
]
......@@ -22,7 +22,7 @@ __all__ = [
def _mobilenet_v3(
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
weights: Optional[Weights],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> MobileNetV3:
......@@ -44,8 +44,8 @@ _COMMON_META = {
}
class MobileNetV3LargeWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class MobileNet_V3_Large_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -56,7 +56,7 @@ class MobileNetV3LargeWeights(Weights):
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
......@@ -69,8 +69,8 @@ class MobileNetV3LargeWeights(Weights):
)
class MobileNetV3SmallWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
class MobileNet_V3_Small_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -84,26 +84,26 @@ class MobileNetV3SmallWeights(Weights):
def mobilenet_v3_large(
weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any
weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_RefV1)
weights = MobileNetV3LargeWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Large_Weights.ImageNet1K_V1)
weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
def mobilenet_v3_small(
weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any
weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_RefV1)
weights = MobileNetV3SmallWeights.verify(weights)
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Small_Weights.ImageNet1K_V1)
weights = MobileNet_V3_Small_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
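The same pattern applies to the two MobileNetV3 builders; a short sketch under the same prototype-namespace assumption:

```python
# Minimal sketch, assuming the prototype namespace torchvision.prototype.models.
from torchvision.prototype import models

# Large variant: both entries of the enum above are selectable explicitly.
large_v2 = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.ImageNet1K_V2)

# pretrained=True keeps mapping to ImageNet1K_V1, per the deprecation shim above.
large_v1 = models.mobilenet_v3_large(pretrained=True)

# Small variant with the ImageNet1K_V1 entry shown above.
small = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.ImageNet1K_V1)

# weights=None still builds a randomly initialized network.
scratch = models.mobilenet_v3_small(weights=None)
```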
......@@ -10,21 +10,21 @@ from ....models.quantization.googlenet import (
_replace_relu,
quantize_model,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..googlenet import GoogLeNetWeights
from ..googlenet import GoogLeNet_Weights
__all__ = [
"QuantizableGoogLeNet",
"QuantizedGoogLeNetWeights",
"GoogLeNet_QuantizedWeights",
"googlenet",
]
class QuantizedGoogLeNetWeights(Weights):
ImageNet1K_FBGEMM_TFV1 = WeightEntry(
class GoogLeNet_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -34,7 +34,7 @@ class QuantizedGoogLeNetWeights(Weights):
"backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
"unquantized": GoogLeNetWeights.ImageNet1K_TFV1,
"unquantized": GoogLeNet_Weights.ImageNet1K_V1,
"acc@1": 69.826,
"acc@5": 89.404,
},
......@@ -43,7 +43,7 @@ class QuantizedGoogLeNetWeights(Weights):
def googlenet(
weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None,
weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -51,14 +51,12 @@ def googlenet(
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
)
default_value = GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNet_Weights.ImageNet1K_V1
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedGoogLeNetWeights.verify(weights)
weights = GoogLeNet_QuantizedWeights.verify(weights)
else:
weights = GoogLeNetWeights.verify(weights)
weights = GoogLeNet_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
......
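For the quantized builders, the `quantize` flag decides which of the two enums gets verified; a sketch of both paths, again assuming the prototype namespace:

```python
# Minimal sketch, assuming the prototype namespace torchvision.prototype.models.
from torchvision.prototype import models

# quantize=True pairs with the *_QuantizedWeights enum.
q_weights = models.quantization.GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1
q_model = models.quantization.googlenet(weights=q_weights, quantize=True).eval()

# quantize=False pairs with the float GoogLeNet_Weights enum instead.
f_model = models.quantization.googlenet(
    weights=models.GoogLeNet_Weights.ImageNet1K_V1, quantize=False
).eval()

# The meta links the quantized checkpoint back to the float weights it came from.
assert q_weights.meta["unquantized"] == models.GoogLeNet_Weights.ImageNet1K_V1
```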
......@@ -9,21 +9,21 @@ from ....models.quantization.inception import (
_replace_relu,
quantize_model,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..inception import InceptionV3Weights
from ..inception import Inception_V3_Weights
__all__ = [
"QuantizableInception3",
"QuantizedInceptionV3Weights",
"Inception_V3_QuantizedWeights",
"inception_v3",
]
class QuantizedInceptionV3Weights(Weights):
ImageNet1K_FBGEMM_TFV1 = WeightEntry(
class Inception_V3_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
meta={
......@@ -33,7 +33,7 @@ class QuantizedInceptionV3Weights(Weights):
"backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
"unquantized": InceptionV3Weights.ImageNet1K_TFV1,
"unquantized": Inception_V3_Weights.ImageNet1K_V1,
"acc@1": 77.176,
"acc@5": 93.354,
},
......@@ -42,7 +42,7 @@ class QuantizedInceptionV3Weights(Weights):
def inception_v3(
weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None,
weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -51,13 +51,13 @@ def inception_v3(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else Inception_V3_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedInceptionV3Weights.verify(weights)
weights = Inception_V3_QuantizedWeights.verify(weights)
else:
weights = InceptionV3Weights.verify(weights)
weights = Inception_V3_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
......
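inception_v3 follows the same shape but with 299x299 evaluation preprocessing; the accuracy figures below are the ones recorded in the meta above:

```python
# Minimal sketch, assuming the prototype namespace torchvision.prototype.models.
from torchvision.prototype import models

weights = models.quantization.Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1
model = models.quantization.inception_v3(weights=weights, quantize=True).eval()

# The entry declares ImageNetEval(crop_size=299, resize_size=342); building it
# through .transforms() is assumed to mirror the float builders.
preprocess = weights.transforms()

# Reference accuracies recorded for the fbgemm checkpoint.
print(weights.meta["acc@1"], weights.meta["acc@5"])  # 77.176 93.354
```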
......@@ -10,21 +10,21 @@ from ....models.quantization.mobilenetv2 import (
_replace_relu,
quantize_model,
)
from .._api import Weights, WeightEntry
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..mobilenetv2 import MobileNetV2Weights
from ..mobilenetv2 import MobileNet_V2_Weights
__all__ = [
"QuantizableMobileNetV2",
"QuantizedMobileNetV2Weights",
"MobileNet_V2_QuantizedWeights",
"mobilenet_v2",
]
class QuantizedMobileNetV2Weights(Weights):
ImageNet1K_QNNPACK_RefV1 = WeightEntry(
class MobileNet_V2_QuantizedWeights(WeightsEnum):
ImageNet1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
......@@ -34,7 +34,7 @@ class QuantizedMobileNetV2Weights(Weights):
"backend": "qnnpack",
"quantization": "qat",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
"unquantized": MobileNetV2Weights.ImageNet1K_RefV1,
"unquantized": MobileNet_V2_Weights.ImageNet1K_V1,
"acc@1": 71.658,
"acc@5": 90.150,
},
......@@ -43,7 +43,7 @@ class QuantizedMobileNetV2Weights(Weights):
def mobilenet_v2(
weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None,
weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
......@@ -52,13 +52,13 @@ def mobilenet_v2(
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
default_value = (
QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 if quantize else MobileNetV2Weights.ImageNet1K_RefV1
MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize else MobileNet_V2_Weights.ImageNet1K_V1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedMobileNetV2Weights.verify(weights)
weights = MobileNet_V2_QuantizedWeights.verify(weights)
else:
weights = MobileNetV2Weights.verify(weights)
weights = MobileNet_V2_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
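And the QNNPACK QAT checkpoint for MobileNetV2; the explicit engine selection is shown defensively, since the builder is assumed to pick the backend up from the weight meta itself:

```python
# Minimal sketch, assuming the prototype namespace torchvision.prototype.models.
import torch
from torchvision.prototype import models

weights = models.quantization.MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1
model = models.quantization.mobilenet_v2(weights=weights, quantize=True).eval()

# The checkpoint was produced with QAT on the qnnpack backend (see meta above).
# Setting the engine explicitly is defensive; the builder is assumed to select
# it from weights.meta["backend"] already.
torch.backends.quantized.engine = "qnnpack"

with torch.inference_mode():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```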