Unverified Commit 0725ccc8 authored by Aditya Oke, committed by GitHub

Add typing annotations to models/quantization (#4232)



* fix

* add typings

* fixup some more types

* Type more

* remove mypy ignore

* add missing typings

* fix a few mypy errors

* fix mypy errors

* fix mypy

* ignore types

* fixup annotation

* fix remaining types

* cleanup #TODO comments
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent a8d2227f
--- a/mypy.ini
+++ b/mypy.ini
@@ -20,10 +20,6 @@ ignore_errors=True
 ignore_errors = True

-[mypy-torchvision.models.quantization.*]
-
-ignore_errors = True
-
 [mypy-torchvision.ops.*]

 ignore_errors = True
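With the override gone, mypy now actually checks torchvision.models.quantization. A minimal sketch of reproducing the check locally, assuming mypy is installed and the command is run from a torchvision checkout (CI may invoke it differently):

from mypy import api

# mypy's programmatic entry point; returns (report, errors, exit_status).
report, errors, exit_status = api.run(["torchvision/models/quantization"])
print(report)
print("clean" if exit_status == 0 else "type errors found")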
--- a/torchvision/models/quantization/googlenet.py
+++ b/torchvision/models/quantization/googlenet.py
@@ -2,6 +2,8 @@ import warnings
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from typing import Any
+from torch import Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
 from torchvision.models.googlenet import (
@@ -18,7 +20,13 @@ quant_model_urls = {
 }
-def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
+def googlenet(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> "QuantizableGoogLeNet":
     r"""GoogLeNet (Inception v1) model architecture from
     `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
@@ -70,48 +78,51 @@ def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
     if not original_aux_logits:
         model.aux_logits = False
-        model.aux1 = None
-        model.aux2 = None
+        model.aux1 = None  # type: ignore[assignment]
+        model.aux2 = None  # type: ignore[assignment]
     return model
 class QuantizableBasicConv2d(BasicConv2d):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
         self.relu = nn.ReLU()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         x = self.relu(x)
         return x
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 class QuantizableInception(Inception):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInception, self).__init__(
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInception, self).__init__(  # type: ignore[misc]
             conv_block=QuantizableBasicConv2d, *args, **kwargs)
         self.cat = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.cat.cat(outputs, 1)
 class QuantizableInceptionAux(InceptionAux):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionAux, self).__init__(
-            conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.7)
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
         x = F.adaptive_avg_pool2d(x, (4, 4))
         # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
@@ -130,9 +141,9 @@ class QuantizableInceptionAux(InceptionAux):
 class QuantizableGoogLeNet(GoogLeNet):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableGoogLeNet, self).__init__(
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableGoogLeNet, self).__init__(  # type: ignore[misc]
             blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux],
             *args,
             **kwargs
@@ -140,7 +151,7 @@ class QuantizableGoogLeNet(GoogLeNet):
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)
         x = self.quant(x)
         x, aux1, aux2 = self._forward(x)
@@ -153,7 +164,7 @@ class QuantizableGoogLeNet(GoogLeNet):
         else:
             return self.eager_outputs(x, aux2, aux1)
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in googlenet model
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
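Because the factory is now annotated to return QuantizableGoogLeNet, call sites get checked access to quantization-specific methods such as fuse_model(). A small usage sketch (pretrained=False here to avoid a weight download; real quantized inference would typically pair quantize=True with pretrained=True):

import torch
from torchvision.models.quantization import googlenet

model = googlenet(pretrained=False, quantize=False)  # fp32 model
model.eval()
model.fuse_model()  # folds conv+bn+relu triples via fuse_modules
out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])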
--- a/torchvision/models/quantization/inception.py
+++ b/torchvision/models/quantization/inception.py
@@ -3,6 +3,9 @@ import warnings
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from typing import Any, List
 from torchvision.models import inception as inception_module
 from torchvision.models.inception import InceptionOutputs
 from ..._internally_replaced_utils import load_state_dict_from_url
@@ -22,7 +25,13 @@ quant_model_urls = {
 }
-def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
+def inception_v3(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> "QuantizableInception3":
     r"""Inception v3 model architecture from
     `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
@@ -84,68 +93,93 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
 class QuantizableBasicConv2d(inception_module.BasicConv2d):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
         self.relu = nn.ReLU()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         x = self.relu(x)
         return x
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 class QuantizableInceptionA(inception_module.InceptionA):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionA, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionA, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 class QuantizableInceptionB(inception_module.InceptionB):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionB, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionB, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 class QuantizableInceptionC(inception_module.InceptionC):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionC, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionC, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 class QuantizableInceptionD(inception_module.InceptionD):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionD, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionD, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop.cat(outputs, 1)
 class QuantizableInceptionE(inception_module.InceptionE):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionE, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
         self.myop1 = nn.quantized.FloatFunctional()
         self.myop2 = nn.quantized.FloatFunctional()
         self.myop3 = nn.quantized.FloatFunctional()
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> List[Tensor]:
         branch1x1 = self.branch1x1(x)
         branch3x3 = self.branch3x3_1(x)
@@ -166,18 +200,28 @@ class QuantizableInceptionE(inception_module.InceptionE):
         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
         return outputs
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         outputs = self._forward(x)
         return self.myop3.cat(outputs, 1)
 class QuantizableInceptionAux(inception_module.InceptionAux):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableInceptionAux, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
+            conv_block=QuantizableBasicConv2d,
+            *args,
+            **kwargs
+        )
 class QuantizableInception3(inception_module.Inception3):
-    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+    ) -> None:
         super(QuantizableInception3, self).__init__(
             num_classes=num_classes,
             aux_logits=aux_logits,
@@ -195,7 +239,7 @@ class QuantizableInception3(inception_module.Inception3):
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> InceptionOutputs:
         x = self._transform_input(x)
         x = self.quant(x)
         x, aux = self._forward(x)
@@ -208,7 +252,7 @@ class QuantizableInception3(inception_module.Inception3):
         else:
             return self.eager_outputs(x, aux)
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in inception model
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
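The recurring # type: ignore[misc] lines are needed because these super().__init__ calls pass a keyword argument before *args; that is legal at runtime, but mypy cannot prove the keyword is not also supplied through *args/**kwargs and reports a [misc] error ("gets multiple values for keyword argument"). A stripped-down illustration of the pattern, with hypothetical names:

from typing import Any

class Base:
    def __init__(self, channels: int, conv_block: Any = None) -> None:
        self.channels = channels
        self.conv_block = conv_block

class Derived(Base):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Fine at runtime, but mypy flags it: *args might also carry conv_block.
        super().__init__(conv_block=dict, *args, **kwargs)  # type: ignore[misc]

d = Derived(16)  # channels=16 flows through *args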
--- a/torchvision/models/quantization/mobilenetv2.py
+++ b/torchvision/models/quantization/mobilenetv2.py
 from torch import nn
+from torch import Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
+from typing import Any
 from torchvision.models.mobilenetv2 import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from .utils import _replace_relu, quantize_model
@@ -14,24 +19,24 @@ quant_model_urls = {
 class QuantizableInvertedResidual(InvertedResidual):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
         self.skip_add = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return self.skip_add.add(x, self.conv(x))
         else:
             return self.conv(x)
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         for idx in range(len(self.conv)):
             if type(self.conv[idx]) == nn.Conv2d:
                 fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
 class QuantizableMobileNetV2(MobileNetV2):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """
         MobileNet V2 main class
@@ -42,13 +47,13 @@ class QuantizableMobileNetV2(MobileNetV2):
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
         x = self._forward_impl(x)
         x = self.dequant(x)
         return x
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         for m in self.modules():
             if type(m) == ConvBNReLU:
                 fuse_modules(m, ['0', '1', '2'], inplace=True)
@@ -56,7 +61,12 @@ class QuantizableMobileNetV2(MobileNetV2):
                 m.fuse_model()
-def mobilenet_v2(pretrained=False, progress=True, quantize=False, **kwargs):
+def mobilenet_v2(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableMobileNetV2:
     """
     Constructs a MobileNetV2 architecture from
     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
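nn.quantized.FloatFunctional replaces bare tensor operators such as + because eager-mode quantization can only attach observers to modules; a plain x + self.conv(x) would have nothing to observe or convert. A minimal sketch of the idea (layer sizes are arbitrary):

import torch
from torch import nn, Tensor

class Residual(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        # A module-wrapped add: prepare()/convert() can observe it and swap
        # in a quantized add, which a bare `+` would not allow.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        return self.skip_add.add(x, self.conv(x))

print(Residual()(torch.rand(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])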
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -17,23 +17,28 @@ quant_model_urls = {
 class QuantizableSqueezeExcitation(SqueezeExcitation):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.skip_mul = nn.quantized.FloatFunctional()
     def forward(self, input: Tensor) -> Tensor:
         return self.skip_mul.mul(self._scale(input, False), input)
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         fuse_modules(self, ['fc1', 'relu'], inplace=True)
 class QuantizableInvertedResidual(InvertedResidual):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(  # type: ignore[misc]
+            se_layer=QuantizableSqueezeExcitation,
+            *args,
+            **kwargs
+        )
         self.skip_add = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return self.skip_add.add(x, self.block(x))
         else:
@@ -41,7 +46,7 @@ class QuantizableInvertedResidual(InvertedResidual):
 class QuantizableMobileNetV3(MobileNetV3):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """
         MobileNet V3 main class
@@ -52,13 +57,13 @@ class QuantizableMobileNetV3(MobileNetV3):
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
         x = self._forward_impl(x)
         x = self.dequant(x)
         return x
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         for m in self.modules():
             if type(m) == ConvBNActivation:
                 modules_to_fuse = ['0', '1']
@@ -74,7 +79,7 @@ def _load_weights(
     model: QuantizableMobileNetV3,
     model_url: Optional[str],
     progress: bool,
-):
+) -> None:
     if model_url is None:
         raise ValueError("No checkpoint is available for {}".format(arch))
     state_dict = load_state_dict_from_url(model_url, progress=progress)
@@ -88,8 +93,9 @@ def _mobilenet_v3_model(
     pretrained: bool,
     progress: bool,
     quantize: bool,
-    **kwargs: Any
-):
+    **kwargs: Any,
+) -> QuantizableMobileNetV3:
     model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
     _replace_relu(model)
@@ -112,7 +118,12 @@ def _mobilenet_v3_model(
     return model
-def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs):
+def mobilenet_v3_large(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableMobileNetV3:
     """
     Constructs a MobileNetV3 Large architecture from
     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
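fuse_modules addresses submodules by name, and for an nn.Sequential the names are the string indices; hence fuse lists like ['0', '1'] above. A small sketch, with an invented three-layer block:

import torch
from torch import nn
from torch.quantization import fuse_modules

block = nn.Sequential(
    nn.Conv2d(3, 16, 3, bias=False),  # '0'
    nn.BatchNorm2d(16),               # '1'
    nn.ReLU(inplace=True),            # '2'
)
block.eval()  # fusion is only valid in eval mode
fused = fuse_modules(block, [['0', '1', '2']])
print(type(fused[0]).__name__)  # ConvReLU2d: bn folded into conv, relu fused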
--- a/torchvision/models/quantization/resnet.py
+++ b/torchvision/models/quantization/resnet.py
 import torch
 from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
 import torch.nn as nn
+from torch import Tensor
+from typing import Any, Type, Union, List
 from ..._internally_replaced_utils import load_state_dict_from_url
 from torch.quantization import fuse_modules
 from .utils import _replace_relu, quantize_model
@@ -20,11 +23,11 @@ quant_model_urls = {
 class QuantizableBasicBlock(BasicBlock):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBasicBlock, self).__init__(*args, **kwargs)
         self.add_relu = torch.nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         identity = x
         out = self.conv1(x)
@@ -41,7 +44,7 @@ class QuantizableBasicBlock(BasicBlock):
         return out
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
                                                ['conv2', 'bn2']], inplace=True)
         if self.downsample:
@@ -49,13 +52,13 @@
 class QuantizableBottleneck(Bottleneck):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableBottleneck, self).__init__(*args, **kwargs)
         self.skip_add_relu = nn.quantized.FloatFunctional()
         self.relu1 = nn.ReLU(inplace=False)
         self.relu2 = nn.ReLU(inplace=False)
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         identity = x
         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +76,7 @@ class QuantizableBottleneck(Bottleneck):
         return out
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         fuse_modules(self, [['conv1', 'bn1', 'relu1'],
                             ['conv2', 'bn2', 'relu2'],
                             ['conv3', 'bn3']], inplace=True)
@@ -83,13 +86,13 @@
 class QuantizableResNet(ResNet):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableResNet, self).__init__(*args, **kwargs)
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
         # Ensure scriptability
         # super(QuantizableResNet,self).forward(x)
@@ -98,7 +101,7 @@ class QuantizableResNet(ResNet):
         x = self.dequant(x)
         return x
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in resnet models
         Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
@@ -112,7 +115,16 @@ class QuantizableResNet(ResNet):
             m.fuse_model()
-def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs):
+def _resnet(
+    arch: str,
+    block: Type[Union[BasicBlock, Bottleneck]],
+    layers: List[int],
+    pretrained: bool,
+    progress: bool,
+    quantize: bool,
+    **kwargs: Any,
+) -> QuantizableResNet:
     model = QuantizableResNet(block, layers, **kwargs)
     _replace_relu(model)
     if quantize:
@@ -135,7 +147,12 @@ def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs):
     return model
-def resnet18(pretrained=False, progress=True, quantize=False, **kwargs):
+def resnet18(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
     r"""ResNet-18 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -148,7 +165,13 @@ def resnet18(pretrained=False, progress=True, quantize=False, **kwargs):
                    quantize, **kwargs)
-def resnet50(pretrained=False, progress=True, quantize=False, **kwargs):
+def resnet50(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
     r"""ResNet-50 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -161,7 +184,12 @@ def resnet50(pretrained=False, progress=True, quantize=False, **kwargs):
                    quantize, **kwargs)
-def resnext101_32x8d(pretrained=False, progress=True, quantize=False, **kwargs):
+def resnext101_32x8d(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
     r"""ResNeXt-101 32x8d model from
     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
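With _resnet and the builders annotated, a checker can follow the whole path from resnet18(...) to QuantizableResNet. A usage sketch of the quantized path (quantize=True runs the eager post-training flow from utils.quantize_model with the fbgemm backend, so it assumes an x86 build; pretrained=False avoids a download, which means calibration here happens on random data):

import torch
from torchvision.models.quantization import resnet18

model = resnet18(pretrained=False, quantize=True)  # returns an int8 model
model.eval()
out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])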
--- a/torchvision/models/quantization/shufflenetv2.py
+++ b/torchvision/models/quantization/shufflenetv2.py
 import torch
 import torch.nn as nn
+from torch import Tensor
+from typing import Any
 from ..._internally_replaced_utils import load_state_dict_from_url
-import torchvision.models.shufflenetv2
-import sys
+from torchvision.models import shufflenetv2
 from .utils import _replace_relu, quantize_model
-shufflenetv2 = sys.modules['torchvision.models.shufflenetv2']
 __all__ = [
     'QuantizableShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
     'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
@@ -22,16 +22,16 @@ quant_model_urls = {
 class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
         self.cat = nn.quantized.FloatFunctional()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         if self.stride == 1:
             x1, x2 = x.chunk(2, dim=1)
-            out = self.cat.cat((x1, self.branch2(x2)), dim=1)
+            out = self.cat.cat([x1, self.branch2(x2)], dim=1)
         else:
-            out = self.cat.cat((self.branch1(x), self.branch2(x)), dim=1)
+            out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1)
         out = shufflenetv2.channel_shuffle(out, 2)
@@ -39,18 +39,23 @@ class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
 class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
-    def __init__(self, *args, **kwargs):
-        super(QuantizableShuffleNetV2, self).__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs)
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super(QuantizableShuffleNetV2, self).__init__(  # type: ignore[misc]
+            *args,
+            inverted_residual=QuantizableInvertedResidual,
+            **kwargs
+        )
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
         x = self._forward_impl(x)
         x = self.dequant(x)
         return x
-    def fuse_model(self):
+    def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in shufflenetv2 model
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -74,7 +79,15 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
                 )
-def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs):
+def _shufflenetv2(
+    arch: str,
+    pretrained: bool,
+    progress: bool,
+    quantize: bool,
+    *args: Any,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
     model = QuantizableShuffleNetV2(*args, **kwargs)
     _replace_relu(model)
@@ -98,7 +111,12 @@ def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs):
     return model
-def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs):
+def shufflenet_v2_x0_5(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 0.5x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -113,7 +131,12 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs
                          [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
-def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs):
+def shufflenet_v2_x1_0(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 1.0x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -128,7 +151,12 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs
                          [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
-def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs):
+def shufflenet_v2_x1_5(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 1.5x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -143,7 +171,12 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs
                          [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
-def shufflenet_v2_x2_0(pretrained=False, progress=True, quantize=False, **kwargs):
+def shufflenet_v2_x2_0(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 2.0x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
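The switch from tuples to lists in self.cat.cat(...) is one of the few behavioral-looking edits, and it is driven by typing: FloatFunctional.cat is annotated (for TorchScript compatibility) as taking a List[Tensor], so the tuple arguments fail the newly enabled checks. A quick sketch:

import torch
from torch import nn

cat = nn.quantized.FloatFunctional()
x1 = torch.rand(1, 2, 4, 4)
x2 = torch.rand(1, 2, 4, 4)
out = cat.cat([x1, x2], dim=1)  # a list matches the List[Tensor] annotation
print(out.shape)                # torch.Size([1, 4, 4, 4])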
--- a/torchvision/models/quantization/utils.py
+++ b/torchvision/models/quantization/utils.py
@@ -2,7 +2,7 @@ import torch
 from torch import nn
-def _replace_relu(module):
+def _replace_relu(module: nn.Module) -> None:
     reassign = {}
     for name, mod in module.named_children():
         _replace_relu(mod)
@@ -16,7 +16,7 @@ def _replace_relu(module):
         module._modules[key] = value
-def quantize_model(model, backend):
+def quantize_model(model: nn.Module, backend: str) -> None:
     _dummy_input_data = torch.rand(1, 3, 299, 299)
     if backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported ")
@@ -24,15 +24,16 @@ def quantize_model(model, backend):
     model.eval()
     # Make sure that weight qconfig matches that of the serialized models
     if backend == 'fbgemm':
-        model.qconfig = torch.quantization.QConfig(
+        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
             activation=torch.quantization.default_observer,
             weight=torch.quantization.default_per_channel_weight_observer)
     elif backend == 'qnnpack':
-        model.qconfig = torch.quantization.QConfig(
+        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
             activation=torch.quantization.default_observer,
             weight=torch.quantization.default_weight_observer)
-    model.fuse_model()
+    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+    model.fuse_model()  # type: ignore[operator]
     torch.quantization.prepare(model, inplace=True)
     model(_dummy_input_data)
     torch.quantization.convert(model, inplace=True)
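quantize_model is the standard eager-mode post-training quantization recipe: assign a qconfig, fuse, prepare (insert observers), calibrate, convert. A condensed sketch of the same steps on a toy module (random calibration data, mirroring _dummy_input_data; real calibration would use representative inputs, and the fbgemm engine assumes x86):

import torch
from torch import nn

model = nn.Sequential(
    torch.quantization.QuantStub(),    # fp32 -> int8 at the entry
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    torch.quantization.DeQuantStub(),  # int8 -> fp32 at the exit
)
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)   # insert observers
model(torch.rand(1, 3, 32, 32))                   # calibrate
torch.quantization.convert(model, inplace=True)   # swap in quantized modules
print(model[1])  # QuantizedConv2d(...)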