Unverified Commit eb48a1d8 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support to Quantized MobileNetV2 and MobileNetV3 (#4859)

* Adding multiweight support on Quant MobileNetV2 and MobileNetV3.

* Fixing enum name.

* Fixing lint.
parent 50a35717
...@@ -4,8 +4,7 @@ from .efficientnet import *
 from .googlenet import *
 from .inception import *
 from .mnasnet import *
-from .mobilenetv2 import *
-from .mobilenetv3 import *
+from .mobilenet import *
 from .regnet import *
 from .resnet import *
 from .shufflenetv2 import *
......
# Aggregator module: re-exports both MobileNet generations under one name
# so callers can do `from .mobilenet import *`.
from .mobilenetv2 import *  # noqa: F401, F403
from .mobilenetv3 import *  # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

# Public API is the union of both submodules' public APIs.
__all__ = mv2_all + mv3_all
 from .googlenet import *
 from .inception import *
+from .mobilenet import *
 from .resnet import *
 from .shufflenetv2 import *
# Aggregator module for the quantized variants: mirrors models/mobilenet.py,
# re-exporting both quantized MobileNet generations under one name.
from .mobilenetv2 import *  # noqa: F401, F403
from .mobilenetv3 import *  # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

# Public API is the union of both submodules' public APIs.
__all__ = mv2_all + mv3_all
import warnings
from functools import partial
from typing import Any, Optional, Union
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv2 import (
QuantizableInvertedResidual,
QuantizableMobileNetV2,
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv2 import MobileNetV2Weights
# Public API of this module: the quantizable architecture, its weight enum,
# and the builder function.
__all__ = [
    "QuantizableMobileNetV2",
    "QuantizedMobileNetV2Weights",
    "mobilenet_v2",
]
class QuantizedMobileNetV2Weights(Weights):
    """Weight enum for the quantized MobileNetV2 (QNNPACK backend, QAT)."""

    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
            # Link back to the float weights this checkpoint was quantized from.
            "unquantized": MobileNetV2Weights.ImageNet1K_RefV1,
            "acc@1": 71.658,
            "acc@5": 90.150,
        },
    )
def mobilenet_v2(
    weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    """Build a quantizable MobileNetV2, optionally quantized and/or pretrained.

    Args:
        weights: pretrained weights to load; pass a quantized enum entry when
            ``quantize=True`` and a float one otherwise.
        progress: show a download progress bar when fetching weights.
        quantize: when True, quantize the model (post-fusion) before loading.
        **kwargs: forwarded to the ``QuantizableMobileNetV2`` constructor.
    """
    # Legacy `pretrained` flag: translate it into the matching weights entry.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if not kwargs.pop("pretrained"):
            weights = None
        elif quantize:
            weights = QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
        else:
            weights = MobileNetV2Weights.ImageNet1K_RefV1

    # Validate against the enum that matches the requested mode.
    weight_enum = QuantizedMobileNetV2Weights if quantize else MobileNetV2Weights
    weights = weight_enum.verify(weights)

    if weights is not None:
        # Derive constructor/backend options from the weights metadata.
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model
import warnings
from functools import partial
from typing import Any, List, Optional, Union
import torch
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv3 import (
InvertedResidualConfig,
QuantizableInvertedResidual,
QuantizableMobileNetV3,
_replace_relu,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf
# Public API of this module: the quantizable architecture, its weight enum,
# and the builder function.
__all__ = [
    "QuantizableMobileNetV3",
    "QuantizedMobileNetV3LargeWeights",
    "mobilenet_v3_large",
]
def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[Weights],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """Shared builder: construct, (optionally) QAT-quantize, and load weights.

    When ``quantize`` is True the model is fused and prepared for QAT before
    the checkpoint is loaded, then converted and switched to eval mode.
    """
    if weights is not None:
        # Metadata drives the class count and (when present) the backend.
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(
        inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs
    )
    _replace_relu(model)

    if quantize:
        # Prepare for quantization-aware training so the QAT checkpoint's
        # observer/fake-quant state matches the module structure.
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        torch.quantization.prepare_qat(model, inplace=True)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    if quantize:
        # Finalize: swap fake-quant modules for real quantized ops.
        torch.quantization.convert(model, inplace=True)
        model.eval()

    return model
class QuantizedMobileNetV3LargeWeights(Weights):
    """Weight enum for the quantized MobileNetV3-Large (QNNPACK backend, QAT)."""

    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            # Link back to the float weights this checkpoint was quantized from.
            "unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1,
            "acc@1": 73.004,
            "acc@5": 90.858,
        },
    )
def mobilenet_v3_large(
    weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """Build a quantizable MobileNetV3-Large, optionally quantized/pretrained.

    Args:
        weights: pretrained weights to load; pass a quantized enum entry when
            ``quantize=True`` and a float one otherwise.
        progress: show a download progress bar when fetching weights.
        quantize: when True, run the QAT prepare/convert pipeline.
        **kwargs: forwarded to the architecture config and model constructor.
    """
    # Legacy `pretrained` flag: translate it into the matching weights entry.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if not kwargs.pop("pretrained"):
            weights = None
        elif quantize:
            weights = QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
        else:
            weights = MobileNetV3LargeWeights.ImageNet1K_RefV1

    # Validate against the enum that matches the requested mode.
    weight_enum = QuantizedMobileNetV3LargeWeights if quantize else MobileNetV3LargeWeights
    weights = weight_enum.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.