Unverified Commit b43353e8 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support to Quantized GoogLeNet (#4848)

* Reordering the builders to use proper typing.

* Adding additional meta-data on existing quantized models.

* Fixing meta on unquantized model.

* Adding quantized googlenet builder.

* Undo the inception move.

* Adding recipe information.
parent 3300692c
......@@ -31,6 +31,10 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
normalization and thus are trained with the default parameters.
### GoogLeNet
The weights of the GoogLeNet model are ported from the original paper rather than trained from scratch.
### Inception V3
The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
......
......@@ -95,6 +95,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
},
"quantization": {
"input_shape": (1, 3, 224, 224),
"quantize": True,
},
"segmentation": {
"input_shape": (1, 3, 520, 520),
......
......@@ -19,69 +19,6 @@ quant_model_urls = {
}
def googlenet(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableGoogLeNet":
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        # The published checkpoint expects input normalization and was released
        # without trained auxiliary heads, so default the kwargs accordingly.
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" not in kwargs:
            kwargs["aux_logits"] = False
        if kwargs["aux_logits"]:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        original_aux_logits = kwargs["aux_logits"]
        # Instantiate the aux heads regardless, so state-dict keys line up with
        # the checkpoint; they are stripped again below if not requested.
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        # Quantized and float checkpoints live under different url tables.
        if quantize:
            model_url = quant_model_urls["googlenet_" + backend]
        else:
            model_url = model_urls["googlenet"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)

        if not original_aux_logits:
            # Drop the aux heads that were only created to load the checkpoint.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
    return model
class QuantizableBasicConv2d(BasicConv2d):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
......@@ -164,3 +101,65 @@ class QuantizableGoogLeNet(GoogLeNet):
for m in self.modules():
if type(m) is QuantizableBasicConv2d:
m.fuse_model()
def googlenet(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableGoogLeNet:
    r"""Build a quantizable GoogLeNet (Inception v1), from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Passing ``quantize=True`` yields a model with 8 bit weights. Quantized
    models only support inference and run on CPUs; GPU inference is not yet
    supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        # Published weights expect input normalization and ship without trained
        # aux heads, so those become the defaults here.
        kwargs.setdefault("transform_input", True)
        kwargs.setdefault("aux_logits", False)
        if kwargs["aux_logits"]:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        keep_aux = kwargs["aux_logits"]
        # Build with aux heads enabled so checkpoint keys match; strip them
        # after loading when the caller did not ask for them.
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        # Quantized and float checkpoints come from separate url tables.
        url = quant_model_urls["googlenet_" + backend] if quantize else model_urls["googlenet"]
        model.load_state_dict(load_state_dict_from_url(url, progress=progress))

        if not keep_aux:
            # Remove the aux heads that existed only to load the checkpoint.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
    return model
......@@ -14,14 +14,14 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeigh
class GoogLeNetWeights(Weights):
ImageNet1K_Community = WeightEntry(
ImageNet1K_TFV1 = WeightEntry(
url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
"acc@1": 69.778,
"acc@5": 89.530,
},
......@@ -31,7 +31,7 @@ class GoogLeNetWeights(Weights):
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
......
import warnings
from functools import partial
from typing import Any, Optional, Union
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.googlenet import (
QuantizableGoogLeNet,
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..googlenet import GoogLeNetWeights
__all__ = [
"QuantizableGoogLeNet",
"QuantizedGoogLeNetWeights",
"googlenet",
]
class QuantizedGoogLeNetWeights(Weights):
    # Post-training quantized GoogLeNet checkpoint for the fbgemm backend,
    # derived from the unquantized GoogLeNetWeights.ImageNet1K_TFV1 weights
    # (linked via the "unquantized" meta field).
    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "fbgemm",
            "quantization": "ptq",  # post-training quantization, per the recipe below
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            "unquantized": GoogLeNetWeights.ImageNet1K_TFV1,
            "acc@1": 69.826,
            "acc@5": 89.404,
        },
    )
def googlenet(
    weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableGoogLeNet:
    """Multi-weight builder for the quantizable GoogLeNet.

    Args:
        weights: Optional weight entry to load; quantized entries require
            ``quantize=True`` and float entries require ``quantize=False``.
        progress: If True, displays a progress bar of the download to stderr.
        quantize: If True, return a quantized version of the model.
        **kwargs: Extra arguments forwarded to ``QuantizableGoogLeNet``.
    """
    # Legacy `pretrained=` flag is translated into the matching weight entry.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        legacy = kwargs.pop("pretrained")
        weights = None
        if legacy:
            weights = (
                QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1
                if quantize
                else GoogLeNetWeights.ImageNet1K_TFV1
            )

    # Validate the entry against the enum that matches the requested mode.
    weights_enum = QuantizedGoogLeNetWeights if quantize else GoogLeNetWeights
    weights = weights_enum.verify(weights)

    wants_aux = kwargs.get("aux_logits", False)
    if weights is not None:
        # Checkpoints expect transformed inputs and contain (untrained) aux
        # heads, so configure the model to match before loading.
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if wants_aux:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))
        if not wants_aux:
            # Strip the aux heads that existed only to load the checkpoint.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]

    return model
......@@ -2,6 +2,8 @@ import warnings
from functools import partial
from typing import Any, List, Optional, Type, Union
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.resnet import (
QuantizableBasicBlock,
QuantizableBottleneck,
......@@ -54,7 +56,9 @@ def _resnet(
# Meta-data shared by every quantized ResNet weight entry below; entry-specific
# fields (accuracies, the "unquantized" back-reference) are merged in per entry.
_common_meta = {
    "size": (224, 224),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
    "backend": "fbgemm",
    "quantization": "ptq",  # post-training quantization, per the recipe below
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}
......@@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"unquantized": ResNet18Weights.ImageNet1K_RefV1,
"acc@1": 69.494,
"acc@5": 88.882,
},
......@@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"unquantized": ResNet50Weights.ImageNet1K_RefV1,
"acc@1": 75.920,
"acc@5": 92.814,
},
......@@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
"acc@1": 78.986,
"acc@5": 94.480,
},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.