Unverified Commit 0725ccc8 authored by Aditya Oke, committed by GitHub

Add typing annotations to models/quantization (#4232)



* fix

* add typings

* fixup some more types

* Type more

* remove mypy ignore

* add missing typings

* fix a few mypy errors

* fix mypy errors

* fix mypy

* ignore types

* fixup annotation

* fix remaining types

* cleanup #TODO comments
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent a8d2227f
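The diff below annotates the quantizable model builders and their classes across torchvision.models.quantization. A minimal sketch of what the annotations buy downstream code, using the GoogLeNet factory shown first (the explicit annotation on `model` only illustrates what mypy now infers; no weights are downloaded):

```python
from torchvision.models.quantization import googlenet
from torchvision.models.quantization.googlenet import QuantizableGoogLeNet

# The factory is now annotated as `-> "QuantizableGoogLeNet"`, so this checks statically.
model: QuantizableGoogLeNet = googlenet(pretrained=False, quantize=False)
# googlenet(pretrained="yes")  # a call like this is now flagged by mypy (bool expected)
```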
@@ -20,10 +20,6 @@ ignore_errors=True
ignore_errors = True
[mypy-torchvision.models.quantization.*]
ignore_errors = True
[mypy-torchvision.ops.*]
ignore_errors = True
......
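With the `[mypy-torchvision.models.quantization.*]` override removed above, the package is part of the regular mypy run. A hedged sketch of checking just this package through mypy's Python API (assumes mypy is installed in the same environment as torchvision):

```python
# Sketch only: run mypy programmatically on the package that is no longer ignored.
from mypy import api

stdout, stderr, exit_status = api.run(["-p", "torchvision.models.quantization"])
print(stdout or stderr)
```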
@@ -2,6 +2,8 @@ import warnings
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Any
from torch import Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from torchvision.models.googlenet import (
@@ -18,7 +20,13 @@ quant_model_urls = {
}
def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
def googlenet(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> "QuantizableGoogLeNet":
r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
@@ -70,48 +78,51 @@ def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None
model.aux2 = None
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
return model
class QuantizableBasicConv2d(BasicConv2d):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
self.relu = nn.ReLU()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
def fuse_model(self):
def fuse_model(self) -> None:
torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
class QuantizableInception(Inception):
def __init__(self, *args, **kwargs):
super(QuantizableInception, self).__init__(
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInception, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, *args, **kwargs)
self.cat = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.cat.cat(outputs, 1)
class QuantizableInceptionAux(InceptionAux):
def __init__(self, *args, **kwargs):
super(QuantizableInceptionAux, self).__init__(
conv_block=QuantizableBasicConv2d, *args, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionAux, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.7)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
# aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
x = F.adaptive_avg_pool2d(x, (4, 4))
# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
@@ -130,9 +141,9 @@ class QuantizableInceptionAux(InceptionAux):
class QuantizableGoogLeNet(GoogLeNet):
def __init__(self, *args, **kwargs):
super(QuantizableGoogLeNet, self).__init__(
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableGoogLeNet, self).__init__( # type: ignore[misc]
blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux],
*args,
**kwargs
@@ -140,7 +151,7 @@ class QuantizableGoogLeNet(GoogLeNet):
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
def forward(self, x: Tensor) -> GoogLeNetOutputs:
x = self._transform_input(x)
x = self.quant(x)
x, aux1, aux2 = self._forward(x)
@@ -153,7 +164,7 @@ class QuantizableGoogLeNet(GoogLeNet):
else:
return self.eager_outputs(x, aux2, aux1)
def fuse_model(self):
def fuse_model(self) -> None:
r"""Fuse conv/bn/relu modules in googlenet model
Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
......
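The GoogLeNet changes above keep `fuse_model() -> None` on each block. As a reminder of what eager-mode fusion does, here is a small sketch on `QuantizableBasicConv2d` alone (importing the block directly is only for illustration; fusion must happen in eval mode):

```python
from torchvision.models.quantization.googlenet import QuantizableBasicConv2d

block = QuantizableBasicConv2d(3, 8, kernel_size=3).eval()
block.fuse_model()
print(type(block.conv).__name__)  # ConvReLU2d: conv, bn and relu folded into one module
print(type(block.bn).__name__)    # Identity placeholder left behind by fuse_modules
```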
@@ -3,6 +3,9 @@ import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, List
from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs
from ..._internally_replaced_utils import load_state_dict_from_url
@@ -22,7 +25,13 @@ quant_model_urls = {
}
def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
def inception_v3(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> "QuantizableInception3":
r"""Inception v3 model architecture from
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
@@ -84,68 +93,93 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
class QuantizableBasicConv2d(inception_module.BasicConv2d):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
self.relu = nn.ReLU()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
def fuse_model(self):
def fuse_model(self) -> None:
torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
class QuantizableInceptionA(inception_module.InceptionA):
def __init__(self, *args, **kwargs):
super(QuantizableInceptionA, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionA, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
)
self.myop = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.myop.cat(outputs, 1)
class QuantizableInceptionB(inception_module.InceptionB):
def __init__(self, *args, **kwargs):
super(QuantizableInceptionB, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionB, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
)
self.myop = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.myop.cat(outputs, 1)
class QuantizableInceptionC(inception_module.InceptionC):
def __init__(self, *args, **kwargs):
super(QuantizableInceptionC, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionC, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
)
self.myop = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.myop.cat(outputs, 1)
class QuantizableInceptionD(inception_module.InceptionD):
def __init__(self, *args, **kwargs):
super(QuantizableInceptionD, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionD, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
)
self.myop = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.myop.cat(outputs, 1)
class QuantizableInceptionE(inception_module.InceptionE):
def __init__(self, *args, **kwargs):
super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionE, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
)
self.myop1 = nn.quantized.FloatFunctional()
self.myop2 = nn.quantized.FloatFunctional()
self.myop3 = nn.quantized.FloatFunctional()
def _forward(self, x):
def _forward(self, x: Tensor) -> List[Tensor]:
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
@@ -166,18 +200,28 @@ class QuantizableInceptionE(inception_module.InceptionE):
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return outputs
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.myop3.cat(outputs, 1)
class QuantizableInceptionAux(inception_module.InceptionAux):
def __init__(self, *args, **kwargs):
super(QuantizableInceptionAux, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionAux, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
)
class QuantizableInception3(inception_module.Inception3):
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
def __init__(
self,
num_classes: int = 1000,
aux_logits: bool = True,
transform_input: bool = False,
) -> None:
super(QuantizableInception3, self).__init__(
num_classes=num_classes,
aux_logits=aux_logits,
@@ -195,7 +239,7 @@ class QuantizableInception3(inception_module.Inception3):
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
def forward(self, x: Tensor) -> InceptionOutputs:
x = self._transform_input(x)
x = self.quant(x)
x, aux = self._forward(x)
@@ -208,7 +252,7 @@ class QuantizableInception3(inception_module.Inception3):
else:
return self.eager_outputs(x, aux)
def fuse_model(self):
def fuse_model(self) -> None:
r"""Fuse conv/bn/relu modules in inception model
Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
......
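As in the float model, `QuantizableInception3.forward` above is annotated as returning `InceptionOutputs`, but `eager_outputs` collapses that to the main tensor outside scripting/training. A sketch with a float (non-quantized) model and the 299x299 input Inception v3 expects; no pretrained weights are used:

```python
import torch
from torchvision.models.quantization import inception_v3

model = inception_v3(pretrained=False, quantize=False).eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 299, 299))
print(type(out))  # plain Tensor in eager eval mode; InceptionOutputs when scripted/training
```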
from torch import nn
from torch import Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from typing import Any
from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from .utils import _replace_relu, quantize_model
@@ -14,24 +19,24 @@ quant_model_urls = {
class QuantizableInvertedResidual(InvertedResidual):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
return self.skip_add.add(x, self.conv(x))
else:
return self.conv(x)
def fuse_model(self):
def fuse_model(self) -> None:
for idx in range(len(self.conv)):
if type(self.conv[idx]) == nn.Conv2d:
fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
class QuantizableMobileNetV2(MobileNetV2):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
MobileNet V2 main class
@@ -42,13 +47,13 @@ class QuantizableMobileNetV2(MobileNetV2):
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
x = self._forward_impl(x)
x = self.dequant(x)
return x
def fuse_model(self):
def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvBNReLU:
fuse_modules(m, ['0', '1', '2'], inplace=True)
@@ -56,7 +61,12 @@ class QuantizableMobileNetV2(MobileNetV2):
m.fuse_model()
def mobilenet_v2(pretrained=False, progress=True, quantize=False, **kwargs):
def mobilenet_v2(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV2:
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
......
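The `QuantStub`/`DeQuantStub` pair and `fuse_model` in `QuantizableMobileNetV2` above are the hooks for PyTorch's eager post-training static quantization. A hedged sketch of that flow done by hand, mirroring what `quantize_model` in the utils module does when `quantize=True` (assumes the fbgemm backend is available on this machine):

```python
import torch
from torchvision.models.quantization import mobilenet_v2

model = mobilenet_v2(pretrained=False, quantize=False).eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
model(torch.rand(4, 3, 224, 224))                # calibration pass on dummy data
torch.quantization.convert(model, inplace=True)  # swap float modules for quantized ones
```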
@@ -17,23 +17,28 @@ quant_model_urls = {
class QuantizableSqueezeExcitation(SqueezeExcitation):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.skip_mul = nn.quantized.FloatFunctional()
def forward(self, input: Tensor) -> Tensor:
return self.skip_mul.mul(self._scale(input, False), input)
def fuse_model(self):
def fuse_model(self) -> None:
fuse_modules(self, ['fc1', 'relu'], inplace=True)
class QuantizableInvertedResidual(InvertedResidual):
def __init__(self, *args, **kwargs):
super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__( # type: ignore[misc]
se_layer=QuantizableSqueezeExcitation,
*args,
**kwargs
)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
return self.skip_add.add(x, self.block(x))
else:
@@ -41,7 +46,7 @@ class QuantizableInvertedResidual(InvertedResidual):
class QuantizableMobileNetV3(MobileNetV3):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
MobileNet V3 main class
@@ -52,13 +57,13 @@ class QuantizableMobileNetV3(MobileNetV3):
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
x = self._forward_impl(x)
x = self.dequant(x)
return x
def fuse_model(self):
def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvBNActivation:
modules_to_fuse = ['0', '1']
@@ -74,7 +79,7 @@ def _load_weights(
model: QuantizableMobileNetV3,
model_url: Optional[str],
progress: bool,
):
) -> None:
if model_url is None:
raise ValueError("No checkpoint is available for {}".format(arch))
state_dict = load_state_dict_from_url(model_url, progress=progress)
@@ -88,8 +93,9 @@ def _mobilenet_v3_model(
pretrained: bool,
progress: bool,
quantize: bool,
**kwargs: Any
):
**kwargs: Any,
) -> QuantizableMobileNetV3:
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)
@@ -112,7 +118,12 @@ def _mobilenet_v3_model(
return model
def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs):
def mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV3:
"""
Constructs a MobileNetV3 Large architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
......
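`_mobilenet_v3_model` above runs the whole fuse/prepare/calibrate/convert pipeline itself when `quantize=True`; the released quantized MobileNetV3 weights target the qnnpack backend, so this sketch guards on it being available (assumption: no pretrained weights are needed):

```python
import torch
from torchvision.models.quantization import mobilenet_v3_large

if "qnnpack" in torch.backends.quantized.supported_engines:
    model = mobilenet_v3_large(pretrained=False, quantize=True)  # already converted to int8 modules
    out = model(torch.rand(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])
```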
import torch
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
import torch.nn as nn
from torch import Tensor
from typing import Any, Type, Union, List
from ..._internally_replaced_utils import load_state_dict_from_url
from torch.quantization import fuse_modules
from .utils import _replace_relu, quantize_model
@@ -20,11 +23,11 @@ quant_model_urls = {
class QuantizableBasicBlock(BasicBlock):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableBasicBlock, self).__init__(*args, **kwargs)
self.add_relu = torch.nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
@@ -41,7 +44,7 @@ class QuantizableBasicBlock(BasicBlock):
return out
def fuse_model(self):
def fuse_model(self) -> None:
torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
['conv2', 'bn2']], inplace=True)
if self.downsample:
@@ -49,13 +52,13 @@ class QuantizableBasicBlock(BasicBlock):
class QuantizableBottleneck(Bottleneck):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableBottleneck, self).__init__(*args, **kwargs)
self.skip_add_relu = nn.quantized.FloatFunctional()
self.relu1 = nn.ReLU(inplace=False)
self.relu2 = nn.ReLU(inplace=False)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
@@ -73,7 +76,7 @@ class QuantizableBottleneck(Bottleneck):
return out
def fuse_model(self):
def fuse_model(self) -> None:
fuse_modules(self, [['conv1', 'bn1', 'relu1'],
['conv2', 'bn2', 'relu2'],
['conv3', 'bn3']], inplace=True)
@@ -83,13 +86,13 @@ class QuantizableBottleneck(Bottleneck):
class QuantizableResNet(ResNet):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableResNet, self).__init__(*args, **kwargs)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
# Ensure scriptability
# super(QuantizableResNet,self).forward(x)
@@ -98,7 +101,7 @@ class QuantizableResNet(ResNet):
x = self.dequant(x)
return x
def fuse_model(self):
def fuse_model(self) -> None:
r"""Fuse conv/bn/relu modules in resnet models
Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
@@ -112,7 +115,16 @@ class QuantizableResNet(ResNet):
m.fuse_model()
def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs):
def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
quantize: bool,
**kwargs: Any,
) -> QuantizableResNet:
model = QuantizableResNet(block, layers, **kwargs)
_replace_relu(model)
if quantize:
@@ -135,7 +147,12 @@ def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs):
return model
def resnet18(pretrained=False, progress=True, quantize=False, **kwargs):
def resnet18(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -148,7 +165,13 @@ def resnet18(pretrained=False, progress=True, quantize=False, **kwargs):
quantize, **kwargs)
def resnet50(pretrained=False, progress=True, quantize=False, **kwargs):
def resnet50(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -161,7 +184,12 @@ def resnet50(pretrained=False, progress=True, quantize=False, **kwargs):
quantize, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, quantize=False, **kwargs):
def resnext101_32x8d(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
......
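The `# Ensure scriptability` comment in `QuantizableResNet.forward` above is the reason the override exists: with the stubs inlined and the new annotations, the model stays TorchScript-friendly. A sketch of scripting the float variant (no pretrained weights):

```python
import torch
from torchvision.models.quantization import resnet18

model = resnet18(pretrained=False, quantize=False).eval()
scripted = torch.jit.script(model)          # relies on the explicit, annotated forward
out = scripted(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```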
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any
from ..._internally_replaced_utils import load_state_dict_from_url
import torchvision.models.shufflenetv2
import sys
from torchvision.models import shufflenetv2
from .utils import _replace_relu, quantize_model
shufflenetv2 = sys.modules['torchvision.models.shufflenetv2']
__all__ = [
'QuantizableShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
@@ -22,16 +22,16 @@ quant_model_urls = {
class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
self.cat = nn.quantized.FloatFunctional()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = self.cat.cat((x1, self.branch2(x2)), dim=1)
out = self.cat.cat([x1, self.branch2(x2)], dim=1)
else:
out = self.cat.cat((self.branch1(x), self.branch2(x)), dim=1)
out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1)
out = shufflenetv2.channel_shuffle(out, 2)
@@ -39,18 +39,23 @@ class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
def __init__(self, *args, **kwargs):
super(QuantizableShuffleNetV2, self).__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs)
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableShuffleNetV2, self).__init__( # type: ignore[misc]
*args,
inverted_residual=QuantizableInvertedResidual,
**kwargs
)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
x = self._forward_impl(x)
x = self.dequant(x)
return x
def fuse_model(self):
def fuse_model(self) -> None:
r"""Fuse conv/bn/relu modules in shufflenetv2 model
Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -74,7 +79,15 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
)
def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs):
def _shufflenetv2(
arch: str,
pretrained: bool,
progress: bool,
quantize: bool,
*args: Any,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
model = QuantizableShuffleNetV2(*args, **kwargs)
_replace_relu(model)
@@ -98,7 +111,12 @@ def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs):
return model
def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs):
def shufflenet_v2_x0_5(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -113,7 +131,12 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs):
def shufflenet_v2_x1_0(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -128,7 +151,12 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs):
def shufflenet_v2_x1_5(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 1.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -143,7 +171,12 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenet_v2_x2_0(pretrained=False, progress=True, quantize=False, **kwargs):
def shufflenet_v2_x2_0(
pretrained: bool = False,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 2.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
......
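The `cat` calls in `QuantizableInvertedResidual` above switched from tuples to lists because `FloatFunctional.cat` is annotated to take a `List[Tensor]`, which is also what TorchScript expects. A standalone sketch of the same call pattern on plain float tensors:

```python
import torch
import torch.nn as nn

f = nn.quantized.FloatFunctional()
a, b = torch.rand(1, 2, 4, 4), torch.rand(1, 2, 4, 4)
out = f.cat([a, b], dim=1)  # list argument matches the annotated signature
print(out.shape)            # torch.Size([1, 4, 4, 4])
```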
@@ -2,7 +2,7 @@ import torch
from torch import nn
def _replace_relu(module):
def _replace_relu(module: nn.Module) -> None:
reassign = {}
for name, mod in module.named_children():
_replace_relu(mod)
@@ -16,7 +16,7 @@ def _replace_relu(module):
module._modules[key] = value
def quantize_model(model, backend):
def quantize_model(model: nn.Module, backend: str) -> None:
_dummy_input_data = torch.rand(1, 3, 299, 299)
if backend not in torch.backends.quantized.supported_engines:
raise RuntimeError("Quantized backend not supported ")
@@ -24,15 +24,16 @@ def quantize_model(model, backend):
model.eval()
# Make sure that weight qconfig matches that of the serialized models
if backend == 'fbgemm':
model.qconfig = torch.quantization.QConfig(
model.qconfig = torch.quantization.QConfig( # type: ignore[assignment]
activation=torch.quantization.default_observer,
weight=torch.quantization.default_per_channel_weight_observer)
elif backend == 'qnnpack':
model.qconfig = torch.quantization.QConfig(
model.qconfig = torch.quantization.QConfig( # type: ignore[assignment]
activation=torch.quantization.default_observer,
weight=torch.quantization.default_weight_observer)
model.fuse_model()
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
model.fuse_model() # type: ignore[operator]
torch.quantization.prepare(model, inplace=True)
model(_dummy_input_data)
torch.quantization.convert(model, inplace=True)
......
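`_replace_relu` above recursively swaps children whose exact type is `nn.ReLU` or `nn.ReLU6` for out-of-place `nn.ReLU`, which the eager quantization flow expects. A sketch using the private helper purely for illustration:

```python
import torch.nn as nn
from torchvision.models.quantization.utils import _replace_relu

m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU6(inplace=True))
_replace_relu(m)
print(m[1])  # ReLU() -- the ReLU6 was replaced and inplace dropped
```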