"cacheflow/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "608f74ffe5635cd8ac7fc78a13e3d67a4f59b9e0"
Unverified commit 8a16e12f, authored by Vasilis Vryniotis, committed by GitHub

Implement is_qat in TorchVision (#5299)

* Add is_qat support using a method getter

* Switch to an internal _fuse_modules

* Fix linter.

* Pass is_qat=False on PTQ

* Fix bug on ra_sampler flag.

* Set is_qat=True for QAT
parent 61a52b93
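
The net effect of the patch: every quantizable model's fuse_model() gains an optional is_qat argument, and fusion is routed through an internal _fuse_modules() helper that picks torch.ao.quantization.fuse_modules_qat for QAT and torch.ao.quantization.fuse_modules for PTQ. A minimal sketch of both workflows against the new signature (the model and backend choices below are illustrative, not part of the diff, and assume a PyTorch build that ships fuse_modules_qat):

import torch
import torchvision.models.quantization as M

# QAT: fuse while in train mode, then insert fake-quant observers
qat_model = M.mobilenet_v2(pretrained=False, quantize=False).train()
qat_model.fuse_model(is_qat=True)  # routed to fuse_modules_qat
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
torch.ao.quantization.prepare_qat(qat_model, inplace=True)

# PTQ: fuse while in eval mode, then prepare for calibration
ptq_model = M.mobilenet_v2(pretrained=False, quantize=False).eval()
ptq_model.fuse_model(is_qat=False)  # routed to fuse_modules
ptq_model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
torch.ao.quantization.prepare(ptq_model, inplace=True)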
@@ -178,7 +178,7 @@ def load_data(traindir, valdir, args):
     print("Creating data loaders")
     if args.distributed:
-        if args.ra_sampler:
+        if hasattr(args, "ra_sampler") and args.ra_sampler:
             train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
         else:
             train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -63,7 +63,7 @@ def main(args):
     model.to(device)

     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
         torch.ao.quantization.prepare_qat(model, inplace=True)
@@ -97,7 +97,7 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model()
+        model.fuse_model(is_qat=False)
         model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
         torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
@@ -344,7 +344,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
         # Quantized Classification
         model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
         _ = torch.ao.quantization.prepare_qat(model, inplace=True)
         print(store_model_weights(model, './qat.pth'))
@@ -833,7 +833,7 @@ def test_quantized_classification_model(model_fn):
         model.train()
         model.qconfig = torch.ao.quantization.default_qat_qconfig

-    model.fuse_model()
+    model.fuse_model(is_qat=not eval_mode)
     if eval_mode:
         torch.ao.quantization.prepare(model, inplace=True)
     else:
 import warnings
-from typing import Any
+from typing import Any, Optional

 import torch
 import torch.nn as nn
@@ -8,7 +8,7 @@ from torch.nn import functional as F
 from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls

 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model

 __all__ = ["QuantizableGoogLeNet", "googlenet"]
@@ -30,8 +30,8 @@ class QuantizableBasicConv2d(BasicConv2d):
         x = self.relu(x)
         return x

-    def fuse_model(self) -> None:
-        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)


 class QuantizableInception(Inception):
@@ -90,7 +90,7 @@ class QuantizableGoogLeNet(GoogLeNet):
         else:
             return self.eager_outputs(x, aux2, aux1)

-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in googlenet model

         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -100,7 +100,7 @@ class QuantizableGoogLeNet(GoogLeNet):
         for m in self.modules():
             if type(m) is QuantizableBasicConv2d:
-                m.fuse_model()
+                m.fuse_model(is_qat)


 def googlenet(
 import warnings
-from typing import Any, List
+from typing import Any, List, Optional

 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@ from torchvision.models import inception as inception_module
 from torchvision.models.inception import InceptionOutputs

 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model

 __all__ = [
@@ -35,8 +35,8 @@ class QuantizableBasicConv2d(inception_module.BasicConv2d):
         x = self.relu(x)
         return x

-    def fuse_model(self) -> None:
-        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)


 class QuantizableInceptionA(inception_module.InceptionA):
@@ -160,7 +160,7 @@ class QuantizableInception3(inception_module.Inception3):
         else:
             return self.eager_outputs(x, aux)

-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in inception model

         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
@@ -170,7 +170,7 @@ class QuantizableInception3(inception_module.Inception3):
         for m in self.modules():
             if type(m) is QuantizableBasicConv2d:
-                m.fuse_model()
+                m.fuse_model(is_qat)


 def inception_v3(
-from typing import Any
+from typing import Any, Optional

 from torch import Tensor
 from torch import nn
-from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub
 from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls

 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model

 __all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
@@ -28,10 +28,10 @@ class QuantizableInvertedResidual(InvertedResidual):
         else:
             return self.conv(x)

-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         for idx in range(len(self.conv)):
             if type(self.conv[idx]) is nn.Conv2d:
-                fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
+                _fuse_modules(self.conv, [str(idx), str(idx + 1)], is_qat, inplace=True)


 class QuantizableMobileNetV2(MobileNetV2):
@@ -52,12 +52,12 @@ class QuantizableMobileNetV2(MobileNetV2):
         x = self.dequant(x)
         return x

-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         for m in self.modules():
             if type(m) is ConvNormActivation:
-                fuse_modules(m, ["0", "1", "2"], inplace=True)
+                _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
             if type(m) is QuantizableInvertedResidual:
-                m.fuse_model()
+                m.fuse_model(is_qat)


 def mobilenet_v2(
@@ -2,12 +2,12 @@ from typing import Any, List, Optional

 import torch
 from torch import nn, Tensor
-from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub

 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation, SqueezeExcitation
 from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
-from .utils import _replace_relu
+from .utils import _fuse_modules, _replace_relu

 __all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
@@ -28,8 +28,8 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
     def forward(self, input: Tensor) -> Tensor:
         return self.skip_mul.mul(self._scale(input), input)

-    def fuse_model(self) -> None:
-        fuse_modules(self, ["fc1", "activation"], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True)

     def _load_from_state_dict(
         self,
@@ -101,15 +101,15 @@ class QuantizableMobileNetV3(MobileNetV3):
         x = self.dequant(x)
         return x

-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         for m in self.modules():
             if type(m) is ConvNormActivation:
                 modules_to_fuse = ["0", "1"]
                 if len(m) == 3 and type(m[2]) is nn.ReLU:
                     modules_to_fuse.append("2")
-                fuse_modules(m, modules_to_fuse, inplace=True)
+                _fuse_modules(m, modules_to_fuse, is_qat, inplace=True)
             elif type(m) is QuantizableSqueezeExcitation:
-                m.fuse_model()
+                m.fuse_model(is_qat)


 def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
@@ -135,7 +135,7 @@ def _mobilenet_v3_model(
     if quantize:
         backend = "qnnpack"

-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
         torch.ao.quantization.prepare_qat(model, inplace=True)
-from typing import Any, Type, Union, List
+from typing import Any, Type, Union, List, Optional

 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.ao.quantization import fuse_modules
 from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls

 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model

 __all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
@@ -41,10 +40,10 @@ class QuantizableBasicBlock(BasicBlock):
         return out

-    def fuse_model(self) -> None:
-        torch.ao.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
         if self.downsample:
-            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


 class QuantizableBottleneck(Bottleneck):
@@ -72,10 +71,12 @@ class QuantizableBottleneck(Bottleneck):
         return out

-    def fuse_model(self) -> None:
-        fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
+        _fuse_modules(
+            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
+        )
         if self.downsample:
-            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


 class QuantizableResNet(ResNet):
@@ -94,18 +95,17 @@ class QuantizableResNet(ResNet):
         x = self.dequant(x)
         return x

-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in resnet models

         Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.

         Model is modified in place. Note that this operation does not change numerics
         and the model after modification is in floating point
         """
-
-        fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
+        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
         for m in self.modules():
             if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
-                m.fuse_model()
+                m.fuse_model(is_qat)


 def _resnet(
@@ -6,7 +6,7 @@ from torch import Tensor
 from torchvision.models import shufflenetv2

 from ..._internally_replaced_utils import load_state_dict_from_url
-from .utils import _replace_relu, quantize_model
+from .utils import _fuse_modules, _replace_relu, quantize_model

 __all__ = [
     "QuantizableShuffleNetV2",
@@ -50,24 +50,24 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
         x = self.dequant(x)
         return x

-    def fuse_model(self) -> None:
+    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
         r"""Fuse conv/bn/relu modules in shufflenetv2 model

         Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.

         Model is modified in place. Note that this operation does not change numerics
         and the model after modification is in floating point
         """
         for name, m in self._modules.items():
-            if name in ["conv1", "conv5"]:
-                torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
+            if name in ["conv1", "conv5"] and m is not None:
+                _fuse_modules(m, [["0", "1", "2"]], is_qat, inplace=True)
         for m in self.modules():
             if type(m) is QuantizableInvertedResidual:
                 if len(m.branch1._modules.items()) > 0:
-                    torch.ao.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
-                torch.ao.quantization.fuse_modules(
+                    _fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], is_qat, inplace=True)
+                _fuse_modules(
                     m.branch2,
                     [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
+                    is_qat,
                     inplace=True,
                 )
+from typing import Any, List, Optional, Union
+
 import torch
 from torch import nn
@@ -39,4 +41,11 @@ def quantize_model(model: nn.Module, backend: str) -> None:
     model(_dummy_input_data)
     torch.ao.quantization.convert(model, inplace=True)
     return
+
+
+def _fuse_modules(
+    model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
+):
+    if is_qat is None:
+        is_qat = model.training
+    method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
+    return method(model, modules_to_fuse, **kwargs)
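
A note on the default: when is_qat is None, the helper above falls back to model.training, so a module fuses according to its current mode. A small sketch of the dispatch (the plain Sequential is illustrative only; _fuse_modules is an internal torchvision helper, and fuse_modules_qat assumes a PyTorch build that provides it):

from torch import nn
from torchvision.models.quantization.utils import _fuse_modules

# train mode + is_qat=None -> torch.ao.quantization.fuse_modules_qat
m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).train()
_fuse_modules(m, [["0", "1", "2"]], is_qat=None, inplace=True)

# eval mode + is_qat=None -> torch.ao.quantization.fuse_modules
m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
_fuse_modules(m, [["0", "1", "2"]], is_qat=None, inplace=True)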
@@ -42,7 +42,7 @@ def _mobilenet_v3_model(
     _replace_relu(model)

     if quantize:
-        model.fuse_model()
+        model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
         torch.ao.quantization.prepare_qat(model, inplace=True)