"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "c8bf2501040e8ee43482e65ee73419367102755e"
Unverified Commit 02def7c4 authored by moto, committed by GitHub

[fbsync] torchaudio: torch.quantization -> torch.ao.quantization (#1823)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/1817



This changes the imports in `torchaudio` to point to the new locations. The rename was applied mechanically with:

```
codemod -d pytorch/audio --extensions py 'torch.quantization' 'torch.ao.quantization'
```
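The scripts and tests touched here may also run against PyTorch releases that predate `torch.ao.quantization`, so the diff gates the new import behind a version check instead of taking the codemod output verbatim. A minimal, self-contained sketch of that pattern (the `tq` alias matches the diff below):

```
from typing import Tuple

import torch

# First two components of the installed PyTorch version, e.g. "1.10.0" -> (1, 10).
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])

if TORCH_VERSION >= (1, 10):
    # New canonical location as of PyTorch 1.10.
    import torch.ao.quantization as tq
else:
    # Legacy location on older releases.
    import torch.quantization as tq
```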

Reviewed By: mthrok

Differential Revision: D31302450

fbshipit-source-id: f31a0d4f453f840ea690edb688555a9d585787b5
Co-authored-by: Zafar Takhirov <zaf@fb.com>
parent 358e9e93
First file (a script that requires `fairseq`):

```
@@ -6,6 +6,7 @@ To use this script, you need `fairseq`.
 import os
 import argparse
 import logging
+from typing import Tuple

 import torch
 from torch.utils.mobile_optimizer import optimize_for_mobile
@@ -15,6 +16,12 @@
 import fairseq

 from greedy_decoder import Decoder

+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    import torch.ao.quantization as tq
+else:
+    import torch.quantization as tq
+
 _LG = logging.getLogger(__name__)
@@ -149,7 +156,7 @@ def _main():
     if args.quantize:
         _LG.info('Quantizing the model')
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        encoder = torch.quantization.quantize_dynamic(
+        encoder = tq.quantize_dynamic(
             encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
         _LG.info(encoder)
```
Second file (a script that imports a Hugging Face wav2vec2 model):

```
@@ -2,12 +2,19 @@
 import argparse
 import logging
 import os
+from typing import Tuple

 import torch
 import torchaudio
 from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
 from greedy_decoder import Decoder

+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    import torch.ao.quantization as tq
+else:
+    import torch.quantization as tq
+
 _LG = logging.getLogger(__name__)
@@ -90,7 +97,7 @@ def _main():
     if args.quantize:
         _LG.info('Quantizing the model')
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        encoder = torch.quantization.quantize_dynamic(
+        encoder = tq.quantize_dynamic(
             encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
         _LG.info(encoder)
```
Third file (the wav2vec2 model tests):

```
@@ -2,6 +2,7 @@ import os
 import torch
 import torch.nn.functional as F
+from typing import Tuple

 from torchaudio.models.wav2vec2 import (
     wav2vec2_ft_base,
@@ -24,6 +25,12 @@ from torchaudio_unittest.common_utils import (
 )
 from parameterized import parameterized

+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    import torch.ao.quantization as tq
+else:
+    import torch.quantization as tq
+

 def _name_func(testcase_func, i, param):
     return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
@@ -210,7 +217,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         # Remove the weight normalization forward hook
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        quantized = torch.quantization.quantize_dynamic(
+        quantized = tq.quantize_dynamic(
             model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)

         # A lazy way to check that Modules are different
@@ -241,7 +248,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         # Remove the weight normalization forward hook
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        quantized = torch.quantization.quantize_dynamic(
+        quantized = tq.quantize_dynamic(
             model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)

         # A lazy way to check that Modules are different
```
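For reference, the updated `tq.quantize_dynamic` call site can be exercised in isolation; the toy model below is illustrative only and is not part of this patch:

```
import torch
import torch.ao.quantization as tq  # torch.quantization on PyTorch < 1.10

# Toy stand-in for the wav2vec2 encoder quantized above (hypothetical model).
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 8),
)

# Dynamically quantize every Linear layer to int8, matching the call sites in the diff.
quantized = tq.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)

print(quantized)  # Linear modules are replaced with their dynamic-quantized variants
```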