Commit 5859923a authored by Joao Gomes's avatar Joao Gomes Committed by Facebook GitHub Bot
Browse files

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
import pytest
import torch
from torchaudio._internal import download_url_to_file
import pytest
class GreedyCTCDecoder(torch.nn.Module):
......@@ -24,7 +24,7 @@ class GreedyCTCDecoder(torch.nn.Module):
for i in best_path:
if i != self.blank:
hypothesis.append(self.labels[i])
return ''.join(hypothesis)
return "".join(hypothesis)
@pytest.fixture
......@@ -33,24 +33,24 @@ def ctc_decoder():
_FILES = {
'en': 'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac',
'de': '20090505-0900-PLENARY-16-de_20090505-21_56_00_8.flac',
'en2': '20120613-0900-PLENARY-8-en_20120613-13_46_50_3.flac',
'es': '20130207-0900-PLENARY-7-es_20130207-13_02_05_5.flac',
'fr': '20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac',
'it': '20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac',
"en": "Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac",
"de": "20090505-0900-PLENARY-16-de_20090505-21_56_00_8.flac",
"en2": "20120613-0900-PLENARY-8-en_20120613-13_46_50_3.flac",
"es": "20130207-0900-PLENARY-7-es_20130207-13_02_05_5.flac",
"fr": "20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac",
"it": "20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac",
}
@pytest.fixture
def sample_speech(tmp_path, lang):
if lang not in _FILES:
raise NotImplementedError(f'Unexpected lang: {lang}')
raise NotImplementedError(f"Unexpected lang: {lang}")
filename = _FILES[lang]
path = tmp_path.parent / filename
if not path.exists():
url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}'
print(f'downloading from {url}')
url = f"https://download.pytorch.org/torchaudio/test-assets/{filename}"
print(f"downloading from {url}")
download_url_to_file(url, path, progress=False)
return path
......@@ -62,13 +62,13 @@ def pytest_addoption(parser):
help=(
"When provided, tests will use temporary directory as Torch Hub directory. "
"Downloaded models will be deleted after each test."
)
),
)
@pytest.fixture(autouse=True)
def temp_hub_dir(tmpdir, pytestconfig):
if not pytestconfig.getoption('use_tmp_hub_dir'):
if not pytestconfig.getoption("use_tmp_hub_dir"):
yield
else:
org_dir = torch.hub.get_dir()
......
import pytest
from torchaudio.pipelines import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
import pytest
@pytest.mark.parametrize(
'bundle',
"bundle",
[
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
]
],
)
def test_tts_models(bundle):
"""Smoke test of TTS pipeline"""
......
import pytest
import torchaudio
from torchaudio.pipelines import (
WAV2VEC2_BASE,
......@@ -24,7 +25,6 @@ from torchaudio.pipelines import (
VOXPOPULI_ASR_BASE_10K_FR,
VOXPOPULI_ASR_BASE_10K_IT,
)
import pytest
@pytest.mark.parametrize(
......@@ -37,7 +37,7 @@ import pytest
HUBERT_BASE,
HUBERT_LARGE,
HUBERT_XLARGE,
]
],
)
def test_pretraining_models(bundle):
"""Smoke test of downloading weights for pretraining models"""
......@@ -47,30 +47,46 @@ def test_pretraining_models(bundle):
@pytest.mark.parametrize(
"bundle,lang,expected",
[
(WAV2VEC2_ASR_BASE_10M, 'en', 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_BASE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_BASE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_10M, 'en', 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_LV60K_10M, 'en', 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_LV60K_100H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(WAV2VEC2_ASR_LARGE_LV60K_960H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(HUBERT_ASR_LARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(HUBERT_ASR_XLARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
(VOXPOPULI_ASR_BASE_10K_EN, 'en2', 'i|hope|that|we|will|see|a|ddrasstic|decrease|of|funding|for|the|failed|eu|project|and|that|more|money|will|come|back|to|the|taxpayers'), # noqa: E501
(VOXPOPULI_ASR_BASE_10K_ES, 'es', "la|primera|que|es|imprescindible|pensar|a|pequeña|a|escala|para|implicar|y|complementar|así|la|actuación|global"), # noqa: E501
(VOXPOPULI_ASR_BASE_10K_DE, 'de', "dabei|spielt|auch|eine|sorgfältige|berichterstattung|eine|wichtige|rolle"),
(VOXPOPULI_ASR_BASE_10K_FR, 'fr', 'la|commission|va|faire|des|propositions|sur|ce|sujet|comment|mettre|en|place|cette|capacité|fiscale|et|le|conseil|européen|y|reviendra|sour|les|sujets|au|moins|de|mars'), # noqa: E501
(VOXPOPULI_ASR_BASE_10K_IT, 'it', 'credo|che|illatino|non|sia|contemplato|tra|le|traduzioni|e|quindi|mi|attengo|allitaliano') # noqa: E501
]
(WAV2VEC2_ASR_BASE_10M, "en", "I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_BASE_100H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_BASE_960H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_10M, "en", "I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_100H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_960H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_LV60K_10M, "en", "I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_LV60K_100H, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(WAV2VEC2_ASR_LARGE_LV60K_960H, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(HUBERT_ASR_LARGE, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(HUBERT_ASR_XLARGE, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
(
VOXPOPULI_ASR_BASE_10K_EN,
"en2",
"i|hope|that|we|will|see|a|ddrasstic|decrease|of|funding|for|the|failed|eu|project|and|that|more|money|will|come|back|to|the|taxpayers",
), # noqa: E501
(
VOXPOPULI_ASR_BASE_10K_ES,
"es",
"la|primera|que|es|imprescindible|pensar|a|pequeña|a|escala|para|implicar|y|complementar|así|la|actuación|global",
), # noqa: E501
(VOXPOPULI_ASR_BASE_10K_DE, "de", "dabei|spielt|auch|eine|sorgfältige|berichterstattung|eine|wichtige|rolle"),
(
VOXPOPULI_ASR_BASE_10K_FR,
"fr",
"la|commission|va|faire|des|propositions|sur|ce|sujet|comment|mettre|en|place|cette|capacité|fiscale|et|le|conseil|européen|y|reviendra|sour|les|sujets|au|moins|de|mars",
), # noqa: E501
(
VOXPOPULI_ASR_BASE_10K_IT,
"it",
"credo|che|illatino|non|sia|contemplato|tra|le|traduzioni|e|quindi|mi|attengo|allitaliano",
), # noqa: E501
],
)
def test_finetune_asr_model(
bundle,
lang,
expected,
sample_speech,
ctc_decoder,
bundle,
lang,
expected,
sample_speech,
ctc_decoder,
):
"""Smoke test of downloading weights for fine-tuning models and simple transcription"""
model = bundle.get_model().eval()
......
......@@ -8,30 +8,26 @@ import torch
def _parse_args():
parser = argparse.ArgumentParser(
description='Generate opus files for test'
)
parser.add_argument('--num-channels', required=True, type=int)
parser.add_argument('--compression-level', required=True, type=int, choices=list(range(11)))
parser.add_argument('--bitrate', default='96k')
parser = argparse.ArgumentParser(description="Generate opus files for test")
parser.add_argument("--num-channels", required=True, type=int)
parser.add_argument("--compression-level", required=True, type=int, choices=list(range(11)))
parser.add_argument("--bitrate", default="96k")
return parser.parse_args()
def convert_to_opus(
src_path, dst_path,
*, bitrate, compression_level):
def convert_to_opus(src_path, dst_path, *, bitrate, compression_level):
"""Convert audio file with `ffmpeg` command."""
command = ['ffmpeg', '-y', '-i', src_path, '-c:a', 'libopus', '-b:a', bitrate]
command = ["ffmpeg", "-y", "-i", src_path, "-c:a", "libopus", "-b:a", bitrate]
if compression_level is not None:
command += ['-compression_level', str(compression_level)]
command += ["-compression_level", str(compression_level)]
command += [dst_path]
print(' '.join(command))
print(" ".join(command))
subprocess.run(command, check=True)
def _generate(num_channels, compression_level, bitrate):
org_path = 'original.wav'
ops_path = f'{bitrate}_{compression_level}_{num_channels}ch.opus'
org_path = "original.wav"
ops_path = f"{bitrate}_{compression_level}_{num_channels}ch.opus"
# Note: ffmpeg forces sample rate 48k Hz for opus https://stackoverflow.com/a/39186779
# 1. generate original wav
......@@ -46,5 +42,5 @@ def _main():
_generate(args.num_channels, args.compression_level, args.bitrate)
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -32,8 +32,8 @@ python generate_hubert_model_config.py \
> hubert_large_ll60k_finetune_ls960.json
```
"""
import json
import argparse
import json
def _parse_args():
......@@ -42,12 +42,9 @@ def _parse_args():
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'--model-file',
"--model-file",
required=True,
help=(
'A pt file from '
'https://github.com/pytorch/fairseq/tree/main/examples/hubert'
)
help=("A pt file from " "https://github.com/pytorch/fairseq/tree/main/examples/hubert"),
)
return parser.parse_args()
......@@ -66,27 +63,27 @@ def _main():
args = _parse_args()
model, cfg = _load(args.model_file)
if model.__class__.__name__ == 'HubertModel':
cfg['task']['data'] = '/foo/bar'
cfg['task']['label_dir'] = None
if model.__class__.__name__ == "HubertModel":
cfg["task"]["data"] = "/foo/bar"
cfg["task"]["label_dir"] = None
conf = {
'_name': 'hubert',
'model': cfg['model'],
'task': cfg['task'],
'num_classes': model.num_classes,
"_name": "hubert",
"model": cfg["model"],
"task": cfg["task"],
"num_classes": model.num_classes,
}
elif model.__class__.__name__ == 'HubertCtc':
conf = cfg['model']
del conf['w2v_path']
keep = ['_name', 'task', 'model']
for key in list(k for k in conf['w2v_args'] if k not in keep):
del conf['w2v_args'][key]
conf['data'] = '/foo/bar/'
conf['w2v_args']['task']['data'] = '/foo/bar'
conf['w2v_args']['task']['labels'] = []
conf['w2v_args']['task']['label_dir'] = '/foo/bar'
elif model.__class__.__name__ == "HubertCtc":
conf = cfg["model"]
del conf["w2v_path"]
keep = ["_name", "task", "model"]
for key in list(k for k in conf["w2v_args"] if k not in keep):
del conf["w2v_args"][key]
conf["data"] = "/foo/bar/"
conf["w2v_args"]["task"]["data"] = "/foo/bar"
conf["w2v_args"]["task"]["labels"] = []
conf["w2v_args"]["task"]["label_dir"] = "/foo/bar"
print(json.dumps(conf, indent=4, sort_keys=True))
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -41,9 +41,9 @@ python generate_wav2vec2_model_config.py \
> wav2vec_large_lv60_self_960h.json
```
"""
import os
import json
import argparse
import json
import os
def _parse_args():
......@@ -52,19 +52,13 @@ def _parse_args():
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'--model-file',
"--model-file",
required=True,
help=(
'A point file from '
'https://github.com/pytorch/fairseq/tree/main/examples/wav2vec'
)
help=("A point file from " "https://github.com/pytorch/fairseq/tree/main/examples/wav2vec"),
)
parser.add_argument(
'--dict-dir',
help=(
'Directory where `dict.ltr.txt` file is found. '
'Default: the directory of the given model.'
)
"--dict-dir",
help=("Directory where `dict.ltr.txt` file is found. " "Default: the directory of the given model."),
)
args = parser.parse_args()
if args.dict_dir is None:
......@@ -75,32 +69,29 @@ def _parse_args():
def _to_json(conf):
import yaml
from omegaconf import OmegaConf
return yaml.safe_load(OmegaConf.to_yaml(conf))
def _load(model_file, dict_dir):
import fairseq
overrides = {'data': dict_dir}
_, args, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[model_file], arg_overrides=overrides
)
return _to_json(args['model'])
overrides = {"data": dict_dir}
_, args, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file], arg_overrides=overrides)
return _to_json(args["model"])
def _main():
args = _parse_args()
conf = _load(args.model_file, args.dict_dir)
if conf['_name'] == 'wav2vec_ctc':
del conf['data']
del conf['w2v_args']['task']['data']
conf['w2v_args'] = {
key: conf['w2v_args'][key] for key in ['model', 'task']
}
if conf["_name"] == "wav2vec_ctc":
del conf["data"]
del conf["w2v_args"]["task"]["data"]
conf["w2v_args"] = {key: conf["w2v_args"][key] for key in ["model", "task"]}
print(json.dumps(conf, indent=4, sort_keys=True))
if __name__ == '__main__':
if __name__ == "__main__":
_main()
import os
import json
import os
from transformers import Wav2Vec2Model
......@@ -22,16 +22,16 @@ def _main():
"facebook/wav2vec2-large-xlsr-53-german",
]
for key in keys:
path = os.path.join(_THIS_DIR, f'{key}.json')
print('Generating ', path)
path = os.path.join(_THIS_DIR, f"{key}.json")
print("Generating ", path)
cfg = Wav2Vec2Model.from_pretrained(key).config
cfg = json.loads(cfg.to_json_string())
del cfg['_name_or_path']
del cfg["_name_or_path"]
with open(path, 'w') as file_:
with open(path, "w") as file_:
file_.write(json.dumps(cfg, indent=4, sort_keys=True))
file_.write('\n')
file_.write("\n")
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -3,23 +3,23 @@ from torchaudio_unittest.common_utils import sox_utils
def get_encoding(ext, dtype):
exts = {
'mp3',
'flac',
'vorbis',
"mp3",
"flac",
"vorbis",
}
encodings = {
'float32': 'PCM_F',
'int32': 'PCM_S',
'int16': 'PCM_S',
'uint8': 'PCM_U',
"float32": "PCM_F",
"int32": "PCM_S",
"int16": "PCM_S",
"uint8": "PCM_U",
}
return ext.upper() if ext in exts else encodings[dtype]
def get_bits_per_sample(ext, dtype):
bits_per_samples = {
'flac': 24,
'mp3': 0,
'vorbis': 0,
"flac": 24,
"mp3": 0,
"vorbis": 0,
}
return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype))
......@@ -38,20 +38,19 @@ def fetch_wav_subtype(dtype, encoding, bits_per_sample):
subtype = {
(None, None): dtype2subtype(dtype),
(None, 8): "PCM_U8",
('PCM_U', None): "PCM_U8",
('PCM_U', 8): "PCM_U8",
('PCM_S', None): "PCM_32",
('PCM_S', 16): "PCM_16",
('PCM_S', 32): "PCM_32",
('PCM_F', None): "FLOAT",
('PCM_F', 32): "FLOAT",
('PCM_F', 64): "DOUBLE",
('ULAW', None): "ULAW",
('ULAW', 8): "ULAW",
('ALAW', None): "ALAW",
('ALAW', 8): "ALAW",
("PCM_U", None): "PCM_U8",
("PCM_U", 8): "PCM_U8",
("PCM_S", None): "PCM_32",
("PCM_S", 16): "PCM_16",
("PCM_S", 32): "PCM_32",
("PCM_F", None): "FLOAT",
("PCM_F", 32): "FLOAT",
("PCM_F", 64): "DOUBLE",
("ULAW", None): "ULAW",
("ULAW", 8): "ULAW",
("ALAW", None): "ALAW",
("ALAW", 8): "ALAW",
}.get((encoding, bits_per_sample))
if subtype:
return subtype
raise ValueError(
f"wav does not support ({encoding}, {bits_per_sample}).")
raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
from unittest.mock import patch
import warnings
import tarfile
import warnings
from unittest.mock import patch
import torch
from torchaudio.backend import soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
......@@ -14,10 +17,7 @@ from torchaudio_unittest.common_utils import (
save_wav,
nested_params,
)
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from .common import skipIfFormatNotSupported, parameterize
if _mod_utils.is_module_available("soundfile"):
......@@ -27,15 +27,15 @@ if _mod_utils.is_module_available("soundfile"):
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2],
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
dtype, num_channels, normalize=False, num_frames=duration * sample_rate
)
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
......@@ -81,10 +81,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
@nested_params(
[8000, 16000],
[1, 2],
[
('PCM_24', 24),
('PCM_32', 32)
],
[("PCM_24", 24), ("PCM_32", 32)],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
......@@ -109,13 +106,15 @@ class TestInfo(TempDirMixin, PytorchTestCase):
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
dict should be updated.
"""
def _mock_info_func(_):
class MockSoundFileInfo:
samplerate = 8000
frames = 356
channels = 2
subtype = 'UNSEEN_SUBTYPE'
format = 'UNKNOWN'
subtype = "UNSEEN_SUBTYPE"
format = "UNKNOWN"
return MockSoundFileInfo()
with patch("soundfile.info", _mock_info_func):
......@@ -134,27 +133,27 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}')
path = self.get_temp_path(f"test.{ext}")
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(path, data, sample_rate, subtype=subtype)
with open(path, 'rb') as fileobj:
with open(path, "rb") as fileobj:
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == 'flac' else "PCM_S"
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav', 'PCM_16', 16)
self._test_fileobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac', 'PCM_16', 16)
self._test_fileobj("flac", "PCM_16", 16)
def _test_tarobj(self, ext, subtype, bits_per_sample):
"""Query compressed audio via file-like object works"""
......@@ -162,29 +161,29 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
audio_file = f'test.{ext}'
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
archive_path = self.get_temp_path("archive.tar.gz")
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(audio_path, data, sample_rate, subtype=subtype)
with tarfile.TarFile(archive_path, 'w') as tarobj:
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == 'flac' else "PCM_S"
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_tarobj_wav(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj('wav', 'PCM_16', 16)
self._test_tarobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC")
def test_tarobj_flac(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj('flac', 'PCM_16', 16)
self._test_tarobj("flac", "PCM_16", 16)
......@@ -3,10 +3,9 @@ import tarfile
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
......@@ -16,6 +15,7 @@ from torchaudio_unittest.common_utils import (
load_wav,
save_wav,
)
from .common import (
parameterize,
dtype2subtype,
......@@ -27,7 +27,11 @@ if _mod_utils.is_module_available("soundfile"):
def _get_mock_path(
ext: str, dtype: str, sample_rate: int, num_channels: int, num_frames: int,
ext: str,
dtype: str,
sample_rate: int,
num_channels: int,
num_frames: int,
):
return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}"
......@@ -86,7 +90,7 @@ class SoundFileMock:
num_frames=self._params["num_frames"],
channels_first=False,
).numpy()
return data[self._start:self._start + frames]
return data[self._start : self._start + frames]
def __enter__(self):
return self
......@@ -96,21 +100,13 @@ class SoundFileMock:
class MockedLoadTest(PytorchTestCase):
def assert_dtype(
self, ext, dtype, sample_rate, num_channels, normalize, channels_first
):
def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
num_frames = 3 * sample_rate
path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
expected_dtype = (
torch.float32
if normalize or ext not in ["wav", "nist"]
else getattr(torch, dtype)
)
expected_dtype = torch.float32 if normalize or ext not in ["wav", "nist"] else getattr(torch, dtype)
with patch("soundfile.SoundFile", SoundFileMock):
found, sr = soundfile_backend.load(
path, normalize=normalize, channels_first=channels_first
)
found, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
assert found.dtype == expected_dtype
assert sample_rate == sr
......@@ -123,32 +119,28 @@ class MockedLoadTest(PytorchTestCase):
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns native dtype when normalize=False else float32"""
self.assert_dtype(
"wav", dtype, sample_rate, num_channels, normalize, channels_first
)
self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(
["int8", "int16", "int32"], [8000, 16000], [1, 2], [True, False], [True, False],
["int8", "int16", "int32"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype(
"sph", dtype, sample_rate, num_channels, normalize, channels_first
)
self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype(
"ogg", "int16", sample_rate, num_channels, normalize, channels_first
)
self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_flac(self, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load ogg format."""
self.assert_dtype(
"flac", "int16", sample_rate, num_channels, normalize, channels_first
)
self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
class LoadTestBase(TempDirMixin, PytorchTestCase):
......@@ -176,14 +168,17 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
)
save_wav(path, data, sample_rate, channels_first=channels_first)
expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0]
data, sr = soundfile_backend.load(
path, normalize=normalize, channels_first=channels_first
)
data, sr = soundfile_backend.load(path, normalize=normalize, channels_first=channels_first)
assert sr == sample_rate
self.assertEqual(data, expected)
def assert_sphere(
self, dtype, sample_rate, num_channels, channels_first=True, duration=1,
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load SPHERE format correctly."""
path = self.get_temp_path("reference.sph")
......@@ -195,16 +190,19 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
normalize=False,
channels_first=False,
)
soundfile.write(
path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST"
)
soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate
self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
def assert_flac(
self, dtype, sample_rate, num_channels, channels_first=True, duration=1,
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load FLAC format correctly."""
path = self.get_temp_path("reference.flac")
......@@ -239,7 +237,10 @@ class TestLoad(LoadTestBase):
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(
["int16"], [16000], [2], [False],
["int16"],
[16000],
[2],
[False],
)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`soundfile_backend.load` can load large wav file correctly."""
......@@ -269,15 +270,16 @@ class TestLoad(LoadTestBase):
@skipIfNoModule("soundfile")
class TestLoadFormat(TempDirMixin, PytorchTestCase):
"""Given `format` parameter, `so.load` can load files without extension"""
original = None
path = None
def _make_file(self, format_):
sample_rate = 8000
path_with_ext = self.get_temp_path(f'test.{format_}')
data = get_wav_data('float32', num_channels=2).numpy().T
path_with_ext = self.get_temp_path(f"test.{format_}")
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path_with_ext, data, sample_rate)
expected = soundfile.read(path_with_ext, dtype='float32')[0].T
expected = soundfile.read(path_with_ext, dtype="float32")[0].T
path = os.path.splitext(path_with_ext)[0]
os.rename(path_with_ext, path)
return path, expected
......@@ -288,15 +290,21 @@ class TestLoadFormat(TempDirMixin, PytorchTestCase):
found, _ = soundfile_backend.load(path)
self.assertEqual(found, expected)
@parameterized.expand([
('WAV', ), ('wav', ),
])
@parameterized.expand(
[
("WAV",),
("wav",),
]
)
def test_wav(self, format_):
self._test_format(format_)
@parameterized.expand([
('FLAC', ), ('flac',),
])
@parameterized.expand(
[
("FLAC",),
("flac",),
]
)
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)
......@@ -307,40 +315,40 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')
path = self.get_temp_path(f"test.{ext}")
data = get_wav_data('float32', num_channels=2).numpy().T
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype='float32')[0].T
expected = soundfile.read(path, dtype="float32")[0].T
with open(path, 'rb') as fileobj:
with open(path, "rb") as fileobj:
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav')
self._test_fileobj("wav")
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac')
self._test_fileobj("flac")
def _test_tarfile(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
audio_file = f'test.{ext}'
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
archive_path = self.get_temp_path("archive.tar.gz")
data = get_wav_data('float32', num_channels=2).numpy().T
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(audio_path, data, sample_rate)
expected = soundfile.read(audio_path, dtype='float32')[0].T
expected = soundfile.read(audio_path, dtype="float32")[0].T
with tarfile.TarFile(archive_path, 'w') as tarobj:
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = soundfile_backend.load(fileobj)
......@@ -349,9 +357,9 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
def test_tarfile_wav(self):
"""Loading audio via file-like object works"""
self._test_tarfile('wav')
self._test_tarfile("wav")
@skipIfFormatNotSupported("FLAC")
def test_tarfile_flac(self):
"""Loading audio via file-like object works"""
self._test_tarfile('flac')
self._test_tarfile("flac")
......@@ -3,7 +3,6 @@ from unittest.mock import patch
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
......@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import (
load_wav,
nested_params,
)
from .common import (
fetch_wav_subtype,
parameterize,
......@@ -30,23 +30,22 @@ class MockedSaveTest(PytorchTestCase):
[False, True],
[
(None, None),
('PCM_U', None),
('PCM_U', 8),
('PCM_S', None),
('PCM_S', 16),
('PCM_S', 32),
('PCM_F', None),
('PCM_F', 32),
('PCM_F', 64),
('ULAW', None),
('ULAW', 8),
('ALAW', None),
('ALAW', 8),
("PCM_U", None),
("PCM_U", 8),
("PCM_S", None),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", None),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", None),
("ULAW", 8),
("ALAW", None),
("ALAW", 8),
],
)
@patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first,
enc_params, mocked_write):
def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav"
input_tensor = get_wav_data(
......@@ -59,25 +58,33 @@ class MockedSaveTest(PytorchTestCase):
encoding, bits_per_sample = enc_params
soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first=channels_first,
encoding=encoding, bits_per_sample=bits_per_sample
filepath,
input_tensor,
sample_rate,
channels_first=channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] == fetch_wav_subtype(
dtype, encoding, bits_per_sample)
assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
assert args["format"] is None
self.assertEqual(
args["data"], input_tensor.t() if channels_first else input_tensor
)
self.assertEqual(args["data"], input_tensor.t() if channels_first else input_tensor)
@patch("soundfile.write")
def assert_non_wav(
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write,
encoding=None, bits_per_sample=None,
self,
fmt,
dtype,
sample_rate,
num_channels,
channels_first,
mocked_write,
encoding=None,
bits_per_sample=None,
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}"
......@@ -91,8 +98,12 @@ class MockedSaveTest(PytorchTestCase):
expected_data = input_tensor.t() if channels_first else input_tensor
soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first,
encoding=encoding, bits_per_sample=bits_per_sample,
filepath,
input_tensor,
sample_rate,
channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
# on +Py3.8 call_args.kwargs is more descreptive
......@@ -112,38 +123,42 @@ class MockedSaveTest(PytorchTestCase):
[1, 2],
[False, True],
[
('PCM_S', 8),
('PCM_S', 16),
('PCM_S', 24),
('PCM_S', 32),
('ULAW', 8),
('ALAW', 8),
('ALAW', 16),
('ALAW', 24),
('ALAW', 32),
("PCM_S", 8),
("PCM_S", 16),
("PCM_S", 24),
("PCM_S", 32),
("ULAW", 8),
("ALAW", 8),
("ALAW", 16),
("ALAW", 24),
("ALAW", 32),
],
)
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
encoding, bits_per_sample = enc_params
self.assert_non_wav(fmt, dtype, sample_rate, num_channels,
channels_first, encoding=encoding,
bits_per_sample=bits_per_sample)
self.assert_non_wav(
fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample
)
@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
[8, 16, 24],
)
def test_flac(self, dtype, sample_rate, num_channels,
channels_first, bits_per_sample):
def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels,
channels_first, bits_per_sample=bits_per_sample)
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample)
@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
)
def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
"""soundfile_backend.save passes default format and subtype (None-s) to
......@@ -156,9 +171,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`soundfile_backend.save` can save wav format."""
path = self.get_temp_path("data.wav")
expected = get_wav_data(
dtype, num_channels, num_frames=num_frames, normalize=False
)
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
soundfile_backend.save(path, expected, sample_rate)
found, sr = load_wav(path, normalize=False)
assert sample_rate == sr
......@@ -172,9 +185,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
"""
num_frames = sample_rate * 3
path = self.get_temp_path(f"data.{fmt}")
expected = get_wav_data(
dtype, num_channels, num_frames=num_frames, normalize=False
)
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
soundfile_backend.save(path, expected, sample_rate)
sinfo = soundfile.info(path)
assert sinfo.format == fmt.upper()
......@@ -201,14 +212,17 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
@skipIfNoModule("soundfile")
class TestSave(SaveTestBase):
@parameterize(
["float32", "int32", "int16"], [8000, 16000], [1, 2],
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["float32", "int32", "int16"], [4, 8, 16, 32],
["float32", "int32", "int16"],
[4, 8, 16, 32],
)
def test_multiple_channels(self, dtype, num_channels):
"""`soundfile_backend.save` can save wav with more than 2 channels."""
......@@ -216,7 +230,9 @@ class TestSave(SaveTestBase):
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["int32", "int16"], [8000, 16000], [1, 2],
["int32", "int16"],
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels):
......@@ -224,7 +240,8 @@ class TestSave(SaveTestBase):
self.assert_sphere(dtype, sample_rate, num_channels)
@parameterize(
[8000, 16000], [1, 2],
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
......@@ -232,7 +249,8 @@ class TestSave(SaveTestBase):
self.assert_flac("float32", sample_rate, num_channels)
@parameterize(
[8000, 16000], [1, 2],
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels):
......@@ -260,36 +278,36 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Saving audio to file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')
path = self.get_temp_path(f"test.{ext}")
subtype = 'FLOAT' if ext == 'wav' else None
data = get_wav_data('float32', num_channels=2)
subtype = "FLOAT" if ext == "wav" else None
data = get_wav_data("float32", num_channels=2)
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
expected = soundfile.read(path, dtype='float32')[0]
expected = soundfile.read(path, dtype="float32")[0]
fileobj = io.BytesIO()
soundfile_backend.save(fileobj, data, sample_rate, format=ext)
fileobj.seek(0)
found, sr = soundfile.read(fileobj, dtype='float32')
found, sr = soundfile.read(fileobj, dtype="float32")
assert sr == sample_rate
self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)
def test_fileobj_wav(self):
"""Saving audio via file-like object works"""
self._test_fileobj('wav')
self._test_fileobj("wav")
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Saving audio via file-like object works"""
self._test_fileobj('flac')
self._test_fileobj("flac")
@skipIfFormatNotSupported("NIST")
def test_fileobj_nist(self):
"""Saving audio via file-like object works"""
self._test_fileobj('NIST')
self._test_fileobj("NIST")
@skipIfFormatNotSupported("OGG")
def test_fileobj_ogg(self):
"""Saving audio via file-like object works"""
self._test_fileobj('OGG')
self._test_fileobj("OGG")
......@@ -3,12 +3,12 @@ def name_func(func, _, params):
def get_enc_params(dtype):
if dtype == 'float32':
return 'PCM_F', 32
if dtype == 'int32':
return 'PCM_S', 32
if dtype == 'int16':
return 'PCM_S', 16
if dtype == 'uint8':
return 'PCM_U', 8
raise ValueError(f'Unexpected dtype: {dtype}')
if dtype == "float32":
return "PCM_F", 32
if dtype == "int32":
return "PCM_S", 32
if dtype == "int16":
return "PCM_S", 16
if dtype == "uint8":
return "PCM_U", 8
raise ValueError(f"Unexpected dtype: {dtype}")
from contextlib import contextmanager
import io
import os
import itertools
import os
import tarfile
from contextlib import contextmanager
from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
......@@ -25,6 +24,7 @@ from torchaudio_unittest.common_utils import (
save_wav,
sox_utils,
)
from .common import (
name_func,
)
......@@ -34,18 +34,23 @@ if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path('data.wav')
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
......@@ -53,17 +58,22 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[4, 8, 16, 32],
)), name_func=name_func)
assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[4, 8, 16, 32],
)
),
name_func=name_func,
)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path('data.wav')
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
......@@ -71,20 +81,28 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)), name_func=name_func)
assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)
),
name_func=name_func,
)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path('data.mp3')
path = self.get_temp_path("data.mp3")
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
path,
sample_rate,
num_channels,
compression=bit_rate,
duration=duration,
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
......@@ -94,18 +112,26 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "MP3"
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path('data.flac')
path = self.get_temp_path("data.flac")
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
path,
sample_rate,
num_channels,
compression=compression_level,
duration=duration,
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
......@@ -114,18 +140,26 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 24 # FLAC standard
assert info.encoding == "FLAC"
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
),
name_func=name_func,
)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path('data.vorbis')
path = self.get_temp_path("data.vorbis")
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
path,
sample_rate,
num_channels,
compression=quality_level,
duration=duration,
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
......@@ -134,18 +168,21 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "VORBIS"
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[16, 32],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[16, 32],
)
),
name_func=name_func,
)
def test_sphere(self, sample_rate, num_channels, bits_per_sample):
"""`sox_io_backend.info` can check sph file correctly"""
duration = 1
path = self.get_temp_path('data.sph')
sox_utils.gen_audio_file(
path, sample_rate, num_channels, duration=duration,
bit_depth=bits_per_sample)
path = self.get_temp_path("data.sph")
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
......@@ -153,19 +190,22 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
@parameterized.expand(list(itertools.product(
['int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly"""
duration = 1
path = self.get_temp_path('data.amb')
path = self.get_temp_path("data.amb")
bits_per_sample = sox_utils.get_bit_depth(dtype)
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=bits_per_sample, duration=duration)
sox_utils.gen_audio_file(path, sample_rate, num_channels, bit_depth=bits_per_sample, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
......@@ -178,10 +218,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.amr-nb')
path = self.get_temp_path("data.amr-nb")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16,
duration=duration)
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
......@@ -194,11 +234,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.wav')
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
bit_depth=8, encoding='u-law',
duration=duration)
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="u-law", duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
......@@ -211,11 +250,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.wav')
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
bit_depth=8, encoding='a-law',
duration=duration)
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="a-law", duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
......@@ -228,10 +266,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.gsm')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
duration=duration)
path = self.get_temp_path("data.gsm")
sox_utils.gen_audio_file(path, sample_rate=sample_rate, num_channels=num_channels, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
......@@ -243,10 +279,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.htk')
path = self.get_temp_path("data.htk")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
bit_depth=16, duration=duration)
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
......@@ -257,14 +293,19 @@ class TestInfo(TempDirMixin, PytorchTestCase):
@skipIfNoSox
class TestInfoOpus(PytorchTestCase):
@parameterized.expand(list(itertools.product(
['96k'],
[1, 2],
[0, 5, 10],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["96k"],
[1, 2],
[0, 5, 10],
)
),
name_func=name_func,
)
def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.info` can check opus file correcty"""
path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
info = sox_io_backend.info(path)
assert info.sample_rate == 48000
assert info.num_frames == 32768
......@@ -296,13 +337,15 @@ class TestLoadWithoutExtension(PytorchTestCase):
class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f'test.{ext}')
path = self.get_temp_path(f"test.{ext}")
bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file(
path, sample_rate, num_channels=num_channels,
path,
sample_rate,
num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth,
duration=duration,
......@@ -318,29 +361,29 @@ class FileObjTestBase(TempDirMixin):
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ['mp3'] else None
with open(path, 'rb') as fileobj:
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as fileobj:
return sox_io_backend.info(fileobj, format_)
def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
format_ = ext if ext in ['mp3'] else None
with open(path, 'rb') as file_:
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
return sox_io_backend.info(fileobj, format_)
def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
archive_path = self.get_temp_path('archive.tar.gz')
with tarfile.TarFile(archive_path, 'w') as tarobj:
archive_path = self.get_temp_path("archive.tar.gz")
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
format_ = ext if ext in ['mp3'] else None
with tarfile.TarFile(archive_path, 'r') as tarobj:
format_ = ext if ext in ["mp3"] else None
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_)
......@@ -353,16 +396,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
finally:
set_buffer_size(original_buffer_size)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
......@@ -371,7 +416,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -379,9 +424,11 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('vorbis', "float32"),
])
@parameterized.expand(
[
("vorbis", "float32"),
]
)
def test_fileobj_large_header(self, ext, dtype):
"""
For audio file with header size exceeding default buffer size:
......@@ -399,7 +446,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -407,16 +454,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_bytesio(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 16000
......@@ -425,7 +474,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -433,16 +482,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 8000
......@@ -451,7 +502,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -459,16 +510,18 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
......@@ -477,7 +530,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -487,7 +540,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
......@@ -495,20 +548,22 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
audio_file = os.path.basename(audio_path)
url = self.get_url(audio_file)
format_ = ext if ext in ['mp3'] else None
format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(resp.raw, format=format_)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_requests(self, ext, dtype):
"""Querying compressed audio via requests works"""
sample_rate = 16000
......@@ -517,7 +572,7 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......
......@@ -3,9 +3,8 @@ import itertools
import tarfile
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
......@@ -19,6 +18,7 @@ from torchaudio_unittest.common_utils import (
save_wav,
sox_utils,
)
from .common import (
name_func,
)
......@@ -30,17 +30,17 @@ if _mod_utils.is_module_available("requests"):
class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_format(
self,
format: str,
sample_rate: float,
num_channels: int,
compression: float = None,
bit_depth: int = None,
duration: float = 1,
normalize: bool = True,
encoding: str = None,
atol: float = 4e-05,
rtol: float = 1.3e-06,
self,
format: str,
sample_rate: float,
num_channels: int,
compression: float = None,
bit_depth: int = None,
duration: float = 1,
normalize: bool = True,
encoding: str = None,
atol: float = 4e-05,
rtol: float = 1.3e-06,
):
"""`sox_io_backend.load` can load given format correctly.
......@@ -68,13 +68,18 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
data without using torchaudio
"""
path = self.get_temp_path(f'1.original.{format}')
ref_path = self.get_temp_path('2.reference.wav')
path = self.get_temp_path(f"1.original.{format}")
ref_path = self.get_temp_path("2.reference.wav")
# 1. Generate the given format with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels, encoding=encoding,
compression=compression, bit_depth=bit_depth, duration=duration,
path,
sample_rate,
num_channels,
encoding=encoding,
compression=compression,
bit_depth=bit_depth,
duration=duration,
)
# 2. Convert to wav with sox
wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
......@@ -92,7 +97,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path('reference.wav')
path = self.get_temp_path("reference.wav")
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
......@@ -101,118 +106,176 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
self.assertEqual(data, expected)
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoSox
class TestLoad(LoadTestBase):
"""Test the correctness of `sox_io_backend.load` for various formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_24bit_wav(self, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load 24bit wav format correctly. Corectly casts it to ``int32`` tensor dtype."""
self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1)
@parameterized.expand(list(itertools.product(
['int16'],
[16000],
[2],
[False],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["int16"],
[16000],
[2],
[False],
)
),
name_func=name_func,
)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[4, 8, 16, 32],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[4, 8, 16, 32],
)
),
name_func=name_func,
)
def test_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(list(itertools.product(
[8000, 16000, 44100],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000, 44100],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)
),
name_func=name_func,
)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly."""
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[128],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[16000],
[2],
[128],
)
),
name_func=name_func,
)
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly."""
self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[0],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[16000],
[2],
[0],
)
),
name_func=name_func,
)
def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60
self.assert_format(
"flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours)
"flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours
)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
),
name_func=name_func,
)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly."""
self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[10],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[16000],
[2],
[10],
)
),
name_func=name_func,
)
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_format(
"vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours)
"vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours
)
@parameterized.expand(list(itertools.product(
['96k'],
[1, 2],
[0, 5, 10],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["96k"],
[1, 2],
[0, 5, 10],
)
),
name_func=name_func,
)
def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.load` can load opus file correctly."""
ops_path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
wav_path = self.get_temp_path(f'{bitrate}_{compression_level}_{num_channels}ch.opus.wav')
ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
wav_path = self.get_temp_path(f"{bitrate}_{compression_level}_{num_channels}ch.opus.wav")
sox_utils.convert_audio_file(ops_path, wav_path)
expected, sample_rate = load_wav(wav_path)
......@@ -221,57 +284,74 @@ class TestLoad(LoadTestBase):
assert sample_rate == sr
self.assertEqual(expected, found)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load amb format correctly."""
bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype)
self.assert_format(
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize)
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize
)
def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""
original = None
path = None
def setUp(self):
super().setUp()
sample_rate = 8000
self.original = get_wav_data('float32', num_channels=2)
self.path = self.get_temp_path('test.wav')
self.original = get_wav_data("float32", num_channels=2)
self.path = self.get_temp_path("test.wav")
save_wav(self.path, self.original, sample_rate)
@parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[:, frame_offset:frame_end])
@parameterized.expand([(True, ), (False, )], name_func=name_func)
@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
......@@ -299,7 +379,7 @@ class TestLoadWithoutExtension(PytorchTestCase):
class CloggedFileObj:
def __init__(self, fileobj):
self.fileobj = fileobj
self.buffer = b''
self.buffer = b""
def read(self, n):
if not self.buffer:
......@@ -310,163 +390,168 @@ class CloggedFileObj:
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoExec("sox")
class TestFileObject(TempDirMixin, PytorchTestCase):
"""
In this test suite, the result of file-like object input is compared against file path input,
because `load` function is rigrously tested for file path inputs to match libsox's result,
"""
@parameterized.expand([
('wav', {'bit_depth': 16}),
('wav', {'bit_depth': 24}),
('wav', {'bit_depth': 32}),
('mp3', {'compression': 128}),
('mp3', {'compression': 320}),
('flac', {'compression': 0}),
('flac', {'compression': 5}),
('flac', {'compression': 8}),
('vorbis', {'compression': -1}),
('vorbis', {'compression': 10}),
('amb', {}),
])
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_fileobj(self, ext, kwargs):
"""Loading audio via file object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2, **kwargs)
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as fileobj:
with open(path, "rb") as fileobj:
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', {'bit_depth': 16}),
('wav', {'bit_depth': 24}),
('wav', {'bit_depth': 32}),
('mp3', {'compression': 128}),
('mp3', {'compression': 320}),
('flac', {'compression': 0}),
('flac', {'compression': 5}),
('flac', {'compression': 8}),
('vorbis', {'compression': -1}),
('vorbis', {'compression': 10}),
('amb', {}),
])
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_bytesio(self, ext, kwargs):
"""Loading audio via BytesIO object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2, **kwargs)
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_:
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', {'bit_depth': 16}),
('wav', {'bit_depth': 24}),
('wav', {'bit_depth': 32}),
('mp3', {'compression': 128}),
('mp3', {'compression': 320}),
('flac', {'compression': 0}),
('flac', {'compression': 5}),
('flac', {'compression': 8}),
('vorbis', {'compression': -1}),
('vorbis', {'compression': 10}),
('amb', {}),
])
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_bytesio_clogged(self, ext, kwargs):
"""Loading audio via clogged file object returns the same result as via file path.
This test case validates the case where fileobject returns shorter bytes than requeted.
"""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2, **kwargs)
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_:
with open(path, "rb") as file_:
fileobj = CloggedFileObj(io.BytesIO(file_.read()))
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', {'bit_depth': 16}),
('wav', {'bit_depth': 24}),
('wav', {'bit_depth': 32}),
('mp3', {'compression': 128}),
('mp3', {'compression': 320}),
('flac', {'compression': 0}),
('flac', {'compression': 5}),
('flac', {'compression': 8}),
('vorbis', {'compression': -1}),
('vorbis', {'compression': 10}),
('amb', {}),
])
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_bytesio_tiny(self, ext, kwargs):
"""Loading very small audio via file object returns the same result as via file path.
"""
"""Loading very small audio via file object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2, duration=1 / 1600, **kwargs)
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, duration=1 / 1600, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_:
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', {'bit_depth': 16}),
('wav', {'bit_depth': 24}),
('wav', {'bit_depth': 32}),
('mp3', {'compression': 128}),
('mp3', {'compression': 320}),
('flac', {'compression': 0}),
('flac', {'compression': 5}),
('flac', {'compression': 8}),
('vorbis', {'compression': -1}),
('vorbis', {'compression': 10}),
('amb', {}),
])
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_tarfile(self, ext, kwargs):
"""Loading compressed audio via file-like object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
format_ = ext if ext in ["mp3"] else None
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
archive_path = self.get_temp_path("archive.tar.gz")
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2, **kwargs)
sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(audio_path)
with tarfile.TarFile(archive_path, 'w') as tarobj:
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = sox_io_backend.load(fileobj, format=format_)
......@@ -475,30 +560,31 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand([
('wav', {'bit_depth': 16}),
('wav', {'bit_depth': 24}),
('wav', {'bit_depth': 32}),
('mp3', {'compression': 128}),
('mp3', {'compression': 320}),
('flac', {'compression': 0}),
('flac', {'compression': 5}),
('flac', {'compression': 8}),
('vorbis', {'compression': -1}),
('vorbis', {'compression': 10}),
('amb', {}),
])
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_requests(self, ext, kwargs):
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
format_ = ext if ext in ["mp3"] else None
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2, **kwargs)
sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(audio_path)
url = self.get_url(audio_file)
......@@ -508,17 +594,22 @@ class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
sample_rate = 8000
audio_file = 'test.wav'
audio_file = "test.wav"
audio_path = self.get_temp_path(audio_file)
original = get_wav_data('float32', num_channels=2)
original = get_wav_data("float32", num_channels=2)
save_wav(audio_path, original, sample_rate)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, frame_offset:frame_end]
......
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
......@@ -10,44 +9,56 @@ from torchaudio_unittest.common_utils import (
skipIfNoSox,
get_wav_data,
)
from .common import (
name_func,
get_enc_params,
)
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.wav')
path = self.get_temp_path(f"{i}.wav")
sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
self.assertEqual(original, data)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""save/load round trip should not degrade data for flac formats"""
original = get_wav_data('float32', num_channels)
original = get_wav_data("float32", num_channels)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.flac')
path = self.get_temp_path(f"{i}.flac")
sox_io_backend.save(path, data, sample_rate, compression=compression_level)
data, sr = sox_io_backend.load(path)
assert sr == sample_rate
......
......@@ -3,9 +3,8 @@ import os
import unittest
import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
......@@ -18,6 +17,7 @@ from torchaudio_unittest.common_utils import (
sox_utils,
nested_params,
)
from .common import (
name_func,
get_enc_params,
......@@ -26,28 +26,28 @@ from .common import (
def _get_sox_encoding(encoding):
encodings = {
'PCM_F': 'floating-point',
'PCM_S': 'signed-integer',
'PCM_U': 'unsigned-integer',
'ULAW': 'u-law',
'ALAW': 'a-law',
"PCM_F": "floating-point",
"PCM_S": "signed-integer",
"PCM_U": "unsigned-integer",
"ULAW": "u-law",
"ALAW": "a-law",
}
return encodings.get(encoding)
class SaveTestBase(TempDirMixin, TorchaudioTestCase):
def assert_save_consistency(
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = 'int32',
test_mode: str = "path",
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = "int32",
test_mode: str = "path",
):
"""`save` function produces file that is comparable with `sox` command
......@@ -86,14 +86,14 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
tensor -------> compare <--------- tensor
"""
cmp_encoding = 'floating-point'
cmp_encoding = "floating-point"
cmp_bit_depth = 32
src_path = self.get_temp_path('1.source.wav')
tgt_path = self.get_temp_path(f'2.1.torchaudio.{format}')
tst_path = self.get_temp_path('2.2.result.wav')
sox_path = self.get_temp_path(f'3.1.sox.{format}')
ref_path = self.get_temp_path('3.2.ref.wav')
src_path = self.get_temp_path("1.source.wav")
tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
tst_path = self.get_temp_path("2.2.result.wav")
sox_path = self.get_temp_path(f"3.1.sox.{format}")
ref_path = self.get_temp_path("3.2.ref.wav")
# 1. Generate original wav
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
......@@ -103,78 +103,84 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
sox_io_backend.save(
tgt_path, data, sample_rate,
compression=compression, encoding=encoding, bits_per_sample=bits_per_sample)
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj":
with open(tgt_path, 'bw') as file_:
with open(tgt_path, "bw") as file_:
sox_io_backend.save(
file_, data, sample_rate,
format=format, compression=compression,
encoding=encoding, bits_per_sample=bits_per_sample)
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
sox_io_backend.save(
file_, data, sample_rate,
format=format, compression=compression,
encoding=encoding, bits_per_sample=bits_per_sample)
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
file_.seek(0)
with open(tgt_path, 'bw') as f:
with open(tgt_path, "bw") as f:
f.write(file_.read())
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(
tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]
# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(
src_path, sox_path,
compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample)
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(
sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]
self.assertEqual(found, expected)
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoSox
class SaveTest(SaveTestBase):
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_U', 8),
('PCM_S', 16),
('PCM_S', 32),
('PCM_F', 32),
('PCM_F', 64),
('ULAW', 8),
('ALAW', 8),
("PCM_U", 8),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", 8),
("ALAW", 8),
],
)
def test_save_wav(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('float32', ),
('int32', ),
('int16', ),
('uint8', ),
("float32",),
("int32",),
("int16",),
("uint8",),
],
)
def test_save_wav_dtype(self, test_mode, params):
dtype, = params
self.assert_save_consistency(
"wav", src_dtype=dtype, test_mode=test_mode)
(dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
......@@ -197,10 +203,9 @@ class SaveTest(SaveTestBase):
if test_mode in ["fileobj", "bytesio"]:
if bit_rate is not None and bit_rate < 1:
raise unittest.SkipTest(
"mp3 format with variable bit rate is known to "
"not yield the exact same result as sox command.")
self.assert_save_consistency(
"mp3", compression=bit_rate, test_mode=test_mode)
"mp3 format with variable bit rate is known to " "not yield the exact same result as sox command."
)
self.assert_save_consistency("mp3", compression=bit_rate, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
......@@ -220,8 +225,8 @@ class SaveTest(SaveTestBase):
)
def test_save_flac(self, test_mode, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level,
bits_per_sample=bits_per_sample, test_mode=test_mode)
"flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode=test_mode
)
@nested_params(
["path", "fileobj", "bytesio"],
......@@ -244,45 +249,78 @@ class SaveTest(SaveTestBase):
],
)
def test_save_vorbis(self, test_mode, quality_level):
self.assert_save_consistency(
"vorbis", compression=quality_level, test_mode=test_mode)
self.assert_save_consistency("vorbis", compression=quality_level, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_S', 8, ),
('PCM_S', 16, ),
('PCM_S', 24, ),
('PCM_S', 32, ),
('ULAW', 8),
('ALAW', 8),
('ALAW', 16),
('ALAW', 24),
('ALAW', 32),
(
"PCM_S",
8,
),
(
"PCM_S",
16,
),
(
"PCM_S",
24,
),
(
"PCM_S",
32,
),
("ULAW", 8),
("ALAW", 8),
("ALAW", 16),
("ALAW", 24),
("ALAW", 32),
],
)
def test_save_sphere(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_U', 8, ),
('PCM_S', 16, ),
('PCM_S', 24, ),
('PCM_S', 32, ),
('PCM_F', 32, ),
('PCM_F', 64, ),
('ULAW', 8, ),
('ALAW', 8, ),
(
"PCM_U",
8,
),
(
"PCM_S",
16,
),
(
"PCM_S",
24,
),
(
"PCM_S",
32,
),
(
"PCM_F",
32,
),
(
"PCM_F",
64,
),
(
"ULAW",
8,
),
(
"ALAW",
8,
),
],
)
def test_save_amb(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
......@@ -299,90 +337,94 @@ class SaveTest(SaveTestBase):
],
)
def test_save_amr_nb(self, test_mode, bit_rate):
self.assert_save_consistency(
"amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)
self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
)
def test_save_gsm(self, test_mode):
self.assert_save_consistency(
"gsm", num_channels=1, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports single channel audio."):
self.assert_save_consistency(
"gsm", num_channels=2, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency(
"gsm", sample_rate=16000, test_mode=test_mode)
@parameterized.expand([
("wav", "PCM_S", 16),
("mp3", ),
("flac", ),
("vorbis", ),
("sph", "PCM_S", 16),
("amr-nb", ),
("amb", "PCM_S", 16),
], name_func=name_func)
self.assert_save_consistency("gsm", num_channels=1, test_mode=test_mode)
with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
self.assert_save_consistency("gsm", num_channels=2, test_mode=test_mode)
with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency("gsm", sample_rate=16000, test_mode=test_mode)
@parameterized.expand(
[
("wav", "PCM_S", 16),
("mp3",),
("flac",),
("vorbis",),
("sph", "PCM_S", 16),
("amr-nb",),
("amb", "PCM_S", 16),
],
name_func=name_func,
)
def test_save_large(self, format, encoding=None, bits_per_sample=None):
"""`sox_io_backend.save` can save large files."""
sample_rate = 8000
one_hour = 60 * 60 * sample_rate
self.assert_save_consistency(
format, num_channels=1, sample_rate=8000, num_frames=one_hour,
encoding=encoding, bits_per_sample=bits_per_sample)
@parameterized.expand([
(32, ),
(64, ),
(128, ),
(256, ),
], name_func=name_func)
format,
num_channels=1,
sample_rate=8000,
num_frames=one_hour,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
@parameterized.expand(
[
(32,),
(64,),
(128,),
(256,),
],
name_func=name_func,
)
def test_save_multi_channels(self, num_channels):
"""`sox_io_backend.save` can save audio with many channels"""
self.assert_save_consistency(
"wav", encoding="PCM_S", bits_per_sample=16,
num_channels=num_channels)
self.assert_save_consistency("wav", encoding="PCM_S", bits_per_sample=16, num_channels=num_channels)
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=name_func)
@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_save_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path('data.wav')
data = get_wav_data(
'int16', 2, channels_first=channels_first, normalize=False)
sox_io_backend.save(
path, data, 8000, channels_first=channels_first)
path = self.get_temp_path("data.wav")
data = get_wav_data("int16", 2, channels_first=channels_first, normalize=False)
sox_io_backend.save(path, data, 8000, channels_first=channels_first)
found = load_wav(path, normalize=False)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected)
@parameterized.expand([
'float32', 'int32', 'int16', 'uint8'
], name_func=name_func)
@parameterized.expand(["float32", "int32", "int16", "uint8"], name_func=name_func)
def test_save_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly"""
path = self.get_temp_path('data.wav')
path = self.get_temp_path("data.wav")
enc, bps = get_enc_params(dtype)
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
assert not expected.is_contiguous()
sox_io_backend.save(
path, expected, 8000, encoding=enc, bits_per_sample=bps)
sox_io_backend.save(path, expected, 8000, encoding=enc, bits_per_sample=bps)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected)
@parameterized.expand([
'float32', 'int32', 'int16', 'uint8',
])
@parameterized.expand(
[
"float32",
"int32",
"int16",
"uint8",
]
)
def test_save_tensor_preserve(self, dtype):
"""save function should not alter Tensor"""
path = self.get_temp_path('data.wav')
path = self.get_temp_path("data.wav")
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
data = expected.clone()
......
......@@ -2,25 +2,24 @@ import io
import itertools
import unittest
from torchaudio.utils import sox_utils
from torchaudio.backend import sox_io_backend
from torchaudio._internal.module_utils import is_sox_available
from parameterized import parameterized
from torchaudio._internal.module_utils import is_sox_available
from torchaudio.backend import sox_io_backend
from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoSox,
get_wav_data,
)
from .common import name_func
skipIfNoMP3 = unittest.skipIf(
not is_sox_available() or
'mp3' not in sox_utils.list_read_formats() or
'mp3' not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3')
not is_sox_available() or "mp3" not in sox_utils.list_read_formats() or "mp3" not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3',
)
@skipIfNoSox
......@@ -33,10 +32,11 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
duration = 1
num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}')
path = self.get_temp_path(f"test.{ext}")
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
# 1. run save
......@@ -50,42 +50,60 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)))
self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)))
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
)
)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
@skipIfNoSox
......@@ -98,7 +116,8 @@ class SmokeTestFileObj(TorchaudioTestCase):
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
duration = 1
num_frames = sample_rate * duration
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
......@@ -117,39 +136,57 @@ class SmokeTestFileObj(TorchaudioTestCase):
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)))
self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)
)
)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)))
self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
)
)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
......@@ -4,7 +4,6 @@ from typing import Optional
import torch
import torchaudio
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
......@@ -16,6 +15,7 @@ from torchaudio_unittest.common_utils import (
sox_utils,
torch_script,
)
from .common import (
name_func,
get_enc_params,
......@@ -27,38 +27,41 @@ def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaDa
def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return torchaudio.load(
filepath, normalize=normalize, channels_first=channels_first)
return torchaudio.load(filepath, normalize=normalize, channels_first=channels_first)
def py_save_func(
filepath: str,
tensor: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
filepath: str,
tensor: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
torchaudio.save(
filepath, tensor, sample_rate, channels_first,
compression, None, encoding, bits_per_sample)
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample)
@skipIfNoExec('sox')
@skipIfNoExec("sox")
@skipIfNoSox
class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io'
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
backend = "sox_io"
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_info_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
audio_path = self.get_temp_path(f"{dtype}_{sample_rate}_{num_channels}.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
......@@ -71,40 +74,48 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
assert py_info.num_frames == ts_info.num_frames
assert py_info.num_channels == ts_info.num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
[False, True],
[False, True],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
[False, True],
[False, True],
)
),
name_func=name_func,
)
def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""`sox_io_backend.load` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
audio_path = self.get_temp_path(f"test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
ts_load_func = torch_script(py_load_func)
py_data, py_sr = py_load_func(
audio_path, normalize=normalize, channels_first=channels_first)
ts_data, ts_sr = ts_load_func(
audio_path, normalize=normalize, channels_first=channels_first)
py_data, py_sr = py_load_func(audio_path, normalize=normalize, channels_first=channels_first)
ts_data, ts_sr = ts_load_func(audio_path, normalize=normalize, channels_first=channels_first)
self.assertEqual(py_sr, ts_sr)
self.assertEqual(py_data, ts_data)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_save_wav(self, dtype, sample_rate, num_channels):
ts_save_func = torch_script(py_save_func)
expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
py_path = self.get_temp_path(f"test_save_py_{dtype}_{sample_rate}_{num_channels}.wav")
ts_path = self.get_temp_path(f"test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav")
enc, bps = get_enc_params(dtype)
py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
......@@ -118,24 +129,29 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
self.assertEqual(expected, py_data)
self.assertEqual(expected, ts_data)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_save_flac(self, sample_rate, num_channels, compression_level):
ts_save_func = torch_script(py_save_func)
expected = get_wav_data('float32', num_channels)
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')
expected = get_wav_data("float32", num_channels)
py_path = self.get_temp_path(f"test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac")
ts_path = self.get_temp_path(f"test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac")
py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav'
ts_path_wav = f'{ts_path}.wav'
py_path_wav = f"{py_path}.wav"
ts_path_wav = f"{ts_path}.wav"
sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32)
sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32)
......
import torchaudio
from torchaudio_unittest import common_utils
class BackendSwitchMixin:
"""Test set/get_audio_backend works"""
backend = None
backend_module = None
......@@ -26,11 +26,11 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
@common_utils.skipIfNoSox
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'
backend = "sox_io"
backend_module = torchaudio.backend.sox_io_backend
@common_utils.skipIfNoModule('soundfile')
@common_utils.skipIfNoModule("soundfile")
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend = "soundfile"
backend_module = torchaudio.backend.soundfile_backend
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment