Commit ffeba11a authored by mayp777's avatar mayp777
Browse files

UPDATE

parent 29deb085
...@@ -117,7 +117,6 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -117,7 +117,6 @@ class TestInfo(TempDirMixin, PytorchTestCase):
with patch("soundfile.info", _mock_info_func): with patch("soundfile.info", _mock_info_func):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
info = soundfile_backend.info("foo") info = soundfile_backend.info("foo")
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message) assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0 assert info.bits_per_sample == 0
......
import io
import itertools import itertools
import os
import tarfile
from contextlib import contextmanager
from parameterized import parameterized from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size from torchaudio_unittest.backend.common import get_encoding
from torchaudio_unittest.backend.common import get_bits_per_sample, get_encoding
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
get_wav_data, get_wav_data,
HttpServerMixin,
PytorchTestCase, PytorchTestCase,
save_wav, save_wav,
skipIfNoExec, skipIfNoExec,
skipIfNoModule,
skipIfNoSox, skipIfNoSox,
skipIfNoSoxDecoder,
sox_utils, sox_utils,
TempDirMixin, TempDirMixin,
) )
...@@ -25,10 +18,6 @@ from torchaudio_unittest.common_utils import ( ...@@ -25,10 +18,6 @@ from torchaudio_unittest.common_utils import (
from .common import name_func from .common import name_func
if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExec("sox") @skipIfNoExec("sox")
@skipIfNoSox @skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase): class TestInfo(TempDirMixin, PytorchTestCase):
...@@ -208,6 +197,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -208,6 +197,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == bits_per_sample assert info.bits_per_sample == bits_per_sample
assert info.encoding == get_encoding("amb", dtype) assert info.encoding == get_encoding("amb", dtype)
@skipIfNoSoxDecoder("amr-nb")
def test_amr_nb(self): def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly""" """`sox_io_backend.info` can check amr-nb file correctly"""
duration = 1 duration = 1
...@@ -287,6 +277,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -287,6 +277,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
@skipIfNoSox @skipIfNoSox
@skipIfNoSoxDecoder("opus")
class TestInfoOpus(PytorchTestCase): class TestInfoOpus(PytorchTestCase):
@parameterized.expand( @parameterized.expand(
list( list(
...@@ -314,283 +305,19 @@ class TestLoadWithoutExtension(PytorchTestCase): ...@@ -314,283 +305,19 @@ class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self): def test_mp3(self):
"""MP3 file without extension can be loaded """MP3 file without extension can be loaded
Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040 https://github.com/pytorch/audio/issues/1040
The file was generated with the following command The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
""" """
path = get_asset_path("mp3_without_ext") path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path) sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000 assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 80000 assert sinfo.num_frames == 81216
assert sinfo.num_channels == 1 assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3" assert sinfo.encoding == "MP3"
with open(path, "rb") as fileobj:
sinfo = sox_io_backend.info(fileobj, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 80000
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0
assert sinfo.encoding == "MP3"
class FileObjTestBase(TempDirMixin):
    """Mixin with helpers that generate on-disk audio fixtures for file-object tests."""

    def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate a test audio file with sox and return its path.

        When ``comments`` is given (non-empty), it is written to a side file and
        embedded into the generated audio's metadata.
        """
        out_path = self.get_temp_path(f"test.{ext}")
        comment_file = self._gen_comment_file(comments) if comments else None
        sox_utils.gen_audio_file(
            out_path,
            sample_rate,
            num_channels=num_channels,
            encoding=sox_utils.get_encoding(dtype),
            bit_depth=sox_utils.get_bit_depth(dtype),
            duration=num_frames / sample_rate,
            comment_file=comment_file,
        )
        return out_path

    def _gen_comment_file(self, comments):
        """Write ``comments`` to a temporary text file and return its path."""
        target = self.get_temp_path("comment.txt")
        with open(target, "w") as fp:
            fp.writelines(comments)
        return target
class Unseekable:
    """Expose only ``read`` of a wrapped file-like object.

    By hiding ``seek``, this forces the consumer down its non-seekable
    input code path.
    """

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def read(self, n):
        chunk = self.fileobj.read(n)
        return chunk
# NOTE(review): leading indentation was lost in this paste; code lines below
# are kept byte-identical and only comments/docstrings are added.
@skipIfNoSox
@skipIfNoExec("sox")
# Verifies `sox_io_backend.info` against file-like inputs: an open file
# handle, an in-memory BytesIO, and a tarfile member object.
class TestFileObject(FileObjTestBase, PytorchTestCase):
# Generate a fixture and query its metadata through an open binary file handle.
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
# Only mp3 gets an explicit format hint; other formats rely on auto-detection.
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as fileobj:
return sox_io_backend.info(fileobj, format_)
# Same query, but the whole file is first read into an in-memory BytesIO.
def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
return sox_io_backend.info(fileobj, format_)
# Same query, but through the file object returned by TarFile.extractfile.
def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
archive_path = self.get_temp_path("archive.tar.gz")
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
format_ = ext if ext in ["mp3"] else None
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_)
# Temporarily override the sox I/O buffer size; always restores the original.
@contextmanager
def _set_buffer_size(self, buffer_size):
try:
original_buffer_size = get_buffer_size()
set_buffer_size(buffer_size)
yield
finally:
set_buffer_size(original_buffer_size)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
# vorbis reports no frame count through info; mp3's count differs from the
# requested length (presumably codec padding — TODO confirm).
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("vorbis", "float32"),
]
)
def test_fileobj_large_header(self, ext, dtype):
"""
For audio file with header size exceeding default buffer size:
- Querying audio via file object without enlarging buffer size fails.
- Querying audio via file object after enlarging buffer size succeeds.
"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
# Build a metadata comment large enough to overflow the default buffer.
comments = "metadata=" + " ".join(["value" for _ in range(1000)])
with self.assertRaises(RuntimeError):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_bytesio(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
# Deliberately tiny input: 4 frames at 8 kHz.
sample_rate = 8000
num_frames = 4
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 1728}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
# NOTE(review): float (3.0 *) here vs int (3 *) in the sibling tests — confirm intentional.
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
# NOTE(review): leading indentation was lost in this paste; code lines below
# are kept byte-identical and only comments are added.
@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
# Verifies `sox_io_backend.info` over HTTP using an unseekable response stream.
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
# Serve a generated fixture over the test HTTP server and query its metadata
# through an unseekable wrapper around the raw response stream.
def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
url = self.get_url(audio_file)
# Only mp3 gets an explicit format hint; other formats are auto-detected.
format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(Unseekable(resp.raw), format=format_)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_requests(self, ext, dtype):
"""Querying compressed audio via requests works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
# vorbis reports no frame count; mp3's count reflects codec padding.
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoSox @skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase): class TestInfoNoSuchFile(PytorchTestCase):
......
import io
import itertools import itertools
import tarfile
import torch import torch
import torchaudio
from parameterized import parameterized from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
get_wav_data, get_wav_data,
HttpServerMixin,
load_wav, load_wav,
nested_params, nested_params,
PytorchTestCase, PytorchTestCase,
save_wav, save_wav,
skipIfNoExec, skipIfNoExec,
skipIfNoModule,
skipIfNoSox, skipIfNoSox,
skipIfNoSoxDecoder,
sox_utils, sox_utils,
TempDirMixin, TempDirMixin,
) )
...@@ -25,10 +20,6 @@ from torchaudio_unittest.common_utils import ( ...@@ -25,10 +20,6 @@ from torchaudio_unittest.common_utils import (
from .common import name_func from .common import name_func
if _mod_utils.is_module_available("requests"):
import requests
class LoadTestBase(TempDirMixin, PytorchTestCase): class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_format( def assert_format(
self, self,
...@@ -244,6 +235,7 @@ class TestLoad(LoadTestBase): ...@@ -244,6 +235,7 @@ class TestLoad(LoadTestBase):
), ),
name_func=name_func, name_func=name_func,
) )
@skipIfNoSoxDecoder("opus")
def test_opus(self, bitrate, num_channels, compression_level): def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.load` can load opus file correctly.""" """`sox_io_backend.load` can load opus file correctly."""
ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus") ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
...@@ -288,6 +280,7 @@ class TestLoad(LoadTestBase): ...@@ -288,6 +280,7 @@ class TestLoad(LoadTestBase):
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize "amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize
) )
@skipIfNoSoxDecoder("amr-nb")
def test_amr_nb(self): def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly.""" """`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1) self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
...@@ -322,306 +315,21 @@ class TestLoadParams(TempDirMixin, PytorchTestCase): ...@@ -322,306 +315,21 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)
# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args)
self._test(func, frame_offset, num_frames, channels_first, normalize)
@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize):
"""frame_offset/num_frames/channels_first/normalize shape the output tensor (FFmpeg path)"""
# Imported lazily so the module imports even when the FFmpeg extension is absent.
from torchaudio.io._compat import load_audio, load_audio_fileobj
self._test(load_audio, frame_offset, num_frames, channels_first, normalize)
# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return load_audio_fileobj(fileobj, *args)
self._test(func, frame_offset, num_frames, channels_first, normalize)
@skipIfNoSox @skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase): class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self): def test_mp3(self):
"""MP3 file without extension can be loaded """MP3 file without extension can be loaded
Originally, we added `format` argument for this case, but now we use FFmpeg
for MP3 decoding, which works even without `format` argument.
https://github.com/pytorch/audio/issues/1040 https://github.com/pytorch/audio/issues/1040
The file was generated with the following command The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
""" """
path = get_asset_path("mp3_without_ext") path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path) _, sr = sox_io_backend.load(path, format="mp3")
assert sr == 16000 assert sr == 16000
with open(path, "rb") as fileobj:
_, sr = sox_io_backend.load(fileobj)
assert sr == 16000
class CloggedFileObj:
    """File-like wrapper whose ``read`` returns at most 2 bytes per call.

    Simulates a "clogged" stream that yields fewer bytes than requested,
    regardless of how many were asked for. ``seek`` is forwarded untouched.
    """

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def read(self, _):
        # Intentionally ignore the requested size and drip-feed 2 bytes.
        return self.fileobj.read(2)

    def seek(self, offset, whence):
        return self.fileobj.seek(offset, whence)
