"docs/vscode:/vscode.git/clone" did not exist on "eadf0e2555cfa19b033e02de53553f71ac33536f"
Unverified Commit b9da6db4 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Cleanup namings of Multi-weights classes and enums (#5003)

* Rename classes Weights => WeightsEnum and WeightEntry => Weights.

* Make enum values follow the naming convention `_V1`, `_V2` etc

* Cleanup the Enum class naming conventions.

* Add a test to check naming conventions.
parent b3cdec1f
...@@ -18,6 +18,12 @@ def _get_original_model(model_fn): ...@@ -18,6 +18,12 @@ def _get_original_model(model_fn):
return module.__dict__[model_fn.__name__] return module.__dict__[model_fn.__name__]
def _get_parent_module(model_fn):
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
module = importlib.import_module(parent_module_name)
return module
def _build_model(fn, **kwargs): def _build_model(fn, **kwargs):
try: try:
model = fn(**kwargs) model = fn(**kwargs)
...@@ -29,20 +35,20 @@ def _build_model(fn, **kwargs): ...@@ -29,20 +35,20 @@ def _build_model(fn, **kwargs):
return model.eval() return model.eval()
def get_models_with_module_names(module):
module_name = module.__name__.split(".")[-1]
return [(fn, module_name) for fn in TM.get_models_from_module(module)]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn, name, weight", "model_fn, name, weight",
[ [
(models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1), (models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
(models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2), (models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
(
models.quantization.resnet50,
"default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
),
( (
models.quantization.resnet50, models.quantization.resnet50,
"ImageNet1K_FBGEMM_RefV1", "ImageNet1K_FBGEMM_V1",
models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1, models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
), ),
], ],
) )
...@@ -50,6 +56,21 @@ def test_get_weight(model_fn, name, weight): ...@@ -50,6 +56,21 @@ def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight assert models._api.get_weight(model_fn, name) == weight
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video),
)
def test_naming_conventions(model_fn):
model_name = model_fn.__name__
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype @run_if_test_with_prototype
...@@ -85,16 +106,16 @@ def test_video_model(model_fn, dev): ...@@ -85,16 +106,16 @@ def test_video_model(model_fn, dev):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn, module_name", "model_fn",
get_models_with_module_names(models) TM.get_models_from_module(models)
+ get_models_with_module_names(models.detection) + TM.get_models_from_module(models.detection)
+ get_models_with_module_names(models.quantization) + TM.get_models_from_module(models.quantization)
+ get_models_with_module_names(models.segmentation) + TM.get_models_from_module(models.segmentation)
+ get_models_with_module_names(models.video), + TM.get_models_from_module(models.video),
) )
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype @run_if_test_with_prototype
def test_old_vs_new_factory(model_fn, module_name, dev): def test_old_vs_new_factory(model_fn, dev):
defaults = { defaults = {
"models": { "models": {
"input_shape": (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
...@@ -114,6 +135,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev): ...@@ -114,6 +135,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
}, },
} }
model_name = model_fn.__name__ model_name = model_fn.__name__
module_name = model_fn.__module__.split(".")[-2]
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape") input_shape = kwargs.pop("input_shape")
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
......
...@@ -7,11 +7,11 @@ from typing import Any, Callable, Dict ...@@ -7,11 +7,11 @@ from typing import Any, Callable, Dict
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
__all__ = ["Weights", "WeightEntry", "get_weight"] __all__ = ["WeightsEnum", "Weights", "get_weight"]
@dataclass @dataclass
class WeightEntry: class Weights:
""" """
This class is used to group important attributes associated with the pre-trained weights. This class is used to group important attributes associated with the pre-trained weights.
...@@ -33,17 +33,17 @@ class WeightEntry: ...@@ -33,17 +33,17 @@ class WeightEntry:
default: bool default: bool
class Weights(Enum): class WeightsEnum(Enum):
""" """
This class is the parent class of all model weights. Each model building method receives an optional `weights` This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
`WeightEntry`. `Weights`.
Args: Args:
value (WeightEntry): The data class entry with the weight information. value (Weights): The data class entry with the weight information.
""" """
def __init__(self, value: WeightEntry): def __init__(self, value: Weights):
self._value_ = value self._value_ = value
@classmethod @classmethod
...@@ -58,7 +58,7 @@ class Weights(Enum): ...@@ -58,7 +58,7 @@ class Weights(Enum):
return obj return obj
@classmethod @classmethod
def from_str(cls, value: str) -> "Weights": def from_str(cls, value: str) -> "WeightsEnum":
for v in cls: for v in cls:
if v._name_ == value or (value == "default" and v.default): if v._name_ == value or (value == "default" and v.default):
return v return v
...@@ -71,14 +71,14 @@ class Weights(Enum): ...@@ -71,14 +71,14 @@ class Weights(Enum):
return f"{self.__class__.__name__}.{self._name_}" return f"{self.__class__.__name__}.{self._name_}"
def __getattr__(self, name): def __getattr__(self, name):
# Be able to fetch WeightEntry attributes directly # Be able to fetch Weights attributes directly
for f in fields(WeightEntry): for f in fields(Weights):
if f.name == name: if f.name == name:
return object.__getattribute__(self.value, name) return object.__getattribute__(self.value, name)
return super().__getattr__(name) return super().__getattr__(name)
def get_weight(fn: Callable, weight_name: str) -> Weights: def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
""" """
Gets the weight enum of a specific model builder method and weight name combination. Gets the weight enum of a specific model builder method and weight name combination.
...@@ -87,32 +87,32 @@ def get_weight(fn: Callable, weight_name: str) -> Weights: ...@@ -87,32 +87,32 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
weight_name (str): The name of the weight enum entry of the specific model. weight_name (str): The name of the weight enum entry of the specific model.
Returns: Returns:
Weights: The requested weight enum. WeightsEnum: The requested weight enum.
""" """
sig = signature(fn) sig = signature(fn)
if "weights" not in sig.parameters: if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' parameter.") raise ValueError("The method is missing the 'weights' parameter.")
ann = signature(fn).parameters["weights"].annotation ann = signature(fn).parameters["weights"].annotation
weights_class = None weights_enum = None
if isinstance(ann, type) and issubclass(ann, Weights): if isinstance(ann, type) and issubclass(ann, WeightsEnum):
weights_class = ann weights_enum = ann
else: else:
# handle cases like Union[Optional, T] # handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr] for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, Weights): if isinstance(t, type) and issubclass(t, WeightsEnum):
# ensure the name exists. handles builders with multiple types of weights like in quantization # ensure the name exists. handles builders with multiple types of weights like in quantization
try: try:
t.from_str(weight_name) t.from_str(weight_name)
except ValueError: except ValueError:
continue continue
weights_class = t weights_enum = t
break break
if weights_class is None: if weights_enum is None:
raise ValueError( raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct." "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
) )
return weights_class.from_str(weight_name) return weights_enum.from_str(weight_name)
import warnings import warnings
from typing import Any, Dict, Optional, TypeVar from typing import Any, Dict, Optional, TypeVar
from ._api import Weights from ._api import WeightsEnum
W = TypeVar("W", bound=Weights) W = TypeVar("W", bound=WeightsEnum)
V = TypeVar("V") V = TypeVar("V")
......
...@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet from ...models.alexnet import AlexNet
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"] __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
class AlexNetWeights(Weights): class AlexNet_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -29,12 +29,12 @@ class AlexNetWeights(Weights): ...@@ -29,12 +29,12 @@ class AlexNetWeights(Weights):
) )
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
weights = AlexNetWeights.verify(weights) weights = AlexNet_Weights.verify(weights)
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
...@@ -7,17 +7,17 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -7,17 +7,17 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet from ...models.densenet import DenseNet
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"DenseNet", "DenseNet",
"DenseNet121Weights", "DenseNet121_Weights",
"DenseNet161Weights", "DenseNet161_Weights",
"DenseNet169Weights", "DenseNet169_Weights",
"DenseNet201Weights", "DenseNet201_Weights",
"densenet121", "densenet121",
"densenet161", "densenet161",
"densenet169", "densenet169",
...@@ -25,7 +25,7 @@ __all__ = [ ...@@ -25,7 +25,7 @@ __all__ = [
] ]
def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None: def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer # '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used # They are also in the checkpoints in model_urls. This pattern is used
...@@ -48,7 +48,7 @@ def _densenet( ...@@ -48,7 +48,7 @@ def _densenet(
growth_rate: int, growth_rate: int,
block_config: Tuple[int, int, int, int], block_config: Tuple[int, int, int, int],
num_init_features: int, num_init_features: int,
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
**kwargs: Any, **kwargs: Any,
) -> DenseNet: ) -> DenseNet:
...@@ -71,8 +71,8 @@ _COMMON_META = { ...@@ -71,8 +71,8 @@ _COMMON_META = {
} }
class DenseNet121Weights(Weights): class DenseNet121_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth", url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -84,8 +84,8 @@ class DenseNet121Weights(Weights): ...@@ -84,8 +84,8 @@ class DenseNet121Weights(Weights):
) )
class DenseNet161Weights(Weights): class DenseNet161_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth", url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -97,8 +97,8 @@ class DenseNet161Weights(Weights): ...@@ -97,8 +97,8 @@ class DenseNet161Weights(Weights):
) )
class DenseNet169Weights(Weights): class DenseNet169_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -110,8 +110,8 @@ class DenseNet169Weights(Weights): ...@@ -110,8 +110,8 @@ class DenseNet169Weights(Weights):
) )
class DenseNet201Weights(Weights): class DenseNet201_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth", url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -123,41 +123,41 @@ class DenseNet201Weights(Weights): ...@@ -123,41 +123,41 @@ class DenseNet201Weights(Weights):
) )
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
weights = DenseNet121Weights.verify(weights) weights = DenseNet121_Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
weights = DenseNet161Weights.verify(weights) weights = DenseNet161_Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
weights = DenseNet169Weights.verify(weights) weights = DenseNet169_Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
weights = DenseNet201Weights.verify(weights) weights = DenseNet201_Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
...@@ -12,18 +12,18 @@ from ....models.detection.faster_rcnn import ( ...@@ -12,18 +12,18 @@ from ....models.detection.faster_rcnn import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
__all__ = [ __all__ = [
"FasterRCNN", "FasterRCNN",
"FasterRCNNResNet50FPNWeights", "FasterRCNN_ResNet50_FPN_Weights",
"FasterRCNNMobileNetV3LargeFPNWeights", "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
"FasterRCNNMobileNetV3Large320FPNWeights", "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
"fasterrcnn_resnet50_fpn", "fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_fpn", "fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn",
...@@ -36,8 +36,8 @@ _COMMON_META = { ...@@ -36,8 +36,8 @@ _COMMON_META = {
} }
class FasterRCNNResNet50FPNWeights(Weights): class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -49,8 +49,8 @@ class FasterRCNNResNet50FPNWeights(Weights): ...@@ -49,8 +49,8 @@ class FasterRCNNResNet50FPNWeights(Weights):
) )
class FasterRCNNMobileNetV3LargeFPNWeights(Weights): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -62,8 +62,8 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights): ...@@ -62,8 +62,8 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
) )
class FasterRCNNMobileNetV3Large320FPNWeights(Weights): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -76,25 +76,25 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights): ...@@ -76,25 +76,25 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
def fasterrcnn_resnet50_fpn( def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None, weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = FasterRCNNResNet50FPNWeights.verify(weights) weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -112,17 +112,17 @@ def fasterrcnn_resnet50_fpn( ...@@ -112,17 +112,17 @@ def fasterrcnn_resnet50_fpn(
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1: if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
return model return model
def _fasterrcnn_mobilenet_v3_large_fpn( def _fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]], weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
progress: bool, progress: bool,
num_classes: Optional[int], num_classes: Optional[int],
weights_backbone: Optional[MobileNetV3LargeWeights], weights_backbone: Optional[MobileNet_V3_Large_Weights],
trainable_backbone_layers: Optional[int], trainable_backbone_layers: Optional[int],
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
...@@ -159,25 +159,25 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -159,25 +159,25 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None, weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1)
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
) )
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = { defaults = {
"rpn_score_thresh": 0.05, "rpn_score_thresh": 0.05,
...@@ -195,25 +195,27 @@ def fasterrcnn_mobilenet_v3_large_fpn( ...@@ -195,25 +195,27 @@ def fasterrcnn_mobilenet_v3_large_fpn(
def fasterrcnn_mobilenet_v3_large_320_fpn( def fasterrcnn_mobilenet_v3_large_320_fpn(
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None, weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1) weights = _deprecated_param(
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1
)
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
) )
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = { defaults = {
"min_size": 320, "min_size": 320,
......
...@@ -9,15 +9,15 @@ from ....models.detection.keypoint_rcnn import ( ...@@ -9,15 +9,15 @@ from ....models.detection.keypoint_rcnn import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
__all__ = [ __all__ = [
"KeypointRCNN", "KeypointRCNN",
"KeypointRCNNResNet50FPNWeights", "KeypointRCNN_ResNet50_FPN_Weights",
"keypointrcnn_resnet50_fpn", "keypointrcnn_resnet50_fpn",
] ]
...@@ -25,8 +25,8 @@ __all__ = [ ...@@ -25,8 +25,8 @@ __all__ = [
_COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES} _COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
class KeypointRCNNResNet50FPNWeights(Weights): class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_RefV1_Legacy = WeightEntry( Coco_Legacy = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -37,7 +37,7 @@ class KeypointRCNNResNet50FPNWeights(Weights): ...@@ -37,7 +37,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
}, },
default=False, default=False,
) )
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -51,30 +51,30 @@ class KeypointRCNNResNet50FPNWeights(Weights): ...@@ -51,30 +51,30 @@ class KeypointRCNNResNet50FPNWeights(Weights):
def keypointrcnn_resnet50_fpn( def keypointrcnn_resnet50_fpn(
weights: Optional[KeypointRCNNResNet50FPNWeights] = None, weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
num_keypoints: Optional[int] = None, num_keypoints: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> KeypointRCNN: ) -> KeypointRCNN:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1 default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_V1
if kwargs["pretrained"] == "legacy": if kwargs["pretrained"] == "legacy":
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
kwargs["pretrained"] = True kwargs["pretrained"] = True
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)
weights = KeypointRCNNResNet50FPNWeights.verify(weights) weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -96,7 +96,7 @@ def keypointrcnn_resnet50_fpn( ...@@ -96,7 +96,7 @@ def keypointrcnn_resnet50_fpn(
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1: if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
return model return model
...@@ -10,21 +10,21 @@ from ....models.detection.mask_rcnn import ( ...@@ -10,21 +10,21 @@ from ....models.detection.mask_rcnn import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
__all__ = [ __all__ = [
"MaskRCNN", "MaskRCNN",
"MaskRCNNResNet50FPNWeights", "MaskRCNN_ResNet50_FPN_Weights",
"maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn",
] ]
class MaskRCNNResNet50FPNWeights(Weights): class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -39,25 +39,25 @@ class MaskRCNNResNet50FPNWeights(Weights): ...@@ -39,25 +39,25 @@ class MaskRCNNResNet50FPNWeights(Weights):
def maskrcnn_resnet50_fpn( def maskrcnn_resnet50_fpn(
weights: Optional[MaskRCNNResNet50FPNWeights] = None, weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> MaskRCNN: ) -> MaskRCNN:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNN_ResNet50_FPN_Weights.Coco_V1)
weights = MaskRCNNResNet50FPNWeights.verify(weights) weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -75,7 +75,7 @@ def maskrcnn_resnet50_fpn( ...@@ -75,7 +75,7 @@ def maskrcnn_resnet50_fpn(
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1: if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
return model return model
...@@ -11,21 +11,21 @@ from ....models.detection.retinanet import ( ...@@ -11,21 +11,21 @@ from ....models.detection.retinanet import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50_Weights, resnet50
__all__ = [ __all__ = [
"RetinaNet", "RetinaNet",
"RetinaNetResNet50FPNWeights", "RetinaNet_ResNet50_FPN_Weights",
"retinanet_resnet50_fpn", "retinanet_resnet50_fpn",
] ]
class RetinaNetResNet50FPNWeights(Weights): class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -39,25 +39,25 @@ class RetinaNetResNet50FPNWeights(Weights): ...@@ -39,25 +39,25 @@ class RetinaNetResNet50FPNWeights(Weights):
def retinanet_resnet50_fpn( def retinanet_resnet50_fpn(
weights: Optional[RetinaNetResNet50FPNWeights] = None, weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50Weights] = None, weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> RetinaNet: ) -> RetinaNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNet_ResNet50_FPN_Weights.Coco_V1)
weights = RetinaNetResNet50FPNWeights.verify(weights) weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1
) )
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
...@@ -78,7 +78,7 @@ def retinanet_resnet50_fpn( ...@@ -78,7 +78,7 @@ def retinanet_resnet50_fpn(
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNetResNet50FPNWeights.Coco_RefV1: if weights == RetinaNet_ResNet50_FPN_Weights.Coco_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
return model return model
...@@ -10,20 +10,20 @@ from ....models.detection.ssd import ( ...@@ -10,20 +10,20 @@ from ....models.detection.ssd import (
DefaultBoxGenerator, DefaultBoxGenerator,
SSD, SSD,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..vgg import VGG16Weights, vgg16 from ..vgg import VGG16_Weights, vgg16
__all__ = [ __all__ = [
"SSD300VGG16Weights", "SSD300_VGG16_Weights",
"ssd300_vgg16", "ssd300_vgg16",
] ]
class SSD300VGG16Weights(Weights): class SSD300_VGG16_Weights(WeightsEnum):
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -38,25 +38,25 @@ class SSD300VGG16Weights(Weights): ...@@ -38,25 +38,25 @@ class SSD300VGG16Weights(Weights):
def ssd300_vgg16( def ssd300_vgg16(
weights: Optional[SSD300VGG16Weights] = None, weights: Optional[SSD300_VGG16_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[VGG16Weights] = None, weights_backbone: Optional[VGG16_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> SSD: ) -> SSD:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300_VGG16_Weights.Coco_V1)
weights = SSD300VGG16Weights.verify(weights) weights = SSD300_VGG16_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", VGG16Weights.ImageNet1K_Features kwargs, "pretrained_backbone", "weights_backbone", VGG16_Weights.ImageNet1K_Features
) )
weights_backbone = VGG16Weights.verify(weights_backbone) weights_backbone = VGG16_Weights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the parameter.") warnings.warn("The size of the model is already fixed; ignoring the parameter.")
......
...@@ -15,20 +15,20 @@ from ....models.detection.ssdlite import ( ...@@ -15,20 +15,20 @@ from ....models.detection.ssdlite import (
SSD, SSD,
SSDLiteHead, SSDLiteHead,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
__all__ = [ __all__ = [
"SSDlite320MobileNetV3LargeFPNWeights", "SSDLite320_MobileNet_V3_Large_Weights",
"ssdlite320_mobilenet_v3_large", "ssdlite320_mobilenet_v3_large",
] ]
class SSDlite320MobileNetV3LargeFPNWeights(Weights): class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
Coco_RefV1 = WeightEntry( Coco_V1 = Weights(
url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
...@@ -43,10 +43,10 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights): ...@@ -43,10 +43,10 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights):
def ssdlite320_mobilenet_v3_large( def ssdlite320_mobilenet_v3_large(
weights: Optional[SSDlite320MobileNetV3LargeFPNWeights] = None, weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: Optional[int] = None, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None, trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any, **kwargs: Any,
...@@ -54,15 +54,15 @@ def ssdlite320_mobilenet_v3_large( ...@@ -54,15 +54,15 @@ def ssdlite320_mobilenet_v3_large(
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1)
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights) weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone: if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
weights_backbone = _deprecated_param( weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1
) )
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if "size" in kwargs: if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the parameter.") warnings.warn("The size of the model is already fixed; ignoring the parameter.")
......
...@@ -6,21 +6,21 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -6,21 +6,21 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig from ...models.efficientnet import EfficientNet, MBConvConfig
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"EfficientNet", "EfficientNet",
"EfficientNetB0Weights", "EfficientNet_B0_Weights",
"EfficientNetB1Weights", "EfficientNet_B1_Weights",
"EfficientNetB2Weights", "EfficientNet_B2_Weights",
"EfficientNetB3Weights", "EfficientNet_B3_Weights",
"EfficientNetB4Weights", "EfficientNet_B4_Weights",
"EfficientNetB5Weights", "EfficientNet_B5_Weights",
"EfficientNetB6Weights", "EfficientNet_B6_Weights",
"EfficientNetB7Weights", "EfficientNet_B7_Weights",
"efficientnet_b0", "efficientnet_b0",
"efficientnet_b1", "efficientnet_b1",
"efficientnet_b2", "efficientnet_b2",
...@@ -36,7 +36,7 @@ def _efficientnet( ...@@ -36,7 +36,7 @@ def _efficientnet(
width_mult: float, width_mult: float,
depth_mult: float, depth_mult: float,
dropout: float, dropout: float,
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
**kwargs: Any, **kwargs: Any,
) -> EfficientNet: ) -> EfficientNet:
...@@ -69,8 +69,8 @@ _COMMON_META = { ...@@ -69,8 +69,8 @@ _COMMON_META = {
} }
class EfficientNetB0Weights(Weights): class EfficientNet_B0_Weights(WeightsEnum):
ImageNet1K_TimmV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -83,8 +83,8 @@ class EfficientNetB0Weights(Weights): ...@@ -83,8 +83,8 @@ class EfficientNetB0Weights(Weights):
) )
class EfficientNetB1Weights(Weights): class EfficientNet_B1_Weights(WeightsEnum):
ImageNet1K_TimmV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -97,8 +97,8 @@ class EfficientNetB1Weights(Weights): ...@@ -97,8 +97,8 @@ class EfficientNetB1Weights(Weights):
) )
class EfficientNetB2Weights(Weights): class EfficientNet_B2_Weights(WeightsEnum):
ImageNet1K_TimmV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -111,8 +111,8 @@ class EfficientNetB2Weights(Weights): ...@@ -111,8 +111,8 @@ class EfficientNetB2Weights(Weights):
) )
class EfficientNetB3Weights(Weights): class EfficientNet_B3_Weights(WeightsEnum):
ImageNet1K_TimmV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -125,8 +125,8 @@ class EfficientNetB3Weights(Weights): ...@@ -125,8 +125,8 @@ class EfficientNetB3Weights(Weights):
) )
class EfficientNetB4Weights(Weights): class EfficientNet_B4_Weights(WeightsEnum):
ImageNet1K_TimmV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -139,8 +139,8 @@ class EfficientNetB4Weights(Weights): ...@@ -139,8 +139,8 @@ class EfficientNetB4Weights(Weights):
) )
class EfficientNetB5Weights(Weights): class EfficientNet_B5_Weights(WeightsEnum):
ImageNet1K_TFV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -153,8 +153,8 @@ class EfficientNetB5Weights(Weights): ...@@ -153,8 +153,8 @@ class EfficientNetB5Weights(Weights):
) )
class EfficientNetB6Weights(Weights): class EfficientNet_B6_Weights(WeightsEnum):
ImageNet1K_TFV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -167,8 +167,8 @@ class EfficientNetB6Weights(Weights): ...@@ -167,8 +167,8 @@ class EfficientNetB6Weights(Weights):
) )
class EfficientNetB7Weights(Weights): class EfficientNet_B7_Weights(WeightsEnum):
ImageNet1K_TFV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
meta={ meta={
...@@ -182,73 +182,73 @@ class EfficientNetB7Weights(Weights): ...@@ -182,73 +182,73 @@ class EfficientNetB7Weights(Weights):
def efficientnet_b0( def efficientnet_b0(
weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_TimmV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B0_Weights.ImageNet1K_V1)
weights = EfficientNetB0Weights.verify(weights) weights = EfficientNet_B0_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b1( def efficientnet_b1(
weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_TimmV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B1_Weights.ImageNet1K_V1)
weights = EfficientNetB1Weights.verify(weights) weights = EfficientNet_B1_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
def efficientnet_b2( def efficientnet_b2(
weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_TimmV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B2_Weights.ImageNet1K_V1)
weights = EfficientNetB2Weights.verify(weights) weights = EfficientNet_B2_Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b3( def efficientnet_b3(
weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_TimmV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B3_Weights.ImageNet1K_V1)
weights = EfficientNetB3Weights.verify(weights) weights = EfficientNet_B3_Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
def efficientnet_b4( def efficientnet_b4(
weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_TimmV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B4_Weights.ImageNet1K_V1)
weights = EfficientNetB4Weights.verify(weights) weights = EfficientNet_B4_Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
def efficientnet_b5( def efficientnet_b5(
weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_TFV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B5_Weights.ImageNet1K_V1)
weights = EfficientNetB5Weights.verify(weights) weights = EfficientNet_B5_Weights.verify(weights)
return _efficientnet( return _efficientnet(
width_mult=1.6, width_mult=1.6,
...@@ -262,13 +262,13 @@ def efficientnet_b5( ...@@ -262,13 +262,13 @@ def efficientnet_b5(
def efficientnet_b6( def efficientnet_b6(
weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_TFV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B6_Weights.ImageNet1K_V1)
weights = EfficientNetB6Weights.verify(weights) weights = EfficientNet_B6_Weights.verify(weights)
return _efficientnet( return _efficientnet(
width_mult=1.8, width_mult=1.8,
...@@ -282,13 +282,13 @@ def efficientnet_b6( ...@@ -282,13 +282,13 @@ def efficientnet_b6(
def efficientnet_b7( def efficientnet_b7(
weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet: ) -> EfficientNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_TFV1) weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B7_Weights.ImageNet1K_V1)
weights = EfficientNetB7Weights.verify(weights) weights = EfficientNet_B7_Weights.verify(weights)
return _efficientnet( return _efficientnet(
width_mult=2.0, width_mult=2.0,
......
...@@ -6,16 +6,16 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -6,16 +6,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"] __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
class GoogLeNetWeights(Weights): class GoogLeNet_Weights(WeightsEnum):
ImageNet1K_TFV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/googlenet-1378be20.pth", url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -30,12 +30,12 @@ class GoogLeNetWeights(Weights): ...@@ -30,12 +30,12 @@ class GoogLeNetWeights(Weights):
) )
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_TFV1) weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNet_Weights.ImageNet1K_V1)
weights = GoogLeNetWeights.verify(weights) weights = GoogLeNet_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
......
...@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"] __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
class InceptionV3Weights(Weights): class Inception_V3_Weights(WeightsEnum):
ImageNet1K_TFV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342), transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
meta={ meta={
...@@ -29,12 +29,12 @@ class InceptionV3Weights(Weights): ...@@ -29,12 +29,12 @@ class InceptionV3Weights(Weights):
) )
def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_TFV1) weights = _deprecated_param(kwargs, "pretrained", "weights", Inception_V3_Weights.ImageNet1K_V1)
weights = InceptionV3Weights.verify(weights) weights = Inception_V3_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", True) original_aux_logits = kwargs.get("aux_logits", True)
if weights is not None: if weights is not None:
......
...@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,17 +5,17 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet from ...models.mnasnet import MNASNet
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"MNASNet", "MNASNet",
"MNASNet0_5Weights", "MNASNet0_5_Weights",
"MNASNet0_75Weights", "MNASNet0_75_Weights",
"MNASNet1_0Weights", "MNASNet1_0_Weights",
"MNASNet1_3Weights", "MNASNet1_3_Weights",
"mnasnet0_5", "mnasnet0_5",
"mnasnet0_75", "mnasnet0_75",
"mnasnet1_0", "mnasnet1_0",
...@@ -31,8 +31,8 @@ _COMMON_META = { ...@@ -31,8 +31,8 @@ _COMMON_META = {
} }
class MNASNet0_5Weights(Weights): class MNASNet0_5_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -44,13 +44,13 @@ class MNASNet0_5Weights(Weights): ...@@ -44,13 +44,13 @@ class MNASNet0_5Weights(Weights):
) )
class MNASNet0_75Weights(Weights): class MNASNet0_75_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet0_75 # If a default model is added here the corresponding changes need to be done in mnasnet0_75
pass pass
class MNASNet1_0Weights(Weights): class MNASNet1_0_Weights(WeightsEnum):
ImageNet1K_Community = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -62,12 +62,12 @@ class MNASNet1_0Weights(Weights): ...@@ -62,12 +62,12 @@ class MNASNet1_0Weights(Weights):
) )
class MNASNet1_3Weights(Weights): class MNASNet1_3_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet1_3 # If a default model is added here the corresponding changes need to be done in mnasnet1_3
pass pass
def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet: def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
...@@ -79,41 +79,41 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: ...@@ -79,41 +79,41 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
return model return model
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_5(weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5_Weights.ImageNet1K_V1)
weights = MNASNet0_5Weights.verify(weights) weights = MNASNet0_5_Weights.verify(weights)
return _mnasnet(0.5, weights, progress, **kwargs) return _mnasnet(0.5, weights, progress, **kwargs)
def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_75(weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet0_75Weights.verify(weights) weights = MNASNet0_75_Weights.verify(weights)
return _mnasnet(0.75, weights, progress, **kwargs) return _mnasnet(0.75, weights, progress, **kwargs)
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_0(weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_Community) weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0_Weights.ImageNet1K_V1)
weights = MNASNet1_0Weights.verify(weights) weights = MNASNet1_0_Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs) return _mnasnet(1.0, weights, progress, **kwargs)
def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_3(weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = MNASNet1_3Weights.verify(weights) weights = MNASNet1_3_Weights.verify(weights)
return _mnasnet(1.3, weights, progress, **kwargs) return _mnasnet(1.3, weights, progress, **kwargs)
...@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,16 +5,16 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2 from ...models.mobilenetv2 import MobileNetV2
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"] __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
class MobileNetV2Weights(Weights): class MobileNet_V2_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -29,12 +29,12 @@ class MobileNetV2Weights(Weights): ...@@ -29,12 +29,12 @@ class MobileNetV2Weights(Weights):
) )
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V2_Weights.ImageNet1K_V1)
weights = MobileNetV2Weights.verify(weights) weights = MobileNet_V2_Weights.verify(weights)
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
...@@ -5,15 +5,15 @@ from torchvision.prototype.transforms import ImageNetEval ...@@ -5,15 +5,15 @@ from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import Weights, WeightEntry from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
"MobileNetV3", "MobileNetV3",
"MobileNetV3LargeWeights", "MobileNet_V3_Large_Weights",
"MobileNetV3SmallWeights", "MobileNet_V3_Small_Weights",
"mobilenet_v3_large", "mobilenet_v3_large",
"mobilenet_v3_small", "mobilenet_v3_small",
] ]
...@@ -22,7 +22,7 @@ __all__ = [ ...@@ -22,7 +22,7 @@ __all__ = [
def _mobilenet_v3( def _mobilenet_v3(
inverted_residual_setting: List[InvertedResidualConfig], inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int, last_channel: int,
weights: Optional[Weights], weights: Optional[WeightsEnum],
progress: bool, progress: bool,
**kwargs: Any, **kwargs: Any,
) -> MobileNetV3: ) -> MobileNetV3:
...@@ -44,8 +44,8 @@ _COMMON_META = { ...@@ -44,8 +44,8 @@ _COMMON_META = {
} }
class MobileNetV3LargeWeights(Weights): class MobileNet_V3_Large_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -56,7 +56,7 @@ class MobileNetV3LargeWeights(Weights): ...@@ -56,7 +56,7 @@ class MobileNetV3LargeWeights(Weights):
}, },
default=False, default=False,
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
...@@ -69,8 +69,8 @@ class MobileNetV3LargeWeights(Weights): ...@@ -69,8 +69,8 @@ class MobileNetV3LargeWeights(Weights):
) )
class MobileNetV3SmallWeights(Weights): class MobileNet_V3_Small_Weights(WeightsEnum):
ImageNet1K_RefV1 = WeightEntry( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -84,26 +84,26 @@ class MobileNetV3SmallWeights(Weights): ...@@ -84,26 +84,26 @@ class MobileNetV3SmallWeights(Weights):
def mobilenet_v3_large( def mobilenet_v3_large(
weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Large_Weights.ImageNet1K_V1)
weights = MobileNetV3LargeWeights.verify(weights) weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
def mobilenet_v3_small( def mobilenet_v3_small(
weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3: ) -> MobileNetV3:
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_RefV1) weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Small_Weights.ImageNet1K_V1)
weights = MobileNetV3SmallWeights.verify(weights) weights = MobileNet_V3_Small_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
...@@ -10,21 +10,21 @@ from ....models.quantization.googlenet import ( ...@@ -10,21 +10,21 @@ from ....models.quantization.googlenet import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..googlenet import GoogLeNetWeights from ..googlenet import GoogLeNet_Weights
__all__ = [ __all__ = [
"QuantizableGoogLeNet", "QuantizableGoogLeNet",
"QuantizedGoogLeNetWeights", "GoogLeNet_QuantizedWeights",
"googlenet", "googlenet",
] ]
class QuantizedGoogLeNetWeights(Weights): class GoogLeNet_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_TFV1 = WeightEntry( ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -34,7 +34,7 @@ class QuantizedGoogLeNetWeights(Weights): ...@@ -34,7 +34,7 @@ class QuantizedGoogLeNetWeights(Weights):
"backend": "fbgemm", "backend": "fbgemm",
"quantization": "ptq", "quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
"unquantized": GoogLeNetWeights.ImageNet1K_TFV1, "unquantized": GoogLeNet_Weights.ImageNet1K_V1,
"acc@1": 69.826, "acc@1": 69.826,
"acc@5": 89.404, "acc@5": 89.404,
}, },
...@@ -43,7 +43,7 @@ class QuantizedGoogLeNetWeights(Weights): ...@@ -43,7 +43,7 @@ class QuantizedGoogLeNetWeights(Weights):
def googlenet( def googlenet(
weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None, weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -51,14 +51,12 @@ def googlenet( ...@@ -51,14 +51,12 @@ def googlenet(
if type(weights) == bool and weights: if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNet_Weights.ImageNet1K_V1
QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedGoogLeNetWeights.verify(weights) weights = GoogLeNet_QuantizedWeights.verify(weights)
else: else:
weights = GoogLeNetWeights.verify(weights) weights = GoogLeNet_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
......
...@@ -9,21 +9,21 @@ from ....models.quantization.inception import ( ...@@ -9,21 +9,21 @@ from ....models.quantization.inception import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..inception import InceptionV3Weights from ..inception import Inception_V3_Weights
__all__ = [ __all__ = [
"QuantizableInception3", "QuantizableInception3",
"QuantizedInceptionV3Weights", "Inception_V3_QuantizedWeights",
"inception_v3", "inception_v3",
] ]
class QuantizedInceptionV3Weights(Weights): class Inception_V3_QuantizedWeights(WeightsEnum):
ImageNet1K_FBGEMM_TFV1 = WeightEntry( ImageNet1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342), transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
meta={ meta={
...@@ -33,7 +33,7 @@ class QuantizedInceptionV3Weights(Weights): ...@@ -33,7 +33,7 @@ class QuantizedInceptionV3Weights(Weights):
"backend": "fbgemm", "backend": "fbgemm",
"quantization": "ptq", "quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
"unquantized": InceptionV3Weights.ImageNet1K_TFV1, "unquantized": Inception_V3_Weights.ImageNet1K_V1,
"acc@1": 77.176, "acc@1": 77.176,
"acc@5": 93.354, "acc@5": 93.354,
}, },
...@@ -42,7 +42,7 @@ class QuantizedInceptionV3Weights(Weights): ...@@ -42,7 +42,7 @@ class QuantizedInceptionV3Weights(Weights):
def inception_v3( def inception_v3(
weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None, weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -51,13 +51,13 @@ def inception_v3( ...@@ -51,13 +51,13 @@ def inception_v3(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = (
QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1 Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else Inception_V3_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedInceptionV3Weights.verify(weights) weights = Inception_V3_QuantizedWeights.verify(weights)
else: else:
weights = InceptionV3Weights.verify(weights) weights = Inception_V3_Weights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None: if weights is not None:
......
...@@ -10,21 +10,21 @@ from ....models.quantization.mobilenetv2 import ( ...@@ -10,21 +10,21 @@ from ....models.quantization.mobilenetv2 import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from .._api import Weights, WeightEntry from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..mobilenetv2 import MobileNetV2Weights from ..mobilenetv2 import MobileNet_V2_Weights
__all__ = [ __all__ = [
"QuantizableMobileNetV2", "QuantizableMobileNetV2",
"QuantizedMobileNetV2Weights", "MobileNet_V2_QuantizedWeights",
"mobilenet_v2", "mobilenet_v2",
] ]
class QuantizedMobileNetV2Weights(Weights): class MobileNet_V2_QuantizedWeights(WeightsEnum):
ImageNet1K_QNNPACK_RefV1 = WeightEntry( ImageNet1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -34,7 +34,7 @@ class QuantizedMobileNetV2Weights(Weights): ...@@ -34,7 +34,7 @@ class QuantizedMobileNetV2Weights(Weights):
"backend": "qnnpack", "backend": "qnnpack",
"quantization": "qat", "quantization": "qat",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
"unquantized": MobileNetV2Weights.ImageNet1K_RefV1, "unquantized": MobileNet_V2_Weights.ImageNet1K_V1,
"acc@1": 71.658, "acc@1": 71.658,
"acc@5": 90.150, "acc@5": 90.150,
}, },
...@@ -43,7 +43,7 @@ class QuantizedMobileNetV2Weights(Weights): ...@@ -43,7 +43,7 @@ class QuantizedMobileNetV2Weights(Weights):
def mobilenet_v2( def mobilenet_v2(
weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None, weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None,
progress: bool = True, progress: bool = True,
quantize: bool = False, quantize: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -52,13 +52,13 @@ def mobilenet_v2( ...@@ -52,13 +52,13 @@ def mobilenet_v2(
_deprecated_positional(kwargs, "pretrained", "weights", True) _deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
default_value = ( default_value = (
QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 if quantize else MobileNetV2Weights.ImageNet1K_RefV1 MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize else MobileNet_V2_Weights.ImageNet1K_V1
) )
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize: if quantize:
weights = QuantizedMobileNetV2Weights.verify(weights) weights = MobileNet_V2_QuantizedWeights.verify(weights)
else: else:
weights = MobileNetV2Weights.verify(weights) weights = MobileNet_V2_Weights.verify(weights)
if weights is not None: if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment