Unverified Commit 8a16e12f authored by Vasilis Vryniotis, committed by GitHub

Implement is_qat in TorchVision (#5299)

* Add is_qat support using a method getter

* Switch to an internal _fuse_modules

* Fix linter.

* Pass is_qat=False on PTQ

* Fix bug in ra_sampler flag.

* Set is_qat=True for QAT
parent 61a52b93
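
The net effect of this commit: every quantizable model's `fuse_model` now accepts an optional `is_qat` flag, and a new internal `_fuse_modules` helper (defined in the quantization utils hunk below) dispatches to `torch.ao.quantization.fuse_modules_qat` or `torch.ao.quantization.fuse_modules` accordingly. A minimal usage sketch mirroring the reference-script hunks below; the backend strings here are illustrative, not mandated by the commit:

import torch
import torchvision.models.quantization as M

model = M.mobilenet_v3_large(pretrained=False, quantize=False)

# QAT: fuse with is_qat=True, then attach a QAT qconfig and prepare_qat.
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
torch.ao.quantization.prepare_qat(model, inplace=True)

# PTQ instead runs in eval mode and pairs is_qat=False with prepare:
#     model.eval()
#     model.fuse_model(is_qat=False)
#     model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
#     torch.ao.quantization.prepare(model, inplace=True)
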
@@ -178,7 +178,7 @@ def load_data(traindir, valdir, args):
     print("Creating data loaders")
     if args.distributed:
-        if args.ra_sampler:
+        if hasattr(args, "ra_sampler") and args.ra_sampler:
             train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
         else:
             train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -63,7 +63,7 @@ def main(args):
     model.to(device)
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
         torch.ao.quantization.prepare_qat(model, inplace=True)
@@ -97,7 +97,7 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model()
+        model.fuse_model(is_qat=False)
         model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
         torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
@@ -344,7 +344,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
         # Quantized Classification
         model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
         _ = torch.ao.quantization.prepare_qat(model, inplace=True)
         print(store_model_weights(model, './qat.pth'))
@@ -833,7 +833,7 @@ def test_quantized_classification_model(model_fn):
         model.train()
         model.qconfig = torch.ao.quantization.default_qat_qconfig
-    model.fuse_model()
+    model.fuse_model(is_qat=not eval_mode)
     if eval_mode:
         torch.ao.quantization.prepare(model, inplace=True)
     else:
 import warnings
-from typing import Any
+from typing import Any, Optional
 import torch
 import torch.nn as nn
@@ -8,7 +8,7 @@ from torch.nn import functional as F
 from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls
 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model
 __all__ = ["QuantizableGoogLeNet", "googlenet"]
@@ -30,8 +30,8 @@ class QuantizableBasicConv2d(BasicConv2d):
         x = self.relu(x)
         return x
-    def fuse_model(self) -> None:
-        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)
 class QuantizableInception(Inception):
@@ -90,7 +90,7 @@ class QuantizableGoogLeNet(GoogLeNet):
         else:
             return self.eager_outputs(x, aux2, aux1)
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in googlenet model
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -100,7 +100,7 @@ class QuantizableGoogLeNet(GoogLeNet):
         for m in self.modules():
             if type(m) is QuantizableBasicConv2d:
-                m.fuse_model()
+                m.fuse_model(is_qat)
 def googlenet(
 import warnings
-from typing import Any, List
+from typing import Any, List, Optional
 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@ from torchvision.models import inception as inception_module
 from torchvision.models.inception import InceptionOutputs
 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model
 __all__ = [
@@ -35,8 +35,8 @@ class QuantizableBasicConv2d(inception_module.BasicConv2d):
         x = self.relu(x)
         return x
-    def fuse_model(self) -> None:
-        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)
 class QuantizableInceptionA(inception_module.InceptionA):
@@ -160,7 +160,7 @@ class QuantizableInception3(inception_module.Inception3):
         else:
             return self.eager_outputs(x, aux)
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in inception model
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -170,7 +170,7 @@ class QuantizableInception3(inception_module.Inception3):
         for m in self.modules():
             if type(m) is QuantizableBasicConv2d:
-                m.fuse_model()
+                m.fuse_model(is_qat)
 def inception_v3(
-from typing import Any
+from typing import Any, Optional
 from torch import Tensor
 from torch import nn
-from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub
 from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model
 __all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
@@ -28,10 +28,10 @@ class QuantizableInvertedResidual(InvertedResidual):
         else:
             return self.conv(x)
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         for idx in range(len(self.conv)):
             if type(self.conv[idx]) is nn.Conv2d:
-                fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
+                _fuse_modules(self.conv, [str(idx), str(idx + 1)], is_qat, inplace=True)
 class QuantizableMobileNetV2(MobileNetV2):
@@ -52,12 +52,12 @@ class QuantizableMobileNetV2(MobileNetV2):
         x = self.dequant(x)
         return x
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         for m in self.modules():
             if type(m) is ConvNormActivation:
-                fuse_modules(m, ["0", "1", "2"], inplace=True)
+                _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
             if type(m) is QuantizableInvertedResidual:
-                m.fuse_model()
+                m.fuse_model(is_qat)
 def mobilenet_v2(
@@ -2,12 +2,12 @@ from typing import Any, List, Optional
 import torch
 from torch import nn, Tensor
-from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation, SqueezeExcitation
 from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
-from .utils import _replace_relu
+from .utils import _fuse_modules, _replace_relu
 __all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
@@ -28,8 +28,8 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
     def forward(self, input: Tensor) -> Tensor:
         return self.skip_mul.mul(self._scale(input), input)
-    def fuse_model(self) -> None:
-        fuse_modules(self, ["fc1", "activation"], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True)
     def _load_from_state_dict(
         self,
@@ -101,15 +101,15 @@ class QuantizableMobileNetV3(MobileNetV3):
         x = self.dequant(x)
         return x
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         for m in self.modules():
             if type(m) is ConvNormActivation:
                 modules_to_fuse = ["0", "1"]
                 if len(m) == 3 and type(m[2]) is nn.ReLU:
                     modules_to_fuse.append("2")
-                fuse_modules(m, modules_to_fuse, inplace=True)
+                _fuse_modules(m, modules_to_fuse, is_qat, inplace=True)
             elif type(m) is QuantizableSqueezeExcitation:
-                m.fuse_model()
+                m.fuse_model(is_qat)
 def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
@@ -135,7 +135,7 @@ def _mobilenet_v3_model(
     if quantize:
         backend = "qnnpack"
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
         torch.ao.quantization.prepare_qat(model, inplace=True)
-from typing import Any, Type, Union, List
+from typing import Any, Type, Union, List, Optional
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.ao.quantization import fuse_modules
 from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model
 __all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
@@ -41,10 +40,10 @@ class QuantizableBasicBlock(BasicBlock):
         return out
-    def fuse_model(self) -> None:
-        torch.ao.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
         if self.downsample:
-            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
 class QuantizableBottleneck(Bottleneck):
@@ -72,10 +71,12 @@ class QuantizableBottleneck(Bottleneck):
         return out
-    def fuse_model(self) -> None:
-        fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(
+            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
+        )
         if self.downsample:
-            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
 class QuantizableResNet(ResNet):
@@ -94,18 +95,17 @@ class QuantizableResNet(ResNet):
         x = self.dequant(x)
         return x
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in resnet models
         Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
         Model is modified in place. Note that this operation does not change numerics
         and the model after modification is in floating point
         """
-        fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
+        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
         for m in self.modules():
             if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
-                m.fuse_model()
+                m.fuse_model(is_qat)
 def _resnet(
@@ -6,7 +6,7 @@ from torch import Tensor
 from torchvision.models import shufflenetv2
 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model
 __all__ = [
     "QuantizableShuffleNetV2",
@@ -50,24 +50,24 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
         x = self.dequant(x)
         return x
-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in shufflenetv2 model
         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
         Model is modified in place. Note that this operation does not change numerics
         and the model after modification is in floating point
         """
         for name, m in self._modules.items():
-            if name in ["conv1", "conv5"]:
-                torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
+            if name in ["conv1", "conv5"] and m is not None:
+                _fuse_modules(m, [["0", "1", "2"]], is_qat, inplace=True)
         for m in self.modules():
             if type(m) is QuantizableInvertedResidual:
                 if len(m.branch1._modules.items()) > 0:
-                    torch.ao.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
-                torch.ao.quantization.fuse_modules(
+                    _fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], is_qat, inplace=True)
+                _fuse_modules(
                     m.branch2,
                     [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
+                    is_qat,
                     inplace=True,
                 )
 from typing import Any, List, Optional, Union
 import torch
 from torch import nn
@@ -39,4 +41,11 @@ def quantize_model(model: nn.Module, backend: str) -> None:
     model(_dummy_input_data)
     torch.ao.quantization.convert(model, inplace=True)
     return
+def _fuse_modules(
+    model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
+):
+    if is_qat is None:
+        is_qat = model.training
+    method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
+    return method(model, modules_to_fuse, **kwargs)
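
For reference, a minimal sketch of what this helper's dispatch means in practice. The toy `block` below is illustrative and not part of the commit; it assumes the `_fuse_modules` defined above is in scope:

import torch
from torch import nn

# Toy conv/bn/relu stack, analogous to the ConvNormActivation blocks fused above.
block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

block.train()
# is_qat=None falls back to block.training (True here), so the helper picks
# torch.ao.quantization.fuse_modules_qat, which keeps BatchNorm around for
# fake-quant training instead of folding it away.
_fuse_modules(block, ["0", "1", "2"], is_qat=None, inplace=True)

block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
# is_qat=False picks torch.ao.quantization.fuse_modules (the PTQ path),
# which folds the BatchNorm statistics into the conv weights.
_fuse_modules(block, ["0", "1", "2"], is_qat=False, inplace=True)
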
@@ -42,7 +42,7 @@ def _mobilenet_v3_model(
     _replace_relu(model)
     if quantize:
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
         torch.ao.quantization.prepare_qat(model, inplace=True)