# NOTE(review): leading indentation was lost in this paste; code lines below
# are kept byte-identical and only comments/docstrings are changed.
@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(TempDirMixin, PytorchTestCase):
"""
In this test suite, the result of file-like object input is compared against file path input,
because the `load` function is rigorously tested for file path inputs to match libsox's result.
"""
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_fileobj(self, ext, kwargs):
"""Loading audio via file object returns the same result as via file path."""
sample_rate = 16000
# Only mp3 needs an explicit format hint; other formats are auto-detected.
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, "rb") as fileobj:
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_bytesio(self, ext, kwargs):
"""Loading audio via BytesIO object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_bytesio_clogged(self, ext, kwargs):
"""Loading audio via clogged file object returns the same result as via file path.
This test case validates the case where the file object returns shorter bytes than requested.
"""
sample_rate = 16000
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, "rb") as file_:
# CloggedFileObj drip-feeds 2 bytes per read call.
fileobj = CloggedFileObj(io.BytesIO(file_.read()))
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_bytesio_tiny(self, ext, kwargs):
"""Loading very small audio via file object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
# duration=1/1600 second -> ~10 frames at 16 kHz.
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, duration=1 / 1600, **kwargs)
expected, _ = sox_io_backend.load(path)
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_tarfile(self, ext, kwargs):
"""Loading compressed audio via file-like object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ["mp3"] else None
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path("archive.tar.gz")
sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(audio_path)
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, "r") as tarobj:
# extractfile yields a non-seekable-ish member stream.
fileobj = tarobj.extractfile(audio_file)
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
class Unseekable:
    """Wrap a file-like object, exposing only ``read``.

    The absence of ``seek`` makes consumers treat the stream as non-seekable.
    """

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def read(self, n):
        result = self.fileobj.read(n)
        return result
# NOTE(review): leading indentation was lost in this paste; code lines below
# are kept byte-identical and only comments are added.
@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
# Verifies `sox_io_backend.load` over HTTP via an unseekable response stream.
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("mp3", {"compression": 128}),
("mp3", {"compression": 320}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_requests(self, ext, kwargs):
# Loading over HTTP matches loading from the local file path.
sample_rate = 16000
format_ = ext if ext in ["mp3"] else None
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
expected, _ = sox_io_backend.load(audio_path)
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_)
assert sr == sample_rate
# mp3 content comparison is skipped here; presumably decoded samples differ
# on an unseekable stream — TODO confirm the reason.
if ext != "mp3":
self.assertEqual(expected, found)
@parameterized.expand(
list(
itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)
),
name_func=name_func,
)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
sample_rate = 8000
audio_file = "test.wav"
audio_path = self.get_temp_path(audio_file)
original = get_wav_data("float32", num_channels=2)
save_wav(audio_path, original, sample_rate)
# num_frames == -1 means "until the end of the stream".
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, frame_offset:frame_end]
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames)
assert sr == sample_rate
self.assertEqual(expected, found)
@skipIfNoSox @skipIfNoSox
class TestLoadNoSuchFile(PytorchTestCase): class TestLoadNoSuchFile(PytorchTestCase):
......
import io
import os import os
import torch import torch
...@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
skipIfNoExec, skipIfNoExec,
skipIfNoSox, skipIfNoSox,
skipIfNoSoxEncoder,
sox_utils, sox_utils,
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -43,7 +43,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -43,7 +43,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
num_channels: int = 2, num_channels: int = 2,
num_frames: float = 3 * 8000, num_frames: float = 3 * 8000,
src_dtype: str = "int32", src_dtype: str = "int32",
test_mode: str = "path",
): ):
"""`save` function produces file that is comparable with `sox` command """`save` function produces file that is comparable with `sox` command
...@@ -97,37 +96,9 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -97,37 +96,9 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
# 2.1. Convert the original wav to target format with torchaudio # 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0] data = load_wav(src_path, normalize=False)[0]
if test_mode == "path": sox_io_backend.save(
sox_io_backend.save( tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample )
)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
sox_io_backend.save(
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
sox_io_backend.save(
file_,
data,
sample_rate,
format=format,
compression=compression,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
file_.seek(0)
with open(tgt_path, "bw") as f:
f.write(file_.read())
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox # 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy # 2.3. Load with SciPy
...@@ -150,7 +121,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -150,7 +121,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
@skipIfNoSox @skipIfNoSox
class SaveTest(SaveTestBase): class SaveTest(SaveTestBase):
@nested_params( @nested_params(
["path", "fileobj", "bytesio"],
[ [
("PCM_U", 8), ("PCM_U", 8),
("PCM_S", 16), ("PCM_S", 16),
...@@ -161,12 +131,11 @@ class SaveTest(SaveTestBase): ...@@ -161,12 +131,11 @@ class SaveTest(SaveTestBase):
("ALAW", 8), ("ALAW", 8),
], ],
) )
def test_save_wav(self, test_mode, enc_params): def test_save_wav(self, enc_params):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"],
[ [
("float32",), ("float32",),
("int32",), ("int32",),
...@@ -174,12 +143,11 @@ class SaveTest(SaveTestBase): ...@@ -174,12 +143,11 @@ class SaveTest(SaveTestBase):
("uint8",), ("uint8",),
], ],
) )
def test_save_wav_dtype(self, test_mode, params): def test_save_wav_dtype(self, params):
(dtype,) = params (dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode) self.assert_save_consistency("wav", src_dtype=dtype)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"],
[8, 16, 24], [8, 16, 24],
[ [
None, None,
...@@ -194,19 +162,13 @@ class SaveTest(SaveTestBase): ...@@ -194,19 +162,13 @@ class SaveTest(SaveTestBase):
8, 8,
], ],
) )
def test_save_flac(self, test_mode, bits_per_sample, compression_level): def test_save_flac(self, bits_per_sample, compression_level):
self.assert_save_consistency( self.assert_save_consistency("flac", compression=compression_level, bits_per_sample=bits_per_sample)
"flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode=test_mode
)
@nested_params( def test_save_htk(self):
["path", "fileobj", "bytesio"], self.assert_save_consistency("htk", num_channels=1)
)
def test_save_htk(self, test_mode):
self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"],
[ [
None, None,
-1, -1,
...@@ -219,11 +181,10 @@ class SaveTest(SaveTestBase): ...@@ -219,11 +181,10 @@ class SaveTest(SaveTestBase):
10, 10,
], ],
) )
def test_save_vorbis(self, test_mode, quality_level): def test_save_vorbis(self, quality_level):
self.assert_save_consistency("vorbis", compression=quality_level, test_mode=test_mode) self.assert_save_consistency("vorbis", compression=quality_level)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"],
[ [
( (
"PCM_S", "PCM_S",
...@@ -248,12 +209,11 @@ class SaveTest(SaveTestBase): ...@@ -248,12 +209,11 @@ class SaveTest(SaveTestBase):
("ALAW", 32), ("ALAW", 32),
], ],
) )
def test_save_sphere(self, test_mode, enc_params): def test_save_sphere(self, enc_params):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"],
[ [
( (
"PCM_U", "PCM_U",
...@@ -289,12 +249,11 @@ class SaveTest(SaveTestBase): ...@@ -289,12 +249,11 @@ class SaveTest(SaveTestBase):
), ),
], ],
) )
def test_save_amb(self, test_mode, enc_params): def test_save_amb(self, enc_params):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"],
[ [
None, None,
0, 0,
...@@ -307,18 +266,16 @@ class SaveTest(SaveTestBase): ...@@ -307,18 +266,16 @@ class SaveTest(SaveTestBase):
7, 7,
], ],
) )
def test_save_amr_nb(self, test_mode, bit_rate): @skipIfNoSoxEncoder("amr-nb")
self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode) def test_save_amr_nb(self, bit_rate):
self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1)
@nested_params( def test_save_gsm(self):
["path", "fileobj", "bytesio"], self.assert_save_consistency("gsm", num_channels=1)
)
def test_save_gsm(self, test_mode):
self.assert_save_consistency("gsm", num_channels=1, test_mode=test_mode)
with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."): with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
self.assert_save_consistency("gsm", num_channels=2, test_mode=test_mode) self.assert_save_consistency("gsm", num_channels=2)
with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."): with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency("gsm", sample_rate=16000, test_mode=test_mode) self.assert_save_consistency("gsm", sample_rate=16000)
@parameterized.expand( @parameterized.expand(
[ [
...@@ -326,12 +283,18 @@ class SaveTest(SaveTestBase): ...@@ -326,12 +283,18 @@ class SaveTest(SaveTestBase):
("flac",), ("flac",),
("vorbis",), ("vorbis",),
("sph", "PCM_S", 16), ("sph", "PCM_S", 16),
("amr-nb",),
("amb", "PCM_S", 16), ("amb", "PCM_S", 16),
], ],
name_func=name_func, name_func=name_func,
) )
def test_save_large(self, format, encoding=None, bits_per_sample=None): def test_save_large(self, format, encoding=None, bits_per_sample=None):
self._test_save_large(format, encoding, bits_per_sample)
@skipIfNoSoxEncoder("amr-nb")
def test_save_large_amr_nb(self):
self._test_save_large("amr-nb")
def _test_save_large(self, format, encoding=None, bits_per_sample=None):
"""`sox_io_backend.save` can save large files.""" """`sox_io_backend.save` can save large files."""
sample_rate = 8000 sample_rate = 8000
one_hour = 60 * 60 * sample_rate one_hour = 60 * 60 * sample_rate
......
import io
import itertools import itertools
from parameterized import parameterized from parameterized import parameterized
...@@ -89,88 +88,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -89,88 +88,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format""" """Run smoke test on flac format"""
self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level) self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
@skipIfNoSox
class SmokeTestFileObj(TorchaudioTestCase):
    """Smoke tests for sox_io_backend operating on file-like objects.

    These tests only verify that save/info/load run without abnormal
    behaviors. They require no additional tools (such as the sox command),
    but without such tools the numerical correctness of each function
    cannot be verified.
    """

    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
        """Round-trip save -> info -> load through an in-memory buffer."""
        num_frames = sample_rate * 1  # one second of audio
        waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
        buffer = io.BytesIO()

        # 1. save into the buffer
        sox_io_backend.save(buffer, waveform, sample_rate, compression=compression, format=ext)

        # 2. query metadata from the same buffer
        buffer.seek(0)
        metadata = sox_io_backend.info(buffer, format=ext)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_channels == num_channels

        # 3. load the audio back
        buffer.seek(0)
        decoded, decoded_sr = sox_io_backend.load(buffer, normalize=False, format=ext)
        assert decoded_sr == sample_rate
        assert decoded.shape[0] == num_channels

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """Run smoke test on wav format"""
        self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
            )
        )
    )
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """Run smoke test on mp3 format"""
        self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [-1, 0, 1, 2, 3, 3.6, 5, 10],
            )
        )
    )
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """Run smoke test on vorbis format"""
        self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                list(range(9)),
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """Run smoke test on flac format"""
        self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
