Unverified commit 02def7c4, authored by moto, committed by GitHub
Browse files

[fbsync] torchaudio: torch.quantization -> torch.ao.quantization (#1823)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/1817



This changes the imports in `torchaudio` to use the new import locations.

```
codemod -d pytorch/audio --extensions py 'torch.quantization' 'torch.ao.quantization'
```

Reviewed By: mthrok

Differential Revision: D31302450

fbshipit-source-id: f31a0d4f453f840ea690edb688555a9d585787b5
Co-authored-by: Zafar Takhirov <zaf@fb.com>
parent 358e9e93
......@@ -6,6 +6,7 @@ To use this script, you need `fairseq`.
import os
import argparse
import logging
from typing import Tuple
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
......@@ -15,6 +16,12 @@ import fairseq
from greedy_decoder import Decoder
# Resolve the quantization namespace once at import time: `torch.quantization`
# was moved to `torch.ao.quantization` in PyTorch 1.10, so pick the location
# that matches the installed torch (major, minor) version.
TORCH_VERSION: Tuple[int, ...] = tuple(
    int(part) for part in torch.__version__.split(".")[:2]
)
if TORCH_VERSION < (1, 10):
    import torch.quantization as tq
else:
    import torch.ao.quantization as tq
_LG = logging.getLogger(__name__)
......@@ -149,7 +156,7 @@ def _main():
if args.quantize:
_LG.info('Quantizing the model')
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
encoder = torch.quantization.quantize_dynamic(
encoder = tq.quantize_dynamic(
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
_LG.info(encoder)
......
......@@ -2,12 +2,19 @@
import argparse
import logging
import os
from typing import Tuple
import torch
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
from greedy_decoder import Decoder
# `torch.quantization` was renamed to `torch.ao.quantization` in PyTorch 1.10;
# alias whichever module exists for this torch build as `tq`.
TORCH_VERSION: Tuple[int, ...] = tuple(
    int(part) for part in torch.__version__.split(".")[:2]
)
if TORCH_VERSION < (1, 10):
    import torch.quantization as tq
else:
    import torch.ao.quantization as tq
_LG = logging.getLogger(__name__)
......@@ -90,7 +97,7 @@ def _main():
if args.quantize:
_LG.info('Quantizing the model')
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
encoder = torch.quantization.quantize_dynamic(
encoder = tq.quantize_dynamic(
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
_LG.info(encoder)
......
......@@ -2,6 +2,7 @@ import os
import torch
import torch.nn.functional as F
from typing import Tuple
from torchaudio.models.wav2vec2 import (
wav2vec2_ft_base,
......@@ -24,6 +25,12 @@ from torchaudio_unittest.common_utils import (
)
from parameterized import parameterized
# Select the quantization module compatible with the installed torch:
# the package moved from `torch.quantization` to `torch.ao.quantization`
# in PyTorch 1.10.
TORCH_VERSION: Tuple[int, ...] = tuple(
    int(part) for part in torch.__version__.split(".")[:2]
)
if TORCH_VERSION < (1, 10):
    import torch.quantization as tq
else:
    import torch.ao.quantization as tq
def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
......@@ -210,7 +217,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = torch.quantization.quantize_dynamic(
quantized = tq.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
......@@ -241,7 +248,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = torch.quantization.quantize_dynamic(
quantized = tq.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment