Unverified Commit b43353e8 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support to Quantized GoogLeNet (#4848)

* Reordering the builders to use proper typing.

* Adding additional meta-data on existing quantized models.

* Fixing meta on unquantized model.

* Adding quantized googlenet builder.

* undo inception move.

* Adding recipe information.
parent 3300692c
...@@ -31,6 +31,10 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note ...@@ -31,6 +31,10 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
normalization and thus are trained with the default parameters. normalization and thus are trained with the default parameters.
### GoogLeNet
The weights of the GoogLeNet model are ported from the original paper rather than trained from scratch.
### Inception V3 ### Inception V3
The weights of the Inception V3 model are ported from the original paper rather than trained from scratch. The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
......
...@@ -95,6 +95,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev): ...@@ -95,6 +95,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
}, },
"quantization": { "quantization": {
"input_shape": (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
"quantize": True,
}, },
"segmentation": { "segmentation": {
"input_shape": (1, 3, 520, 520), "input_shape": (1, 3, 520, 520),
......
...@@ -19,69 +19,6 @@ quant_model_urls = { ...@@ -19,69 +19,6 @@ quant_model_urls = {
} }
def googlenet(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableGoogLeNet":
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        # Pretrained checkpoints were trained with input transformation and
        # without pretrained aux heads, so default the kwargs accordingly.
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" not in kwargs:
            kwargs["aux_logits"] = False
        if kwargs["aux_logits"]:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        # Remember the caller's choice; the checkpoint stores aux weights, so
        # the model must be built with aux_logits=True to load it, then the
        # aux branches are stripped again below if they were not requested.
        original_aux_logits = kwargs["aux_logits"]
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False
    model = QuantizableGoogLeNet(**kwargs)
    # Swap inplace ReLUs for quantization-friendly ones.
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]
    if pretrained:
        # Quantized and float checkpoints live under different URL tables;
        # `backend` is only defined when quantize is True, matching this branch.
        if quantize:
            model_url = quant_model_urls["googlenet_" + backend]
        else:
            model_url = model_urls["googlenet"]
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            # Drop the aux classifiers the caller did not ask for.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
    return model
class QuantizableBasicConv2d(BasicConv2d): class QuantizableBasicConv2d(BasicConv2d):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -164,3 +101,65 @@ class QuantizableGoogLeNet(GoogLeNet): ...@@ -164,3 +101,65 @@ class QuantizableGoogLeNet(GoogLeNet):
for m in self.modules(): for m in self.modules():
if type(m) is QuantizableBasicConv2d: if type(m) is QuantizableBasicConv2d:
m.fuse_model() m.fuse_model()
def googlenet(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableGoogLeNet:
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        # Checkpoints expect input transformation and were trained without
        # useful aux heads, so fill in those defaults unless the caller set them.
        kwargs.setdefault("transform_input", True)
        kwargs.setdefault("aux_logits", False)
        if kwargs["aux_logits"]:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        # The checkpoint contains aux weights, so the model is constructed with
        # aux_logits=True for loading and the branches are removed afterwards.
        original_aux_logits = kwargs["aux_logits"]
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False
    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in (True, False)
    if pretrained:
        # Pick the quantized or float checkpoint; `backend` exists exactly
        # when quantize is True, so this lookup is always well-defined.
        url = quant_model_urls["googlenet_" + backend] if quantize else model_urls["googlenet"]
        model.load_state_dict(load_state_dict_from_url(url, progress=progress))
        if not original_aux_logits:
            # Strip the aux classifiers the caller did not request.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
    return model
...@@ -14,14 +14,14 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeigh ...@@ -14,14 +14,14 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeigh
class GoogLeNetWeights(Weights): class GoogLeNetWeights(Weights):
ImageNet1K_Community = WeightEntry( ImageNet1K_TFV1 = WeightEntry(
url="https://download.pytorch.org/models/googlenet-1378be20.pth", url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
"acc@1": 69.778, "acc@1": 69.778,
"acc@5": 89.530, "acc@5": 89.530,
}, },
...@@ -31,7 +31,7 @@ class GoogLeNetWeights(Weights): ...@@ -31,7 +31,7 @@ class GoogLeNetWeights(Weights):
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = GoogLeNetWeights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights) weights = GoogLeNetWeights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
......
from .googlenet import *
from .resnet import * from .resnet import *
import warnings
from functools import partial
from typing import Any, Optional, Union
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.googlenet import (
QuantizableGoogLeNet,
_replace_relu,
quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..googlenet import GoogLeNetWeights
__all__ = [
"QuantizableGoogLeNet",
"QuantizedGoogLeNetWeights",
"googlenet",
]
# Weight-enum for the quantized GoogLeNet prototype builder below.
class QuantizedGoogLeNetWeights(Weights):
    # Post-training-quantized (fbgemm) checkpoint derived from the float
    # weights referenced by `unquantized` in the metadata.
    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "fbgemm",  # consumed by the builder to pick the quantization backend
            "quantization": "ptq",  # post-training quantization
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            "unquantized": GoogLeNetWeights.ImageNet1K_TFV1,  # float counterpart of this checkpoint
            "acc@1": 69.826,
            "acc@5": 89.404,
        },
    )
def googlenet(
    weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableGoogLeNet:
    """Build a (optionally quantized) GoogLeNet, loading the given weights if any.

    Args:
        weights: quantized or float weight entry to load, or None for random init.
        progress: show a download progress bar on stderr.
        quantize: return an 8-bit quantized model (CPU inference only).
        **kwargs: forwarded to ``QuantizableGoogLeNet``; also accepts the
            deprecated ``pretrained`` flag.
    """
    # Back-compat: translate the deprecated `pretrained` flag into a weights entry.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        default_weights = (
            QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1
        )
        weights = default_weights if kwargs.pop("pretrained") else None

    # Validate against the enum that matches the requested mode.
    weight_enum = QuantizedGoogLeNetWeights if quantize else GoogLeNetWeights
    weights = weight_enum.verify(weights)

    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        # Checkpoints need input transformation and were saved with aux heads,
        # so build with aux_logits=True and strip the branches after loading.
        kwargs.setdefault("transform_input", True)
        if original_aux_logits:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    # Pop before construction — QuantizableGoogLeNet does not take `backend`.
    quant_backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, quant_backend)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))
        if not original_aux_logits:
            # Remove the aux classifiers the caller did not ask for.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]

    return model