...@@ -20,11 +20,11 @@ from .common import get_enc_params, name_func ...@@ -20,11 +20,11 @@ from .common import get_enc_params, name_func
def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData: def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
return torchaudio.info(filepath) return torchaudio.backend.sox_io_backend.info(filepath)
def py_load_func(filepath: str, normalize: bool, channels_first: bool): def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return torchaudio.load(filepath, normalize=normalize, channels_first=channels_first) return torchaudio.backend.sox_io_backend.load(filepath, normalize=normalize, channels_first=channels_first)
def py_save_func( def py_save_func(
...@@ -36,7 +36,9 @@ def py_save_func( ...@@ -36,7 +36,9 @@ def py_save_func(
encoding: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
): ):
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample) torchaudio.backend.sox_io_backend.save(
filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample
)
@skipIfNoExec("sox") @skipIfNoExec("sox")
...@@ -44,8 +46,6 @@ def py_save_func( ...@@ -44,8 +46,6 @@ def py_save_func(
class SoxIO(TempDirMixin, TorchaudioTestCase): class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`""" """TorchScript-ability Test suite for `sox_io_backend`"""
backend = "sox_io"
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
......
...@@ -9,11 +9,11 @@ class BackendSwitchMixin: ...@@ -9,11 +9,11 @@ class BackendSwitchMixin:
backend_module = None backend_module = None
def test_switch(self): def test_switch(self):
torchaudio.set_audio_backend(self.backend) torchaudio.backend.utils.set_audio_backend(self.backend)
if self.backend is None: if self.backend is None:
assert torchaudio.get_audio_backend() is None assert torchaudio.backend.utils.get_audio_backend() is None
else: else:
assert torchaudio.get_audio_backend() == self.backend assert torchaudio.backend.utils.get_audio_backend() == self.backend
assert torchaudio.load == self.backend_module.load assert torchaudio.load == self.backend_module.load
assert torchaudio.save == self.backend_module.save assert torchaudio.save == self.backend_module.save
assert torchaudio.info == self.backend_module.info assert torchaudio.info == self.backend_module.info
......
from .autograd_utils import use_deterministic_algorithms
from .backend_utils import set_audio_backend from .backend_utils import set_audio_backend
from .case_utils import ( from .case_utils import (
disabledInCI,
HttpServerMixin, HttpServerMixin,
is_ffmpeg_available, is_ffmpeg_available,
PytorchTestCase, PytorchTestCase,
skipIfCudaSmallMemory,
skipIfNoAudioDevice,
skipIfNoCtcDecoder, skipIfNoCtcDecoder,
skipIfNoCuCtcDecoder,
skipIfNoCuda, skipIfNoCuda,
skipIfNoExec, skipIfNoExec,
skipIfkmeMark,
skipIfNoFFmpeg, skipIfNoFFmpeg,
skipIfNoKaldi, skipIfNoHWAccel,
skipIfNoMacOS,
skipIfNoModule, skipIfNoModule,
skipIfNoQengine, skipIfNoQengine,
skipIfNoRIR,
skipIfNoSox, skipIfNoSox,
skipIfNoSoxDecoder,
skipIfNoSoxEncoder,
skipIfPy310, skipIfPy310,
skipIfRocm, skipIfRocm,
TempDirMixin, TempDirMixin,
TestBaseMixin, TestBaseMixin,
TorchaudioTestCase, TorchaudioTestCase,
zip_equal,
) )
from .data_utils import get_asset_path, get_sinusoid, get_spectrogram, get_whitenoise from .data_utils import get_asset_path, get_sinusoid, get_spectrogram, get_whitenoise
from .func_utils import torch_script from .func_utils import torch_script
...@@ -35,17 +46,24 @@ __all__ = [ ...@@ -35,17 +46,24 @@ __all__ = [
"PytorchTestCase", "PytorchTestCase",
"TorchaudioTestCase", "TorchaudioTestCase",
"is_ffmpeg_available", "is_ffmpeg_available",
"skipIfNoAudioDevice",
"skipIfNoCtcDecoder", "skipIfNoCtcDecoder",
"skipIfNoCuCtcDecoder",
"skipIfNoCuda", "skipIfNoCuda",
"skipIfCudaSmallMemory",
"skipIfNoExec", "skipIfNoExec",
"skipIfNoMacOS",
"skipIfNoModule", "skipIfNoModule",
"skipIfNoKaldi", "skipIfNoRIR",
"skipIfNoSox", "skipIfNoSox",
"skipIfNoSoxBackend", "skipIfNoSoxDecoder",
"skipIfNoSoxEncoder",
"skipIfRocm", "skipIfRocm",
"skipIfNoQengine", "skipIfNoQengine",
"skipIfNoFFmpeg", "skipIfNoFFmpeg",
"skipIfNoHWAccel",
"skipIfPy310", "skipIfPy310",
"disabledInCI",
"get_wav_data", "get_wav_data",
"normalize_wav", "normalize_wav",
"load_wav", "load_wav",
...@@ -57,4 +75,6 @@ __all__ = [ ...@@ -57,4 +75,6 @@ __all__ = [
"get_image", "get_image",
"rgb_to_gray", "rgb_to_gray",
"rgb_to_yuv_ccir", "rgb_to_yuv_ccir",
"use_deterministic_algorithms",
"zip_equal",
] ]
import contextlib
import torch
@contextlib.contextmanager
def use_deterministic_algorithms(mode: bool, warn_only: bool):
    r"""Temporarily enable or disable torch's deterministic-algorithms mode.

    Upon exiting the context manager, the previous state of both the mode
    flag and the ``warn_only`` flag is restored, even if the body raises.

    Args:
        mode: value passed to :func:`torch.use_deterministic_algorithms`.
        warn_only: if True, non-deterministic ops only warn instead of raising.
    """
    previous_mode: bool = torch.are_deterministic_algorithms_enabled()
    previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
    try:
        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
        # Nothing meaningful to yield; exceptions propagate naturally and the
        # finally-block restores the previous flags. (The original yielded an
        # empty dict and re-raised RuntimeError unchanged — both no-ops.)
        yield
    finally:
        torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)
...@@ -6,11 +6,13 @@ import sys ...@@ -6,11 +6,13 @@ import sys
import tempfile import tempfile
import time import time
import unittest import unittest
from itertools import zip_longest
import torch import torch
import torchaudio import torchaudio
from torch.testing._internal.common_utils import TestCase as PytorchTestCase from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torchaudio._internal.module_utils import is_kaldi_available, is_module_available, is_sox_available from torchaudio._internal.module_utils import eval_env, is_module_available
from torchaudio.utils.ffmpeg_utils import get_video_decoders, get_video_encoders
from .backend_utils import set_audio_backend from .backend_utils import set_audio_backend
...@@ -65,7 +67,7 @@ class HttpServerMixin(TempDirMixin): ...@@ -65,7 +67,7 @@ class HttpServerMixin(TempDirMixin):
""" """
_proc = None _proc = None
_port = 8000 _port = 12345
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -110,10 +112,11 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): ...@@ -110,10 +112,11 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
def is_ffmpeg_available(): def is_ffmpeg_available():
return torchaudio._extension._FFMPEG_INITIALIZED return torchaudio._extension._FFMPEG_EXT is not None
_IS_CTC_DECODER_AVAILABLE = None _IS_CTC_DECODER_AVAILABLE = None
_IS_CUDA_CTC_DECODER_AVAILABLE = None
def is_ctc_decoder_available(): def is_ctc_decoder_available():
...@@ -128,22 +131,16 @@ def is_ctc_decoder_available(): ...@@ -128,22 +131,16 @@ def is_ctc_decoder_available():
return _IS_CTC_DECODER_AVAILABLE return _IS_CTC_DECODER_AVAILABLE
def _eval_env(var, default): def is_cuda_ctc_decoder_available():
if var not in os.environ: global _IS_CUDA_CTC_DECODER_AVAILABLE
return default if _IS_CUDA_CTC_DECODER_AVAILABLE is None:
try:
from torchaudio.models.decoder import CUCTCDecoder # noqa: F401
val = os.environ.get(var, "0") _IS_CUDA_CTC_DECODER_AVAILABLE = True
trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"] except Exception:
falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"] _IS_CUDA_CTC_DECODER_AVAILABLE = False
if val in trues: return _IS_CUDA_CTC_DECODER_AVAILABLE
return True
if val not in falses:
# fmt: off
raise RuntimeError(
f"Unexpected environment variable value `{var}={val}`. "
f"Expected one of {trues + falses}")
# fmt: on
return False
def _fail(reason): def _fail(reason):
...@@ -170,7 +167,7 @@ def _pass(test_item): ...@@ -170,7 +167,7 @@ def _pass(test_item):
return test_item return test_item
_IN_CI = _eval_env("CI", default=False) _IN_CI = eval_env("CI", default=False)
def _skipIf(condition, reason, key): def _skipIf(condition, reason, key):
...@@ -180,7 +177,7 @@ def _skipIf(condition, reason, key): ...@@ -180,7 +177,7 @@ def _skipIf(condition, reason, key):
# In CI, default to fail, so as to prevent accidental skip. # In CI, default to fail, so as to prevent accidental skip.
# In other env, default to skip # In other env, default to skip
var = f"TORCHAUDIO_TEST_ALLOW_SKIP_IF_{key}" var = f"TORCHAUDIO_TEST_ALLOW_SKIP_IF_{key}"
skip_allowed = _eval_env(var, default=not _IN_CI) skip_allowed = eval_env(var, default=not _IN_CI)
if skip_allowed: if skip_allowed:
return unittest.skip(reason) return unittest.skip(reason)
return _fail(f"{reason} But the test cannot be skipped. (CI={_IN_CI}, {var}={skip_allowed}.)") return _fail(f"{reason} But the test cannot be skipped. (CI={_IN_CI}, {var}={skip_allowed}.)")
...@@ -207,23 +204,53 @@ skipIfNoCuda = _skipIf( ...@@ -207,23 +204,53 @@ skipIfNoCuda = _skipIf(
reason="CUDA is not available.", reason="CUDA is not available.",
key="NO_CUDA", key="NO_CUDA",
) )
# Skip test if CUDA memory is not enough
# TODO: detect the real CUDA memory size and allow call site to configure how much the test needs
skipIfCudaSmallMemory = _skipIf(
"CI" in os.environ and torch.cuda.is_available(), # temporary
reason="CUDA does not have enough memory.",
key="CUDA_SMALL_MEMORY",
)
skipIfNoSox = _skipIf( skipIfNoSox = _skipIf(
not is_sox_available(), not torchaudio._extension._SOX_INITIALIZED,
reason="Sox features are not available.", reason="Sox features are not available.",
key="NO_SOX", key="NO_SOX",
) )
skipIfNoKaldi = _skipIf(
not is_kaldi_available(),
reason="Kaldi features are not available.", def skipIfNoSoxDecoder(ext):
key="NO_KALDI", return _skipIf(
not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_read_formats(),
f'sox does not handle "{ext}" for read.',
key="NO_SOX_DECODER",
)
def skipIfNoSoxEncoder(ext):
return _skipIf(
not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_write_formats(),
f'sox does not handle "{ext}" for write.',
key="NO_SOX_ENCODER",
)
skipIfNoRIR = _skipIf(
not torchaudio._extension._IS_RIR_AVAILABLE,
reason="RIR features are not available.",
key="NO_RIR",
) )
skipIfNoCtcDecoder = _skipIf( skipIfNoCtcDecoder = _skipIf(
not is_ctc_decoder_available(), not is_ctc_decoder_available(),
reason="CTC decoder not available.", reason="CTC decoder not available.",
key="NO_CTC_DECODER", key="NO_CTC_DECODER",
) )
skipIfNoCuCtcDecoder = _skipIf(
not is_cuda_ctc_decoder_available(),
reason="CUCTC decoder not available.",
key="NO_CUCTC_DECODER",
)
skipIfRocm = _skipIf( skipIfRocm = _skipIf(
_eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False), eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False),
reason="The test doesn't currently work on the ROCm stack.", reason="The test doesn't currently work on the ROCm stack.",
key="ON_ROCM", key="ON_ROCM",
) )
...@@ -245,3 +272,55 @@ skipIfPy310 = _skipIf( ...@@ -245,3 +272,55 @@ skipIfPy310 = _skipIf(
), ),
key="ON_PYTHON_310", key="ON_PYTHON_310",
) )
# Skip unless ffmpeg reports at least one audio output device.
skipIfNoAudioDevice = _skipIf(
    not torchaudio.utils.ffmpeg_utils.get_output_devices(),
    reason="No output audio device is available.",
    key="NO_AUDIO_OUT_DEVICE",
)

# Skip on any platform other than macOS.
skipIfNoMacOS = _skipIf(
    sys.platform != "darwin",
    reason="This feature is only available for MacOS.",
    key="NO_MACOS",
)

# Unconditional skip when running under CI (i.e. the CI env var is set).
disabledInCI = _skipIf(
    "CI" in os.environ,
    reason="Tests are failing on CI consistently. Disabled while investigating.",
    key="TEMPORARY_DISABLED",
)
def skipIfNoHWAccel(name):
    """Return a skip decorator for tests needing FFmpeg hardware acceleration.

    The test is skipped (or failed, depending on the
    TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_HW_ACCEL env var handled by ``_skipIf``)
    unless ffmpeg is available, CUDA is usable, torchaudio is compiled with
    CUDA, and ``name`` appears among the available video decoders or encoders.
    The checks are ordered from cheapest to most specific.
    """
    key = "NO_HW_ACCEL"
    if not is_ffmpeg_available():
        return _skipIf(True, reason="ffmpeg features are not available.", key=key)
    if not torch.cuda.is_available():
        return _skipIf(True, reason="CUDA is not available.", key=key)
    if torchaudio._extension._check_cuda_version() is None:
        return _skipIf(True, "Torchaudio is not compiled with CUDA.", key=key)
    if name not in get_video_decoders() and name not in get_video_encoders():
        return _skipIf(True, f"{name} is not in the list of available decoders or encoders", key=key)
    # All preconditions met: decorate with the identity (no skip).
    return _pass
def checkkme():
    """Return True if a KME GPU (gfx928) is visible via ``rocminfo``.

    Best-effort probe: if ``rocminfo`` is missing or ``grep`` finds no
    match, stdout is empty and False is returned. ``shell=True`` is
    required for the pipe; the command is a fixed literal, so there is
    no injection risk.
    """
    res = subprocess.run("rocminfo | grep gfx928", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # Any captured stdout means grep matched gfx928.
    return bool(res.stdout)
# Probe once at import time whether we are running on a KME (gfx928) device.
iskme = checkkme()

# Skip decorator for tests that require fp64, which KME does not support.
skipIfkmeMark = _skipIf(
    iskme,
    reason="not support fp64 in kme for this case",
    key = "NOT_SUPPORT_FP64_IN_KME",
)
def zip_equal(*iterables):
    """Zip iterables, raising ValueError if their lengths differ.

    With the regular Python `zip` function, if one iterable is longer than the
    other, the remainder portions are ignored. This is resolved in Python 3.10
    where we can use `strict=True` in the `zip` function.

    From https://github.com/pytorch/text/blob/c047efeba813ac943cb8046a49e858a8b529d577/test/torchtext_unittest/common/case_utils.py#L45-L54 # noqa: E501
    """
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        # Compare by identity, not `in`/`==`: elements with nonstandard
        # __eq__ (e.g. numpy arrays) would otherwise misfire or raise.
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo
...@@ -189,13 +189,19 @@ def compute_with_numpy_transducer(data): ...@@ -189,13 +189,19 @@ def compute_with_numpy_transducer(data):
def compute_with_pytorch_transducer(data): def compute_with_pytorch_transducer(data):
fused_log_softmax = data.get("fused_log_softmax", True)
input = data["logits"]
if not fused_log_softmax:
input = torch.nn.functional.log_softmax(input, dim=-1)
costs = rnnt_loss( costs = rnnt_loss(
logits=data["logits"], logits=input,
logit_lengths=data["logit_lengths"], logit_lengths=data["logit_lengths"],
target_lengths=data["target_lengths"], target_lengths=data["target_lengths"],
targets=data["targets"], targets=data["targets"],
blank=data["blank"], blank=data["blank"],
reduction="none", reduction="none",
fused_log_softmax=fused_log_softmax,
) )
loss = torch.sum(costs) loss = torch.sum(costs)
...@@ -260,6 +266,7 @@ def get_B1_T10_U3_D4_data( ...@@ -260,6 +266,7 @@ def get_B1_T10_U3_D4_data(
data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device) data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device)
data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device) data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device)
data["blank"] = 0 data["blank"] = 0
data["fused_log_softmax"] = False
return data return data
...@@ -552,6 +559,7 @@ def get_random_data( ...@@ -552,6 +559,7 @@ def get_random_data(
max_U=32, max_U=32,
max_D=40, max_D=40,
blank=-1, blank=-1,
fused_log_softmax=True,
dtype=torch.float32, dtype=torch.float32,
device=CPU_DEVICE, device=CPU_DEVICE,
seed=None, seed=None,
...@@ -591,6 +599,7 @@ def get_random_data( ...@@ -591,6 +599,7 @@ def get_random_data(
"logit_lengths": logit_lengths, "logit_lengths": logit_lengths,
"target_lengths": target_lengths, "target_lengths": target_lengths,
"blank": blank, "blank": blank,
"fused_log_softmax": fused_log_softmax,
} }
......
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase
from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor
if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule
class MockMUSTC:
    """Stand-in for the MUSTC dataset: ten identical synthetic items."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore whatever the real dataset constructor takes.
        pass

    def __getitem__(self, n: int):
        # One second-ish random waveform plus a dummy transcript string.
        waveform = torch.rand(1, 32640)
        return (waveform, "sup")

    def __len__(self):
        return 10
@contextmanager
def get_lightning_module():
    """Yield a MuSTCRNNTModule built entirely from mocks (no data on disk).

    Patches out the heavyweight dependencies — the sentencepiece processor,
    global-stats normalization, the real MUSTC dataset, CustomDataset and
    DataLoader — so the lightning module can be constructed and stepped
    without any external assets.
    """
    with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
        "asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC), patch(
        "asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset
    ), patch(
        "torch.utils.data.DataLoader", new=MockDataloader
    ):
        # The path arguments are never dereferenced thanks to the patches above.
        yield MuSTCRNNTModule(
            mustc_path="mustc_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestMuSTCRNNTModule(TorchaudioTestCase):
    """Smoke tests: each lightning step / forward runs on one mocked batch."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_common_dataloader"),
            ("test_step", "test_he_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        """Run the named step method on the first batch of the named loader."""
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            # Batch index 0; only checks the call completes without error.
            getattr(lightning_module, step_fname)(batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        """Run the module's forward pass on the first validation batch."""
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            lightning_module(batch)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase
from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor
if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule
class MockTEDLIUM:
    """Stand-in for the TEDLIUM dataset: ten identical synthetic items."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore whatever the real dataset constructor takes.
        pass

    def __getitem__(self, n: int):
        # Presumably (waveform, sample_rate, transcript, talk_id, speaker_id,
        # identifier) to mirror TEDLIUM's item layout — TODO confirm.
        waveform = torch.rand(1, 32640)
        return (waveform, 16000, "sup", 2, 3, 4)

    def __len__(self):
        return 10
@contextmanager
def get_lightning_module():
    """Yield a TEDLIUM3RNNTModule built entirely from mocks (no data on disk).

    Patches out the heavyweight dependencies — the sentencepiece processor,
    global-stats normalization, the TEDLIUM dataset, CustomDataset and
    DataLoader — so the lightning module can be constructed and stepped
    without any external assets.
    """
    with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
        "asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM), patch(
        "asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset
    ), patch(
        "torch.utils.data.DataLoader", new=MockDataloader
    ):
        # The path arguments are never dereferenced thanks to the patches above.
        yield TEDLIUM3RNNTModule(
            tedlium_path="tedlium_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
    """Smoke tests: each lightning step / forward runs on one mocked batch."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        """Run the named step method on the first batch of the named loader."""
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            # Batch index 0; only checks the call completes without error.
            getattr(lightning_module, step_fname)(batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        """Run the module's forward pass on the first validation batch."""
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            lightning_module(batch)
...@@ -5,6 +5,7 @@ from .autograd_impl import Autograd, AutogradFloat32 ...@@ -5,6 +5,7 @@ from .autograd_impl import Autograd, AutogradFloat32
@common_utils.skipIfNoCuda @common_utils.skipIfNoCuda
@common_utils.skipIfkmeMark
class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase): class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device("cuda") device = torch.device("cuda")
......
...@@ -6,7 +6,14 @@ import torchaudio.functional as F ...@@ -6,7 +6,14 @@ import torchaudio.functional as F
from parameterized import parameterized from parameterized import parameterized
from torch import Tensor from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, rnnt_utils, TestBaseMixin from torchaudio_unittest.common_utils import (
get_spectrogram,
get_whitenoise,
nested_params,
rnnt_utils,
TestBaseMixin,
use_deterministic_algorithms,
)
class Autograd(TestBaseMixin): class Autograd(TestBaseMixin):
...@@ -71,26 +78,30 @@ class Autograd(TestBaseMixin): ...@@ -71,26 +78,30 @@ class Autograd(TestBaseMixin):
a = torch.tensor([0.7, 0.2, 0.6]) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9]) b = torch.tensor([0.4, 0.2, 0.9])
a.requires_grad = True a.requires_grad = True
self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False) with use_deterministic_algorithms(True, False):
self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
def test_filtfilt_b(self): def test_filtfilt_b(self):
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6]) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9]) b = torch.tensor([0.4, 0.2, 0.9])
b.requires_grad = True b.requires_grad = True
self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False) with use_deterministic_algorithms(True, False):
self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
def test_filtfilt_all_inputs(self): def test_filtfilt_all_inputs(self):
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6]) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9]) b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.filtfilt, (x, a, b)) with use_deterministic_algorithms(True, False):
self.assert_grad(F.filtfilt, (x, a, b))
def test_filtfilt_batching(self): def test_filtfilt_batching(self):
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]]) a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]]) b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
self.assert_grad(F.filtfilt, (x, a, b)) with use_deterministic_algorithms(True, False):
self.assert_grad(F.filtfilt, (x, a, b))
def test_biquad(self): def test_biquad(self):
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
...@@ -335,6 +346,51 @@ class Autograd(TestBaseMixin): ...@@ -335,6 +346,51 @@ class Autograd(TestBaseMixin):
beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat) beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
self.assert_grad(F.apply_beamforming, (beamform_weights, specgram)) self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))
@nested_params(
["convolve", "fftconvolve"],
["full", "valid", "same"],
)
def test_convolve(self, fn, mode):
leading_dims = (4, 3, 2)
L_x, L_y = 23, 40
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
self.assert_grad(getattr(F, fn), (x, y, mode))
def test_add_noise(self):
leading_dims = (5, 2, 3)
L = 51
waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
self.assert_grad(F.add_noise, (waveform, noise, snr, lengths))
def test_speed(self):
leading_dims = (3, 2)
T = 200
waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
self.assert_grad(F.speed, (waveform, 1000, 1.1, lengths), enable_all_grad=False)
def test_preemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
coeff = 0.9
self.assert_grad(F.preemphasis, (waveform, coeff))
def test_deemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
coeff = 0.9
self.assert_grad(F.deemphasis, (waveform, coeff))
def test_frechet_distance(self):
N = 16
mu_x = torch.rand((N,))
sigma_x = torch.rand((N, N))
mu_y = torch.rand((N,))
sigma_y = torch.rand((N, N))
self.assert_grad(F.frechet_distance, (mu_x, sigma_x, mu_y, sigma_y))
class AutogradFloat32(TestBaseMixin): class AutogradFloat32(TestBaseMixin):
def assert_grad( def assert_grad(
......
...@@ -27,7 +27,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -27,7 +27,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
backend = "default" backend = "default"
def assert_batch_consistency(self, functional, inputs, atol=1e-8, rtol=1e-5, seed=42): def assert_batch_consistency(self, functional, inputs, atol=1e-6, rtol=1e-5, seed=42):
n = inputs[0].size(0) n = inputs[0].size(0)
for i in range(1, len(inputs)): for i in range(1, len(inputs)):
self.assertEqual(inputs[i].size(0), n) self.assertEqual(inputs[i].size(0), n)
...@@ -65,7 +65,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -65,7 +65,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
"rand_init": False, "rand_init": False,
} }
func = partial(F.griffinlim, **kwargs) func = partial(F.griffinlim, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,), atol=5e-5) self.assert_batch_consistency(func, inputs=(batch,), atol=1e-4)
@parameterized.expand( @parameterized.expand(
list( list(
...@@ -194,7 +194,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -194,7 +194,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assert_batch_consistency(func, inputs=(waveforms,)) self.assert_batch_consistency(func, inputs=(waveforms,))
def test_phaser(self): def test_phaser(self):
sample_rate = 44100 sample_rate = 8000
n_channels = 2 n_channels = 2
waveform = common_utils.get_whitenoise( waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, n_channels=self.batch_size * n_channels, duration=1 sample_rate=sample_rate, n_channels=self.batch_size * n_channels, duration=1
...@@ -208,7 +208,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -208,7 +208,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_flanger(self): def test_flanger(self):
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
sample_rate = 44100 sample_rate = 8000
kwargs = { kwargs = {
"sample_rate": sample_rate, "sample_rate": sample_rate,
} }
...@@ -233,7 +233,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -233,7 +233,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
func = partial(F.sliding_window_cmn, **kwargs) func = partial(F.sliding_window_cmn, **kwargs)
self.assert_batch_consistency(func, inputs=(spectrogram,)) self.assert_batch_consistency(func, inputs=(spectrogram,))
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform(self, resampling_method): def test_resample_waveform(self, resampling_method):
num_channels = 3 num_channels = 3
sr = 16000 sr = 16000
...@@ -257,18 +257,6 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -257,18 +257,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
atol=1e-7, atol=1e-7,
) )
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
sample_rate = 44100
n_channels = 2
waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.compute_kaldi_pitch, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,))
def test_lfilter(self): def test_lfilter(self):
signal_length = 2048 signal_length = 2048
x = torch.randn(self.batch_size, signal_length) x = torch.randn(self.batch_size, signal_length)
...@@ -407,3 +395,90 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -407,3 +395,90 @@ class TestFunctional(common_utils.TorchaudioTestCase):
specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1)) specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat) beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram)) self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram))
@common_utils.nested_params(
    ["convolve", "fftconvolve"],
    ["full", "valid", "same"],
)
def test_convolve(self, fn, mode):
    """Batched (f)fftconvolve agrees with applying the op to each signal pair."""
    batch_dims = (2, 3)
    len_x, len_y = 89, 43
    x = torch.rand(*batch_dims, len_x, dtype=self.dtype, device=self.device)
    y = torch.rand(*batch_dims, len_y, dtype=self.dtype, device=self.device)
    conv = getattr(F, fn)
    actual = conv(x, y, mode)
    # Reference: run each (i, j) signal pair through the op individually.
    rows = []
    for i in range(batch_dims[0]):
        per_channel = [
            conv(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0), mode).squeeze(0) for j in range(batch_dims[1])
        ]
        rows.append(torch.stack(per_channel))
    expected = torch.stack(rows)
    self.assertEqual(expected, actual)
def test_add_noise(self):
    """Batched add_noise matches element-wise application over the leading dims."""
    leading_dims = (5, 2, 3)
    num_frames = 51
    waveform = torch.rand(*leading_dims, num_frames, dtype=self.dtype, device=self.device)
    noise = torch.rand(*leading_dims, num_frames, dtype=self.dtype, device=self.device)
    lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
    snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
    actual = F.add_noise(waveform, noise, snr, lengths)
    # Reference: apply add_noise to every (i, j, k) element independently.
    expected = torch.stack(
        [
            F.add_noise(waveform[i][j][k], noise[i][j][k], snr[i][j][k], lengths[i][j][k])
            for i in range(leading_dims[0])
            for j in range(leading_dims[1])
            for k in range(leading_dims[2])
        ]
    )
    self.assertEqual(expected, actual.reshape(-1, num_frames))
def test_speed(self):
    """Batched speed on padded inputs matches per-example calls (valid region only)."""
    num_examples = 5
    orig_freq = 100
    factor = 0.8
    input_lengths = torch.randint(1, 1000, (num_examples,), dtype=torch.int32)
    unbatched_input = [torch.ones((int(n),)) * 1.0 for n in input_lengths]
    batched_input = torch.nn.utils.rnn.pad_sequence(unbatched_input, batch_first=True)
    output, output_lengths = F.speed(batched_input, orig_freq=orig_freq, factor=factor, lengths=input_lengths)
    # Reference: run each example through F.speed on its own.
    per_example = [
        F.speed(unbatched_input[idx], orig_freq=orig_freq, factor=factor, lengths=input_lengths[idx])
        for idx in range(len(unbatched_input))
    ]
    unbatched_output = [w for w, _ in per_example]
    unbatched_output_lengths = [length for _, length in per_example]
    self.assertEqual(output_lengths, torch.stack(unbatched_output_lengths))
    # Only compare up to each example's valid output length (the rest is padding).
    for idx in range(len(unbatched_output)):
        valid = output_lengths[idx]
        self.assertEqual(unbatched_output[idx], output[idx][:valid])
def test_preemphasis(self):
    """preemphasis on a batch equals stacking per-item results."""
    waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
    coeff = 0.9
    actual = F.preemphasis(waveform, coeff=coeff)
    expected = torch.stack([F.preemphasis(item, coeff=coeff) for item in waveform])
    self.assertEqual(expected, actual)
def test_deemphasis(self):
    """deemphasis on a batch equals stacking per-item results."""
    waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
    coeff = 0.9
    actual = F.deemphasis(waveform, coeff=coeff)
    expected = torch.stack([F.deemphasis(item, coeff=coeff) for item in waveform])
    self.assertEqual(expected, actual)
import unittest import unittest
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark
from .functional_impl import Functional from .functional_impl import Functional, FunctionalCUDAOnly
@skipIfNoCuda @skipIfNoCuda
...@@ -17,6 +17,20 @@ class TestFunctionalFloat32(Functional, PytorchTestCase): ...@@ -17,6 +17,20 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
@skipIfNoCuda @skipIfNoCuda
@skipIfkmeMark
class TestLFilterFloat64(Functional, PytorchTestCase): class TestLFilterFloat64(Functional, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device("cuda") device = torch.device("cuda")
@skipIfNoCuda
class TestFunctionalCUDAOnlyFloat32(FunctionalCUDAOnly, PytorchTestCase):
    # Runs the CUDA-only functional test suite in float32 on a CUDA device.
    dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
@skipIfkmeMark
class TestFunctionalCUDAOnlyFloat64(FunctionalCUDAOnly, PytorchTestCase):
    # Runs the CUDA-only functional test suite in float64 on a CUDA device.
    dtype = torch.float64
    device = torch.device("cuda")
...@@ -20,7 +20,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -20,7 +20,7 @@ from torchaudio_unittest.common_utils import (
class Functional(TestBaseMixin): class Functional(TestBaseMixin):
def _test_resample_waveform_accuracy( def _test_resample_waveform_accuracy(
self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4 self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interp_hann", atol=1e-1, rtol=1e-4
): ):
# resample the signal and compare it to the ground truth # resample the signal and compare it to the ground truth
n_to_trim = 20 n_to_trim = 20
...@@ -51,6 +51,7 @@ class Functional(TestBaseMixin): ...@@ -51,6 +51,7 @@ class Functional(TestBaseMixin):
def _test_costs_and_gradients(self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2): def _test_costs_and_gradients(self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2):
logits_shape = data["logits"].shape logits_shape = data["logits"].shape
costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data) costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol) self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape) self.assertEqual(logits_shape, gradients.shape)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol) self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
...@@ -396,22 +397,38 @@ class Functional(TestBaseMixin): ...@@ -396,22 +397,38 @@ class Functional(TestBaseMixin):
close_to_limit = decibels < 6.0207 close_to_limit = decibels < 6.0207
assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}" assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}"
@parameterized.expand(list(itertools.product([(1, 201, 100), (10, 2, 201, 300)])))
def test_mask_along_axis_input_axis_check(self, shape):
    """mask_along_axis rejects axes other than the frequency/time axes."""
    specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
    expected_message = "Only Frequency and Time masking are supported"
    with self.assertRaisesRegex(ValueError, expected_message):
        F.mask_along_axis(specgram, 100, 0.0, 0, 1.0)
@parameterized.expand( @parameterized.expand(
list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0])) list(
itertools.product([(1025, 400), (1, 201, 100), (10, 2, 201, 300)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0])
)
) )
def test_mask_along_axis(self, shape, mask_param, mask_value, axis, p): def test_mask_along_axis(self, shape, mask_param, mask_value, last_axis, p):
specgram = torch.randn(*shape, dtype=self.dtype, device=self.device) specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
# last_axis = 1 means the last axis; 2 means the second-to-last axis.
axis = len(shape) - last_axis
if p != 1.0: if p != 1.0:
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis, p=p) mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis, p=p)
else: else:
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
other_axis = 1 if axis == 2 else 2 other_axis = axis - 1 if last_axis == 1 else axis + 1
masked_columns = (mask_specgram == mask_value).sum(other_axis) masked_columns = (mask_specgram == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum() num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns = torch.div(num_masked_columns, mask_specgram.size(0), rounding_mode="floor")
den = 1
for i in range(len(shape) - 2):
den *= mask_specgram.size(i)
num_masked_columns = torch.div(num_masked_columns, den, rounding_mode="floor")
if p != 1.0: if p != 1.0:
mask_param = min(mask_param, int(specgram.shape[axis] * p)) mask_param = min(mask_param, int(specgram.shape[axis] * p))
...@@ -470,7 +487,7 @@ class Functional(TestBaseMixin): ...@@ -470,7 +487,7 @@ class Functional(TestBaseMixin):
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
[16000, 44100], [16000, 44100],
) )
) )
...@@ -481,7 +498,7 @@ class Functional(TestBaseMixin): ...@@ -481,7 +498,7 @@ class Functional(TestBaseMixin):
resampled = F.resample(waveform, sample_rate, sample_rate) resampled = F.resample(waveform, sample_rate, sample_rate)
self.assertEqual(waveform, resampled) self.assertEqual(waveform, resampled)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_upsample_size(self, resampling_method): def test_resample_waveform_upsample_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
...@@ -491,7 +508,7 @@ class Functional(TestBaseMixin): ...@@ -491,7 +508,7 @@ class Functional(TestBaseMixin):
upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method) upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method)
assert upsampled.size(-1) == waveform.size(-1) * 2 assert upsampled.size(-1) == waveform.size(-1) * 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_downsample_size(self, resampling_method): def test_resample_waveform_downsample_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
...@@ -501,7 +518,7 @@ class Functional(TestBaseMixin): ...@@ -501,7 +518,7 @@ class Functional(TestBaseMixin):
downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method) downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
assert downsampled.size(-1) == waveform.size(-1) // 2 assert downsampled.size(-1) == waveform.size(-1) // 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
def test_resample_waveform_identity_size(self, resampling_method): def test_resample_waveform_identity_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
...@@ -514,7 +531,7 @@ class Functional(TestBaseMixin): ...@@ -514,7 +531,7 @@ class Functional(TestBaseMixin):
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
list(range(1, 20)), list(range(1, 20)),
) )
) )
...@@ -525,7 +542,7 @@ class Functional(TestBaseMixin): ...@@ -525,7 +542,7 @@ class Functional(TestBaseMixin):
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
["sinc_interpolation", "kaiser_window"], ["sinc_interp_hann", "sinc_interp_kaiser"],
list(range(1, 20)), list(range(1, 20)),
) )
) )
...@@ -637,13 +654,25 @@ class Functional(TestBaseMixin): ...@@ -637,13 +654,25 @@ class Functional(TestBaseMixin):
rtol=rtol, rtol=rtol,
) )
def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self): @parameterized.expand([(True,), (False,)])
def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self, fused_log_softmax):
seed = 777 seed = 777
for i in range(5): for i in range(5):
data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i)) data = rnnt_utils.get_random_data(
fused_log_softmax=fused_log_softmax, dtype=torch.float32, device=self.device, seed=(seed + i)
)
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data) ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients) self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)
def test_rnnt_loss_nonfused_softmax(self):
    """RNNT loss with the non-fused softmax path matches the numpy reference."""
    data = rnnt_utils.get_B1_T10_U3_D4_data()
    ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
    self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)
def test_psd(self): def test_psd(self):
"""Verify the ``F.psd`` method by the numpy implementation. """Verify the ``F.psd`` method by the numpy implementation.
Given the multi-channel complex-valued spectrum as the input, Given the multi-channel complex-valued spectrum as the input,
...@@ -879,6 +908,412 @@ class Functional(TestBaseMixin): ...@@ -879,6 +908,412 @@ class Functional(TestBaseMixin):
torch.tensor(specgram_enhanced, dtype=self.complex_dtype, device=self.device), specgram_enhanced_audio torch.tensor(specgram_enhanced, dtype=self.complex_dtype, device=self.device), specgram_enhanced_audio
) )
@nested_params(
    [(10, 4), (4, 3, 1, 2), (2,), ()],
    [(100, 43), (21, 45)],
    ["full", "valid", "same"],
)
def test_convolve_numerics(self, leading_dims, lengths, mode):
    """Check that convolve returns values identical to those that SciPy produces."""
    len_x, len_y = lengths
    x = torch.rand(*(leading_dims + (len_x,)), dtype=self.dtype, device=self.device)
    y = torch.rand(*(leading_dims + (len_y,)), dtype=self.dtype, device=self.device)
    actual = F.convolve(x, y, mode=mode)
    # Flatten the leading dims and convolve each 1-D signal pair with SciPy.
    num_signals = torch.tensor(leading_dims).prod() if leading_dims else 1
    flat_x = x.reshape((num_signals, len_x))
    flat_y = y.reshape((num_signals, len_y))
    per_signal = [
        signal.convolve(flat_x[i].detach().cpu().numpy(), flat_y[i].detach().cpu().numpy(), mode=mode)
        for i in range(num_signals)
    ]
    expected = torch.tensor(np.array(per_signal)).reshape(leading_dims + (-1,))
    self.assertEqual(expected, actual)
@nested_params(
    [(10, 4), (4, 3, 1, 2), (2,), ()],
    [(100, 43), (21, 45)],
    ["full", "valid", "same"],
)
def test_fftconvolve_numerics(self, leading_dims, lengths, mode):
    """Check that fftconvolve returns values identical to those that SciPy produces."""
    len_x, len_y = lengths
    x = torch.rand(*(leading_dims + (len_x,)), dtype=self.dtype, device=self.device)
    y = torch.rand(*(leading_dims + (len_y,)), dtype=self.dtype, device=self.device)
    actual = F.fftconvolve(x, y, mode=mode)
    # SciPy's fftconvolve handles the leading dims directly via axes=-1.
    reference = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1, mode=mode)
    self.assertEqual(torch.tensor(reference), actual)
@nested_params(
    ["convolve", "fftconvolve"],
    [(5, 2, 3)],
    [(5, 1, 3), (1, 2, 3), (1, 1, 3)],
)
def test_convolve_broadcast(self, fn, x_shape, y_shape):
    """Broadcast-able leading dims give the same result as pre-expanded inputs."""
    x = torch.rand(x_shape, dtype=self.dtype, device=self.device)
    y = torch.rand(y_shape, dtype=self.dtype, device=self.device)
    conv = getattr(F, fn)
    # 1. Let the op broadcast internally.
    broadcast_result = conv(x, y)
    # 2. Expand y to x's shape by hand and re-run.
    expanded = y.expand(x_shape).clone()
    assert expanded is not y
    assert expanded.shape == x.shape
    explicit_result = conv(x, expanded)
    # Both paths must agree.
    self.assertEqual(broadcast_result, explicit_result)
@parameterized.expand(
    [
        # fmt: off
        # different ndim
        (0, F.convolve, (4, 3, 1, 2), (10, 4)),
        (0, F.convolve, (4, 3, 1, 2), (2, 2, 2)),
        (0, F.convolve, (1, ), (10, 4)),
        (0, F.convolve, (1, ), (2, 2, 2)),
        (0, F.fftconvolve, (4, 3, 1, 2), (10, 4)),
        (0, F.fftconvolve, (4, 3, 1, 2), (2, 2, 2)),
        (0, F.fftconvolve, (1, ), (10, 4)),
        (0, F.fftconvolve, (1, ), (2, 2, 2)),
        # non-broadcastable leading dimensions
        (1, F.convolve, (5, 2, 3), (5, 3, 3)),
        (1, F.convolve, (5, 2, 3), (5, 3, 4)),
        (1, F.convolve, (5, 2, 3), (5, 3, 5)),
        (1, F.fftconvolve, (5, 2, 3), (5, 3, 3)),
        (1, F.fftconvolve, (5, 2, 3), (5, 3, 4)),
        (1, F.fftconvolve, (5, 2, 3), (5, 3, 5)),
        # fmt: on
    ],
)
def test_convolve_input_dim_check(self, case, fn, x_shape, y_shape):
    """Check that convolve properly rejects inputs with incompatible dimensions."""
    x = torch.rand(*x_shape, dtype=self.dtype, device=self.device)
    y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)
    # `case` selects which of the two failure modes this parameter set triggers.
    messages = (
        "The operands must be the same dimension",
        "Leading dimensions of x and y are not broadcastable",
    )
    with self.assertRaisesRegex(ValueError, messages[case]):
        fn(x, y)
def test_add_noise_broadcast(self):
    """add_noise with broadcast-able inputs matches explicitly expanded inputs."""
    leading_dims = (5, 2, 3)
    num_frames = 51
    waveform = torch.rand(*leading_dims, num_frames, dtype=self.dtype, device=self.device)
    # noise/lengths/snr each use singleton dims that must broadcast to leading_dims.
    noise = torch.rand(5, 1, 1, num_frames, dtype=self.dtype, device=self.device)
    lengths = torch.rand(5, 1, 3, dtype=self.dtype, device=self.device)
    snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10
    actual = F.add_noise(waveform, noise, snr, lengths)
    expected = F.add_noise(
        waveform,
        noise.expand(*leading_dims, num_frames),
        snr.expand(*leading_dims),
        lengths.expand(*leading_dims),
    )
    self.assertEqual(expected, actual)
@parameterized.expand(
    [((5, 2, 3), (2, 1, 1), (5, 2), (5, 2, 3)), ((2, 1), (5,), (5,), (5,)), ((3,), (5, 2, 3), (2, 1, 1), (5, 2))]
)
def test_add_noise_leading_dim_check(self, waveform_dims, noise_dims, lengths_dims, snr_dims):
    """add_noise rejects inputs whose leading dimensions cannot be reconciled."""
    num_frames = 51
    waveform = torch.rand(*waveform_dims, num_frames, dtype=self.dtype, device=self.device)
    noise = torch.rand(*noise_dims, num_frames, dtype=self.dtype, device=self.device)
    lengths = torch.rand(*lengths_dims, dtype=self.dtype, device=self.device)
    snr = torch.rand(*snr_dims, dtype=self.dtype, device=self.device) * 10
    with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
        F.add_noise(waveform, noise, snr, lengths)
def test_add_noise_length_check(self):
    """add_noise rejects waveform and noise with different frame counts."""
    leading_dims = (5, 2, 3)
    num_frames = 51
    waveform = torch.rand(*leading_dims, num_frames, dtype=self.dtype, device=self.device)
    # Noise is one frame shorter than the waveform on purpose.
    noise = torch.rand(*leading_dims, 50, dtype=self.dtype, device=self.device)
    lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
    snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
    with self.assertRaisesRegex(ValueError, "Length dimensions"):
        F.add_noise(waveform, noise, snr, lengths)
def test_speed_identity(self):
    """speed of 1.0 does not alter input waveform and length"""
    leading_dims = (5, 4, 2)
    num_frames = 1000
    waveform = torch.rand(*leading_dims, num_frames)
    lengths = torch.randint(1, 1000, leading_dims)
    out_waveform, out_lengths = F.speed(waveform, orig_freq=1000, factor=1.0, lengths=lengths)
    self.assertEqual(waveform, out_waveform)
    self.assertEqual(lengths, out_lengths)
@nested_params([0.8, 1.1, 1.2], [True, False])
def test_speed_accuracy(self, factor, use_lengths):
    """A sinusoid sped up by `factor` matches the analytically rescaled sinusoid."""
    n_to_trim = 20  # edge samples carry filter transients; exclude from comparison
    sample_rate = 1000
    freq = 2
    times = torch.arange(0, 5, 1.0 / sample_rate)
    waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype)
    lengths = torch.tensor([waveform.size(1)]) if use_lengths else None
    output, output_lengths = F.speed(waveform, orig_freq=sample_rate, factor=factor, lengths=lengths)
    if use_lengths:
        self.assertEqual(output.size(1), output_lengths[0])
    else:
        self.assertEqual(None, output_lengths)
    # Speeding up by `factor` compresses the time axis: the result should look
    # like a cosine at frequency freq*factor over a window shortened by `factor`.
    new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate)
    expected_waveform = torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype)
    self.assertEqual(
        expected_waveform[..., n_to_trim:-n_to_trim], output[..., n_to_trim:-n_to_trim], atol=1e-1, rtol=1e-4
    )
@nested_params(
    [(3, 2, 100), (95,)],
    [0.97, 0.9, 0.68],
)
def test_preemphasis(self, input_shape, coeff):
    """preemphasis equals the FIR filter y[n] = x[n] - coeff * x[n-1] via lfilter."""
    waveform = torch.rand(*input_shape, device=self.device, dtype=self.dtype)
    actual = F.preemphasis(waveform, coeff=coeff)
    denominator = torch.tensor([1.0, 0.0], device=self.device, dtype=self.dtype)
    numerator = torch.tensor([1.0, -coeff], device=self.device, dtype=self.dtype)
    expected = F.lfilter(waveform, a_coeffs=denominator, b_coeffs=numerator)
    self.assertEqual(actual, expected)
@nested_params(
    [(3, 2, 100), (95,)],
    [0.97, 0.9, 0.68],
)
def test_preemphasis_deemphasis_roundtrip(self, input_shape, coeff):
    """deemphasis inverts preemphasis for the same coefficient."""
    waveform = torch.rand(*input_shape, device=self.device, dtype=self.dtype)
    roundtrip = F.deemphasis(F.preemphasis(waveform, coeff=coeff), coeff=coeff)
    self.assertEqual(roundtrip, waveform)
@parameterized.expand(
    [
        ([[0, 1, 1, 0]], [[0, 1, 5, 1, 0]], torch.int32),
        ([[0, 1, 2, 3, 4]], [[0, 1, 2, 3, 4]], torch.int32),
        ([[3, 3, 3]], [[3, 5, 3, 5, 3]], torch.int64),
        ([[0, 1, 2]], [[0, 1, 1, 1, 2]], torch.int64),
    ]
)
def test_forced_align(self, targets, ref_path, targets_dtype):
    """forced_align produces the expected alignment path and per-frame scores.

    Each case pairs a target token sequence with the reference path that the
    aligner should produce for the fixed emission matrix below; class index 5
    is used as the blank token.
    """
    # Probabilities (not yet log) for 1 batch x 5 frames x 6 classes.
    emission = torch.tensor(
        [
            [
                [0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
                [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
                [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
                [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
                [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
            ]
        ],
        dtype=self.dtype,
        device=self.device,
    )
    blank = 5
    batch_index = 0
    ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
    # The reference score at frame i is the log-probability of the token that
    # the reference path assigns to frame i.
    ref_scores = torch.tensor(
        [torch.log(emission[batch_index, i, ref_path[batch_index, i]]).item() for i in range(emission.shape[1])],
        dtype=emission.dtype,
        device=self.device,
    ).unsqueeze(0)
    log_probs = torch.log(emission)
    targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
    input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
    target_lengths = torch.tensor([targets.shape[1]], device=self.device)
    hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    assert hyp_path.shape == ref_path.shape
    assert hyp_scores.shape == ref_scores.shape
    self.assertEqual(hyp_path, ref_path)
    self.assertEqual(hyp_scores, ref_scores)
@parameterized.expand([(torch.int32,), (torch.int64,)])
def test_forced_align_fail(self, targets_dtype):
    """forced_align raises a descriptive error for each kind of invalid input.

    Each section below corrupts exactly one aspect of the inputs (length,
    dtype, batch size, dimensionality, value range, or blank index) and
    checks that the matching error message is raised.
    """
    log_probs = torch.rand(1, 5, 6, dtype=self.dtype, device=self.device)
    targets = torch.tensor([[0, 1, 2, 3, 4, 4]], dtype=targets_dtype, device=self.device)
    blank = 5
    input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
    target_lengths = torch.tensor([targets.shape[1]], device=self.device)
    # Target sequence too long to be aligned within the available frames.
    with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # Targets containing the blank index.
    targets = torch.tensor([[5, 3, 3]], dtype=targets_dtype, device=self.device)
    with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # Integer log_probs are rejected.
    log_probs = log_probs.int()
    targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
    with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # Floating-point targets are rejected.
    log_probs = log_probs.float()
    targets = targets.float()
    with self.assertRaisesRegex(RuntimeError, r"targets must be int32 or int64 type"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # log_probs batch size other than 1 is rejected.
    log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
    targets = targets.int()
    with self.assertRaisesRegex(
        RuntimeError, r"The batch dimension for log_probs must be 1 at the current version"
    ):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # targets batch size other than 1 is rejected.
    targets = torch.randint(0, 4, (3, 4), device=self.device)
    log_probs = torch.rand(1, 3, 6, dtype=self.dtype, device=self.device)
    with self.assertRaisesRegex(RuntimeError, r"The batch dimension for targets must be 1 at the current version"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # input_lengths must be one-dimensional.
    targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
    input_lengths = torch.randint(1, 5, (3, 5), device=self.device)
    with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 1-D"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # target_lengths must be one-dimensional.
    input_lengths = torch.tensor([log_probs.shape[0]], device=self.device)
    target_lengths = torch.randint(1, 5, (3, 5), device=self.device)
    with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 1-D"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # input_lengths larger than the actual number of frames.
    input_lengths = torch.tensor([10000], device=self.device)
    target_lengths = torch.tensor([targets.shape[1]], device=self.device)
    with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # target_lengths larger than the actual number of targets.
    input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
    target_lengths = torch.tensor([10000], device=self.device)
    with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # Target values outside the class dimension of log_probs.
    targets = torch.tensor([[7, 8, 9, 10]], dtype=targets_dtype, device=self.device)
    log_probs = torch.rand(1, 10, 5, dtype=self.dtype, device=self.device)
    with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    # Blank index outside the class dimension.
    targets = torch.tensor([[1, 3, 3]], dtype=targets_dtype, device=self.device)
    blank = 10000
    with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
def _assert_tokens(self, first, second):
    """Assert that two TokenSpan sequences are field-wise identical."""
    assert len(first) == len(second)
    for lhs, rhs in zip(first, second):
        for attr in ("token", "score", "start", "end"):
            self.assertEqual(getattr(lhs, attr), getattr(rhs, attr))
@parameterized.expand(
    [
        ([], [], []),
        ([F.TokenSpan(1, 0, 1, 1.0)], [1], [1.0]),
        ([F.TokenSpan(1, 0, 2, 0.5)], [1, 1], [0.4, 0.6]),
        ([F.TokenSpan(1, 0, 3, 0.6)], [1, 1, 1], [0.5, 0.6, 0.7]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 1, 2, 0.9)], [1, 2], [0.8, 0.9]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 1, 3, 0.5)], [1, 2, 2], [1.0, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(1, 2, 3, 1.0)], [1, 0, 1], [0.8, 0.9, 1.0]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 2, 3, 1.0)], [1, 0, 2], [0.8, 0.9, 1.0]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 2, 4, 0.5)], [1, 0, 1, 1], [1.0, 0.1, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 2, 4, 0.5)], [1, 0, 2, 2], [1.0, 0.1, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 4, 0.4)], [1, 0, 0, 1], [1.0, 0.9, 0.7, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 4, 0.4)], [1, 0, 0, 2], [1.0, 0.9, 0.7, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 5, 0.5)], [1, 0, 0, 1, 1], [1.0, 0.9, 0.8, 0.6, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 5, 0.5)], [1, 0, 0, 2, 2], [1.0, 0.9, 0.8, 0.6, 0.4]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 2, 3, 0.5)], [1, 1, 2], [1.0, 0.8, 0.5]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 4, 0.7)], [1, 1, 0, 1], [1.0, 0.8, 0.1, 0.7]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 4, 0.7)], [1, 1, 0, 2], [1.0, 0.8, 0.1, 0.7]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 5, 0.4)], [1, 1, 0, 1, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 5, 0.4)], [1, 1, 0, 2, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 5, 0.3)], [1, 1, 0, 0, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 5, 0.3)], [1, 1, 0, 0, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
        (
            [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 6, 0.2)],
            [1, 1, 0, 0, 1, 1],
            [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
        ),
        (
            [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 6, 0.2)],
            [1, 1, 0, 0, 2, 2],
            [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
        ),
    ]
)
def test_merge_repeated_tokens(self, expected, tokens, scores):
    """merge_tokens collapses repeated non-blank tokens into TokenSpans.

    Also verifies that leading/trailing blanks only shift span boundaries
    (by the prefix length) without changing the merged tokens or scores.
    Stray debug ``print`` calls from the original version were removed so
    the test runs silently.
    """
    scores_ = torch.tensor(scores, dtype=torch.float32, device=self.device)
    tokens_ = torch.tensor(tokens, dtype=torch.int64, device=self.device)
    spans = F.merge_tokens(tokens_, scores_, blank=0)
    self._assert_tokens(spans, expected)

    # Append blanks at the beginning and at the end.
    for num_prefix, num_suffix in itertools.product([0, 1, 2], repeat=2):
        tokens_ = ([0] * num_prefix) + tokens + ([0] * num_suffix)
        scores_ = ([0.1] * num_prefix) + scores + ([0.1] * num_suffix)
        tokens_ = torch.tensor(tokens_, dtype=torch.int64, device=self.device)
        scores_ = torch.tensor(scores_, dtype=torch.float32, device=self.device)
        # Padded spans are the originals shifted right by the prefix length.
        expected_ = [F.TokenSpan(s.token, s.start + num_prefix, s.end + num_prefix, s.score) for s in expected]
        spans = F.merge_tokens(tokens_, scores_, blank=0)
        self._assert_tokens(spans, expected_)
def test_frechet_distance_univariate(self):
    r"""Frechet distance reduces to the scalar closed form in one dimension."""
    mu_x = torch.rand((1,), device=self.device)
    sigma_x = torch.rand((1, 1), device=self.device)
    mu_y = torch.rand((1,), device=self.device)
    sigma_y = torch.rand((1, 1), device=self.device)
    # In 1-D the matrix square root is a plain square root.
    closed_form = ((mu_x - mu_y) ** 2 + sigma_x + sigma_y - 2 * torch.sqrt(sigma_x * sigma_y)).item()
    actual = F.frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
    self.assertEqual(closed_form, actual)
def test_frechet_distance_diagonal_covariance(self):
    r"""Frechet distance matches the closed form when covariances are diagonal."""
    dim = 15
    mu_x = torch.rand((dim,), device=self.device)
    sigma_x = torch.diag(torch.rand((dim,), device=self.device))
    mu_y = torch.rand((dim,), device=self.device)
    sigma_y = torch.diag(torch.rand((dim,), device=self.device))
    # For diagonal covariances the trace terms reduce to element-wise sums.
    closed_form = (
        torch.sum((mu_x - mu_y) ** 2) + torch.sum(sigma_x + sigma_y) - 2 * torch.sum(torch.sqrt(sigma_x * sigma_y))
    ).item()
    actual = F.frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
    self.assertEqual(closed_form, actual)
class FunctionalCPUOnly(TestBaseMixin): class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self): def test_melscale_fbanks_no_warning_high_n_freq(self):
...@@ -898,3 +1333,27 @@ class FunctionalCPUOnly(TestBaseMixin): ...@@ -898,3 +1333,27 @@ class FunctionalCPUOnly(TestBaseMixin):
warnings.simplefilter("always") warnings.simplefilter("always")
F.melscale_fbanks(201, 0, 8000, 128, 16000) F.melscale_fbanks(201, 0, 8000, 128, 16000)
assert len(w) == 1 assert len(w) == 1
class FunctionalCUDAOnly(TestBaseMixin):
    """Functional tests that compare results computed on ``self.device`` with CUDA."""

    @nested_params(
        [torch.half, torch.float, torch.double],
        [torch.int32, torch.int64],
        [(1, 50, 100), (1, 100, 100)],
        [(1, 10), (1, 40), (1, 45)],
    )
    def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
        """forced_align returns identical paths/scores for the same inputs on CUDA.

        NOTE(review): inputs are created on ``self.device`` and then mirrored via
        ``.cuda()`` — presumably ``self.device`` is CPU in some configurations so
        the comparison is CPU vs CUDA; verify against the concrete test classes.
        """
        log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
        # Targets start at 1 so they never contain the default blank index 0.
        targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
        # Mirror every input tensor onto the CUDA device.
        log_probs_cuda = log_probs.cuda()
        targets_cuda = targets.cuda()
        input_lengths_cuda = input_lengths.cuda()
        target_lengths_cuda = target_lengths.cuda()
        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths)
        hyp_path_cuda, hyp_scores_cuda = F.forced_align(
            log_probs_cuda, targets_cuda, input_lengths_cuda, target_lengths_cuda
        )
        self.assertEqual(hyp_path, hyp_path_cuda.cpu())
        self.assertEqual(hyp_scores, hyp_scores_cuda.cpu())
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly from .kaldi_compatibility_test_impl import Kaldi
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
class TestKaldiFloat32(Kaldi, PytorchTestCase): class TestKaldiFloat32(Kaldi, PytorchTestCase):
......
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized from torchaudio_unittest.common_utils import skipIfNoExec, TempDirMixin, TestBaseMixin
from torchaudio_unittest.common_utils import (
get_sinusoid,
load_params,
save_wav,
skipIfNoExec,
TempDirMixin,
TestBaseMixin,
)
from torchaudio_unittest.common_utils.kaldi_utils import convert_args, run_kaldi from torchaudio_unittest.common_utils.kaldi_utils import convert_args, run_kaldi
...@@ -32,25 +24,3 @@ class Kaldi(TempDirMixin, TestBaseMixin): ...@@ -32,25 +24,3 @@ class Kaldi(TempDirMixin, TestBaseMixin):
command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"] command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"]
kaldi_result = run_kaldi(command, "ark", tensor) kaldi_result = run_kaldi(command, "ark", tensor)
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
    """Compatibility checks between torchaudio and the Kaldi command-line tools."""

    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        # Cast the reference onto the suite's dtype/device before comparing.
        reference = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, reference, rtol=rtol, atol=atol)

    @parameterized.expand(load_params("kaldi_test_pitch_args.jsonl"))
    @skipIfNoExec("compute-kaldi-pitch-feats")
    def test_pitch_feats(self, kwargs):
        """compute_kaldi_pitch matches compute-kaldi-pitch-feats numerically."""
        sample_rate = kwargs["sample_rate"]

        # torchaudio side: pitch features of a float32 sinusoid (first channel).
        result = F.compute_kaldi_pitch(get_sinusoid(dtype="float32", sample_rate=sample_rate)[0], **kwargs)

        # Kaldi side: persist the same signal as an int16 WAV and invoke the CLI tool.
        wave_file = self.get_temp_path("test.wav")
        save_wav(wave_file, get_sinusoid(dtype="int16", sample_rate=sample_rate), sample_rate)
        command = ["compute-kaldi-pitch-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
        kaldi_result = run_kaldi(command, "scp", wave_file)

        self.assert_equal(result, expected=kaldi_result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment