"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "5e75c8e803e623481e2e76ba93444301d498be54"
Unverified commit b3cdec1f, authored by Nicolas Hug and committed by GitHub

Use new torch.ao.quantization instead of torch.quantization (#4554)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 1c9ccb7b
@@ -4,7 +4,7 @@ import os
 import time
 import torch
-import torch.quantization
+import torch.ao.quantization
 import torch.utils.data
 import torchvision
 import utils
@@ -62,8 +62,8 @@ def main(args):
     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
-        torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
         if args.distributed and args.sync_bn:
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
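These are the QAT preparation steps in the training reference's main(args), moved to the new namespace. As a minimal, self-contained sketch of the same eager-mode QAT setup (the resnet18 model and fbgemm backend below are illustrative choices, not taken from the diff):

```python
import torch
import torch.ao.quantization
import torchvision

# Any eager-mode quantizable torchvision model works the same way; resnet18 is only an example.
model = torchvision.models.quantization.resnet18(quantize=False)
model.train()
model.fuse_model()  # fuse Conv+BN(+ReLU) groups before fake-quant modules are inserted

backend = "fbgemm"  # stands in for args.backend; "qnnpack" targets ARM
torch.backends.quantized.engine = backend
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
torch.ao.quantization.prepare_qat(model, inplace=True)
# ... train as usual: observers and fake-quant modules are active during training ...
```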
@@ -96,12 +96,12 @@ def main(args):
         )
         model.eval()
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
-        torch.quantization.prepare(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.quantization.convert(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
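This hunk is the post-training (static) quantization branch: prepare, calibrate, convert. A hedged, self-contained sketch of that flow on a toy module (the module, backend, and random calibration data are placeholders, not the reference script's):

```python
import torch
import torch.nn as nn
import torch.ao.quantization

class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()       # fp32 -> int8 boundary
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()   # int8 -> fp32 boundary

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

model = TinyNet().eval()
torch.backends.quantized.engine = "fbgemm"
torch.ao.quantization.fuse_modules(model, [["conv", "relu"]], inplace=True)
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
torch.ao.quantization.prepare(model, inplace=True)

# Calibrate with representative data (random tensors here, purely illustrative).
with torch.inference_mode():
    for _ in range(4):
        model(torch.randn(1, 3, 32, 32))

torch.ao.quantization.convert(model, inplace=True)  # swap observed modules for int8 kernels
with torch.inference_mode():
    model(torch.randn(1, 3, 32, 32))  # now runs with the quantized conv+relu
```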
@@ -114,8 +114,8 @@ def main(args):
         evaluate(model, criterion, data_loader_test, device=device)
         return
-    model.apply(torch.quantization.enable_observer)
-    model.apply(torch.quantization.enable_fake_quant)
+    model.apply(torch.ao.quantization.enable_observer)
+    model.apply(torch.ao.quantization.enable_fake_quant)
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
@@ -126,7 +126,7 @@ def main(args):
         with torch.inference_mode():
             if epoch >= args.num_observer_update_epochs:
                 print("Disabling observer for subseq epochs, epoch = ", epoch)
-                model.apply(torch.quantization.disable_observer)
+                model.apply(torch.ao.quantization.disable_observer)
             if epoch >= args.num_batch_norm_update_epochs:
                 print("Freezing BN for subseq epochs, epoch = ", epoch)
                 model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
@@ -136,7 +136,7 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
-            torch.quantization.convert(quantized_eval_model, inplace=True)
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
             print("Evaluate Quantized model")
             evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
...
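The remaining hunks of this script are the QAT training loop: observers and fake-quant are switched on via model.apply, later epochs freeze observers and batch-norm statistics, and each epoch an int8 copy is produced on CPU for evaluation. A compressed, hedged sketch of that control flow (the model choice, epoch thresholds, and dummy training step are illustrative):

```python
import copy
import torch
import torch.ao.quantization
import torchvision

torch.backends.quantized.engine = "fbgemm"
model = torchvision.models.quantization.resnet18(quantize=False)  # illustrative model
model.train()
model.fuse_model()
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
torch.ao.quantization.prepare_qat(model, inplace=True)

model.apply(torch.ao.quantization.enable_observer)
model.apply(torch.ao.quantization.enable_fake_quant)

num_observer_update_epochs = 2    # illustrative thresholds
num_batch_norm_update_epochs = 1

for epoch in range(4):
    model(torch.randn(2, 3, 64, 64)).sum().backward()  # stand-in for a real training step
    if epoch >= num_observer_update_epochs:
        model.apply(torch.ao.quantization.disable_observer)   # freeze quantization ranges
    if epoch >= num_batch_norm_update_epochs:
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)   # freeze BN running statistics

    # Evaluate an int8 copy on CPU without touching the training model.
    quantized_eval_model = copy.deepcopy(model).eval().to("cpu")
    torch.ao.quantization.convert(quantized_eval_model, inplace=True)
    with torch.inference_mode():
        quantized_eval_model(torch.randn(1, 3, 64, 64))
```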
@@ -345,8 +345,8 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
         # Quantized Classification
         model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
-        _ = torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
+        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
         print(store_model_weights(model, './qat.pth'))
         # Object Detection
...
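This hunk updates the docstring example of store_model_weights, a helper in the classification references. The underlying point is that a QAT checkpoint's state_dict only matches a model that has been fused and prepared the same way. A hedged sketch of that round trip (the model choice and file path are illustrative):

```python
import torch
import torch.ao.quantization
from torchvision import models as M

def make_qat_model():
    # Rebuild exactly the module structure the QAT weights were produced with.
    model = M.quantization.mobilenet_v3_large(quantize=False)
    model.train()
    model.fuse_model()
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
    torch.ao.quantization.prepare_qat(model, inplace=True)
    return model

model = make_qat_model()
torch.save(model.state_dict(), "./qat.pth")        # illustrative path

restored = make_qat_model()
restored.load_state_dict(torch.load("./qat.pth"))  # keys only line up after the same fuse/prepare steps
```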
@@ -781,19 +781,19 @@ def test_quantized_classification_model(model_fn):
     model = model_fn(**kwargs)
     if eval_mode:
         model.eval()
-        model.qconfig = torch.quantization.default_qconfig
+        model.qconfig = torch.ao.quantization.default_qconfig
     else:
         model.train()
-        model.qconfig = torch.quantization.default_qat_qconfig
+        model.qconfig = torch.ao.quantization.default_qat_qconfig
     model.fuse_model()
     if eval_mode:
-        torch.quantization.prepare(model, inplace=True)
+        torch.ao.quantization.prepare(model, inplace=True)
     else:
-        torch.quantization.prepare_qat(model, inplace=True)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
         model.eval()
-    torch.quantization.convert(model, inplace=True)
+    torch.ao.quantization.convert(model, inplace=True)
     try:
         torch.jit.script(model)
...
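The test exercises both entry points (prepare for eval mode, prepare_qat for train mode) and then checks that the converted model still scripts. A condensed, hedged stand-alone version of the eval-mode path (the toy module stands in for the model_fn factories the test iterates over):

```python
import torch
import torch.nn as nn
import torch.ao.quantization

class Toy(nn.Module):  # stand-in for the torchvision model factories in the test
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc = nn.Linear(8, 4)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = Toy().eval()
model.qconfig = torch.ao.quantization.default_qconfig   # the test's eval-mode qconfig
torch.ao.quantization.prepare(model, inplace=True)
model(torch.randn(2, 8))                                 # one calibration pass
torch.ao.quantization.convert(model, inplace=True)

scripted = torch.jit.script(model)                       # converted eager models remain scriptable
print(scripted(torch.randn(2, 8)).shape)
```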
@@ -31,7 +31,7 @@ class QuantizableBasicConv2d(BasicConv2d):
         return x
     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 class QuantizableInception(Inception):
@@ -74,8 +74,8 @@ class QuantizableGoogLeNet(GoogLeNet):
         super().__init__(  # type: ignore[misc]
             blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs
         )
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
     def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)
...
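Two idioms recur in the quantizable model definitions touched by this commit: a fuse_model hook that fuses named child modules, and a QuantStub/DeQuantStub pair bracketing forward. A minimal, hedged example of the pattern with the new namespace (the module itself is invented for illustration):

```python
import torch
import torch.nn as nn
import torch.ao.quantization

class QuantizableConvBlock(nn.Module):
    """Toy analogue of QuantizableBasicConv2d: conv + bn + relu with a fuse hook."""

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))

    def fuse_model(self) -> None:
        # Children are referenced by attribute name, exactly as in the hunks above.
        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)

class QuantizableToyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()      # marks the fp32 -> int8 boundary
        self.block = QuantizableConvBlock(3, 8)
        self.dequant = torch.ao.quantization.DeQuantStub()  # marks the int8 -> fp32 boundary

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dequant(self.block(self.quant(x)))

    def fuse_model(self) -> None:
        self.block.fuse_model()
```

For post-training quantization, fuse_model is expected to run on an eval-mode model; for QAT, recent PyTorch also offers fuse_modules_qat for train-mode fusion.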
@@ -36,7 +36,7 @@ class QuantizableBasicConv2d(inception_module.BasicConv2d):
         return x
     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 class QuantizableInceptionA(inception_module.InceptionA):
@@ -144,8 +144,8 @@ class QuantizableInception3(inception_module.Inception3):
                 QuantizableInceptionAux,
             ],
         )
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
     def forward(self, x: Tensor) -> InceptionOutputs:
         x = self._transform_input(x)
...
@@ -2,7 +2,7 @@ from typing import Any
 from torch import Tensor
 from torch import nn
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
 from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
 from ..._internally_replaced_utils import load_state_dict_from_url
...
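This file (and the MobileNetV3 one below) only needs the import path updated. Code that has to run on both older and newer PyTorch releases can guard the import; a hedged sketch (torch.ao.quantization appeared around PyTorch 1.10, and torch.quantization remains as a legacy alias):

```python
try:
    # Preferred location in recent PyTorch releases.
    from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
except ImportError:
    # Fallback for older releases that only ship the legacy namespace.
    from torch.quantization import QuantStub, DeQuantStub, fuse_modules
```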
@@ -2,7 +2,7 @@ from typing import Any, List, Optional
 import torch
 from torch import nn, Tensor
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation, SqueezeExcitation
@@ -136,13 +136,13 @@ def _mobilenet_v3_model(
         backend = "qnnpack"
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
-        torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
         if pretrained:
             _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress)
-        torch.quantization.convert(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
         model.eval()
     else:
         if pretrained:
...
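The _mobilenet_v3_model hunk shows the order used to materialize a pretrained quantized model: fuse, attach the QAT qconfig, prepare_qat, load the QAT-trained weights, convert, and switch to eval. The public factory wraps all of that; a hedged usage sketch (downloading the weights needs network access, and the pretrained flag follows the older torchvision API used elsewhere in this diff):

```python
import torch
import torchvision

# The published quantized MobileNetV3 weights were produced with the qnnpack backend (see the hunk above).
torch.backends.quantized.engine = "qnnpack"

# One call performs fuse/prepare_qat/load/convert internally.
model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
model.eval()

with torch.inference_mode():
    out = model(torch.randn(1, 3, 224, 224))  # int8 weights/activations on CPU
print(out.shape)  # torch.Size([1, 1000])
```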
@@ -3,7 +3,7 @@ from typing import Any, Type, Union, List
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.quantization import fuse_modules
+from torch.ao.quantization import fuse_modules
 from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
 from ..._internally_replaced_utils import load_state_dict_from_url
@@ -42,9 +42,9 @@ class QuantizableBasicBlock(BasicBlock):
         return out
     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
+        torch.ao.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
         if self.downsample:
-            torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
 class QuantizableBottleneck(Bottleneck):
@@ -75,15 +75,15 @@ class QuantizableBottleneck(Bottleneck):
     def fuse_model(self) -> None:
         fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
         if self.downsample:
-            torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
 class QuantizableResNet(ResNet):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
     def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
...
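In the ResNet blocks, fuse_modules takes a list of lists so several conv/bn(/relu) groups are fused in one call, and the optional downsample Sequential is fused through its child index names "0" and "1". A hedged toy version of that pattern (simplified; a real quantizable block would also route the residual add through nn.quantized.FloatFunctional):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class ToyBasicBlock(nn.Module):
    """Simplified stand-in for QuantizableBasicBlock."""

    def __init__(self, ch: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch)
        self.downsample = nn.Sequential(nn.Conv2d(ch, ch, 1, bias=False), nn.BatchNorm2d(ch))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return out + self.downsample(x)  # plain float add; kept simple for the fusion demo

    def fuse_model(self) -> None:
        # Two fusion groups in a single call, matching the QuantizableBasicBlock hunk.
        fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
        if self.downsample is not None:
            # Sequential children are addressed by their index names "0" and "1".
            fuse_modules(self.downsample, ["0", "1"], inplace=True)

block = ToyBasicBlock(8).eval()  # eager fuse_modules folds BN only in eval mode
block.fuse_model()
print(block)
```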
@@ -41,8 +41,8 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
     # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs)  # type: ignore[misc]
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
     def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
@@ -60,12 +60,12 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
         for name, m in self._modules.items():
             if name in ["conv1", "conv5"]:
-                torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
+                torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
         for m in self.modules():
             if type(m) is QuantizableInvertedResidual:
                 if len(m.branch1._modules.items()) > 0:
-                    torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
-                torch.quantization.fuse_modules(
+                    torch.ao.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
+                torch.ao.quantization.fuse_modules(
                     m.branch2,
                     [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
                     inplace=True,
...
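The ShuffleNetV2 fuse_model iterates over named children and fuses Sequential entries by index name. A related eager-mode detail, not visible in this hunk but relevant background: tensor ops such as concatenation inside the quantized region go through FloatFunctional so convert() can later swap in the quantized kernel. A hedged sketch:

```python
import torch
import torch.nn as nn
import torch.ao.quantization

# torch.ao.nn.quantized in recent PyTorch; older releases expose this as torch.nn.quantized.
from torch.ao.nn.quantized import FloatFunctional

class ToyBranchyBlock(nn.Module):
    """Illustrative two-branch block in the spirit of QuantizableInvertedResidual."""

    def __init__(self, ch: int) -> None:
        super().__init__()
        self.branch1 = nn.Sequential(nn.Conv2d(ch, ch, 1, bias=False), nn.BatchNorm2d(ch), nn.ReLU())
        self.branch2 = nn.Sequential(nn.Conv2d(ch, ch, 1, bias=False), nn.BatchNorm2d(ch), nn.ReLU())
        self.cat = FloatFunctional()  # replaces torch.cat so it can be quantized at convert time

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.cat.cat([self.branch1(x), self.branch2(x)], dim=1)

    def fuse_model(self) -> None:
        # Sequential children are addressed by index names, as in the hunk above.
        for branch in (self.branch1, self.branch2):
            torch.ao.quantization.fuse_modules(branch, [["0", "1", "2"]], inplace=True)

block = ToyBranchyBlock(8).eval()
block.fuse_model()
block(torch.randn(1, 8, 16, 16))  # still a float model; quantization happens via prepare/convert
```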
@@ -24,19 +24,19 @@ def quantize_model(model: nn.Module, backend: str) -> None:
     model.eval()
     # Make sure that weight qconfig matches that of the serialized models
     if backend == "fbgemm":
-        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.quantization.default_observer,
-            weight=torch.quantization.default_per_channel_weight_observer,
+        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+            activation=torch.ao.quantization.default_observer,
+            weight=torch.ao.quantization.default_per_channel_weight_observer,
         )
     elif backend == "qnnpack":
-        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.quantization.default_observer, weight=torch.quantization.default_weight_observer
+        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+            activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer
         )
     # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
     model.fuse_model()  # type: ignore[operator]
-    torch.quantization.prepare(model, inplace=True)
+    torch.ao.quantization.prepare(model, inplace=True)
     model(_dummy_input_data)
-    torch.quantization.convert(model, inplace=True)
+    torch.ao.quantization.convert(model, inplace=True)
     return
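quantize_model builds QConfig objects explicitly so the observers match the ones used when the released weights were serialized: per-channel weight observers for fbgemm, per-tensor for qnnpack. A hedged stand-alone sketch of the same construction (the tiny Sequential model is illustrative):

```python
import torch
import torch.ao.quantization as tq

def qconfig_for(backend: str) -> tq.QConfig:
    # Mirrors the observer choices in the hunk above.
    if backend == "fbgemm":
        return tq.QConfig(activation=tq.default_observer, weight=tq.default_per_channel_weight_observer)
    if backend == "qnnpack":
        return tq.QConfig(activation=tq.default_observer, weight=tq.default_weight_observer)
    raise ValueError(f"unsupported backend: {backend}")

backend = "fbgemm"
torch.backends.quantized.engine = backend

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
model.qconfig = qconfig_for(backend)
torch.ao.quantization.prepare(model, inplace=True)
model(torch.rand(1, 3, 32, 32))   # single calibration pass on random data (illustrative)
torch.ao.quantization.convert(model, inplace=True)
# Note: feeding float inputs to the converted model needs a QuantStub/DeQuantStub wrapper,
# which the torchvision model classes above provide.
```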