...@@ -2,6 +2,8 @@ import warnings ...@@ -2,6 +2,8 @@ import warnings
from functools import partial from functools import partial
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.resnet import ( from ....models.quantization.resnet import (
QuantizableBasicBlock, QuantizableBasicBlock,
QuantizableBottleneck, QuantizableBottleneck,
...@@ -54,7 +56,9 @@ def _resnet( ...@@ -54,7 +56,9 @@ def _resnet(
_common_meta = { _common_meta = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "fbgemm", "backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
} }
...@@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights): ...@@ -65,6 +69,7 @@ class QuantizedResNet18Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_common_meta,
"unquantized": ResNet18Weights.ImageNet1K_RefV1,
"acc@1": 69.494, "acc@1": 69.494,
"acc@5": 88.882, "acc@5": 88.882,
}, },
...@@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights): ...@@ -77,6 +82,7 @@ class QuantizedResNet50Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_common_meta,
"unquantized": ResNet50Weights.ImageNet1K_RefV1,
"acc@1": 75.920, "acc@1": 75.920,
"acc@5": 92.814, "acc@5": 92.814,
}, },
...@@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights): ...@@ -89,6 +95,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_common_meta,
"unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1,
"acc@1": 78.986, "acc@1": 78.986,
"acc@5": 94.480, "acc@5": 94.480,
}, },
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment