"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "53e1b61a1e2498e66e4af9ff19e0bc55955b24b0"
Unverified commit 1cbd9cde authored by Vasilis Vryniotis, committed by GitHub

Adding multiweight support to Quantized InceptionV3 (#4850)

* Moving builder to the bottom to use proper typing.

* Renaming weights.

* Adding quantized inception builder.

* Correct meta info.

* Fix linter.

* Removing init_weights to avoid exposing it on the class.
parent 4ccef06c
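In short, the PR maps the legacy boolean flags onto explicit weight enums. A hedged before/after sketch of the call sites (enum and builder names are taken from the diff below; the quantized multi-weight builder lives under the prototype namespace):

    # Before: boolean flags, implicit weight selection (torchvision.models.quantization).
    model = inception_v3(pretrained=True, quantize=True)

    # After: explicit multi-weight enum (torchvision.prototype.models.quantization).
    model = inception_v3(weights=QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1, quantize=True)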
@@ -24,72 +24,6 @@ quant_model_urls = {
}
def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableInception3":
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models, inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" in kwargs:
            original_aux_logits = kwargs["aux_logits"]
            kwargs["aux_logits"] = True
        else:
            original_aux_logits = False

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
            model_url = quant_model_urls["inception_v3_google_" + backend]
        else:
            model_url = inception_module.model_urls["inception_v3_google"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)

        if not quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
    return model
class QuantizableBasicConv2d(inception_module.BasicConv2d):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
@@ -237,3 +171,68 @@ class QuantizableInception3(inception_module.Inception3):
        for m in self.modules():
            if type(m) is QuantizableBasicConv2d:
                m.fuse_model()
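(QuantizableBasicConv2d.fuse_model, dispatched in the loop above, is not part of this hunk; as a sketch of what it does in torchvision at this point, it fuses the conv/bn/relu triple:)

    def fuse_model(self) -> None:
        # Fuse Conv2d + BatchNorm2d + ReLU into one module ahead of quantization.
        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)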
def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models, inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" in kwargs:
            original_aux_logits = kwargs["aux_logits"]
            kwargs["aux_logits"] = True
        else:
            original_aux_logits = False

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
            model_url = quant_model_urls["inception_v3_google_" + backend]
        else:
            model_url = inception_module.model_urls["inception_v3_google"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)

        if not quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
    return model
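Moving the builder below the class lets the return annotation reference QuantizableInception3 directly, without the quoted forward reference the old top-of-file definition needed. Behavior is unchanged; a legacy-style call still works:

    # Usage sketch (not part of the diff): legacy flags, same semantics as before.
    model = inception_v3(pretrained=True, quantize=True)
    model.eval()  # quantized models are inference-only and CPU-only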
@@ -10,10 +10,10 @@ from ._api import Weights, WeightEntry
 from ._meta import _IMAGENET_CATEGORIES

-__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
+__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]

-class Inception3Weights(Weights):
+class InceptionV3Weights(Weights):
     ImageNet1K_TFV1 = WeightEntry(
         url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
         transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
@@ -28,11 +28,11 @@ class Inception3Weights(Weights):
     )

-def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
+def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
     if "pretrained" in kwargs:
         warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-        weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
+        weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None

-    weights = Inception3Weights.verify(weights)
+    weights = InceptionV3Weights.verify(weights)
     original_aux_logits = kwargs.get("aux_logits", True)
     if weights is not None:
...
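For callers, the Inception3Weights to InceptionV3Weights rename is mechanical; a hedged sketch, assuming the prototype import path:

    from torchvision.prototype.models.inception import InceptionV3Weights, inception_v3

    model = inception_v3(weights=InceptionV3Weights.ImageNet1K_TFV1)
    legacy = inception_v3(pretrained=True)  # warns: pretrained is deprecated, use weights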
 from .googlenet import *
+from .inception import *
 from .resnet import *
import warnings
from functools import partial
from typing import Any, Optional, Union

from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.inception import (
    QuantizableInception3,
    _replace_relu,
    quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..inception import InceptionV3Weights


__all__ = [
    "QuantizableInception3",
    "QuantizedInceptionV3Weights",
    "inception_v3",
]


class QuantizedInceptionV3Weights(Weights):
    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
        meta={
            "size": (299, 299),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "fbgemm",
            "quantization": "ptq",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            "unquantized": InceptionV3Weights.ImageNet1K_TFV1,
            "acc@1": 77.176,
            "acc@5": 93.354,
        },
    )


def inception_v3(
    weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = (
                QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
            )
        else:
            weights = None

    if quantize:
        weights = QuantizedInceptionV3Weights.verify(weights)
    else:
        weights = InceptionV3Weights.verify(weights)

    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        kwargs["aux_logits"] = True
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        if quantize and not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None
        model.load_state_dict(weights.state_dict(progress=progress))
        if not quantize and not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None

    return model
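For reference, a minimal usage sketch of the new quantized builder (all names are from this diff; the import path and the weights.transforms() call follow the prototype API conventions and are assumptions):

    from torchvision.prototype.models.quantization.inception import (
        QuantizedInceptionV3Weights,
        inception_v3,
    )

    weights = QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1
    model = inception_v3(weights=weights, quantize=True)
    model.eval()  # quantized models are inference-only and CPU-only

    preprocess = weights.transforms()   # builds ImageNetEval(crop_size=299, resize_size=342) per the entry
    print(weights.meta["acc@1"])        # 77.176
    print(weights.meta["unquantized"])  # InceptionV3Weights.ImageNet1K_TFV1

Note that when weights are passed, the builder forces aux_logits=True for construction and strips the auxiliary head afterwards unless the caller explicitly requested it, mirroring the legacy behavior.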