Unverified Commit 7b1b68d7 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Multi-weight support for MobileNetV3 prototype models (#4723)

* Adding multiweight support for mobilenetv3 prototype
parent b280c318
...@@ -281,7 +281,7 @@ def _mobilenet_v3_conf( ...@@ -281,7 +281,7 @@ def _mobilenet_v3_conf(
return inverted_residual_setting, last_channel return inverted_residual_setting, last_channel
def _mobilenet_v3_model( def _mobilenet_v3(
arch: str, arch: str,
inverted_residual_setting: List[InvertedResidualConfig], inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int, last_channel: int,
...@@ -309,7 +309,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs ...@@ -309,7 +309,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
""" """
arch = "mobilenet_v3_large" arch = "mobilenet_v3_large"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
...@@ -323,4 +323,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs ...@@ -323,4 +323,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
""" """
arch = "mobilenet_v3_small" arch = "mobilenet_v3_small"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
...@@ -3,5 +3,6 @@ from .resnet import * ...@@ -3,5 +3,6 @@ from .resnet import *
from .densenet import * from .densenet import *
from .vgg import * from .vgg import *
from .efficientnet import * from .efficientnet import *
from .mobilenetv3 import *
from . import detection from . import detection
from . import quantization from . import quantization
import warnings
from functools import partial
from typing import Any, Optional, List
from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
__all__ = [
"MobileNetV3",
"MobileNetV3LargeWeights",
"MobileNetV3SmallWeights",
"mobilenet_v3_large",
"mobilenet_v3_small",
]
def _mobilenet_v3(
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
weights: Optional[Weights],
progress: bool,
**kwargs: Any,
) -> MobileNetV3:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
return model
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class MobileNetV3LargeWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
"acc@1": 74.042,
"acc@5": 91.340,
},
)
class MobileNetV3SmallWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
"acc@1": 67.668,
"acc@5": 87.402,
},
)
def mobilenet_v3_large(
weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3LargeWeights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
def mobilenet_v3_small(
weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = MobileNetV3SmallWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = MobileNetV3SmallWeights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment