"benchmark/vscode:/vscode.git/clone" did not exist on "45bc170b36ebdb74496e66260bde209ba7687ea8"
Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot
Browse files

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
import pytest
import torch import torch
from torchaudio._internal import download_url_to_file from torchaudio._internal import download_url_to_file
import pytest
class GreedyCTCDecoder(torch.nn.Module): class GreedyCTCDecoder(torch.nn.Module):
...@@ -24,7 +24,7 @@ class GreedyCTCDecoder(torch.nn.Module): ...@@ -24,7 +24,7 @@ class GreedyCTCDecoder(torch.nn.Module):
for i in best_path: for i in best_path:
if i != self.blank: if i != self.blank:
hypothesis.append(self.labels[i]) hypothesis.append(self.labels[i])
return ''.join(hypothesis) return "".join(hypothesis)
@pytest.fixture @pytest.fixture
...@@ -33,24 +33,24 @@ def ctc_decoder(): ...@@ -33,24 +33,24 @@ def ctc_decoder():
_FILES = { _FILES = {
'en': 'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac', "en": "Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac",
'de': '20090505-0900-PLENARY-16-de_20090505-21_56_00_8.flac', "de": "20090505-0900-PLENARY-16-de_20090505-21_56_00_8.flac",
'en2': '20120613-0900-PLENARY-8-en_20120613-13_46_50_3.flac', "en2": "20120613-0900-PLENARY-8-en_20120613-13_46_50_3.flac",
'es': '20130207-0900-PLENARY-7-es_20130207-13_02_05_5.flac', "es": "20130207-0900-PLENARY-7-es_20130207-13_02_05_5.flac",
'fr': '20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac', "fr": "20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac",
'it': '20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac', "it": "20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac",
} }
@pytest.fixture @pytest.fixture
def sample_speech(tmp_path, lang): def sample_speech(tmp_path, lang):
if lang not in _FILES: if lang not in _FILES:
raise NotImplementedError(f'Unexpected lang: {lang}') raise NotImplementedError(f"Unexpected lang: {lang}")
filename = _FILES[lang] filename = _FILES[lang]
path = tmp_path.parent / filename path = tmp_path.parent / filename
if not path.exists(): if not path.exists():
url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}' url = f"https://download.pytorch.org/torchaudio/test-assets/{filename}"
print(f'downloading from {url}') print(f"downloading from {url}")
download_url_to_file(url, path, progress=False) download_url_to_file(url, path, progress=False)
return path return path
...@@ -62,13 +62,13 @@ def pytest_addoption(parser): ...@@ -62,13 +62,13 @@ def pytest_addoption(parser):
help=( help=(
"When provided, tests will use temporary directory as Torch Hub directory. " "When provided, tests will use temporary directory as Torch Hub directory. "
"Downloaded models will be deleted after each test." "Downloaded models will be deleted after each test."
) ),
) )
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def temp_hub_dir(tmpdir, pytestconfig): def temp_hub_dir(tmpdir, pytestconfig):
if not pytestconfig.getoption('use_tmp_hub_dir'): if not pytestconfig.getoption("use_tmp_hub_dir"):
yield yield
else: else:
org_dir = torch.hub.get_dir() org_dir = torch.hub.get_dir()
......
import pytest
from torchaudio.pipelines import ( from torchaudio.pipelines import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH, TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH, TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH, TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH, TACOTRON2_WAVERNN_PHONE_LJSPEECH,
) )
import pytest
@pytest.mark.parametrize( @pytest.mark.parametrize(
'bundle', "bundle",
[ [
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH, TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH, TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH, TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH, TACOTRON2_WAVERNN_PHONE_LJSPEECH,
] ],
) )
def test_tts_models(bundle): def test_tts_models(bundle):
"""Smoke test of TTS pipeline""" """Smoke test of TTS pipeline"""
......
import pytest
import torchaudio import torchaudio
from torchaudio.pipelines import ( from torchaudio.pipelines import (
WAV2VEC2_BASE, WAV2VEC2_BASE,
...@@ -24,7 +25,6 @@ from torchaudio.pipelines import ( ...@@ -24,7 +25,6 @@ from torchaudio.pipelines import (
VOXPOPULI_ASR_BASE_10K_FR, VOXPOPULI_ASR_BASE_10K_FR,
VOXPOPULI_ASR_BASE_10K_IT, VOXPOPULI_ASR_BASE_10K_IT,
) )
import pytest
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -37,7 +37,7 @@ import pytest ...@@ -37,7 +37,7 @@ import pytest
HUBERT_BASE, HUBERT_BASE,
HUBERT_LARGE, HUBERT_LARGE,
HUBERT_XLARGE, HUBERT_XLARGE,
] ],
) )
def test_pretraining_models(bundle): def test_pretraining_models(bundle):
"""Smoke test of downloading weights for pretraining models""" """Smoke test of downloading weights for pretraining models"""
...@@ -47,30 +47,46 @@ def test_pretraining_models(bundle): ...@@ -47,30 +47,46 @@ def test_pretraining_models(bundle):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"bundle,lang,expected", "bundle,lang,expected",
[ [
(WAV2VEC2_ASR_BASE_10M, 'en', 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_BASE_10M, "en", "I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_BASE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_BASE_100H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_BASE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_BASE_960H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_10M, 'en', 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_10M, "en", "I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_100H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_960H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_LV60K_10M, 'en', 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'), (WAV2VEC2_ASR_LARGE_LV60K_10M, "en", "I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_LV60K_100H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_LV60K_100H, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_LV60K_960H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (WAV2VEC2_ASR_LARGE_LV60K_960H, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(HUBERT_ASR_LARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (HUBERT_ASR_LARGE, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(HUBERT_ASR_XLARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), (HUBERT_ASR_XLARGE, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(VOXPOPULI_ASR_BASE_10K_EN, 'en2', 'i|hope|that|we|will|see|a|ddrasstic|decrease|of|funding|for|the|failed|eu|project|and|that|more|money|will|come|back|to|the|taxpayers'), # noqa: E501 (
(VOXPOPULI_ASR_BASE_10K_ES, 'es', "la|primera|que|es|imprescindible|pensar|a|pequeña|a|escala|para|implicar|y|complementar|así|la|actuación|global"), # noqa: E501 VOXPOPULI_ASR_BASE_10K_EN,
(VOXPOPULI_ASR_BASE_10K_DE, 'de', "dabei|spielt|auch|eine|sorgfältige|berichterstattung|eine|wichtige|rolle"), "en2",
(VOXPOPULI_ASR_BASE_10K_FR, 'fr', 'la|commission|va|faire|des|propositions|sur|ce|sujet|comment|mettre|en|place|cette|capacité|fiscale|et|le|conseil|européen|y|reviendra|sour|les|sujets|au|moins|de|mars'), # noqa: E501 "i|hope|that|we|will|see|a|ddrasstic|decrease|of|funding|for|the|failed|eu|project|and|that|more|money|will|come|back|to|the|taxpayers",
(VOXPOPULI_ASR_BASE_10K_IT, 'it', 'credo|che|illatino|non|sia|contemplato|tra|le|traduzioni|e|quindi|mi|attengo|allitaliano') # noqa: E501 ), # noqa: E501
] (
VOXPOPULI_ASR_BASE_10K_ES,
"es",
"la|primera|que|es|imprescindible|pensar|a|pequeña|a|escala|para|implicar|y|complementar|así|la|actuación|global",
), # noqa: E501
(VOXPOPULI_ASR_BASE_10K_DE, "de", "dabei|spielt|auch|eine|sorgfältige|berichterstattung|eine|wichtige|rolle"),
(
VOXPOPULI_ASR_BASE_10K_FR,
"fr",
"la|commission|va|faire|des|propositions|sur|ce|sujet|comment|mettre|en|place|cette|capacité|fiscale|et|le|conseil|européen|y|reviendra|sour|les|sujets|au|moins|de|mars",
), # noqa: E501
(
VOXPOPULI_ASR_BASE_10K_IT,
"it",
"credo|che|illatino|non|sia|contemplato|tra|le|traduzioni|e|quindi|mi|attengo|allitaliano",
), # noqa: E501
],
) )
def test_finetune_asr_model( def test_finetune_asr_model(
bundle, bundle,
lang, lang,
expected, expected,
sample_speech, sample_speech,
ctc_decoder, ctc_decoder,
): ):
"""Smoke test of downloading weights for fine-tuning models and simple transcription""" """Smoke test of downloading weights for fine-tuning models and simple transcription"""
model = bundle.get_model().eval() model = bundle.get_model().eval()
......
...@@ -8,30 +8,26 @@ import torch ...@@ -8,30 +8,26 @@ import torch
def _parse_args(): def _parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Generate opus files for test")
description='Generate opus files for test' parser.add_argument("--num-channels", required=True, type=int)
) parser.add_argument("--compression-level", required=True, type=int, choices=list(range(11)))
parser.add_argument('--num-channels', required=True, type=int) parser.add_argument("--bitrate", default="96k")
parser.add_argument('--compression-level', required=True, type=int, choices=list(range(11)))
parser.add_argument('--bitrate', default='96k')
return parser.parse_args() return parser.parse_args()
def convert_to_opus( def convert_to_opus(src_path, dst_path, *, bitrate, compression_level):
src_path, dst_path,
*, bitrate, compression_level):
"""Convert audio file with `ffmpeg` command.""" """Convert audio file with `ffmpeg` command."""
command = ['ffmpeg', '-y', '-i', src_path, '-c:a', 'libopus', '-b:a', bitrate] command = ["ffmpeg", "-y", "-i", src_path, "-c:a", "libopus", "-b:a", bitrate]
if compression_level is not None: if compression_level is not None:
command += ['-compression_level', str(compression_level)] command += ["-compression_level", str(compression_level)]
command += [dst_path] command += [dst_path]
print(' '.join(command)) print(" ".join(command))
subprocess.run(command, check=True) subprocess.run(command, check=True)
def _generate(num_channels, compression_level, bitrate): def _generate(num_channels, compression_level, bitrate):
org_path = 'original.wav' org_path = "original.wav"
ops_path = f'{bitrate}_{compression_level}_{num_channels}ch.opus' ops_path = f"{bitrate}_{compression_level}_{num_channels}ch.opus"
# Note: ffmpeg forces sample rate 48k Hz for opus https://stackoverflow.com/a/39186779 # Note: ffmpeg forces sample rate 48k Hz for opus https://stackoverflow.com/a/39186779
# 1. generate original wav # 1. generate original wav
...@@ -46,5 +42,5 @@ def _main(): ...@@ -46,5 +42,5 @@ def _main():
_generate(args.num_channels, args.compression_level, args.bitrate) _generate(args.num_channels, args.compression_level, args.bitrate)
if __name__ == '__main__': if __name__ == "__main__":
_main() _main()
...@@ -32,8 +32,8 @@ python generate_hubert_model_config.py \ ...@@ -32,8 +32,8 @@ python generate_hubert_model_config.py \
> hubert_large_ll60k_finetune_ls960.json > hubert_large_ll60k_finetune_ls960.json
``` ```
""" """
import json
import argparse import argparse
import json
def _parse_args(): def _parse_args():
...@@ -42,12 +42,9 @@ def _parse_args(): ...@@ -42,12 +42,9 @@ def _parse_args():
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
parser.add_argument( parser.add_argument(
'--model-file', "--model-file",
required=True, required=True,
help=( help=("A pt file from " "https://github.com/pytorch/fairseq/tree/main/examples/hubert"),
'A pt file from '
'https://github.com/pytorch/fairseq/tree/main/examples/hubert'
)
) )
return parser.parse_args() return parser.parse_args()
...@@ -66,27 +63,27 @@ def _main(): ...@@ -66,27 +63,27 @@ def _main():
args = _parse_args() args = _parse_args()
model, cfg = _load(args.model_file) model, cfg = _load(args.model_file)
if model.__class__.__name__ == 'HubertModel': if model.__class__.__name__ == "HubertModel":
cfg['task']['data'] = '/foo/bar' cfg["task"]["data"] = "/foo/bar"
cfg['task']['label_dir'] = None cfg["task"]["label_dir"] = None
conf = { conf = {
'_name': 'hubert', "_name": "hubert",
'model': cfg['model'], "model": cfg["model"],
'task': cfg['task'], "task": cfg["task"],
'num_classes': model.num_classes, "num_classes": model.num_classes,
} }
elif model.__class__.__name__ == 'HubertCtc': elif model.__class__.__name__ == "HubertCtc":
conf = cfg['model'] conf = cfg["model"]
del conf['w2v_path'] del conf["w2v_path"]
keep = ['_name', 'task', 'model'] keep = ["_name", "task", "model"]
for key in list(k for k in conf['w2v_args'] if k not in keep): for key in list(k for k in conf["w2v_args"] if k not in keep):
del conf['w2v_args'][key] del conf["w2v_args"][key]
conf['data'] = '/foo/bar/' conf["data"] = "/foo/bar/"
conf['w2v_args']['task']['data'] = '/foo/bar' conf["w2v_args"]["task"]["data"] = "/foo/bar"
conf['w2v_args']['task']['labels'] = [] conf["w2v_args"]["task"]["labels"] = []
conf['w2v_args']['task']['label_dir'] = '/foo/bar' conf["w2v_args"]["task"]["label_dir"] = "/foo/bar"
print(json.dumps(conf, indent=4, sort_keys=True)) print(json.dumps(conf, indent=4, sort_keys=True))
if __name__ == '__main__': if __name__ == "__main__":
_main() _main()
...@@ -41,9 +41,9 @@ python generate_wav2vec2_model_config.py \ ...@@ -41,9 +41,9 @@ python generate_wav2vec2_model_config.py \
> wav2vec_large_lv60_self_960h.json > wav2vec_large_lv60_self_960h.json
``` ```
""" """
import os
import json
import argparse import argparse
import json
import os
def _parse_args(): def _parse_args():
...@@ -52,19 +52,13 @@ def _parse_args(): ...@@ -52,19 +52,13 @@ def _parse_args():
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
parser.add_argument( parser.add_argument(
'--model-file', "--model-file",
required=True, required=True,
help=( help=("A point file from " "https://github.com/pytorch/fairseq/tree/main/examples/wav2vec"),
'A point file from '
'https://github.com/pytorch/fairseq/tree/main/examples/wav2vec'
)
) )
parser.add_argument( parser.add_argument(
'--dict-dir', "--dict-dir",
help=( help=("Directory where `dict.ltr.txt` file is found. " "Default: the directory of the given model."),
'Directory where `dict.ltr.txt` file is found. '
'Default: the directory of the given model.'
)
) )
args = parser.parse_args() args = parser.parse_args()
if args.dict_dir is None: if args.dict_dir is None:
...@@ -75,32 +69,29 @@ def _parse_args(): ...@@ -75,32 +69,29 @@ def _parse_args():
def _to_json(conf): def _to_json(conf):
import yaml import yaml
from omegaconf import OmegaConf from omegaconf import OmegaConf
return yaml.safe_load(OmegaConf.to_yaml(conf)) return yaml.safe_load(OmegaConf.to_yaml(conf))
def _load(model_file, dict_dir): def _load(model_file, dict_dir):
import fairseq import fairseq
overrides = {'data': dict_dir} overrides = {"data": dict_dir}
_, args, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( _, args, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file], arg_overrides=overrides)
[model_file], arg_overrides=overrides return _to_json(args["model"])
)
return _to_json(args['model'])
def _main(): def _main():
args = _parse_args() args = _parse_args()
conf = _load(args.model_file, args.dict_dir) conf = _load(args.model_file, args.dict_dir)
if conf['_name'] == 'wav2vec_ctc': if conf["_name"] == "wav2vec_ctc":
del conf['data'] del conf["data"]
del conf['w2v_args']['task']['data'] del conf["w2v_args"]["task"]["data"]
conf['w2v_args'] = { conf["w2v_args"] = {key: conf["w2v_args"][key] for key in ["model", "task"]}
key: conf['w2v_args'][key] for key in ['model', 'task']
}
print(json.dumps(conf, indent=4, sort_keys=True)) print(json.dumps(conf, indent=4, sort_keys=True))
if __name__ == '__main__': if __name__ == "__main__":
_main() _main()
import os
import json import json
import os
from transformers import Wav2Vec2Model from transformers import Wav2Vec2Model
...@@ -22,16 +22,16 @@ def _main(): ...@@ -22,16 +22,16 @@ def _main():
"facebook/wav2vec2-large-xlsr-53-german", "facebook/wav2vec2-large-xlsr-53-german",
] ]
for key in keys: for key in keys:
path = os.path.join(_THIS_DIR, f'{key}.json') path = os.path.join(_THIS_DIR, f"{key}.json")
print('Generating ', path) print("Generating ", path)
cfg = Wav2Vec2Model.from_pretrained(key).config cfg = Wav2Vec2Model.from_pretrained(key).config
cfg = json.loads(cfg.to_json_string()) cfg = json.loads(cfg.to_json_string())
del cfg['_name_or_path'] del cfg["_name_or_path"]
with open(path, 'w') as file_: with open(path, "w") as file_:
file_.write(json.dumps(cfg, indent=4, sort_keys=True)) file_.write(json.dumps(cfg, indent=4, sort_keys=True))
file_.write('\n') file_.write("\n")
if __name__ == '__main__': if __name__ == "__main__":
_main() _main()
...@@ -3,23 +3,23 @@ from torchaudio_unittest.common_utils import sox_utils ...@@ -3,23 +3,23 @@ from torchaudio_unittest.common_utils import sox_utils
def get_encoding(ext, dtype): def get_encoding(ext, dtype):
exts = { exts = {
'mp3', "mp3",
'flac', "flac",
'vorbis', "vorbis",
} }
encodings = { encodings = {
'float32': 'PCM_F', "float32": "PCM_F",
'int32': 'PCM_S', "int32": "PCM_S",
'int16': 'PCM_S', "int16": "PCM_S",
'uint8': 'PCM_U', "uint8": "PCM_U",
} }
return ext.upper() if ext in exts else encodings[dtype] return ext.upper() if ext in exts else encodings[dtype]
def get_bits_per_sample(ext, dtype): def get_bits_per_sample(ext, dtype):
bits_per_samples = { bits_per_samples = {
'flac': 24, "flac": 24,
'mp3': 0, "mp3": 0,
'vorbis': 0, "vorbis": 0,
} }
return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype)) return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype))
...@@ -38,20 +38,19 @@ def fetch_wav_subtype(dtype, encoding, bits_per_sample): ...@@ -38,20 +38,19 @@ def fetch_wav_subtype(dtype, encoding, bits_per_sample):
subtype = { subtype = {
(None, None): dtype2subtype(dtype), (None, None): dtype2subtype(dtype),
(None, 8): "PCM_U8", (None, 8): "PCM_U8",
('PCM_U', None): "PCM_U8", ("PCM_U", None): "PCM_U8",
('PCM_U', 8): "PCM_U8", ("PCM_U", 8): "PCM_U8",
('PCM_S', None): "PCM_32", ("PCM_S", None): "PCM_32",
('PCM_S', 16): "PCM_16", ("PCM_S", 16): "PCM_16",
('PCM_S', 32): "PCM_32", ("PCM_S", 32): "PCM_32",
('PCM_F', None): "FLOAT", ("PCM_F", None): "FLOAT",
('PCM_F', 32): "FLOAT", ("PCM_F", 32): "FLOAT",
('PCM_F', 64): "DOUBLE", ("PCM_F", 64): "DOUBLE",
('ULAW', None): "ULAW", ("ULAW", None): "ULAW",
('ULAW', 8): "ULAW", ("ULAW", 8): "ULAW",
('ALAW', None): "ALAW", ("ALAW", None): "ALAW",
('ALAW', 8): "ALAW", ("ALAW", 8): "ALAW",
}.get((encoding, bits_per_sample)) }.get((encoding, bits_per_sample))
if subtype: if subtype:
return subtype return subtype
raise ValueError( raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
f"wav does not support ({encoding}, {bits_per_sample}).")
from unittest.mock import patch
import warnings
import tarfile import tarfile
import warnings
from unittest.mock import patch
import torch import torch
from torchaudio.backend import soundfile_backend
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
...@@ -14,10 +17,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -14,10 +17,7 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
nested_params, nested_params,
) )
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from .common import skipIfFormatNotSupported, parameterize from .common import skipIfFormatNotSupported, parameterize
if _mod_utils.is_module_available("soundfile"): if _mod_utils.is_module_available("soundfile"):
...@@ -27,15 +27,15 @@ if _mod_utils.is_module_available("soundfile"): ...@@ -27,15 +27,15 @@ if _mod_utils.is_module_available("soundfile"):
@skipIfNoModule("soundfile") @skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase): class TestInfo(TempDirMixin, PytorchTestCase):
@parameterize( @parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], ["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
) )
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly""" """`soundfile_backend.info` can check wav file correctly"""
duration = 1 duration = 1
path = self.get_temp_path("data.wav") path = self.get_temp_path("data.wav")
data = get_wav_data( data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
dtype, num_channels, normalize=False, num_frames=duration * sample_rate
)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
info = soundfile_backend.info(path) info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
...@@ -81,10 +81,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -81,10 +81,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
@nested_params( @nested_params(
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
[ [("PCM_24", 24), ("PCM_32", 32)],
('PCM_24', 24),
('PCM_32', 32)
],
) )
@skipIfFormatNotSupported("NIST") @skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth): def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
...@@ -109,13 +106,15 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -109,13 +106,15 @@ class TestInfo(TempDirMixin, PytorchTestCase):
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
dict should be updated. dict should be updated.
""" """
def _mock_info_func(_): def _mock_info_func(_):
class MockSoundFileInfo: class MockSoundFileInfo:
samplerate = 8000 samplerate = 8000
frames = 356 frames = 356
channels = 2 channels = 2
subtype = 'UNSEEN_SUBTYPE' subtype = "UNSEEN_SUBTYPE"
format = 'UNKNOWN' format = "UNKNOWN"
return MockSoundFileInfo() return MockSoundFileInfo()
with patch("soundfile.info", _mock_info_func): with patch("soundfile.info", _mock_info_func):
...@@ -134,27 +133,27 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -134,27 +133,27 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
sample_rate = 16000 sample_rate = 16000
num_channels = 2 num_channels = 2
num_frames = sample_rate * duration num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
data = torch.randn(num_frames, num_channels).numpy() data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(path, data, sample_rate, subtype=subtype) soundfile.write(path, data, sample_rate, subtype=subtype)
with open(path, 'rb') as fileobj: with open(path, "rb") as fileobj:
info = soundfile_backend.info(fileobj) info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == num_frames assert info.num_frames == num_frames
assert info.num_channels == num_channels assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == 'flac' else "PCM_S" assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_fileobj_wav(self): def test_fileobj_wav(self):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
self._test_fileobj('wav', 'PCM_16', 16) self._test_fileobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC") @skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self): def test_fileobj_flac(self):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
self._test_fileobj('flac', 'PCM_16', 16) self._test_fileobj("flac", "PCM_16", 16)
def _test_tarobj(self, ext, subtype, bits_per_sample): def _test_tarobj(self, ext, subtype, bits_per_sample):
"""Query compressed audio via file-like object works""" """Query compressed audio via file-like object works"""
...@@ -162,29 +161,29 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -162,29 +161,29 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
sample_rate = 16000 sample_rate = 16000
num_channels = 2 num_channels = 2
num_frames = sample_rate * duration num_frames = sample_rate * duration
audio_file = f'test.{ext}' audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file) audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz') archive_path = self.get_temp_path("archive.tar.gz")
data = torch.randn(num_frames, num_channels).numpy() data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(audio_path, data, sample_rate, subtype=subtype) soundfile.write(audio_path, data, sample_rate, subtype=subtype)
with tarfile.TarFile(archive_path, 'w') as tarobj: with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file) tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj: with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file) fileobj = tarobj.extractfile(audio_file)
info = soundfile_backend.info(fileobj) info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == num_frames assert info.num_frames == num_frames
assert info.num_channels == num_channels assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == 'flac' else "PCM_S" assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_tarobj_wav(self): def test_tarobj_wav(self):
"""Query compressed audio via file-like object works""" """Query compressed audio via file-like object works"""
self._test_tarobj('wav', 'PCM_16', 16) self._test_tarobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC") @skipIfFormatNotSupported("FLAC")
def test_tarobj_flac(self): def test_tarobj_flac(self):
"""Query compressed audio via file-like object works""" """Query compressed audio via file-like object works"""
self._test_tarobj('flac', 'PCM_16', 16) self._test_tarobj("flac", "PCM_16", 16)
...@@ -3,10 +3,9 @@ import tarfile ...@@ -3,10 +3,9 @@ import tarfile
from unittest.mock import patch from unittest.mock import patch
import torch import torch
from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend from torchaudio.backend import soundfile_backend
from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
...@@ -16,6 +15,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -16,6 +15,7 @@ from torchaudio_unittest.common_utils import (
load_wav, load_wav,
save_wav, save_wav,
) )
from .common import ( from .common import (
parameterize, parameterize,
dtype2subtype, dtype2subtype,
...@@ -27,7 +27,11 @@ if _mod_utils.is_module_available("soundfile"): ...@@ -27,7 +27,11 @@ if _mod_utils.is_module_available("soundfile"):
def _get_mock_path( def _get_mock_path(
ext: str, dtype: str, sample_rate: int, num_channels: int, num_frames: int, ext: str,
dtype: str,
sample_rate: int,
num_channels: int,
num_frames: int,
): ):
return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}" return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}"
...@@ -86,7 +90,7 @@ class SoundFileMock: ...@@ -86,7 +90,7 @@ class SoundFileMock:
num_frames=self._params["num_frames"], num_frames=self._params["num_frames"],
channels_first=False, channels_first=False,
).numpy() ).numpy()
return data[self._start:self._start + frames] return data[self._start : self._start + frames]
def __enter__(self): def __enter__(self):
return self return self
...@@ -96,21 +100,13 @@ class SoundFileMock: ...@@ -96,21 +100,13 @@ class SoundFileMock:
class MockedLoadTest(PytorchTestCase): class MockedLoadTest(PytorchTestCase):
def assert_dtype( def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
self, ext, dtype, sample_rate, num_channels, normalize, channels_first
):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32""" """When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
num_frames = 3 * sample_rate num_frames = 3 * sample_rate
path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames) path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
expected_dtype = ( expected_dtype = torch.float32 if normalize or ext not in ["wav", "nist"] else getattr(torch, dtype)
torch.float32
if normalize or ext not in ["wav", "nist"]
else getattr(torch, dtype)
)
with patch("soundfile.SoundFile", SoundFileMock): with patch("soundfile.SoundFile", SoundFileMock):
found, sr = soundfile_backend.load( found, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
path, normalize=normalize, channels_first=channels_first
)
assert found.dtype == expected_dtype assert found.dtype == expected_dtype
assert sample_rate == sr assert sample_rate == sr
...@@ -123,32 +119,28 @@ class MockedLoadTest(PytorchTestCase): ...@@ -123,32 +119,28 @@ class MockedLoadTest(PytorchTestCase):
) )
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns native dtype when normalize=False else float32""" """Returns native dtype when normalize=False else float32"""
self.assert_dtype( self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)
"wav", dtype, sample_rate, num_channels, normalize, channels_first
)
@parameterize( @parameterize(
["int8", "int16", "int32"], [8000, 16000], [1, 2], [True, False], [True, False], ["int8", "int16", "int32"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
) )
def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first): def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always""" """Returns float32 always"""
self.assert_dtype( self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)
"sph", dtype, sample_rate, num_channels, normalize, channels_first
)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False]) @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_ogg(self, sample_rate, num_channels, normalize, channels_first): def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always""" """Returns float32 always"""
self.assert_dtype( self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)
"ogg", "int16", sample_rate, num_channels, normalize, channels_first
)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False]) @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_flac(self, sample_rate, num_channels, normalize, channels_first): def test_flac(self, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load flac format.""" """`soundfile_backend.load` can load flac format."""
self.assert_dtype( self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
"flac", "int16", sample_rate, num_channels, normalize, channels_first
)
class LoadTestBase(TempDirMixin, PytorchTestCase): class LoadTestBase(TempDirMixin, PytorchTestCase):
...@@ -176,14 +168,17 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -176,14 +168,17 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
) )
save_wav(path, data, sample_rate, channels_first=channels_first) save_wav(path, data, sample_rate, channels_first=channels_first)
expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0] expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0]
data, sr = soundfile_backend.load( data, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
path, normalize=normalize, channels_first=channels_first
)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(data, expected) self.assertEqual(data, expected)
def assert_sphere( def assert_sphere(
self, dtype, sample_rate, num_channels, channels_first=True, duration=1, self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
): ):
"""`soundfile_backend.load` can load SPHERE format correctly.""" """`soundfile_backend.load` can load SPHERE format correctly."""
path = self.get_temp_path("reference.sph") path = self.get_temp_path("reference.sph")
...@@ -195,16 +190,19 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -195,16 +190,19 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
normalize=False, normalize=False,
channels_first=False, channels_first=False,
) )
soundfile.write( soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST"
)
expected = normalize_wav(raw.t() if channels_first else raw) expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first) data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(data, expected, atol=1e-4, rtol=1e-8) self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
def assert_flac( def assert_flac(
self, dtype, sample_rate, num_channels, channels_first=True, duration=1, self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
): ):
"""`soundfile_backend.load` can load FLAC format correctly.""" """`soundfile_backend.load` can load FLAC format correctly."""
path = self.get_temp_path("reference.flac") path = self.get_temp_path("reference.flac")
...@@ -239,7 +237,10 @@ class TestLoad(LoadTestBase): ...@@ -239,7 +237,10 @@ class TestLoad(LoadTestBase):
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first) self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize( @parameterize(
["int16"], [16000], [2], [False], ["int16"],
[16000],
[2],
[False],
) )
def test_wav_large(self, dtype, sample_rate, num_channels, normalize): def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`soundfile_backend.load` can load large wav file correctly.""" """`soundfile_backend.load` can load large wav file correctly."""
...@@ -269,15 +270,16 @@ class TestLoad(LoadTestBase): ...@@ -269,15 +270,16 @@ class TestLoad(LoadTestBase):
@skipIfNoModule("soundfile") @skipIfNoModule("soundfile")
class TestLoadFormat(TempDirMixin, PytorchTestCase): class TestLoadFormat(TempDirMixin, PytorchTestCase):
"""Given `format` parameter, `so.load` can load files without extension""" """Given `format` parameter, `so.load` can load files without extension"""
original = None original = None
path = None path = None
def _make_file(self, format_): def _make_file(self, format_):
sample_rate = 8000 sample_rate = 8000
path_with_ext = self.get_temp_path(f'test.{format_}') path_with_ext = self.get_temp_path(f"test.{format_}")
data = get_wav_data('float32', num_channels=2).numpy().T data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path_with_ext, data, sample_rate) soundfile.write(path_with_ext, data, sample_rate)
expected = soundfile.read(path_with_ext, dtype='float32')[0].T expected = soundfile.read(path_with_ext, dtype="float32")[0].T
path = os.path.splitext(path_with_ext)[0] path = os.path.splitext(path_with_ext)[0]
os.rename(path_with_ext, path) os.rename(path_with_ext, path)
return path, expected return path, expected
...@@ -288,15 +290,21 @@ class TestLoadFormat(TempDirMixin, PytorchTestCase): ...@@ -288,15 +290,21 @@ class TestLoadFormat(TempDirMixin, PytorchTestCase):
found, _ = soundfile_backend.load(path) found, _ = soundfile_backend.load(path)
self.assertEqual(found, expected) self.assertEqual(found, expected)
@parameterized.expand([ @parameterized.expand(
('WAV', ), ('wav', ), [
]) ("WAV",),
("wav",),
]
)
def test_wav(self, format_): def test_wav(self, format_):
self._test_format(format_) self._test_format(format_)
@parameterized.expand([ @parameterized.expand(
('FLAC', ), ('flac',), [
]) ("FLAC",),
("flac",),
]
)
@skipIfFormatNotSupported("FLAC") @skipIfFormatNotSupported("FLAC")
def test_flac(self, format_): def test_flac(self, format_):
self._test_format(format_) self._test_format(format_)
...@@ -307,40 +315,40 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -307,40 +315,40 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext): def _test_fileobj(self, ext):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
sample_rate = 16000 sample_rate = 16000
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
data = get_wav_data('float32', num_channels=2).numpy().T data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path, data, sample_rate) soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype='float32')[0].T expected = soundfile.read(path, dtype="float32")[0].T
with open(path, 'rb') as fileobj: with open(path, "rb") as fileobj:
found, sr = soundfile_backend.load(fileobj) found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found) self.assertEqual(expected, found)
def test_fileobj_wav(self): def test_fileobj_wav(self):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
self._test_fileobj('wav') self._test_fileobj("wav")
@skipIfFormatNotSupported("FLAC") @skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self): def test_fileobj_flac(self):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
self._test_fileobj('flac') self._test_fileobj("flac")
def _test_tarfile(self, ext): def _test_tarfile(self, ext):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
sample_rate = 16000 sample_rate = 16000
audio_file = f'test.{ext}' audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file) audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz') archive_path = self.get_temp_path("archive.tar.gz")
data = get_wav_data('float32', num_channels=2).numpy().T data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(audio_path, data, sample_rate) soundfile.write(audio_path, data, sample_rate)
expected = soundfile.read(audio_path, dtype='float32')[0].T expected = soundfile.read(audio_path, dtype="float32")[0].T
with tarfile.TarFile(archive_path, 'w') as tarobj: with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file) tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj: with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file) fileobj = tarobj.extractfile(audio_file)
found, sr = soundfile_backend.load(fileobj) found, sr = soundfile_backend.load(fileobj)
...@@ -349,9 +357,9 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -349,9 +357,9 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
def test_tarfile_wav(self): def test_tarfile_wav(self):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
self._test_tarfile('wav') self._test_tarfile("wav")
@skipIfFormatNotSupported("FLAC") @skipIfFormatNotSupported("FLAC")
def test_tarfile_flac(self): def test_tarfile_flac(self):
"""Loading audio via file-like object works""" """Loading audio via file-like object works"""
self._test_tarfile('flac') self._test_tarfile("flac")
...@@ -3,7 +3,6 @@ from unittest.mock import patch ...@@ -3,7 +3,6 @@ from unittest.mock import patch
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend from torchaudio.backend import soundfile_backend
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
...@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import (
load_wav, load_wav,
nested_params, nested_params,
) )
from .common import ( from .common import (
fetch_wav_subtype, fetch_wav_subtype,
parameterize, parameterize,
...@@ -30,23 +30,22 @@ class MockedSaveTest(PytorchTestCase): ...@@ -30,23 +30,22 @@ class MockedSaveTest(PytorchTestCase):
[False, True], [False, True],
[ [
(None, None), (None, None),
('PCM_U', None), ("PCM_U", None),
('PCM_U', 8), ("PCM_U", 8),
('PCM_S', None), ("PCM_S", None),
('PCM_S', 16), ("PCM_S", 16),
('PCM_S', 32), ("PCM_S", 32),
('PCM_F', None), ("PCM_F", None),
('PCM_F', 32), ("PCM_F", 32),
('PCM_F', 64), ("PCM_F", 64),
('ULAW', None), ("ULAW", None),
('ULAW', 8), ("ULAW", 8),
('ALAW', None), ("ALAW", None),
('ALAW', 8), ("ALAW", 8),
], ],
) )
@patch("soundfile.write") @patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first, def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
enc_params, mocked_write):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV""" """soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav" filepath = "foo.wav"
input_tensor = get_wav_data( input_tensor = get_wav_data(
...@@ -59,25 +58,33 @@ class MockedSaveTest(PytorchTestCase): ...@@ -59,25 +58,33 @@ class MockedSaveTest(PytorchTestCase):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
soundfile_backend.save( soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first=channels_first, filepath,
encoding=encoding, bits_per_sample=bits_per_sample input_tensor,
sample_rate,
channels_first=channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
) )
# on Py3.8+ call_args.kwargs is more descriptive # on Py3.8+ call_args.kwargs is more descriptive
args = mocked_write.call_args[1] args = mocked_write.call_args[1]
assert args["file"] == filepath assert args["file"] == filepath
assert args["samplerate"] == sample_rate assert args["samplerate"] == sample_rate
assert args["subtype"] == fetch_wav_subtype( assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
dtype, encoding, bits_per_sample)
assert args["format"] is None assert args["format"] is None
self.assertEqual( self.assertEqual(args["data"], input_tensor.t() if channels_first else input_tensor)
args["data"], input_tensor.t() if channels_first else input_tensor
)
@patch("soundfile.write") @patch("soundfile.write")
def assert_non_wav( def assert_non_wav(
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write, self,
encoding=None, bits_per_sample=None, fmt,
dtype,
sample_rate,
num_channels,
channels_first,
mocked_write,
encoding=None,
bits_per_sample=None,
): ):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE""" """soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}" filepath = f"foo.{fmt}"
...@@ -91,8 +98,12 @@ class MockedSaveTest(PytorchTestCase): ...@@ -91,8 +98,12 @@ class MockedSaveTest(PytorchTestCase):
expected_data = input_tensor.t() if channels_first else input_tensor expected_data = input_tensor.t() if channels_first else input_tensor
soundfile_backend.save( soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first, filepath,
encoding=encoding, bits_per_sample=bits_per_sample, input_tensor,
sample_rate,
channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
) )
# on Py3.8+ call_args.kwargs is more descriptive # on Py3.8+ call_args.kwargs is more descriptive
...@@ -112,38 +123,42 @@ class MockedSaveTest(PytorchTestCase): ...@@ -112,38 +123,42 @@ class MockedSaveTest(PytorchTestCase):
[1, 2], [1, 2],
[False, True], [False, True],
[ [
('PCM_S', 8), ("PCM_S", 8),
('PCM_S', 16), ("PCM_S", 16),
('PCM_S', 24), ("PCM_S", 24),
('PCM_S', 32), ("PCM_S", 32),
('ULAW', 8), ("ULAW", 8),
('ALAW', 8), ("ALAW", 8),
('ALAW', 16), ("ALAW", 16),
('ALAW', 24), ("ALAW", 24),
('ALAW', 32), ("ALAW", 32),
], ],
) )
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params): def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
"""soundfile_backend.save passes default format and subtype (None-s) to """soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV""" soundfile.write when not WAV"""
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_non_wav(fmt, dtype, sample_rate, num_channels, self.assert_non_wav(
channels_first, encoding=encoding, fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample
bits_per_sample=bits_per_sample) )
@parameterize( @parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True], ["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
[8, 16, 24], [8, 16, 24],
) )
def test_flac(self, dtype, sample_rate, num_channels, def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
channels_first, bits_per_sample):
"""soundfile_backend.save passes default format and subtype (None-s) to """soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV""" soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels, self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample)
channels_first, bits_per_sample=bits_per_sample)
@parameterize( @parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True], ["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
) )
def test_ogg(self, dtype, sample_rate, num_channels, channels_first): def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
"""soundfile_backend.save passes default format and subtype (None-s) to """soundfile_backend.save passes default format and subtype (None-s) to
...@@ -156,9 +171,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase): ...@@ -156,9 +171,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames): def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`soundfile_backend.save` can save wav format.""" """`soundfile_backend.save` can save wav format."""
path = self.get_temp_path("data.wav") path = self.get_temp_path("data.wav")
expected = get_wav_data( expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
dtype, num_channels, num_frames=num_frames, normalize=False
)
soundfile_backend.save(path, expected, sample_rate) soundfile_backend.save(path, expected, sample_rate)
found, sr = load_wav(path, normalize=False) found, sr = load_wav(path, normalize=False)
assert sample_rate == sr assert sample_rate == sr
...@@ -172,9 +185,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase): ...@@ -172,9 +185,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
""" """
num_frames = sample_rate * 3 num_frames = sample_rate * 3
path = self.get_temp_path(f"data.{fmt}") path = self.get_temp_path(f"data.{fmt}")
expected = get_wav_data( expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
dtype, num_channels, num_frames=num_frames, normalize=False
)
soundfile_backend.save(path, expected, sample_rate) soundfile_backend.save(path, expected, sample_rate)
sinfo = soundfile.info(path) sinfo = soundfile.info(path)
assert sinfo.format == fmt.upper() assert sinfo.format == fmt.upper()
...@@ -201,14 +212,17 @@ class SaveTestBase(TempDirMixin, PytorchTestCase): ...@@ -201,14 +212,17 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
@skipIfNoModule("soundfile") @skipIfNoModule("soundfile")
class TestSave(SaveTestBase): class TestSave(SaveTestBase):
@parameterize( @parameterize(
["float32", "int32", "int16"], [8000, 16000], [1, 2], ["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
) )
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save wav format.""" """`soundfile_backend.save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize( @parameterize(
["float32", "int32", "int16"], [4, 8, 16, 32], ["float32", "int32", "int16"],
[4, 8, 16, 32],
) )
def test_multiple_channels(self, dtype, num_channels): def test_multiple_channels(self, dtype, num_channels):
"""`soundfile_backend.save` can save wav with more than 2 channels.""" """`soundfile_backend.save` can save wav with more than 2 channels."""
...@@ -216,7 +230,9 @@ class TestSave(SaveTestBase): ...@@ -216,7 +230,9 @@ class TestSave(SaveTestBase):
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize( @parameterize(
["int32", "int16"], [8000, 16000], [1, 2], ["int32", "int16"],
[8000, 16000],
[1, 2],
) )
@skipIfFormatNotSupported("NIST") @skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels): def test_sphere(self, dtype, sample_rate, num_channels):
...@@ -224,7 +240,8 @@ class TestSave(SaveTestBase): ...@@ -224,7 +240,8 @@ class TestSave(SaveTestBase):
self.assert_sphere(dtype, sample_rate, num_channels) self.assert_sphere(dtype, sample_rate, num_channels)
@parameterize( @parameterize(
[8000, 16000], [1, 2], [8000, 16000],
[1, 2],
) )
@skipIfFormatNotSupported("FLAC") @skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels): def test_flac(self, sample_rate, num_channels):
...@@ -232,7 +249,8 @@ class TestSave(SaveTestBase): ...@@ -232,7 +249,8 @@ class TestSave(SaveTestBase):
self.assert_flac("float32", sample_rate, num_channels) self.assert_flac("float32", sample_rate, num_channels)
@parameterize( @parameterize(
[8000, 16000], [1, 2], [8000, 16000],
[1, 2],
) )
@skipIfFormatNotSupported("OGG") @skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels): def test_ogg(self, sample_rate, num_channels):
...@@ -260,36 +278,36 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -260,36 +278,36 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext): def _test_fileobj(self, ext):
"""Saving audio to file-like object works""" """Saving audio to file-like object works"""
sample_rate = 16000 sample_rate = 16000
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
subtype = 'FLOAT' if ext == 'wav' else None subtype = "FLOAT" if ext == "wav" else None
data = get_wav_data('float32', num_channels=2) data = get_wav_data("float32", num_channels=2)
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype) soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
expected = soundfile.read(path, dtype='float32')[0] expected = soundfile.read(path, dtype="float32")[0]
fileobj = io.BytesIO() fileobj = io.BytesIO()
soundfile_backend.save(fileobj, data, sample_rate, format=ext) soundfile_backend.save(fileobj, data, sample_rate, format=ext)
fileobj.seek(0) fileobj.seek(0)
found, sr = soundfile.read(fileobj, dtype='float32') found, sr = soundfile.read(fileobj, dtype="float32")
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found, atol=1e-4, rtol=1e-8) self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)
def test_fileobj_wav(self): def test_fileobj_wav(self):
"""Saving audio via file-like object works""" """Saving audio via file-like object works"""
self._test_fileobj('wav') self._test_fileobj("wav")
@skipIfFormatNotSupported("FLAC") @skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self): def test_fileobj_flac(self):
"""Saving audio via file-like object works""" """Saving audio via file-like object works"""
self._test_fileobj('flac') self._test_fileobj("flac")
@skipIfFormatNotSupported("NIST") @skipIfFormatNotSupported("NIST")
def test_fileobj_nist(self): def test_fileobj_nist(self):
"""Saving audio via file-like object works""" """Saving audio via file-like object works"""
self._test_fileobj('NIST') self._test_fileobj("NIST")
@skipIfFormatNotSupported("OGG") @skipIfFormatNotSupported("OGG")
def test_fileobj_ogg(self): def test_fileobj_ogg(self):
"""Saving audio via file-like object works""" """Saving audio via file-like object works"""
self._test_fileobj('OGG') self._test_fileobj("OGG")
...@@ -3,12 +3,12 @@ def name_func(func, _, params): ...@@ -3,12 +3,12 @@ def name_func(func, _, params):
def get_enc_params(dtype): def get_enc_params(dtype):
if dtype == 'float32': if dtype == "float32":
return 'PCM_F', 32 return "PCM_F", 32
if dtype == 'int32': if dtype == "int32":
return 'PCM_S', 32 return "PCM_S", 32
if dtype == 'int16': if dtype == "int16":
return 'PCM_S', 16 return "PCM_S", 16
if dtype == 'uint8': if dtype == "uint8":
return 'PCM_U', 8 return "PCM_U", 8
raise ValueError(f'Unexpected dtype: {dtype}') raise ValueError(f"Unexpected dtype: {dtype}")
from contextlib import contextmanager
import io import io
import os
import itertools import itertools
import os
import tarfile import tarfile
from contextlib import contextmanager
from parameterized import parameterized from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import ( from torchaudio_unittest.backend.common import (
get_bits_per_sample, get_bits_per_sample,
get_encoding, get_encoding,
...@@ -25,6 +24,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -25,6 +24,7 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
sox_utils, sox_utils,
) )
from .common import ( from .common import (
name_func, name_func,
) )
...@@ -34,18 +34,23 @@ if _mod_utils.is_module_available("requests"): ...@@ -34,18 +34,23 @@ if _mod_utils.is_module_available("requests"):
import requests import requests
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase): class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16', 'uint8'], list(
[8000, 16000], itertools.product(
[1, 2], ["float32", "int32", "int16", "uint8"],
)), name_func=name_func) [8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly""" """`sox_io_backend.info` can check wav file correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.wav') path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
...@@ -53,17 +58,22 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -53,17 +58,22 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype) assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16', 'uint8'], list(
[8000, 16000], itertools.product(
[4, 8, 16, 32], ["float32", "int32", "int16", "uint8"],
)), name_func=name_func) [8000, 16000],
[4, 8, 16, 32],
)
),
name_func=name_func,
)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly""" """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.wav') path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
...@@ -71,20 +81,28 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -71,20 +81,28 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype) assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[96, 128, 160, 192, 224, 256, 320], [8000, 16000],
)), name_func=name_func) [1, 2],
[96, 128, 160, 192, 224, 256, 320],
)
),
name_func=name_func,
)
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly""" """`sox_io_backend.info` can check mp3 file correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.mp3') path = self.get_temp_path("data.mp3")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, path,
compression=bit_rate, duration=duration, sample_rate,
num_channels,
compression=bit_rate,
duration=duration,
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
...@@ -94,18 +112,26 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -94,18 +112,26 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "MP3" assert info.encoding == "MP3"
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
list(range(9)), [8000, 16000],
)), name_func=name_func) [1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly""" """`sox_io_backend.info` can check flac file correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.flac') path = self.get_temp_path("data.flac")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, path,
compression=compression_level, duration=duration, sample_rate,
num_channels,
compression=compression_level,
duration=duration,
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
...@@ -114,18 +140,26 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -114,18 +140,26 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 24 # FLAC standard assert info.bits_per_sample == 24 # FLAC standard
assert info.encoding == "FLAC" assert info.encoding == "FLAC"
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[-1, 0, 1, 2, 3, 3.6, 5, 10], [8000, 16000],
)), name_func=name_func) [1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
),
name_func=name_func,
)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly""" """`sox_io_backend.info` can check vorbis file correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.vorbis') path = self.get_temp_path("data.vorbis")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, path,
compression=quality_level, duration=duration, sample_rate,
num_channels,
compression=quality_level,
duration=duration,
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
...@@ -134,18 +168,21 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -134,18 +168,21 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "VORBIS" assert info.encoding == "VORBIS"
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[16, 32], [8000, 16000],
)), name_func=name_func) [1, 2],
[16, 32],
)
),
name_func=name_func,
)
def test_sphere(self, sample_rate, num_channels, bits_per_sample): def test_sphere(self, sample_rate, num_channels, bits_per_sample):
"""`sox_io_backend.info` can check sph file correctly""" """`sox_io_backend.info` can check sph file correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.sph') path = self.get_temp_path("data.sph")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)
path, sample_rate, num_channels, duration=duration,
bit_depth=bits_per_sample)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
...@@ -153,19 +190,22 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -153,19 +190,22 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == bits_per_sample assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S" assert info.encoding == "PCM_S"
@parameterized.expand(list(itertools.product( @parameterized.expand(
['int32', 'int16', 'uint8'], list(
[8000, 16000], itertools.product(
[1, 2], ["int32", "int16", "uint8"],
)), name_func=name_func) [8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels): def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly""" """`sox_io_backend.info` can check amb file correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.amb') path = self.get_temp_path("data.amb")
bits_per_sample = sox_utils.get_bit_depth(dtype) bits_per_sample = sox_utils.get_bit_depth(dtype)
sox_utils.gen_audio_file( sox_utils.gen_audio_file(path, sample_rate, num_channels, bit_depth=bits_per_sample, duration=duration)
path, sample_rate, num_channels,
bit_depth=bits_per_sample, duration=duration)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
...@@ -178,10 +218,10 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -178,10 +218,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1 duration = 1
num_channels = 1 num_channels = 1
sample_rate = 8000 sample_rate = 8000
path = self.get_temp_path('data.amr-nb') path = self.get_temp_path("data.amr-nb")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
duration=duration) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
...@@ -194,11 +234,10 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -194,11 +234,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1 duration = 1
num_channels = 1 num_channels = 1
sample_rate = 8000 sample_rate = 8000
path = self.get_temp_path('data.wav') path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="u-law", duration=duration
bit_depth=8, encoding='u-law', )
duration=duration)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
...@@ -211,11 +250,10 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -211,11 +250,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1 duration = 1
num_channels = 1 num_channels = 1
sample_rate = 8000 sample_rate = 8000
path = self.get_temp_path('data.wav') path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="a-law", duration=duration
bit_depth=8, encoding='a-law', )
duration=duration)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
...@@ -228,10 +266,8 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -228,10 +266,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1 duration = 1
num_channels = 1 num_channels = 1
sample_rate = 8000 sample_rate = 8000
path = self.get_temp_path('data.gsm') path = self.get_temp_path("data.gsm")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(path, sample_rate=sample_rate, num_channels=num_channels, duration=duration)
path, sample_rate=sample_rate, num_channels=num_channels,
duration=duration)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_channels == num_channels assert info.num_channels == num_channels
...@@ -243,10 +279,10 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -243,10 +279,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1 duration = 1
num_channels = 1 num_channels = 1
sample_rate = 8000 sample_rate = 8000
path = self.get_temp_path('data.htk') path = self.get_temp_path("data.htk")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
bit_depth=16, duration=duration) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
...@@ -257,14 +293,19 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -257,14 +293,19 @@ class TestInfo(TempDirMixin, PytorchTestCase):
@skipIfNoSox @skipIfNoSox
class TestInfoOpus(PytorchTestCase): class TestInfoOpus(PytorchTestCase):
@parameterized.expand(list(itertools.product( @parameterized.expand(
['96k'], list(
[1, 2], itertools.product(
[0, 5, 10], ["96k"],
)), name_func=name_func) [1, 2],
[0, 5, 10],
)
),
name_func=name_func,
)
def test_opus(self, bitrate, num_channels, compression_level): def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.info` can check opus file correctly""" """`sox_io_backend.info` can check opus file correctly"""
path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus') path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == 48000 assert info.sample_rate == 48000
assert info.num_frames == 32768 assert info.num_frames == 32768
...@@ -296,13 +337,15 @@ class TestLoadWithoutExtension(PytorchTestCase): ...@@ -296,13 +337,15 @@ class TestLoadWithoutExtension(PytorchTestCase):
class FileObjTestBase(TempDirMixin): class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
bit_depth = sox_utils.get_bit_depth(dtype) bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels=num_channels, path,
sample_rate,
num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype), encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth, bit_depth=bit_depth,
duration=duration, duration=duration,
...@@ -318,29 +361,29 @@ class FileObjTestBase(TempDirMixin): ...@@ -318,29 +361,29 @@ class FileObjTestBase(TempDirMixin):
@skipIfNoSox @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase): class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
with open(path, 'rb') as fileobj: with open(path, "rb") as fileobj:
return sox_io_backend.info(fileobj, format_) return sox_io_backend.info(fileobj, format_)
def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames): def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames) path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
with open(path, 'rb') as file_: with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read()) fileobj = io.BytesIO(file_.read())
return sox_io_backend.info(fileobj, format_) return sox_io_backend.info(fileobj, format_)
def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames): def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames) audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path) audio_file = os.path.basename(audio_path)
archive_path = self.get_temp_path('archive.tar.gz') archive_path = self.get_temp_path("archive.tar.gz")
with tarfile.TarFile(archive_path, 'w') as tarobj: with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file) tarobj.add(audio_path, arcname=audio_file)
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
with tarfile.TarFile(archive_path, 'r') as tarobj: with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file) fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_) return sox_io_backend.info(fileobj, format_)
...@@ -353,16 +396,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -353,16 +396,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
finally: finally:
set_buffer_size(original_buffer_size) set_buffer_size(original_buffer_size)
@parameterized.expand([ @parameterized.expand(
('wav', "float32"), [
('wav', "int32"), ("wav", "float32"),
('wav', "int16"), ("wav", "int32"),
('wav', "uint8"), ("wav", "int16"),
('mp3', "float32"), ("wav", "uint8"),
('flac', "float32"), ("mp3", "float32"),
('vorbis', "float32"), ("flac", "float32"),
('amb', "int16"), ("vorbis", "float32"),
]) ("amb", "int16"),
]
)
def test_fileobj(self, ext, dtype): def test_fileobj(self, ext, dtype):
"""Querying audio via file object works""" """Querying audio via file object works"""
sample_rate = 16000 sample_rate = 16000
...@@ -371,7 +416,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -371,7 +416,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -379,9 +424,11 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -379,9 +424,11 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype) assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([ @parameterized.expand(
('vorbis', "float32"), [
]) ("vorbis", "float32"),
]
)
def test_fileobj_large_header(self, ext, dtype): def test_fileobj_large_header(self, ext, dtype):
""" """
For audio file with header size exceeding default buffer size: For audio file with header size exceeding default buffer size:
...@@ -399,7 +446,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -399,7 +446,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
with self._set_buffer_size(16384): with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -407,16 +454,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -407,16 +454,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype) assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([ @parameterized.expand(
('wav', "float32"), [
('wav', "int32"), ("wav", "float32"),
('wav', "int16"), ("wav", "int32"),
('wav', "uint8"), ("wav", "int16"),
('mp3', "float32"), ("wav", "uint8"),
('flac', "float32"), ("mp3", "float32"),
('vorbis', "float32"), ("flac", "float32"),
('amb', "int16"), ("vorbis", "float32"),
]) ("amb", "int16"),
]
)
def test_bytesio(self, ext, dtype): def test_bytesio(self, ext, dtype):
"""Querying audio via BytesIO object works for small data""" """Querying audio via BytesIO object works for small data"""
sample_rate = 16000 sample_rate = 16000
...@@ -425,7 +474,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -425,7 +474,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -433,16 +482,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -433,16 +482,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype) assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([ @parameterized.expand(
('wav', "float32"), [
('wav', "int32"), ("wav", "float32"),
('wav', "int16"), ("wav", "int32"),
('wav', "uint8"), ("wav", "int16"),
('mp3', "float32"), ("wav", "uint8"),
('flac', "float32"), ("mp3", "float32"),
('vorbis', "float32"), ("flac", "float32"),
('amb', "int16"), ("vorbis", "float32"),
]) ("amb", "int16"),
]
)
def test_bytesio_tiny(self, ext, dtype): def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via BytesIO object works for small data""" """Querying audio via BytesIO object works for small data"""
sample_rate = 8000 sample_rate = 8000
...@@ -451,7 +502,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -451,7 +502,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -459,16 +510,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -459,16 +510,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype) assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([ @parameterized.expand(
('wav', "float32"), [
('wav', "int32"), ("wav", "float32"),
('wav', "int16"), ("wav", "int32"),
('wav', "uint8"), ("wav", "int16"),
('mp3', "float32"), ("wav", "uint8"),
('flac', "float32"), ("mp3", "float32"),
('vorbis', "float32"), ("flac", "float32"),
('amb', "int16"), ("vorbis", "float32"),
]) ("amb", "int16"),
]
)
def test_tarfile(self, ext, dtype): def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works""" """Querying compressed audio via file-like object works"""
sample_rate = 16000 sample_rate = 16000
...@@ -477,7 +530,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -477,7 +530,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -487,7 +540,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -487,7 +540,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
@skipIfNoSox @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoModule("requests") @skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase): class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames): def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
...@@ -495,20 +548,22 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase): ...@@ -495,20 +548,22 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
audio_file = os.path.basename(audio_path) audio_file = os.path.basename(audio_path)
url = self.get_url(audio_file) url = self.get_url(audio_file)
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp: with requests.get(url, stream=True) as resp:
return sox_io_backend.info(resp.raw, format=format_) return sox_io_backend.info(resp.raw, format=format_)
@parameterized.expand([ @parameterized.expand(
('wav', "float32"), [
('wav', "int32"), ("wav", "float32"),
('wav', "int16"), ("wav", "int32"),
('wav', "uint8"), ("wav", "int16"),
('mp3', "float32"), ("wav", "uint8"),
('flac', "float32"), ("mp3", "float32"),
('vorbis', "float32"), ("flac", "float32"),
('amb', "int16"), ("vorbis", "float32"),
]) ("amb", "int16"),
]
)
def test_requests(self, ext, dtype): def test_requests(self, ext, dtype):
"""Querying compressed audio via requests works""" """Querying compressed audio via requests works"""
sample_rate = 16000 sample_rate = 16000
...@@ -517,7 +572,7 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase): ...@@ -517,7 +572,7 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
......
...@@ -3,9 +3,8 @@ import itertools ...@@ -3,9 +3,8 @@ import itertools
import tarfile import tarfile
from parameterized import parameterized from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
HttpServerMixin, HttpServerMixin,
...@@ -19,6 +18,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -19,6 +18,7 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
sox_utils, sox_utils,
) )
from .common import ( from .common import (
name_func, name_func,
) )
...@@ -30,17 +30,17 @@ if _mod_utils.is_module_available("requests"): ...@@ -30,17 +30,17 @@ if _mod_utils.is_module_available("requests"):
class LoadTestBase(TempDirMixin, PytorchTestCase): class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_format( def assert_format(
self, self,
format: str, format: str,
sample_rate: float, sample_rate: float,
num_channels: int, num_channels: int,
compression: float = None, compression: float = None,
bit_depth: int = None, bit_depth: int = None,
duration: float = 1, duration: float = 1,
normalize: bool = True, normalize: bool = True,
encoding: str = None, encoding: str = None,
atol: float = 4e-05, atol: float = 4e-05,
rtol: float = 1.3e-06, rtol: float = 1.3e-06,
): ):
"""`sox_io_backend.load` can load given format correctly. """`sox_io_backend.load` can load given format correctly.
...@@ -68,13 +68,18 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -68,13 +68,18 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
data without using torchaudio data without using torchaudio
""" """
path = self.get_temp_path(f'1.original.{format}') path = self.get_temp_path(f"1.original.{format}")
ref_path = self.get_temp_path('2.reference.wav') ref_path = self.get_temp_path("2.reference.wav")
# 1. Generate the given format with sox # 1. Generate the given format with sox
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, encoding=encoding, path,
compression=compression, bit_depth=bit_depth, duration=duration, sample_rate,
num_channels,
encoding=encoding,
compression=compression,
bit_depth=bit_depth,
duration=duration,
) )
# 2. Convert to wav with sox # 2. Convert to wav with sox
wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
...@@ -92,7 +97,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -92,7 +97,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
Wav data loaded with sox_io backend should match those with scipy Wav data loaded with sox_io backend should match those with scipy
""" """
path = self.get_temp_path('reference.wav') path = self.get_temp_path("reference.wav")
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0] expected = load_wav(path, normalize=normalize)[0]
...@@ -101,118 +106,176 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -101,118 +106,176 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
self.assertEqual(data, expected) self.assertEqual(data, expected)
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class TestLoad(LoadTestBase): class TestLoad(LoadTestBase):
"""Test the correctness of `sox_io_backend.load` for various formats""" """Test the correctness of `sox_io_backend.load` for various formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[False, True], ["float32", "int32", "int16", "uint8"],
)), name_func=name_func) [8000, 16000],
[1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels, normalize): def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly.""" """`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[False, True], [8000, 16000],
)), name_func=name_func) [1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_24bit_wav(self, sample_rate, num_channels, normalize): def test_24bit_wav(self, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load 24bit wav format correctly. Correctly casts it to ``int32`` tensor dtype.""" """`sox_io_backend.load` can load 24bit wav format correctly. Correctly casts it to ``int32`` tensor dtype."""
self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1) self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(
['int16'], list(
[16000], itertools.product(
[2], ["int16"],
[False], [16000],
)), name_func=name_func) [2],
[False],
)
),
name_func=name_func,
)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize): def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load large wav file correctly.""" """`sox_io_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours) self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours)
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16', 'uint8'], list(
[4, 8, 16, 32], itertools.product(
)), name_func=name_func) ["float32", "int32", "int16", "uint8"],
[4, 8, 16, 32],
)
),
name_func=name_func,
)
def test_multiple_channels(self, dtype, num_channels): def test_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.load` can load wav file with more than 2 channels.""" """`sox_io_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000 sample_rate = 8000
normalize = False normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000, 44100], list(
[1, 2], itertools.product(
[96, 128, 160, 192, 224, 256, 320], [8000, 16000, 44100],
)), name_func=name_func) [1, 2],
[96, 128, 160, 192, 224, 256, 320],
)
),
name_func=name_func,
)
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly.""" """`sox_io_backend.load` can load mp3 format correctly."""
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05) self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[16000], list(
[2], itertools.product(
[128], [16000],
)), name_func=name_func) [2],
[128],
)
),
name_func=name_func,
)
def test_mp3_large(self, sample_rate, num_channels, bit_rate): def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly.""" """`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05) self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
list(range(9)), [8000, 16000],
)), name_func=name_func) [1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly.""" """`sox_io_backend.load` can load flac format correctly."""
self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1) self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[16000], list(
[2], itertools.product(
[0], [16000],
)), name_func=name_func) [2],
[0],
)
),
name_func=name_func,
)
def test_flac_large(self, sample_rate, num_channels, compression_level): def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly.""" """`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
self.assert_format( self.assert_format(
"flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours) "flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours
)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[-1, 0, 1, 2, 3, 3.6, 5, 10], [8000, 16000],
)), name_func=name_func) [1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
),
name_func=name_func,
)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly.""" """`sox_io_backend.load` can load vorbis format correctly."""
self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1) self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[16000], list(
[2], itertools.product(
[10], [16000],
)), name_func=name_func) [2],
[10],
)
),
name_func=name_func,
)
def test_vorbis_large(self, sample_rate, num_channels, quality_level): def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly.""" """`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
self.assert_format( self.assert_format(
"vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours) "vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours
)
@parameterized.expand(list(itertools.product( @parameterized.expand(
['96k'], list(
[1, 2], itertools.product(
[0, 5, 10], ["96k"],
)), name_func=name_func) [1, 2],
[0, 5, 10],
)
),
name_func=name_func,
)
def test_opus(self, bitrate, num_channels, compression_level): def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.load` can load opus file correctly.""" """`sox_io_backend.load` can load opus file correctly."""
ops_path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus') ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
wav_path = self.get_temp_path(f'{bitrate}_{compression_level}_{num_channels}ch.opus.wav') wav_path = self.get_temp_path(f"{bitrate}_{compression_level}_{num_channels}ch.opus.wav")
sox_utils.convert_audio_file(ops_path, wav_path) sox_utils.convert_audio_file(ops_path, wav_path)
expected, sample_rate = load_wav(wav_path) expected, sample_rate = load_wav(wav_path)
...@@ -221,57 +284,74 @@ class TestLoad(LoadTestBase): ...@@ -221,57 +284,74 @@ class TestLoad(LoadTestBase):
assert sample_rate == sr assert sample_rate == sr
self.assertEqual(expected, found) self.assertEqual(expected, found)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
)), name_func=name_func) [8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_sphere(self, sample_rate, num_channels): def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly.""" """`sox_io_backend.load` can load sph format correctly."""
self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1) self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16'], list(
[8000, 16000], itertools.product(
[1, 2], ["float32", "int32", "int16"],
[False, True], [8000, 16000],
)), name_func=name_func) [1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels, normalize): def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load amb format correctly.""" """`sox_io_backend.load` can load amb format correctly."""
bit_depth = sox_utils.get_bit_depth(dtype) bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype) encoding = sox_utils.get_encoding(dtype)
self.assert_format( self.assert_format(
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize) "amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize
)
def test_amr_nb(self): def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly.""" """`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1) self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase): class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`""" """Test the correctness of frame parameters of `sox_io_backend.load`"""
original = None original = None
path = None path = None
def setUp(self): def setUp(self):
super().setUp() super().setUp()
sample_rate = 8000 sample_rate = 8000
self.original = get_wav_data('float32', num_channels=2) self.original = get_wav_data("float32", num_channels=2)
self.path = self.get_temp_path('test.wav') self.path = self.get_temp_path("test.wav")
save_wav(self.path, self.original, sample_rate) save_wav(self.path, self.original, sample_rate)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[0, 1, 10, 100, 1000], list(
[-1, 1, 10, 100, 1000], itertools.product(
)), name_func=name_func) [0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
)
def test_frame(self, frame_offset, num_frames): def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data""" """num_frames and frame_offset correctly specify the region of data"""
found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[:, frame_offset:frame_end]) self.assertEqual(found, self.original[:, frame_offset:frame_end])
@parameterized.expand([(True, ), (False, )], name_func=name_func) @parameterized.expand([(True,), (False,)], name_func=name_func)
def test_channels_first(self, channels_first): def test_channels_first(self, channels_first):
"""channels_first swaps axes""" """channels_first swaps axes"""
found, _ = sox_io_backend.load(self.path, channels_first=channels_first) found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
...@@ -299,7 +379,7 @@ class TestLoadWithoutExtension(PytorchTestCase): ...@@ -299,7 +379,7 @@ class TestLoadWithoutExtension(PytorchTestCase):
class CloggedFileObj: class CloggedFileObj:
def __init__(self, fileobj): def __init__(self, fileobj):
self.fileobj = fileobj self.fileobj = fileobj
self.buffer = b'' self.buffer = b""
def read(self, n): def read(self, n):
if not self.buffer: if not self.buffer:
...@@ -310,163 +390,168 @@ class CloggedFileObj: ...@@ -310,163 +390,168 @@ class CloggedFileObj:
@skipIfNoSox @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec("sox")
class TestFileObject(TempDirMixin, PytorchTestCase): class TestFileObject(TempDirMixin, PytorchTestCase):
""" """
In this test suite, the result of file-like object input is compared against file path input, In this test suite, the result of file-like object input is compared against file path input,
because `load` function is rigrously tested for file path inputs to match libsox's result, because `load` function is rigrously tested for file path inputs to match libsox's result,
""" """
@parameterized.expand([
('wav', {'bit_depth': 16}), @parameterized.expand(
('wav', {'bit_depth': 24}), [
('wav', {'bit_depth': 32}), ("wav", {"bit_depth": 16}),
('mp3', {'compression': 128}), ("wav", {"bit_depth": 24}),
('mp3', {'compression': 320}), ("wav", {"bit_depth": 32}),
('flac', {'compression': 0}), ("mp3", {"compression": 128}),
('flac', {'compression': 5}), ("mp3", {"compression": 320}),
('flac', {'compression': 8}), ("flac", {"compression": 0}),
('vorbis', {'compression': -1}), ("flac", {"compression": 5}),
('vorbis', {'compression': 10}), ("flac", {"compression": 8}),
('amb', {}), ("vorbis", {"compression": -1}),
]) ("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_fileobj(self, ext, kwargs): def test_fileobj(self, ext, kwargs):
"""Loading audio via file object returns the same result as via file path.""" """Loading audio via file object returns the same result as via file path."""
sample_rate = 16000 sample_rate = 16000
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path) expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as fileobj: with open(path, "rb") as fileobj:
found, sr = sox_io_backend.load(fileobj, format=format_) found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found) self.assertEqual(expected, found)
@parameterized.expand([ @parameterized.expand(
('wav', {'bit_depth': 16}), [
('wav', {'bit_depth': 24}), ("wav", {"bit_depth": 16}),
('wav', {'bit_depth': 32}), ("wav", {"bit_depth": 24}),
('mp3', {'compression': 128}), ("wav", {"bit_depth": 32}),
('mp3', {'compression': 320}), ("mp3", {"compression": 128}),
('flac', {'compression': 0}), ("mp3", {"compression": 320}),
('flac', {'compression': 5}), ("flac", {"compression": 0}),
('flac', {'compression': 8}), ("flac", {"compression": 5}),
('vorbis', {'compression': -1}), ("flac", {"compression": 8}),
('vorbis', {'compression': 10}), ("vorbis", {"compression": -1}),
('amb', {}), ("vorbis", {"compression": 10}),
]) ("amb", {}),
]
)
def test_bytesio(self, ext, kwargs): def test_bytesio(self, ext, kwargs):
"""Loading audio via BytesIO object returns the same result as via file path.""" """Loading audio via BytesIO object returns the same result as via file path."""
sample_rate = 16000 sample_rate = 16000
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path) expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_: with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read()) fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_) found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found) self.assertEqual(expected, found)
@parameterized.expand([ @parameterized.expand(
('wav', {'bit_depth': 16}), [
('wav', {'bit_depth': 24}), ("wav", {"bit_depth": 16}),
('wav', {'bit_depth': 32}), ("wav", {"bit_depth": 24}),
('mp3', {'compression': 128}), ("wav", {"bit_depth": 32}),
('mp3', {'compression': 320}), ("mp3", {"compression": 128}),
('flac', {'compression': 0}), ("mp3", {"compression": 320}),
('flac', {'compression': 5}), ("flac", {"compression": 0}),
('flac', {'compression': 8}), ("flac", {"compression": 5}),
('vorbis', {'compression': -1}), ("flac", {"compression": 8}),
('vorbis', {'compression': 10}), ("vorbis", {"compression": -1}),
('amb', {}), ("vorbis", {"compression": 10}),
]) ("amb", {}),
]
)
def test_bytesio_clogged(self, ext, kwargs): def test_bytesio_clogged(self, ext, kwargs):
"""Loading audio via clogged file object returns the same result as via file path. """Loading audio via clogged file object returns the same result as via file path.
This test case validates the case where fileobject returns shorter bytes than requeted. This test case validates the case where fileobject returns shorter bytes than requeted.
""" """
sample_rate = 16000 sample_rate = 16000
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path) expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_: with open(path, "rb") as file_:
fileobj = CloggedFileObj(io.BytesIO(file_.read())) fileobj = CloggedFileObj(io.BytesIO(file_.read()))
found, sr = sox_io_backend.load(fileobj, format=format_) found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found) self.assertEqual(expected, found)
@parameterized.expand([ @parameterized.expand(
('wav', {'bit_depth': 16}), [
('wav', {'bit_depth': 24}), ("wav", {"bit_depth": 16}),
('wav', {'bit_depth': 32}), ("wav", {"bit_depth": 24}),
('mp3', {'compression': 128}), ("wav", {"bit_depth": 32}),
('mp3', {'compression': 320}), ("mp3", {"compression": 128}),
('flac', {'compression': 0}), ("mp3", {"compression": 320}),
('flac', {'compression': 5}), ("flac", {"compression": 0}),
('flac', {'compression': 8}), ("flac", {"compression": 5}),
('vorbis', {'compression': -1}), ("flac", {"compression": 8}),
('vorbis', {'compression': 10}), ("vorbis", {"compression": -1}),
('amb', {}), ("vorbis", {"compression": 10}),
]) ("amb", {}),
]
)
def test_bytesio_tiny(self, ext, kwargs): def test_bytesio_tiny(self, ext, kwargs):
"""Loading very small audio via file object returns the same result as via file path. """Loading very small audio via file object returns the same result as via file path."""
"""
sample_rate = 16000 sample_rate = 16000
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(path, sample_rate, num_channels=2, duration=1 / 1600, **kwargs)
path, sample_rate, num_channels=2, duration=1 / 1600, **kwargs)
expected, _ = sox_io_backend.load(path) expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_: with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read()) fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_) found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found) self.assertEqual(expected, found)
@parameterized.expand([ @parameterized.expand(
('wav', {'bit_depth': 16}), [
('wav', {'bit_depth': 24}), ("wav", {"bit_depth": 16}),
('wav', {'bit_depth': 32}), ("wav", {"bit_depth": 24}),
('mp3', {'compression': 128}), ("wav", {"bit_depth": 32}),
('mp3', {'compression': 320}), ("mp3", {"compression": 128}),
('flac', {'compression': 0}), ("mp3", {"compression": 320}),
('flac', {'compression': 5}), ("flac", {"compression": 0}),
('flac', {'compression': 8}), ("flac", {"compression": 5}),
('vorbis', {'compression': -1}), ("flac", {"compression": 8}),
('vorbis', {'compression': 10}), ("vorbis", {"compression": -1}),
('amb', {}), ("vorbis", {"compression": 10}),
]) ("amb", {}),
]
)
def test_tarfile(self, ext, kwargs): def test_tarfile(self, ext, kwargs):
"""Loading compressed audio via file-like object returns the same result as via file path.""" """Loading compressed audio via file-like object returns the same result as via file path."""
sample_rate = 16000 sample_rate = 16000
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
audio_file = f'test.{ext}' audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file) audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz') archive_path = self.get_temp_path("archive.tar.gz")
sox_utils.gen_audio_file( sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
audio_path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(audio_path) expected, _ = sox_io_backend.load(audio_path)
with tarfile.TarFile(archive_path, 'w') as tarobj: with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file) tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj: with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file) fileobj = tarobj.extractfile(audio_file)
found, sr = sox_io_backend.load(fileobj, format=format_) found, sr = sox_io_backend.load(fileobj, format=format_)
...@@ -475,30 +560,31 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -475,30 +560,31 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
@skipIfNoSox @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoModule("requests") @skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand([ @parameterized.expand(
('wav', {'bit_depth': 16}), [
('wav', {'bit_depth': 24}), ("wav", {"bit_depth": 16}),
('wav', {'bit_depth': 32}), ("wav", {"bit_depth": 24}),
('mp3', {'compression': 128}), ("wav", {"bit_depth": 32}),
('mp3', {'compression': 320}), ("mp3", {"compression": 128}),
('flac', {'compression': 0}), ("mp3", {"compression": 320}),
('flac', {'compression': 5}), ("flac", {"compression": 0}),
('flac', {'compression': 8}), ("flac", {"compression": 5}),
('vorbis', {'compression': -1}), ("flac", {"compression": 8}),
('vorbis', {'compression': 10}), ("vorbis", {"compression": -1}),
('amb', {}), ("vorbis", {"compression": 10}),
]) ("amb", {}),
]
)
def test_requests(self, ext, kwargs): def test_requests(self, ext, kwargs):
sample_rate = 16000 sample_rate = 16000
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ["mp3"] else None
audio_file = f'test.{ext}' audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file) audio_path = self.get_temp_path(audio_file)
sox_utils.gen_audio_file( sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
audio_path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(audio_path) expected, _ = sox_io_backend.load(audio_path)
url = self.get_url(audio_file) url = self.get_url(audio_file)
...@@ -508,17 +594,22 @@ class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): ...@@ -508,17 +594,22 @@ class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found) self.assertEqual(expected, found)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[0, 1, 10, 100, 1000], list(
[-1, 1, 10, 100, 1000], itertools.product(
)), name_func=name_func) [0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
)
def test_frame(self, frame_offset, num_frames): def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data""" """num_frames and frame_offset correctly specify the region of data"""
sample_rate = 8000 sample_rate = 8000
audio_file = 'test.wav' audio_file = "test.wav"
audio_path = self.get_temp_path(audio_file) audio_path = self.get_temp_path(audio_file)
original = get_wav_data('float32', num_channels=2) original = get_wav_data("float32", num_channels=2)
save_wav(audio_path, original, sample_rate) save_wav(audio_path, original, sample_rate)
frame_end = None if num_frames == -1 else frame_offset + num_frames frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, frame_offset:frame_end] expected = original[:, frame_offset:frame_end]
......
import itertools import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
...@@ -10,44 +9,56 @@ from torchaudio_unittest.common_utils import ( ...@@ -10,44 +9,56 @@ from torchaudio_unittest.common_utils import (
skipIfNoSox, skipIfNoSox,
get_wav_data, get_wav_data,
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params, get_enc_params,
) )
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase): class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats""" """save/load round trip should not degrade data for lossless formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
)), name_func=name_func) ["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats""" """save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False) original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype) enc, bps = get_enc_params(dtype)
data = original data = original
for i in range(10): for i in range(10):
path = self.get_temp_path(f'{i}.wav') path = self.get_temp_path(f"{i}.wav")
sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps) sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = sox_io_backend.load(path, normalize=False) data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(original, data) self.assertEqual(original, data)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
list(range(9)), [8000, 16000],
)), name_func=name_func) [1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""save/load round trip should not degrade data for flac formats""" """save/load round trip should not degrade data for flac formats"""
original = get_wav_data('float32', num_channels) original = get_wav_data("float32", num_channels)
data = original data = original
for i in range(10): for i in range(10):
path = self.get_temp_path(f'{i}.flac') path = self.get_temp_path(f"{i}.flac")
sox_io_backend.save(path, data, sample_rate, compression=compression_level) sox_io_backend.save(path, data, sample_rate, compression=compression_level)
data, sr = sox_io_backend.load(path) data, sr = sox_io_backend.load(path)
assert sr == sample_rate assert sr == sample_rate
......
...@@ -3,9 +3,8 @@ import os ...@@ -3,9 +3,8 @@ import os
import unittest import unittest
import torch import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -18,6 +17,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -18,6 +17,7 @@ from torchaudio_unittest.common_utils import (
sox_utils, sox_utils,
nested_params, nested_params,
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params, get_enc_params,
...@@ -26,28 +26,28 @@ from .common import ( ...@@ -26,28 +26,28 @@ from .common import (
def _get_sox_encoding(encoding): def _get_sox_encoding(encoding):
encodings = { encodings = {
'PCM_F': 'floating-point', "PCM_F": "floating-point",
'PCM_S': 'signed-integer', "PCM_S": "signed-integer",
'PCM_U': 'unsigned-integer', "PCM_U": "unsigned-integer",
'ULAW': 'u-law', "ULAW": "u-law",
'ALAW': 'a-law', "ALAW": "a-law",
} }
return encodings.get(encoding) return encodings.get(encoding)
class SaveTestBase(TempDirMixin, TorchaudioTestCase): class SaveTestBase(TempDirMixin, TorchaudioTestCase):
def assert_save_consistency( def assert_save_consistency(
self, self,
format: str, format: str,
*, *,
compression: float = None, compression: float = None,
encoding: str = None, encoding: str = None,
bits_per_sample: int = None, bits_per_sample: int = None,
sample_rate: float = 8000, sample_rate: float = 8000,
num_channels: int = 2, num_channels: int = 2,
num_frames: float = 3 * 8000, num_frames: float = 3 * 8000,
src_dtype: str = 'int32', src_dtype: str = "int32",
test_mode: str = "path", test_mode: str = "path",
): ):
"""`save` function produces file that is comparable with `sox` command """`save` function produces file that is comparable with `sox` command
...@@ -86,14 +86,14 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -86,14 +86,14 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
tensor -------> compare <--------- tensor tensor -------> compare <--------- tensor
""" """
cmp_encoding = 'floating-point' cmp_encoding = "floating-point"
cmp_bit_depth = 32 cmp_bit_depth = 32
src_path = self.get_temp_path('1.source.wav') src_path = self.get_temp_path("1.source.wav")
tgt_path = self.get_temp_path(f'2.1.torchaudio.{format}') tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
tst_path = self.get_temp_path('2.2.result.wav') tst_path = self.get_temp_path("2.2.result.wav")
sox_path = self.get_temp_path(f'3.1.sox.{format}') sox_path = self.get_temp_path(f"3.1.sox.{format}")
ref_path = self.get_temp_path('3.2.ref.wav') ref_path = self.get_temp_path("3.2.ref.wav")
# 1. Generate original wav # 1. Generate original wav
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames) data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
...@@ -103,78 +103,84 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -103,78 +103,84 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
data = load_wav(src_path, normalize=False)[0] data = load_wav(src_path, normalize=False)[0]
if test_mode == "path": if test_mode == "path":
sox_io_backend.save( sox_io_backend.save(
tgt_path, data, sample_rate, tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
compression=compression, encoding=encoding, bits_per_sample=bits_per_sample) )
elif test_mode == "fileobj": elif test_mode == "fileobj":
with open(tgt_path, 'bw') as file_: with open(tgt_path, "bw") as file_:
sox_io_backend.save( sox_io_backend.save(
file_, data, sample_rate, file_,
format=format, compression=compression, data,
encoding=encoding, bits_per_sample=bits_per_sample) sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "bytesio": elif test_mode == "bytesio":
file_ = io.BytesIO() file_ = io.BytesIO()
sox_io_backend.save( sox_io_backend.save(
file_, data, sample_rate, file_,
format=format, compression=compression, data,
encoding=encoding, bits_per_sample=bits_per_sample) sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
file_.seek(0) file_.seek(0)
with open(tgt_path, 'bw') as f: with open(tgt_path, "bw") as f:
f.write(file_.read()) f.write(file_.read())
else: else:
raise ValueError(f"Unexpected test mode: {test_mode}") raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox # 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file( sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy # 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0] found = load_wav(tst_path, normalize=False)[0]
# 3.1. Convert the original wav to target format with sox # 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding) sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file( sox_utils.convert_audio_file(
src_path, sox_path, src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample) )
# 3.2. Convert the target format to wav with sox # 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file( sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy # 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0] expected = load_wav(ref_path, normalize=False)[0]
self.assertEqual(found, expected) self.assertEqual(found, expected)
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class SaveTest(SaveTestBase): class SaveTest(SaveTestBase):
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
[ [
('PCM_U', 8), ("PCM_U", 8),
('PCM_S', 16), ("PCM_S", 16),
('PCM_S', 32), ("PCM_S", 32),
('PCM_F', 32), ("PCM_F", 32),
('PCM_F', 64), ("PCM_F", 64),
('ULAW', 8), ("ULAW", 8),
('ALAW', 8), ("ALAW", 8),
], ],
) )
def test_save_wav(self, test_mode, enc_params): def test_save_wav(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_save_consistency( self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
"wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
[ [
('float32', ), ("float32",),
('int32', ), ("int32",),
('int16', ), ("int16",),
('uint8', ), ("uint8",),
], ],
) )
def test_save_wav_dtype(self, test_mode, params): def test_save_wav_dtype(self, test_mode, params):
dtype, = params (dtype,) = params
self.assert_save_consistency( self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)
"wav", src_dtype=dtype, test_mode=test_mode)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
...@@ -197,10 +203,9 @@ class SaveTest(SaveTestBase): ...@@ -197,10 +203,9 @@ class SaveTest(SaveTestBase):
if test_mode in ["fileobj", "bytesio"]: if test_mode in ["fileobj", "bytesio"]:
if bit_rate is not None and bit_rate < 1: if bit_rate is not None and bit_rate < 1:
raise unittest.SkipTest( raise unittest.SkipTest(
"mp3 format with variable bit rate is known to " "mp3 format with variable bit rate is known to " "not yield the exact same result as sox command."
"not yield the exact same result as sox command.") )
self.assert_save_consistency( self.assert_save_consistency("mp3", compression=bit_rate, test_mode=test_mode)
"mp3", compression=bit_rate, test_mode=test_mode)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
...@@ -220,8 +225,8 @@ class SaveTest(SaveTestBase): ...@@ -220,8 +225,8 @@ class SaveTest(SaveTestBase):
) )
def test_save_flac(self, test_mode, bits_per_sample, compression_level): def test_save_flac(self, test_mode, bits_per_sample, compression_level):
self.assert_save_consistency( self.assert_save_consistency(
"flac", compression=compression_level, "flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode=test_mode
bits_per_sample=bits_per_sample, test_mode=test_mode) )
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
...@@ -244,45 +249,78 @@ class SaveTest(SaveTestBase): ...@@ -244,45 +249,78 @@ class SaveTest(SaveTestBase):
], ],
) )
def test_save_vorbis(self, test_mode, quality_level): def test_save_vorbis(self, test_mode, quality_level):
self.assert_save_consistency( self.assert_save_consistency("vorbis", compression=quality_level, test_mode=test_mode)
"vorbis", compression=quality_level, test_mode=test_mode)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
[ [
('PCM_S', 8, ), (
('PCM_S', 16, ), "PCM_S",
('PCM_S', 24, ), 8,
('PCM_S', 32, ), ),
('ULAW', 8), (
('ALAW', 8), "PCM_S",
('ALAW', 16), 16,
('ALAW', 24), ),
('ALAW', 32), (
"PCM_S",
24,
),
(
"PCM_S",
32,
),
("ULAW", 8),
("ALAW", 8),
("ALAW", 16),
("ALAW", 24),
("ALAW", 32),
], ],
) )
def test_save_sphere(self, test_mode, enc_params): def test_save_sphere(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_save_consistency( self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
"sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
[ [
('PCM_U', 8, ), (
('PCM_S', 16, ), "PCM_U",
('PCM_S', 24, ), 8,
('PCM_S', 32, ), ),
('PCM_F', 32, ), (
('PCM_F', 64, ), "PCM_S",
('ULAW', 8, ), 16,
('ALAW', 8, ), ),
(
"PCM_S",
24,
),
(
"PCM_S",
32,
),
(
"PCM_F",
32,
),
(
"PCM_F",
64,
),
(
"ULAW",
8,
),
(
"ALAW",
8,
),
], ],
) )
def test_save_amb(self, test_mode, enc_params): def test_save_amb(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_save_consistency( self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
"amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
...@@ -299,90 +337,94 @@ class SaveTest(SaveTestBase): ...@@ -299,90 +337,94 @@ class SaveTest(SaveTestBase):
], ],
) )
def test_save_amr_nb(self, test_mode, bit_rate): def test_save_amr_nb(self, test_mode, bit_rate):
self.assert_save_consistency( self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)
"amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
) )
def test_save_gsm(self, test_mode): def test_save_gsm(self, test_mode):
self.assert_save_consistency( self.assert_save_consistency("gsm", num_channels=1, test_mode=test_mode)
"gsm", num_channels=1, test_mode=test_mode) with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
with self.assertRaises( self.assert_save_consistency("gsm", num_channels=2, test_mode=test_mode)
RuntimeError, msg="gsm format only supports single channel audio."): with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency( self.assert_save_consistency("gsm", sample_rate=16000, test_mode=test_mode)
"gsm", num_channels=2, test_mode=test_mode)
with self.assertRaises( @parameterized.expand(
RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."): [
self.assert_save_consistency( ("wav", "PCM_S", 16),
"gsm", sample_rate=16000, test_mode=test_mode) ("mp3",),
("flac",),
@parameterized.expand([ ("vorbis",),
("wav", "PCM_S", 16), ("sph", "PCM_S", 16),
("mp3", ), ("amr-nb",),
("flac", ), ("amb", "PCM_S", 16),
("vorbis", ), ],
("sph", "PCM_S", 16), name_func=name_func,
("amr-nb", ), )
("amb", "PCM_S", 16),
], name_func=name_func)
def test_save_large(self, format, encoding=None, bits_per_sample=None): def test_save_large(self, format, encoding=None, bits_per_sample=None):
"""`sox_io_backend.save` can save large files.""" """`sox_io_backend.save` can save large files."""
sample_rate = 8000 sample_rate = 8000
one_hour = 60 * 60 * sample_rate one_hour = 60 * 60 * sample_rate
self.assert_save_consistency( self.assert_save_consistency(
format, num_channels=1, sample_rate=8000, num_frames=one_hour, format,
encoding=encoding, bits_per_sample=bits_per_sample) num_channels=1,
sample_rate=8000,
@parameterized.expand([ num_frames=one_hour,
(32, ), encoding=encoding,
(64, ), bits_per_sample=bits_per_sample,
(128, ), )
(256, ),
], name_func=name_func) @parameterized.expand(
[
(32,),
(64,),
(128,),
(256,),
],
name_func=name_func,
)
def test_save_multi_channels(self, num_channels): def test_save_multi_channels(self, num_channels):
"""`sox_io_backend.save` can save audio with many channels""" """`sox_io_backend.save` can save audio with many channels"""
self.assert_save_consistency( self.assert_save_consistency("wav", encoding="PCM_S", bits_per_sample=16, num_channels=num_channels)
"wav", encoding="PCM_S", bits_per_sample=16,
num_channels=num_channels)
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase): class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`""" """Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=name_func)
@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_save_channels_first(self, channels_first): def test_save_channels_first(self, channels_first):
"""channels_first swaps axes""" """channels_first swaps axes"""
path = self.get_temp_path('data.wav') path = self.get_temp_path("data.wav")
data = get_wav_data( data = get_wav_data("int16", 2, channels_first=channels_first, normalize=False)
'int16', 2, channels_first=channels_first, normalize=False) sox_io_backend.save(path, data, 8000, channels_first=channels_first)
sox_io_backend.save(
path, data, 8000, channels_first=channels_first)
found = load_wav(path, normalize=False)[0] found = load_wav(path, normalize=False)[0]
expected = data if channels_first else data.transpose(1, 0) expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected) self.assertEqual(found, expected)
@parameterized.expand([ @parameterized.expand(["float32", "int32", "int16", "uint8"], name_func=name_func)
'float32', 'int32', 'int16', 'uint8'
], name_func=name_func)
def test_save_noncontiguous(self, dtype): def test_save_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly""" """Noncontiguous tensors are saved correctly"""
path = self.get_temp_path('data.wav') path = self.get_temp_path("data.wav")
enc, bps = get_enc_params(dtype) enc, bps = get_enc_params(dtype)
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2] expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
assert not expected.is_contiguous() assert not expected.is_contiguous()
sox_io_backend.save( sox_io_backend.save(path, expected, 8000, encoding=enc, bits_per_sample=bps)
path, expected, 8000, encoding=enc, bits_per_sample=bps)
found = load_wav(path, normalize=False)[0] found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected) self.assertEqual(found, expected)
@parameterized.expand([ @parameterized.expand(
'float32', 'int32', 'int16', 'uint8', [
]) "float32",
"int32",
"int16",
"uint8",
]
)
def test_save_tensor_preserve(self, dtype): def test_save_tensor_preserve(self, dtype):
"""save function should not alter Tensor""" """save function should not alter Tensor"""
path = self.get_temp_path('data.wav') path = self.get_temp_path("data.wav")
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2] expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
data = expected.clone() data = expected.clone()
......
...@@ -2,25 +2,24 @@ import io ...@@ -2,25 +2,24 @@ import io
import itertools import itertools
import unittest import unittest
from torchaudio.utils import sox_utils
from torchaudio.backend import sox_io_backend
from torchaudio._internal.module_utils import is_sox_available
from parameterized import parameterized from parameterized import parameterized
from torchaudio._internal.module_utils import is_sox_available
from torchaudio.backend import sox_io_backend
from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
skipIfNoSox, skipIfNoSox,
get_wav_data, get_wav_data,
) )
from .common import name_func from .common import name_func
skipIfNoMP3 = unittest.skipIf( skipIfNoMP3 = unittest.skipIf(
not is_sox_available() or not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(),
'mp3' not in sox_utils.list_read_formats() or '"sox_io" backend does not support MP3',
'mp3' not in sox_utils.list_write_formats(), )
'"sox_io" backend does not support MP3')
@skipIfNoSox @skipIfNoSox
...@@ -33,10 +32,11 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -33,10 +32,11 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
This test suite should be able to run without any additional tools (such as sox command), This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified. however without such tools, the correctness of each function cannot be verified.
""" """
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
duration = 1 duration = 1
num_frames = sample_rate * duration num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f"test.{ext}")
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames) original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
# 1. run save # 1. run save
...@@ -50,42 +50,60 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -50,42 +50,60 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
assert sr == sample_rate assert sr == sample_rate
assert loaded.shape[0] == num_channels assert loaded.shape[0] == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16', 'uint8'], list(
[8000, 16000], itertools.product(
[1, 2], ["float32", "int32", "int16", "uint8"],
)), name_func=name_func) [8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format""" """Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype) self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], [8000, 16000],
))) [1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)
)
)
@skipIfNoMP3 @skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format""" """Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate) self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[-1, 0, 1, 2, 3, 3.6, 5, 10], [8000, 16000],
))) [1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
)
)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format""" """Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level) self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
list(range(9)), [8000, 16000],
)), name_func=name_func) [1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format""" """Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level) self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
@skipIfNoSox @skipIfNoSox
...@@ -98,7 +116,8 @@ class SmokeTestFileObj(TorchaudioTestCase): ...@@ -98,7 +116,8 @@ class SmokeTestFileObj(TorchaudioTestCase):
This test suite should be able to run without any additional tools (such as sox command), This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified. however without such tools, the correctness of each function cannot be verified.
""" """
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
duration = 1 duration = 1
num_frames = sample_rate * duration num_frames = sample_rate * duration
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames) original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
...@@ -117,39 +136,57 @@ class SmokeTestFileObj(TorchaudioTestCase): ...@@ -117,39 +136,57 @@ class SmokeTestFileObj(TorchaudioTestCase):
assert sr == sample_rate assert sr == sample_rate
assert loaded.shape[0] == num_channels assert loaded.shape[0] == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16', 'uint8'], list(
[8000, 16000], itertools.product(
[1, 2], ["float32", "int32", "int16", "uint8"],
)), name_func=name_func) [8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format""" """Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype) self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], [8000, 16000],
))) [1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)
)
)
@skipIfNoMP3 @skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format""" """Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate) self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
[-1, 0, 1, 2, 3, 3.6, 5, 10], [8000, 16000],
))) [1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
)
)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format""" """Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level) self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
list(range(9)), [8000, 16000],
)), name_func=name_func) [1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format""" """Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level) self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
...@@ -4,7 +4,6 @@ from typing import Optional ...@@ -4,7 +4,6 @@ from typing import Optional
import torch import torch
import torchaudio import torchaudio
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -16,6 +15,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -16,6 +15,7 @@ from torchaudio_unittest.common_utils import (
sox_utils, sox_utils,
torch_script, torch_script,
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params, get_enc_params,
...@@ -27,38 +27,41 @@ def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaDa ...@@ -27,38 +27,41 @@ def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaDa
def py_load_func(filepath: str, normalize: bool, channels_first: bool): def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return torchaudio.load( return torchaudio.load(filepath, normalize=normalize, channels_first=channels_first)
filepath, normalize=normalize, channels_first=channels_first)
def py_save_func( def py_save_func(
filepath: str, filepath: str,
tensor: torch.Tensor, tensor: torch.Tensor,
sample_rate: int, sample_rate: int,
channels_first: bool = True, channels_first: bool = True,
compression: Optional[float] = None, compression: Optional[float] = None,
encoding: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
): ):
torchaudio.save( torchaudio.save(filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample)
filepath, tensor, sample_rate, channels_first,
compression, None, encoding, bits_per_sample)
@skipIfNoExec('sox') @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class SoxIO(TempDirMixin, TorchaudioTestCase): class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`""" """TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io'
@parameterized.expand(list(itertools.product( backend = "sox_io"
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], @parameterized.expand(
[1, 2], list(
)), name_func=name_func) itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_info_wav(self, dtype, sample_rate, num_channels): def test_info_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` is torchscript-able and returns the same result""" """`sox_io_backend.info` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') audio_path = self.get_temp_path(f"{dtype}_{sample_rate}_{num_channels}.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate) save_wav(audio_path, data, sample_rate)
...@@ -71,40 +74,48 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -71,40 +74,48 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
assert py_info.num_frames == ts_info.num_frames assert py_info.num_frames == ts_info.num_frames
assert py_info.num_channels == ts_info.num_channels assert py_info.num_channels == ts_info.num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16', 'uint8'], list(
[8000, 16000], itertools.product(
[1, 2], ["float32", "int32", "int16", "uint8"],
[False, True], [8000, 16000],
[False, True], [1, 2],
)), name_func=name_func) [False, True],
[False, True],
)
),
name_func=name_func,
)
def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""`sox_io_backend.load` is torchscript-able and returns the same result""" """`sox_io_backend.load` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') audio_path = self.get_temp_path(f"test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate) save_wav(audio_path, data, sample_rate)
ts_load_func = torch_script(py_load_func) ts_load_func = torch_script(py_load_func)
py_data, py_sr = py_load_func( py_data, py_sr = py_load_func(audio_path, normalize=normalize, channels_first=channels_first)
audio_path, normalize=normalize, channels_first=channels_first) ts_data, ts_sr = ts_load_func(audio_path, normalize=normalize, channels_first=channels_first)
ts_data, ts_sr = ts_load_func(
audio_path, normalize=normalize, channels_first=channels_first)
self.assertEqual(py_sr, ts_sr) self.assertEqual(py_sr, ts_sr)
self.assertEqual(py_data, ts_data) self.assertEqual(py_data, ts_data)
@parameterized.expand(list(itertools.product( @parameterized.expand(
['float32', 'int32', 'int16', 'uint8'], list(
[8000, 16000], itertools.product(
[1, 2], ["float32", "int32", "int16", "uint8"],
)), name_func=name_func) [8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_save_wav(self, dtype, sample_rate, num_channels): def test_save_wav(self, dtype, sample_rate, num_channels):
ts_save_func = torch_script(py_save_func) ts_save_func = torch_script(py_save_func)
expected = get_wav_data(dtype, num_channels, normalize=False) expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav') py_path = self.get_temp_path(f"test_save_py_{dtype}_{sample_rate}_{num_channels}.wav")
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav') ts_path = self.get_temp_path(f"test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav")
enc, bps = get_enc_params(dtype) enc, bps = get_enc_params(dtype)
py_save_func(py_path, expected, sample_rate, True, None, enc, bps) py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
...@@ -118,24 +129,29 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -118,24 +129,29 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
self.assertEqual(expected, py_data) self.assertEqual(expected, py_data)
self.assertEqual(expected, ts_data) self.assertEqual(expected, ts_data)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000], list(
[1, 2], itertools.product(
list(range(9)), [8000, 16000],
)), name_func=name_func) [1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_save_flac(self, sample_rate, num_channels, compression_level): def test_save_flac(self, sample_rate, num_channels, compression_level):
ts_save_func = torch_script(py_save_func) ts_save_func = torch_script(py_save_func)
expected = get_wav_data('float32', num_channels) expected = get_wav_data("float32", num_channels)
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac') py_path = self.get_temp_path(f"test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac")
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac') ts_path = self.get_temp_path(f"test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac")
py_save_func(py_path, expected, sample_rate, True, compression_level, None, None) py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None) ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav' py_path_wav = f"{py_path}.wav"
ts_path_wav = f'{ts_path}.wav' ts_path_wav = f"{ts_path}.wav"
sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32) sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32)
sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32) sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32)
......
import torchaudio import torchaudio
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
class BackendSwitchMixin: class BackendSwitchMixin:
"""Test set/get_audio_backend works""" """Test set/get_audio_backend works"""
backend = None backend = None
backend_module = None backend_module = None
...@@ -26,11 +26,11 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes ...@@ -26,11 +26,11 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
@common_utils.skipIfNoSox @common_utils.skipIfNoSox
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase): class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io' backend = "sox_io"
backend_module = torchaudio.backend.sox_io_backend backend_module = torchaudio.backend.sox_io_backend
@common_utils.skipIfNoModule('soundfile') @common_utils.skipIfNoModule("soundfile")
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase): class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile' backend = "soundfile"
backend_module = torchaudio.backend.soundfile_backend backend_module = torchaudio.backend.soundfile_backend
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment