OpenDAS / Torchaudio · Commits

Commit ffeba11a, authored Sep 02, 2024 by mayp777
Parent: 29deb085

UPDATE

Changes: 337 files. Showing 20 changed files with 1007 additions and 843 deletions (+1007 -843).
test/torchaudio_unittest/backend/soundfile/info_test.py  +0 -1
test/torchaudio_unittest/backend/sox_io/info_test.py  +6 -279
test/torchaudio_unittest/backend/sox_io/load_test.py  +4 -296
test/torchaudio_unittest/backend/sox_io/save_test.py  +32 -69
test/torchaudio_unittest/backend/sox_io/smoke_test.py  +0 -86
test/torchaudio_unittest/backend/sox_io/torchscript_test.py  +5 -5
test/torchaudio_unittest/backend/utils_test.py  +3 -3
test/torchaudio_unittest/common_utils/__init__.py  +23 -3
test/torchaudio_unittest/common_utils/autograd_utils.py  +20 -0
test/torchaudio_unittest/common_utils/case_utils.py  +105 -26
test/torchaudio_unittest/common_utils/rnnt_utils.py  +10 -1
test/torchaudio_unittest/example/emformer_rnnt/test_mustc_lightning.py  +76 -0
test/torchaudio_unittest/example/emformer_rnnt/test_tedlium3_lightning.py  +79 -0
test/torchaudio_unittest/functional/autograd_cuda_test.py  +1 -0
test/torchaudio_unittest/functional/autograd_impl.py  +61 -5
test/torchaudio_unittest/functional/batch_consistency_test.py  +92 -17
test/torchaudio_unittest/functional/functional_cuda_test.py  +16 -2
test/torchaudio_unittest/functional/functional_impl.py  +472 -13
test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py  +1 -6
test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py  +1 -31
Too many changes to show: to preserve performance only 337 of 337+ files are displayed.
test/torchaudio_unittest/backend/soundfile/info_test.py @ ffeba11a

...
@@ -117,7 +117,6 @@ class TestInfo(TempDirMixin, PytorchTestCase):
         with patch("soundfile.info", _mock_info_func):
             with warnings.catch_warnings(record=True) as w:
                 info = soundfile_backend.info("foo")
-            assert len(w) == 1
             assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
         assert info.bits_per_sample == 0
...
test/torchaudio_unittest/backend/sox_io/info_test.py @ ffeba11a

-import io
 import itertools
-import os
-import tarfile
-from contextlib import contextmanager

 from parameterized import parameterized
-from torchaudio._internal import module_utils as _mod_utils
 from torchaudio.backend import sox_io_backend
-from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size
-from torchaudio_unittest.backend.common import get_bits_per_sample, get_encoding
+from torchaudio_unittest.backend.common import get_encoding
 from torchaudio_unittest.common_utils import (
     get_asset_path,
     get_wav_data,
-    HttpServerMixin,
     PytorchTestCase,
     save_wav,
     skipIfNoExec,
-    skipIfNoModule,
     skipIfNoSox,
+    skipIfNoSoxDecoder,
     sox_utils,
     TempDirMixin,
 )
...
@@ -25,10 +18,6 @@ from torchaudio_unittest.common_utils import (
 from .common import name_func

-if _mod_utils.is_module_available("requests"):
-    import requests
-

 @skipIfNoExec("sox")
 @skipIfNoSox
 class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -208,6 +197,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
         assert info.bits_per_sample == bits_per_sample
         assert info.encoding == get_encoding("amb", dtype)

+    @skipIfNoSoxDecoder("amr-nb")
     def test_amr_nb(self):
         """`sox_io_backend.info` can check amr-nb file correctly"""
         duration = 1
...
@@ -287,6 +277,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):

 @skipIfNoSox
+@skipIfNoSoxDecoder("opus")
 class TestInfoOpus(PytorchTestCase):
     @parameterized.expand(
         list(
...
@@ -314,283 +305,19 @@ class TestLoadWithoutExtension(PytorchTestCase):
     def test_mp3(self):
         """MP3 file without extension can be loaded

-        Originally, we added `format` argument for this case, but now we use FFmpeg
-        for MP3 decoding, which works even without `format` argument.
         https://github.com/pytorch/audio/issues/1040

         The file was generated with the following command
             ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
         """
         path = get_asset_path("mp3_without_ext")
-        sinfo = sox_io_backend.info(path)
+        sinfo = sox_io_backend.info(path, format="mp3")
         assert sinfo.sample_rate == 16000
-        assert sinfo.num_frames == 80000
+        assert sinfo.num_frames == 81216
         assert sinfo.num_channels == 1
         assert sinfo.bits_per_sample == 0  # bit_per_sample is irrelevant for compressed formats
         assert sinfo.encoding == "MP3"

-        with open(path, "rb") as fileobj:
-            sinfo = sox_io_backend.info(fileobj, format="mp3")
-            assert sinfo.sample_rate == 16000
-            assert sinfo.num_frames == 80000
-            assert sinfo.num_channels == 1
-            assert sinfo.bits_per_sample == 0
-            assert sinfo.encoding == "MP3"
-
-
-class FileObjTestBase(TempDirMixin):
-    def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
-        path = self.get_temp_path(f"test.{ext}")
-        bit_depth = sox_utils.get_bit_depth(dtype)
-        duration = num_frames / sample_rate
-        comment_file = self._gen_comment_file(comments) if comments else None
-
-        sox_utils.gen_audio_file(
-            path,
-            sample_rate,
-            num_channels=num_channels,
-            encoding=sox_utils.get_encoding(dtype),
-            bit_depth=bit_depth,
-            duration=duration,
-            comment_file=comment_file,
-        )
-        return path
-
-    def _gen_comment_file(self, comments):
-        comment_path = self.get_temp_path("comment.txt")
-        with open(comment_path, "w") as file_:
-            file_.writelines(comments)
-        return comment_path
-
-
-class Unseekable:
-    def __init__(self, fileobj):
-        self.fileobj = fileobj
-
-    def read(self, n):
-        return self.fileobj.read(n)
-
-
-@skipIfNoSox
-@skipIfNoExec("sox")
-class TestFileObject(FileObjTestBase, PytorchTestCase):
-    def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
-        path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
-        format_ = ext if ext in ["mp3"] else None
-        with open(path, "rb") as fileobj:
-            return sox_io_backend.info(fileobj, format_)
-
-    def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
-        path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
-        format_ = ext if ext in ["mp3"] else None
-        with open(path, "rb") as file_:
-            fileobj = io.BytesIO(file_.read())
-        return sox_io_backend.info(fileobj, format_)
-
-    def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
-        audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
-        audio_file = os.path.basename(audio_path)
-        archive_path = self.get_temp_path("archive.tar.gz")
-        with tarfile.TarFile(archive_path, "w") as tarobj:
-            tarobj.add(audio_path, arcname=audio_file)
-        format_ = ext if ext in ["mp3"] else None
-        with tarfile.TarFile(archive_path, "r") as tarobj:
-            fileobj = tarobj.extractfile(audio_file)
-            return sox_io_backend.info(fileobj, format_)
-
-    @contextmanager
-    def _set_buffer_size(self, buffer_size):
-        try:
-            original_buffer_size = get_buffer_size()
-            set_buffer_size(buffer_size)
-            yield
-        finally:
-            set_buffer_size(original_buffer_size)
-
-    @parameterized.expand(
-        [
-            ("wav", "float32"),
-            ("wav", "int32"),
-            ("wav", "int16"),
-            ("wav", "uint8"),
-            ("mp3", "float32"),
-            ("flac", "float32"),
-            ("vorbis", "float32"),
-            ("amb", "int16"),
-        ]
-    )
-    def test_fileobj(self, ext, dtype):
-        """Querying audio via file object works"""
-        sample_rate = 16000
-        num_frames = 3 * sample_rate
-        num_channels = 2
-        sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
-
-        bits_per_sample = get_bits_per_sample(ext, dtype)
-        num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
-
-        assert sinfo.sample_rate == sample_rate
-        assert sinfo.num_channels == num_channels
-        assert sinfo.num_frames == num_frames
-        assert sinfo.bits_per_sample == bits_per_sample
-        assert sinfo.encoding == get_encoding(ext, dtype)
-
-    @parameterized.expand(
-        [
-            ("vorbis", "float32"),
-        ]
-    )
-    def test_fileobj_large_header(self, ext, dtype):
-        """
-        For audio file with header size exceeding default buffer size:
-        - Querying audio via file object without enlarging buffer size fails.
-        - Querying audio via file object after enlarging buffer size succeeds.
-        """
-        sample_rate = 16000
-        num_frames = 3 * sample_rate
-        num_channels = 2
-        comments = "metadata=" + " ".join(["value" for _ in range(1000)])
-
-        with self.assertRaises(RuntimeError):
-            sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
-
-        with self._set_buffer_size(16384):
-            sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
-        bits_per_sample = get_bits_per_sample(ext, dtype)
-        num_frames = 0 if ext in ["vorbis"] else num_frames
-
-        assert sinfo.sample_rate == sample_rate
-        assert sinfo.num_channels == num_channels
-        assert sinfo.num_frames == num_frames
-        assert sinfo.bits_per_sample == bits_per_sample
-        assert sinfo.encoding == get_encoding(ext, dtype)
-
-    @parameterized.expand([("wav", "float32"), ("wav", "int32"), ("wav", "int16"), ("wav", "uint8"), ("mp3", "float32"), ("flac", "float32"), ("vorbis", "float32"), ("amb", "int16")])
-    def test_bytesio(self, ext, dtype):
-        """Querying audio via ByteIO object works for small data"""
-        sample_rate = 16000
-        num_frames = 3 * sample_rate
-        num_channels = 2
-        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
-
-        bits_per_sample = get_bits_per_sample(ext, dtype)
-        num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
-
-        assert sinfo.sample_rate == sample_rate
-        assert sinfo.num_channels == num_channels
-        assert sinfo.num_frames == num_frames
-        assert sinfo.bits_per_sample == bits_per_sample
-        assert sinfo.encoding == get_encoding(ext, dtype)
-
-    @parameterized.expand([("wav", "float32"), ("wav", "int32"), ("wav", "int16"), ("wav", "uint8"), ("mp3", "float32"), ("flac", "float32"), ("vorbis", "float32"), ("amb", "int16")])
-    def test_bytesio_tiny(self, ext, dtype):
-        """Querying audio via ByteIO object works for small data"""
-        sample_rate = 8000
-        num_frames = 4
-        num_channels = 2
-        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
-
-        bits_per_sample = get_bits_per_sample(ext, dtype)
-        num_frames = {"vorbis": 0, "mp3": 1728}.get(ext, num_frames)
-
-        assert sinfo.sample_rate == sample_rate
-        assert sinfo.num_channels == num_channels
-        assert sinfo.num_frames == num_frames
-        assert sinfo.bits_per_sample == bits_per_sample
-        assert sinfo.encoding == get_encoding(ext, dtype)
-
-    @parameterized.expand([("wav", "float32"), ("wav", "int32"), ("wav", "int16"), ("wav", "uint8"), ("mp3", "float32"), ("flac", "float32"), ("vorbis", "float32"), ("amb", "int16")])
-    def test_tarfile(self, ext, dtype):
-        """Querying compressed audio via file-like object works"""
-        sample_rate = 16000
-        num_frames = 3.0 * sample_rate
-        num_channels = 2
-        sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
-
-        bits_per_sample = get_bits_per_sample(ext, dtype)
-        num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
-
-        assert sinfo.sample_rate == sample_rate
-        assert sinfo.num_channels == num_channels
-        assert sinfo.num_frames == num_frames
-        assert sinfo.bits_per_sample == bits_per_sample
-        assert sinfo.encoding == get_encoding(ext, dtype)
-
-
-@skipIfNoSox
-@skipIfNoExec("sox")
-@skipIfNoModule("requests")
-class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
-    def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
-        audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
-        audio_file = os.path.basename(audio_path)
-        url = self.get_url(audio_file)
-        format_ = ext if ext in ["mp3"] else None
-        with requests.get(url, stream=True) as resp:
-            return sox_io_backend.info(Unseekable(resp.raw), format=format_)
-
-    @parameterized.expand([("wav", "float32"), ("wav", "int32"), ("wav", "int16"), ("wav", "uint8"), ("mp3", "float32"), ("flac", "float32"), ("vorbis", "float32"), ("amb", "int16")])
-    def test_requests(self, ext, dtype):
-        """Querying compressed audio via requests works"""
-        sample_rate = 16000
-        num_frames = 3.0 * sample_rate
-        num_channels = 2
-        sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
-
-        bits_per_sample = get_bits_per_sample(ext, dtype)
-        num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
-
-        assert sinfo.sample_rate == sample_rate
-        assert sinfo.num_channels == num_channels
-        assert sinfo.num_frames == num_frames
-        assert sinfo.bits_per_sample == bits_per_sample
-        assert sinfo.encoding == get_encoding(ext, dtype)
-

 @skipIfNoSox
 class TestInfoNoSuchFile(PytorchTestCase):
...
test/torchaudio_unittest/backend/sox_io/load_test.py @ ffeba11a

-import io
 import itertools
-import tarfile

 import torch
-import torchaudio
 from parameterized import parameterized
-from torchaudio._internal import module_utils as _mod_utils
 from torchaudio.backend import sox_io_backend
 from torchaudio_unittest.common_utils import (
     get_asset_path,
     get_wav_data,
-    HttpServerMixin,
     load_wav,
     nested_params,
     PytorchTestCase,
     save_wav,
     skipIfNoExec,
-    skipIfNoModule,
     skipIfNoSox,
+    skipIfNoSoxDecoder,
     sox_utils,
     TempDirMixin,
 )
...
@@ -25,10 +20,6 @@ from torchaudio_unittest.common_utils import (
 from .common import name_func

-if _mod_utils.is_module_available("requests"):
-    import requests
-

 class LoadTestBase(TempDirMixin, PytorchTestCase):
     def assert_format(
         self,
...
@@ -244,6 +235,7 @@ class TestLoad(LoadTestBase):
         ),
         name_func=name_func,
     )
+    @skipIfNoSoxDecoder("opus")
     def test_opus(self, bitrate, num_channels, compression_level):
         """`sox_io_backend.load` can load opus file correctly."""
         ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
...
@@ -288,6 +280,7 @@ class TestLoad(LoadTestBase):
             "amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize
         )

+    @skipIfNoSoxDecoder("amr-nb")
     def test_amr_nb(self):
         """`sox_io_backend.load` can load amr_nb format correctly."""
         self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
...
@@ -322,306 +315,21 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
         self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)

-        # test file-like obj
-        def func(path, *args):
-            with open(path, "rb") as fileobj:
-                return torchaudio._torchaudio.load_audio_fileobj(fileobj, *args)
-
-        self._test(func, frame_offset, num_frames, channels_first, normalize)
-
-    @nested_params(
-        [0, 1, 10, 100, 1000],
-        [-1, 1, 10, 100, 1000],
-        [True, False],
-        [True, False],
-    )
-    def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize):
-        """The combination of properly changes the output tensor"""
-        from torchaudio.io._compat import load_audio, load_audio_fileobj
-
-        self._test(load_audio, frame_offset, num_frames, channels_first, normalize)
-
-        # test file-like obj
-        def func(path, *args):
-            with open(path, "rb") as fileobj:
-                return load_audio_fileobj(fileobj, *args)
-
-        self._test(func, frame_offset, num_frames, channels_first, normalize)
-

 @skipIfNoSox
 class TestLoadWithoutExtension(PytorchTestCase):
     def test_mp3(self):
         """MP3 file without extension can be loaded

-        Originally, we added `format` argument for this case, but now we use FFmpeg
-        for MP3 decoding, which works even without `format` argument.
         https://github.com/pytorch/audio/issues/1040

         The file was generated with the following command
             ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
         """
         path = get_asset_path("mp3_without_ext")
-        _, sr = sox_io_backend.load(path)
+        _, sr = sox_io_backend.load(path, format="mp3")
         assert sr == 16000

-        with open(path, "rb") as fileobj:
-            _, sr = sox_io_backend.load(fileobj)
-            assert sr == 16000
-
-
-class CloggedFileObj:
-    def __init__(self, fileobj):
-        self.fileobj = fileobj
-
-    def read(self, _):
-        return self.fileobj.read(2)
-
-    def seek(self, offset, whence):
-        return self.fileobj.seek(offset, whence)
-
-
-@skipIfNoSox
-@skipIfNoExec("sox")
-class TestFileObject(TempDirMixin, PytorchTestCase):
-    """
-    In this test suite, the result of file-like object input is compared against file path input,
-    because `load` function is rigrously tested for file path inputs to match libsox's result,
-    """
-
-    @parameterized.expand(
-        [
-            ("wav", {"bit_depth": 16}),
-            ("wav", {"bit_depth": 24}),
-            ("wav", {"bit_depth": 32}),
-            ("mp3", {"compression": 128}),
-            ("mp3", {"compression": 320}),
-            ("flac", {"compression": 0}),
-            ("flac", {"compression": 5}),
-            ("flac", {"compression": 8}),
-            ("vorbis", {"compression": -1}),
-            ("vorbis", {"compression": 10}),
-            ("amb", {}),
-        ]
-    )
-    def test_fileobj(self, ext, kwargs):
-        """Loading audio via file object returns the same result as via file path."""
-        sample_rate = 16000
-        format_ = ext if ext in ["mp3"] else None
-        path = self.get_temp_path(f"test.{ext}")
-
-        sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
-        expected, _ = sox_io_backend.load(path)
-
-        with open(path, "rb") as fileobj:
-            found, sr = sox_io_backend.load(fileobj, format=format_)
-        assert sr == sample_rate
-        self.assertEqual(expected, found)
-
-    @parameterized.expand([("wav", {"bit_depth": 16}), ("wav", {"bit_depth": 24}), ("wav", {"bit_depth": 32}), ("mp3", {"compression": 128}), ("mp3", {"compression": 320}), ("flac", {"compression": 0}), ("flac", {"compression": 5}), ("flac", {"compression": 8}), ("vorbis", {"compression": -1}), ("vorbis", {"compression": 10}), ("amb", {})])
-    def test_bytesio(self, ext, kwargs):
-        """Loading audio via BytesIO object returns the same result as via file path."""
-        sample_rate = 16000
-        format_ = ext if ext in ["mp3"] else None
-        path = self.get_temp_path(f"test.{ext}")
-
-        sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
-        expected, _ = sox_io_backend.load(path)
-
-        with open(path, "rb") as file_:
-            fileobj = io.BytesIO(file_.read())
-        found, sr = sox_io_backend.load(fileobj, format=format_)
-        assert sr == sample_rate
-        self.assertEqual(expected, found)
-
-    @parameterized.expand([("wav", {"bit_depth": 16}), ("wav", {"bit_depth": 24}), ("wav", {"bit_depth": 32}), ("mp3", {"compression": 128}), ("mp3", {"compression": 320}), ("flac", {"compression": 0}), ("flac", {"compression": 5}), ("flac", {"compression": 8}), ("vorbis", {"compression": -1}), ("vorbis", {"compression": 10}), ("amb", {})])
-    def test_bytesio_clogged(self, ext, kwargs):
-        """Loading audio via clogged file object returns the same result as via file path.
-
-        This test case validates the case where fileobject returns shorter bytes than requeted.
-        """
-        sample_rate = 16000
-        format_ = ext if ext in ["mp3"] else None
-        path = self.get_temp_path(f"test.{ext}")
-
-        sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
-        expected, _ = sox_io_backend.load(path)
-
-        with open(path, "rb") as file_:
-            fileobj = CloggedFileObj(io.BytesIO(file_.read()))
-        found, sr = sox_io_backend.load(fileobj, format=format_)
-        assert sr == sample_rate
-        self.assertEqual(expected, found)
-
-    @parameterized.expand([("wav", {"bit_depth": 16}), ("wav", {"bit_depth": 24}), ("wav", {"bit_depth": 32}), ("mp3", {"compression": 128}), ("mp3", {"compression": 320}), ("flac", {"compression": 0}), ("flac", {"compression": 5}), ("flac", {"compression": 8}), ("vorbis", {"compression": -1}), ("vorbis", {"compression": 10}), ("amb", {})])
-    def test_bytesio_tiny(self, ext, kwargs):
-        """Loading very small audio via file object returns the same result as via file path."""
-        sample_rate = 16000
-        format_ = ext if ext in ["mp3"] else None
-        path = self.get_temp_path(f"test.{ext}")
-
-        sox_utils.gen_audio_file(path, sample_rate, num_channels=2, duration=1 / 1600, **kwargs)
-        expected, _ = sox_io_backend.load(path)
-
-        with open(path, "rb") as file_:
-            fileobj = io.BytesIO(file_.read())
-        found, sr = sox_io_backend.load(fileobj, format=format_)
-        assert sr == sample_rate
-        self.assertEqual(expected, found)
-
-    @parameterized.expand([("wav", {"bit_depth": 16}), ("wav", {"bit_depth": 24}), ("wav", {"bit_depth": 32}), ("mp3", {"compression": 128}), ("mp3", {"compression": 320}), ("flac", {"compression": 0}), ("flac", {"compression": 5}), ("flac", {"compression": 8}), ("vorbis", {"compression": -1}), ("vorbis", {"compression": 10}), ("amb", {})])
-    def test_tarfile(self, ext, kwargs):
-        """Loading compressed audio via file-like object returns the same result as via file path."""
-        sample_rate = 16000
-        format_ = ext if ext in ["mp3"] else None
-        audio_file = f"test.{ext}"
-        audio_path = self.get_temp_path(audio_file)
-        archive_path = self.get_temp_path("archive.tar.gz")
-
-        sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
-        expected, _ = sox_io_backend.load(audio_path)
-
-        with tarfile.TarFile(archive_path, "w") as tarobj:
-            tarobj.add(audio_path, arcname=audio_file)
-        with tarfile.TarFile(archive_path, "r") as tarobj:
-            fileobj = tarobj.extractfile(audio_file)
-            found, sr = sox_io_backend.load(fileobj, format=format_)
-        assert sr == sample_rate
-        self.assertEqual(expected, found)
-
-
-class Unseekable:
-    def __init__(self, fileobj):
-        self.fileobj = fileobj
-
-    def read(self, n):
-        return self.fileobj.read(n)
-
-
-@skipIfNoSox
-@skipIfNoExec("sox")
-@skipIfNoModule("requests")
-class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
-    @parameterized.expand([("wav", {"bit_depth": 16}), ("wav", {"bit_depth": 24}), ("wav", {"bit_depth": 32}), ("mp3", {"compression": 128}), ("mp3", {"compression": 320}), ("flac", {"compression": 0}), ("flac", {"compression": 5}), ("flac", {"compression": 8}), ("vorbis", {"compression": -1}), ("vorbis", {"compression": 10}), ("amb", {})])
-    def test_requests(self, ext, kwargs):
-        sample_rate = 16000
-        format_ = ext if ext in ["mp3"] else None
-        audio_file = f"test.{ext}"
-        audio_path = self.get_temp_path(audio_file)
-
-        sox_utils.gen_audio_file(audio_path, sample_rate, num_channels=2, **kwargs)
-        expected, _ = sox_io_backend.load(audio_path)
-
-        url = self.get_url(audio_file)
-        with requests.get(url, stream=True) as resp:
-            found, sr = sox_io_backend.load(Unseekable(resp.raw), format=format_)
-
-        assert sr == sample_rate
-        if ext != "mp3":
-            self.assertEqual(expected, found)
-
-    @parameterized.expand(
-        list(
-            itertools.product(
-                [0, 1, 10, 100, 1000],
-                [-1, 1, 10, 100, 1000],
-            )
-        ),
-        name_func=name_func,
-    )
-    def test_frame(self, frame_offset, num_frames):
-        """num_frames and frame_offset correctly specify the region of data"""
-        sample_rate = 8000
-        audio_file = "test.wav"
-        audio_path = self.get_temp_path(audio_file)
-
-        original = get_wav_data("float32", num_channels=2)
-        save_wav(audio_path, original, sample_rate)
-        frame_end = None if num_frames == -1 else frame_offset + num_frames
-        expected = original[:, frame_offset:frame_end]
-
-        url = self.get_url(audio_file)
-        with requests.get(url, stream=True) as resp:
-            found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames)
-
-        assert sr == sample_rate
-        self.assertEqual(expected, found)
-

 @skipIfNoSox
 class TestLoadNoSuchFile(PytorchTestCase):
...
test/torchaudio_unittest/backend/sox_io/save_test.py @ ffeba11a

-import io
 import os

 import torch
...
@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import (
     save_wav,
     skipIfNoExec,
     skipIfNoSox,
+    skipIfNoSoxEncoder,
     sox_utils,
     TempDirMixin,
     TorchaudioTestCase,
...
@@ -43,7 +43,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
         num_channels: int = 2,
         num_frames: float = 3 * 8000,
         src_dtype: str = "int32",
-        test_mode: str = "path",
     ):
         """`save` function produces file that is comparable with `sox` command
...
@@ -97,37 +96,9 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
         # 2.1. Convert the original wav to target format with torchaudio
         data = load_wav(src_path, normalize=False)[0]
-        if test_mode == "path":
             sox_io_backend.save(
                 tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
             )
-        elif test_mode == "fileobj":
-            with open(tgt_path, "bw") as file_:
-                sox_io_backend.save(
-                    file_,
-                    data,
-                    sample_rate,
-                    format=format,
-                    compression=compression,
-                    encoding=encoding,
-                    bits_per_sample=bits_per_sample,
-                )
-        elif test_mode == "bytesio":
-            file_ = io.BytesIO()
-            sox_io_backend.save(
-                file_,
-                data,
-                sample_rate,
-                format=format,
-                compression=compression,
-                encoding=encoding,
-                bits_per_sample=bits_per_sample,
-            )
-            file_.seek(0)
-            with open(tgt_path, "bw") as f:
-                f.write(file_.read())
-        else:
-            raise ValueError(f"Unexpected test mode: {test_mode}")
         # 2.2. Convert the target format to wav with sox
         sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
         # 2.3. Load with SciPy
...
@@ -150,7 +121,6 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
 @skipIfNoSox
 class SaveTest(SaveTestBase):
     @nested_params(
-        ["path", "fileobj", "bytesio"],
         [
             ("PCM_U", 8),
             ("PCM_S", 16),
...
@@ -161,12 +131,11 @@ class SaveTest(SaveTestBase):
             ("ALAW", 8),
         ],
     )
-    def test_save_wav(self, test_mode, enc_params):
+    def test_save_wav(self, enc_params):
         encoding, bits_per_sample = enc_params
-        self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
+        self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample)

     @nested_params(
-        ["path", "fileobj", "bytesio"],
         [
             ("float32",),
             ("int32",),
...
@@ -174,12 +143,11 @@ class SaveTest(SaveTestBase):
             ("uint8",),
         ],
     )
-    def test_save_wav_dtype(self, test_mode, params):
+    def test_save_wav_dtype(self, params):
         (dtype,) = params
-        self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)
+        self.assert_save_consistency("wav", src_dtype=dtype)

     @nested_params(
-        ["path", "fileobj", "bytesio"],
         [8, 16, 24],
         [
             None,
...
@@ -194,19 +162,13 @@ class SaveTest(SaveTestBase):
             8,
         ],
     )
-    def test_save_flac(self, test_mode, bits_per_sample, compression_level):
-        self.assert_save_consistency(
-            "flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode=test_mode
-        )
+    def test_save_flac(self, bits_per_sample, compression_level):
+        self.assert_save_consistency("flac", compression=compression_level, bits_per_sample=bits_per_sample)

-    @nested_params(
-        ["path", "fileobj", "bytesio"],
-    )
-    def test_save_htk(self, test_mode):
-        self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)
+    def test_save_htk(self):
+        self.assert_save_consistency("htk", num_channels=1)

     @nested_params(
-        ["path", "fileobj", "bytesio"],
         [
             None,
             -1,
...
@@ -219,11 +181,10 @@ class SaveTest(SaveTestBase):
             10,
         ],
     )
-    def test_save_vorbis(self, test_mode, quality_level):
-        self.assert_save_consistency("vorbis", compression=quality_level, test_mode=test_mode)
+    def test_save_vorbis(self, quality_level):
+        self.assert_save_consistency("vorbis", compression=quality_level)

     @nested_params(
-        ["path", "fileobj", "bytesio"],
         [
             (
                 "PCM_S",
...
@@ -248,12 +209,11 @@ class SaveTest(SaveTestBase):
             ("ALAW", 32),
         ],
     )
-    def test_save_sphere(self, test_mode, enc_params):
+    def test_save_sphere(self, enc_params):
         encoding, bits_per_sample = enc_params
-        self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
+        self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample)

     @nested_params(
-        ["path", "fileobj", "bytesio"],
         [
             (
                 "PCM_U",
...
@@ -289,12 +249,11 @@ class SaveTest(SaveTestBase):
             ),
         ],
     )
-    def test_save_amb(self, test_mode, enc_params):
+    def test_save_amb(self, enc_params):
         encoding, bits_per_sample = enc_params
-        self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
+        self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample)

     @nested_params(
-        ["path", "fileobj", "bytesio"],
         [
             None,
             0,
...
@@ -307,18 +266,16 @@ class SaveTest(SaveTestBase):
             7,
         ],
     )
-    def test_save_amr_nb(self, test_mode, bit_rate):
-        self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)
+    @skipIfNoSoxEncoder("amr-nb")
+    def test_save_amr_nb(self, bit_rate):
+        self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1)

-    @nested_params(
-        ["path", "fileobj", "bytesio"],
-    )
-    def test_save_gsm(self, test_mode):
-        self.assert_save_consistency("gsm", num_channels=1, test_mode=test_mode)
+    def test_save_gsm(self):
+        self.assert_save_consistency("gsm", num_channels=1)
         with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
-            self.assert_save_consistency("gsm", num_channels=2, test_mode=test_mode)
+            self.assert_save_consistency("gsm", num_channels=2)
         with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
-            self.assert_save_consistency("gsm", sample_rate=16000, test_mode=test_mode)
+            self.assert_save_consistency("gsm", sample_rate=16000)

     @parameterized.expand(
         [
...
@@ -326,12 +283,18 @@ class SaveTest(SaveTestBase):
             ("flac",),
             ("vorbis",),
             ("sph", "PCM_S", 16),
-            ("amr-nb",),
             ("amb", "PCM_S", 16),
         ],
         name_func=name_func,
     )
     def test_save_large(self, format, encoding=None, bits_per_sample=None):
+        self._test_save_large(format, encoding, bits_per_sample)
+
+    @skipIfNoSoxEncoder("amr-nb")
+    def test_save_large_amr_nb(self):
+        self._test_save_large("amr-nb")
+
+    def _test_save_large(self, format, encoding=None, bits_per_sample=None):
         """`sox_io_backend.save` can save large files."""
         sample_rate = 8000
         one_hour = 60 * 60 * sample_rate
...
test/torchaudio_unittest/backend/sox_io/smoke_test.py @ ffeba11a

-import io
 import itertools

 from parameterized import parameterized
...
@@ -89,88 +88,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
     def test_flac(self, sample_rate, num_channels, compression_level):
         """Run smoke test on flac format"""
         self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
-
-
-@skipIfNoSox
-class SmokeTestFileObj(TorchaudioTestCase):
-    """Run smoke test on various audio format
-
-    The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
-    abnormal behaviors.
-
-    This test suite should be able to run without any additional tools (such as sox command),
-    however without such tools, the correctness of each function cannot be verified.
-    """
-
-    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
-        duration = 1
-        num_frames = sample_rate * duration
-        original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
-
-        fileobj = io.BytesIO()
-        # 1. run save
-        sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext)
-        # 2. run info
-        fileobj.seek(0)
-        info = sox_io_backend.info(fileobj, format=ext)
-        assert info.sample_rate == sample_rate
-        assert info.num_channels == num_channels
-        # 3. run load
-        fileobj.seek(0)
-        loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext)
-        assert sr == sample_rate
-        assert loaded.shape[0] == num_channels
-
-    @parameterized.expand(
-        list(
-            itertools.product(
-                ["float32", "int32", "int16", "uint8"],
-                [8000, 16000],
-                [1, 2],
-            )
-        ),
-        name_func=name_func,
-    )
-    def test_wav(self, dtype, sample_rate, num_channels):
-        """Run smoke test on wav format"""
-        self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
-
-    @parameterized.expand(
-        list(
-            itertools.product(
-                [8000, 16000],
-                [1, 2],
-                [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
-            )
-        )
-    )
-    def test_mp3(self, sample_rate, num_channels, bit_rate):
-        """Run smoke test on mp3 format"""
-        self.run_smoke_test("mp3", sample_rate, num_channels, compression=bit_rate)
-
-    @parameterized.expand(
-        list(
-            itertools.product(
-                [8000, 16000],
-                [1, 2],
-                [-1, 0, 1, 2, 3, 3.6, 5, 10],
-            )
-        )
-    )
-    def test_vorbis(self, sample_rate, num_channels, quality_level):
-        """Run smoke test on vorbis format"""
-        self.run_smoke_test("vorbis", sample_rate, num_channels, compression=quality_level)
-
-    @parameterized.expand(
-        list(
-            itertools.product(
-                [8000, 16000],
-                [1, 2],
-                list(range(9)),
-            )
-        ),
-        name_func=name_func,
-    )
-    def test_flac(self, sample_rate, num_channels, compression_level):
-        """Run smoke test on flac format"""
-        self.run_smoke_test("flac", sample_rate, num_channels, compression=compression_level)
test/torchaudio_unittest/backend/sox_io/torchscript_test.py @ ffeba11a

...
@@ -20,11 +20,11 @@ from .common import get_enc_params, name_func
 def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
-    return torchaudio.info(filepath)
+    return torchaudio.backend.sox_io_backend.info(filepath)


 def py_load_func(filepath: str, normalize: bool, channels_first: bool):
-    return torchaudio.load(filepath, normalize=normalize, channels_first=channels_first)
+    return torchaudio.backend.sox_io_backend.load(filepath, normalize=normalize, channels_first=channels_first)


 def py_save_func(
...
@@ -36,7 +36,9 @@ def py_save_func(
     encoding: Optional[str] = None,
     bits_per_sample: Optional[int] = None,
 ):
-    torchaudio.save(filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample)
+    torchaudio.backend.sox_io_backend.save(
+        filepath, tensor, sample_rate, channels_first, compression, None, encoding, bits_per_sample
+    )


 @skipIfNoExec("sox")
...
@@ -44,8 +46,6 @@ def py_save_func(
 class SoxIO(TempDirMixin, TorchaudioTestCase):
     """TorchScript-ability Test suite for `sox_io_backend`"""

-    backend = "sox_io"
-
     @parameterized.expand(
         list(
             itertools.product(
...
test/torchaudio_unittest/backend/utils_test.py @ ffeba11a

...
@@ -9,11 +9,11 @@ class BackendSwitchMixin:
     backend_module = None

     def test_switch(self):
-        torchaudio.set_audio_backend(self.backend)
+        torchaudio.backend.utils.set_audio_backend(self.backend)
         if self.backend is None:
-            assert torchaudio.get_audio_backend() is None
+            assert torchaudio.backend.utils.get_audio_backend() is None
         else:
-            assert torchaudio.get_audio_backend() == self.backend
+            assert torchaudio.backend.utils.get_audio_backend() == self.backend
         assert torchaudio.load == self.backend_module.load
         assert torchaudio.save == self.backend_module.save
         assert torchaudio.info == self.backend_module.info
...
test/torchaudio_unittest/common_utils/__init__.py @ ffeba11a

+from .autograd_utils import use_deterministic_algorithms
 from .backend_utils import set_audio_backend
 from .case_utils import (
+    disabledInCI,
     HttpServerMixin,
     is_ffmpeg_available,
     PytorchTestCase,
+    skipIfCudaSmallMemory,
+    skipIfNoAudioDevice,
     skipIfNoCtcDecoder,
+    skipIfNoCuCtcDecoder,
     skipIfNoCuda,
     skipIfNoExec,
+    skipIfkmeMark,
     skipIfNoFFmpeg,
-    skipIfNoKaldi,
+    skipIfNoHWAccel,
+    skipIfNoMacOS,
     skipIfNoModule,
     skipIfNoQengine,
+    skipIfNoRIR,
     skipIfNoSox,
+    skipIfNoSoxDecoder,
+    skipIfNoSoxEncoder,
     skipIfPy310,
     skipIfRocm,
     TempDirMixin,
     TestBaseMixin,
     TorchaudioTestCase,
+    zip_equal,
 )
 from .data_utils import get_asset_path, get_sinusoid, get_spectrogram, get_whitenoise
 from .func_utils import torch_script
...
@@ -35,17 +46,24 @@ __all__ = [
     "PytorchTestCase",
     "TorchaudioTestCase",
     "is_ffmpeg_available",
+    "skipIfNoAudioDevice",
     "skipIfNoCtcDecoder",
+    "skipIfNoCuCtcDecoder",
     "skipIfNoCuda",
+    "skipIfCudaSmallMemory",
     "skipIfNoExec",
+    "skipIfNoMacOS",
     "skipIfNoModule",
-    "skipIfNoKaldi",
+    "skipIfNoRIR",
     "skipIfNoSox",
-    "skipIfNoSoxBackend",
+    "skipIfNoSoxDecoder",
+    "skipIfNoSoxEncoder",
     "skipIfRocm",
     "skipIfNoQengine",
     "skipIfNoFFmpeg",
+    "skipIfNoHWAccel",
     "skipIfPy310",
+    "disabledInCI",
     "get_wav_data",
     "normalize_wav",
     "load_wav",
...
@@ -57,4 +75,6 @@ __all__ = [
     "get_image",
     "rgb_to_gray",
     "rgb_to_yuv_ccir",
+    "use_deterministic_algorithms",
+    "zip_equal",
 ]
test/torchaudio_unittest/common_utils/autograd_utils.py (new file, 0 → 100644) @ ffeba11a

+import contextlib
+
+import torch
+
+
+@contextlib.contextmanager
+def use_deterministic_algorithms(mode: bool, warn_only: bool):
+    r"""
+    This context manager can be used to temporarily enable or disable deterministic algorithms.
+    Upon exiting the context manager, the previous state of the flag will be restored.
+    """
+    previous_mode: bool = torch.are_deterministic_algorithms_enabled()
+    previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
+    try:
+        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
+        yield {}
+    except RuntimeError as err:
+        raise err
+    finally:
+        torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)
test/torchaudio_unittest/common_utils/case_utils.py @ ffeba11a

...
@@ -6,11 +6,13 @@ import sys
 import tempfile
 import time
 import unittest
+from itertools import zip_longest

 import torch
 import torchaudio
 from torch.testing._internal.common_utils import TestCase as PytorchTestCase
-from torchaudio._internal.module_utils import is_kaldi_available, is_module_available, is_sox_available
+from torchaudio._internal.module_utils import eval_env, is_module_available
+from torchaudio.utils.ffmpeg_utils import get_video_decoders, get_video_encoders

 from .backend_utils import set_audio_backend
...
@@ -65,7 +67,7 @@ class HttpServerMixin(TempDirMixin):
     """

     _proc = None
-    _port = 8000
+    _port = 12345

     @classmethod
     def setUpClass(cls):
...
@@ -110,10 +112,11 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
 def is_ffmpeg_available():
-    return torchaudio._extension._FFMPEG_INITIALIZED
+    return torchaudio._extension._FFMPEG_EXT is not None


 _IS_CTC_DECODER_AVAILABLE = None
+_IS_CUDA_CTC_DECODER_AVAILABLE = None


 def is_ctc_decoder_available():
...
@@ -128,22 +131,16 @@ def is_ctc_decoder_available():
     return _IS_CTC_DECODER_AVAILABLE


-def _eval_env(var, default):
-    if var not in os.environ:
-        return default
-
-    val = os.environ.get(var, "0")
-    trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
-    falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
-    if val in trues:
-        return True
-    if val not in falses:
-        # fmt: off
-        raise RuntimeError(
-            f"Unexpected environment variable value `{var}={val}`. "
-            f"Expected one of {trues + falses}"
-        )
-        # fmt: on
-    return False
+def is_cuda_ctc_decoder_available():
+    global _IS_CUDA_CTC_DECODER_AVAILABLE
+    if _IS_CUDA_CTC_DECODER_AVAILABLE is None:
+        try:
+            from torchaudio.models.decoder import CUCTCDecoder  # noqa: F401
+
+            _IS_CUDA_CTC_DECODER_AVAILABLE = True
+        except Exception:
+            _IS_CUDA_CTC_DECODER_AVAILABLE = False
+    return _IS_CUDA_CTC_DECODER_AVAILABLE


 def _fail(reason):
...
@@ -170,7 +167,7 @@ def _pass(test_item):
     return test_item


-_IN_CI = _eval_env("CI", default=False)
+_IN_CI = eval_env("CI", default=False)


 def _skipIf(condition, reason, key):
...
@@ -180,7 +177,7 @@ def _skipIf(condition, reason, key):
     # In CI, default to fail, so as to prevent accidental skip.
     # In other env, default to skip
     var = f"TORCHAUDIO_TEST_ALLOW_SKIP_IF_{key}"
-    skip_allowed = _eval_env(var, default=not _IN_CI)
+    skip_allowed = eval_env(var, default=not _IN_CI)
     if skip_allowed:
         return unittest.skip(reason)
     return _fail(f"{reason} But the test cannot be skipped. (CI={_IN_CI}, {var}={skip_allowed}.)")
...
@@ -207,23 +204,53 @@ skipIfNoCuda = _skipIf(
     reason="CUDA is not available.",
     key="NO_CUDA",
 )
+# Skip test if CUDA memory is not enough
+# TODO: detect the real CUDA memory size and allow call site to configure how much the test needs
+skipIfCudaSmallMemory = _skipIf(
+    "CI" in os.environ and torch.cuda.is_available(),  # temporary
+    reason="CUDA does not have enough memory.",
+    key="CUDA_SMALL_MEMORY",
+)
 skipIfNoSox = _skipIf(
-    not is_sox_available(),
+    not torchaudio._extension._SOX_INITIALIZED,
     reason="Sox features are not available.",
     key="NO_SOX",
 )
-skipIfNoKaldi = _skipIf(
-    not is_kaldi_available(),
-    reason="Kaldi features are not available.",
-    key="NO_KALDI",
-)
+
+
+def skipIfNoSoxDecoder(ext):
+    return _skipIf(
+        not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_read_formats(),
+        f'sox does not handle "{ext}" for read.',
+        key="NO_SOX_DECODER",
+    )
+
+
+def skipIfNoSoxEncoder(ext):
+    return _skipIf(
+        not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_write_formats(),
+        f'sox does not handle "{ext}" for write.',
+        key="NO_SOX_ENCODER",
+    )
+
+
+skipIfNoRIR = _skipIf(
+    not torchaudio._extension._IS_RIR_AVAILABLE,
+    reason="RIR features are not available.",
+    key="NO_RIR",
+)
 skipIfNoCtcDecoder = _skipIf(
     not is_ctc_decoder_available(),
     reason="CTC decoder not available.",
     key="NO_CTC_DECODER",
 )
+skipIfNoCuCtcDecoder = _skipIf(
+    not is_cuda_ctc_decoder_available(),
+    reason="CUCTC decoder not available.",
+    key="NO_CUCTC_DECODER",
+)
 skipIfRocm = _skipIf(
-    _eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False),
+    eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False),
     reason="The test doesn't currently work on the ROCm stack.",
     key="ON_ROCM",
 )
...
@@ -245,3 +272,55 @@ skipIfPy310 = _skipIf(
     ),
     key="ON_PYTHON_310",
 )
+skipIfNoAudioDevice = _skipIf(
+    not torchaudio.utils.ffmpeg_utils.get_output_devices(),
+    reason="No output audio device is available.",
+    key="NO_AUDIO_OUT_DEVICE",
+)
+skipIfNoMacOS = _skipIf(
+    sys.platform != "darwin",
+    reason="This feature is only available for MacOS.",
+    key="NO_MACOS",
+)
+disabledInCI = _skipIf(
+    "CI" in os.environ,
+    reason="Tests are failing on CI consistently. Disabled while investigating.",
+    key="TEMPORARY_DISABLED",
+)
+
+
+def skipIfNoHWAccel(name):
+    key = "NO_HW_ACCEL"
+    if not is_ffmpeg_available():
+        return _skipIf(True, reason="ffmpeg features are not available.", key=key)
+    if not torch.cuda.is_available():
+        return _skipIf(True, reason="CUDA is not available.", key=key)
+    if torchaudio._extension._check_cuda_version() is None:
+        return _skipIf(True, "Torchaudio is not compiled with CUDA.", key=key)
+    if name not in get_video_decoders() and name not in get_video_encoders():
+        return _skipIf(True, f"{name} is not in the list of available decoders or encoders", key=key)
+    return _pass
+
+
+def checkkme():
+    res = subprocess.run('rocminfo | grep gfx928', shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    if res.stdout:
+        return True
+    return False
+
+
+iskme = checkkme()
+skipIfkmeMark = _skipIf(
+    iskme,
+    reason="not support fp64 in kme for this case",
+    key="NOT_SUPPORT_FP64_IN_KME",
+)
+
+
+def zip_equal(*iterables):
+    """With the regular Python `zip` function, if one iterable is longer than the other,
+    the remainder portions are ignored.This is resolved in Python 3.10 where we can use
+    `strict=True` in the `zip` function
+
+    From https://github.com/pytorch/text/blob/c047efeba813ac943cb8046a49e858a8b529d577/test/torchtext_unittest/common/case_utils.py#L45-L54  # noqa: E501
+    """
+    sentinel = object()
+    for combo in zip_longest(*iterables, fillvalue=sentinel):
+        if sentinel in combo:
+            raise ValueError("Iterables have different lengths")
+        yield combo
test/torchaudio_unittest/common_utils/rnnt_utils.py
@@ -189,13 +189,19 @@ def compute_with_numpy_transducer(data):
 def compute_with_pytorch_transducer(data):
+    fused_log_softmax = data.get("fused_log_softmax", True)
+    input = data["logits"]
+    if not fused_log_softmax:
+        input = torch.nn.functional.log_softmax(input, dim=-1)
     costs = rnnt_loss(
-        logits=data["logits"],
+        logits=input,
         logit_lengths=data["logit_lengths"],
         target_lengths=data["target_lengths"],
         targets=data["targets"],
         blank=data["blank"],
         reduction="none",
+        fused_log_softmax=fused_log_softmax,
     )
     loss = torch.sum(costs)
@@ -260,6 +266,7 @@ def get_B1_T10_U3_D4_data(
     data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device)
     data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device)
     data["blank"] = 0
+    data["fused_log_softmax"] = False
     return data
@@ -552,6 +559,7 @@ def get_random_data(
     max_U=32,
     max_D=40,
     blank=-1,
+    fused_log_softmax=True,
     dtype=torch.float32,
     device=CPU_DEVICE,
     seed=None,
@@ -591,6 +599,7 @@ def get_random_data(
         "logit_lengths": logit_lengths,
         "target_lengths": target_lengths,
         "blank": blank,
+        "fused_log_softmax": fused_log_softmax,
     }
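The change above lets the helper exercise `rnnt_loss` with the log-softmax either fused into the kernel or applied by the caller. A minimal sketch of the equivalence being tested (shapes and values here are illustrative, not taken from the diff):

import torch
from torchaudio.functional import rnnt_loss

logits = torch.randn(1, 10, 3, 4)  # (batch, time, target_len + 1, num_classes)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([10], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Fused path: rnnt_loss applies log_softmax internally.
fused = rnnt_loss(logits, targets, logit_lengths, target_lengths, blank=0, reduction="none")

# Non-fused path: the caller applies log_softmax first and disables the fused one.
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
nonfused = rnnt_loss(
    log_probs, targets, logit_lengths, target_lengths, blank=0, reduction="none", fused_log_softmax=False
)
# The two results are expected to agree, which is what the updated helper checks.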
test/torchaudio_unittest/example/emformer_rnnt/test_mustc_lightning.py
0 → 100644 (new file)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase

from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor

if is_module_available("pytorch_lightning", "sentencepiece"):
    from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule


class MockMUSTC:
    def __init__(self, *args, **kwargs):
        pass

    def __getitem__(self, n: int):
        return (
            torch.rand(1, 32640),
            "sup",
        )

    def __len__(self):
        return 10


@contextmanager
def get_lightning_module():
    with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
        "asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC), patch(
        "asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset
    ), patch("torch.utils.data.DataLoader", new=MockDataloader):
        yield MuSTCRNNTModule(
            mustc_path="mustc_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestMuSTCRNNTModule(TorchaudioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_common_dataloader"),
            ("test_step", "test_he_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            getattr(lightning_module, step_fname)(batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            lightning_module(batch)
test/torchaudio_unittest/example/emformer_rnnt/test_tedlium3_lightning.py
0 → 100644 (new file)
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase

from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor

if is_module_available("pytorch_lightning", "sentencepiece"):
    from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule


class MockTEDLIUM:
    def __init__(self, *args, **kwargs):
        pass

    def __getitem__(self, n: int):
        return (
            torch.rand(1, 32640),
            16000,
            "sup",
            2,
            3,
            4,
        )

    def __len__(self):
        return 10


@contextmanager
def get_lightning_module():
    with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
        "asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity
    ), patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM), patch(
        "asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset
    ), patch("torch.utils.data.DataLoader", new=MockDataloader):
        yield TEDLIUM3RNNTModule(
            tedlium_path="tedlium_path",
            sp_model_path="sp_model_path",
            global_stats_path="global_stats_path",
        )


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            getattr(lightning_module, step_fname)(batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        with get_lightning_module() as lightning_module:
            dataloader = getattr(lightning_module, dataloader_fname)()
            batch = next(iter(dataloader))
            lightning_module(batch)
test/torchaudio_unittest/functional/autograd_cuda_test.py
@@ -5,6 +5,7 @@ from .autograd_impl import Autograd, AutogradFloat32
 @common_utils.skipIfNoCuda
+@common_utils.skipIfkmeMark
 class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase):
     dtype = torch.float64
     device = torch.device("cuda")
test/torchaudio_unittest/functional/autograd_impl.py
@@ -6,7 +6,14 @@ import torchaudio.functional as F
 from parameterized import parameterized
 from torch import Tensor
 from torch.autograd import gradcheck, gradgradcheck
-from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, rnnt_utils, TestBaseMixin
+from torchaudio_unittest.common_utils import (
+    get_spectrogram,
+    get_whitenoise,
+    nested_params,
+    rnnt_utils,
+    TestBaseMixin,
+    use_deterministic_algorithms,
+)


 class Autograd(TestBaseMixin):
@@ -71,26 +78,30 @@ class Autograd(TestBaseMixin):
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         a.requires_grad = True
-        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

     def test_filtfilt_b(self):
         x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         b.requires_grad = True
-        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

     def test_filtfilt_all_inputs(self):
         x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
-        self.assert_grad(F.filtfilt, (x, a, b))
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b))

     def test_filtfilt_batching(self):
         x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
         a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
         b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
-        self.assert_grad(F.filtfilt, (x, a, b))
+        with use_deterministic_algorithms(True, False):
+            self.assert_grad(F.filtfilt, (x, a, b))

     def test_biquad(self):
         x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
@@ -335,6 +346,51 @@ class Autograd(TestBaseMixin):
         beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
         self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))

+    @nested_params(
+        ["convolve", "fftconvolve"],
+        ["full", "valid", "same"],
+    )
+    def test_convolve(self, fn, mode):
+        leading_dims = (4, 3, 2)
+        L_x, L_y = 23, 40
+        x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
+        y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
+        self.assert_grad(getattr(F, fn), (x, y, mode))
+
+    def test_add_noise(self):
+        leading_dims = (5, 2, 3)
+        L = 51
+        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
+        noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
+        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
+        snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
+        self.assert_grad(F.add_noise, (waveform, noise, snr, lengths))
+
+    def test_speed(self):
+        leading_dims = (3, 2)
+        T = 200
+        waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
+        lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
+        self.assert_grad(F.speed, (waveform, 1000, 1.1, lengths), enable_all_grad=False)
+
+    def test_preemphasis(self):
+        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
+        coeff = 0.9
+        self.assert_grad(F.preemphasis, (waveform, coeff))
+
+    def test_deemphasis(self):
+        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
+        coeff = 0.9
+        self.assert_grad(F.deemphasis, (waveform, coeff))
+
+    def test_frechet_distance(self):
+        N = 16
+        mu_x = torch.rand((N,))
+        sigma_x = torch.rand((N, N))
+        mu_y = torch.rand((N,))
+        sigma_y = torch.rand((N, N))
+        self.assert_grad(F.frechet_distance, (mu_x, sigma_x, mu_y, sigma_y))


 class AutogradFloat32(TestBaseMixin):
     def assert_grad(
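The filtfilt autograd tests above now run inside a `use_deterministic_algorithms` context imported from the suite's common_utils. A plausible sketch of what such a context manager looks like is below; this is an assumption for illustration, not the code from this commit:

from contextlib import contextmanager

import torch


@contextmanager
def use_deterministic_algorithms(mode, warn_only):
    # Hypothetical sketch: flip PyTorch's global determinism switch for the duration
    # of the block, then restore the previous setting.
    previous = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(mode, warn_only=warn_only)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(previous)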
test/torchaudio_unittest/functional/batch_consistency_test.py
@@ -27,7 +27,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
     backend = "default"

-    def assert_batch_consistency(self, functional, inputs, atol=1e-8, rtol=1e-5, seed=42):
+    def assert_batch_consistency(self, functional, inputs, atol=1e-6, rtol=1e-5, seed=42):
         n = inputs[0].size(0)
         for i in range(1, len(inputs)):
             self.assertEqual(inputs[i].size(0), n)
@@ -65,7 +65,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
             "rand_init": False,
         }
         func = partial(F.griffinlim, **kwargs)
-        self.assert_batch_consistency(func, inputs=(batch,), atol=5e-5)
+        self.assert_batch_consistency(func, inputs=(batch,), atol=1e-4)

     @parameterized.expand(
         list(
@@ -194,7 +194,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         self.assert_batch_consistency(func, inputs=(waveforms,))

     def test_phaser(self):
-        sample_rate = 44100
+        sample_rate = 8000
         n_channels = 2
         waveform = common_utils.get_whitenoise(
             sample_rate=sample_rate, n_channels=self.batch_size * n_channels, duration=1
@@ -208,7 +208,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
     def test_flanger(self):
         waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
-        sample_rate = 44100
+        sample_rate = 8000
         kwargs = {
             "sample_rate": sample_rate,
         }
@@ -233,7 +233,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         func = partial(F.sliding_window_cmn, **kwargs)
         self.assert_batch_consistency(func, inputs=(spectrogram,))

-    @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
+    @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
     def test_resample_waveform(self, resampling_method):
         num_channels = 3
         sr = 16000
@@ -257,18 +257,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
             atol=1e-7,
         )

-    @common_utils.skipIfNoKaldi
-    def test_compute_kaldi_pitch(self):
-        sample_rate = 44100
-        n_channels = 2
-        waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
-        batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
-        kwargs = {
-            "sample_rate": sample_rate,
-        }
-        func = partial(F.compute_kaldi_pitch, **kwargs)
-        self.assert_batch_consistency(func, inputs=(batch,))
-
     def test_lfilter(self):
         signal_length = 2048
         x = torch.randn(self.batch_size, signal_length)
@@ -407,3 +395,90 @@ class TestFunctional(common_utils.TorchaudioTestCase):
         specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
         beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
         self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram))

+    @common_utils.nested_params(
+        ["convolve", "fftconvolve"],
+        ["full", "valid", "same"],
+    )
+    def test_convolve(self, fn, mode):
+        leading_dims = (2, 3)
+        L_x, L_y = 89, 43
+        x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
+        y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
+
+        fn = getattr(F, fn)
+        actual = fn(x, y, mode)
+        expected = torch.stack(
+            [
+                torch.stack(
+                    [fn(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0), mode).squeeze(0) for j in range(leading_dims[1])]
+                )
+                for i in range(leading_dims[0])
+            ]
+        )
+
+        self.assertEqual(expected, actual)
+
+    def test_add_noise(self):
+        leading_dims = (5, 2, 3)
+        L = 51
+
+        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
+        noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
+        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
+        snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
+
+        actual = F.add_noise(waveform, noise, snr, lengths)
+
+        expected = []
+        for i in range(leading_dims[0]):
+            for j in range(leading_dims[1]):
+                for k in range(leading_dims[2]):
+                    expected.append(F.add_noise(waveform[i][j][k], noise[i][j][k], snr[i][j][k], lengths[i][j][k]))
+
+        self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
+
+    def test_speed(self):
+        B = 5
+        orig_freq = 100
+        factor = 0.8
+        input_lengths = torch.randint(1, 1000, (B,), dtype=torch.int32)
+        unbatched_input = [torch.ones((int(length),)) * 1.0 for length in input_lengths]
+        batched_input = torch.nn.utils.rnn.pad_sequence(unbatched_input, batch_first=True)
+
+        output, output_lengths = F.speed(batched_input, orig_freq=orig_freq, factor=factor, lengths=input_lengths)
+
+        unbatched_output = []
+        unbatched_output_lengths = []
+        for idx in range(len(unbatched_input)):
+            w, l = F.speed(unbatched_input[idx], orig_freq=orig_freq, factor=factor, lengths=input_lengths[idx])
+            unbatched_output.append(w)
+            unbatched_output_lengths.append(l)
+
+        self.assertEqual(output_lengths, torch.stack(unbatched_output_lengths))
+        for idx in range(len(unbatched_output)):
+            w, l = output[idx], output_lengths[idx]
+            self.assertEqual(unbatched_output[idx], w[:l])
+
+    def test_preemphasis(self):
+        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
+        coeff = 0.9
+        actual = F.preemphasis(waveform, coeff=coeff)
+
+        expected = []
+        for i in range(waveform.size(0)):
+            expected.append(F.preemphasis(waveform[i], coeff=coeff))
+        self.assertEqual(torch.stack(expected), actual)
+
+    def test_deemphasis(self):
+        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
+        coeff = 0.9
+        actual = F.deemphasis(waveform, coeff=coeff)
+
+        expected = []
+        for i in range(waveform.size(0)):
+            expected.append(F.deemphasis(waveform[i], coeff=coeff))
+        self.assertEqual(torch.stack(expected), actual)
test/torchaudio_unittest/functional/functional_cuda_test.py
 import unittest

 import torch
-from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
+from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark

-from .functional_impl import Functional
+from .functional_impl import Functional, FunctionalCUDAOnly


 @skipIfNoCuda
@@ -17,6 +17,20 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
 @skipIfNoCuda
+@skipIfkmeMark
 class TestLFilterFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device("cuda")
+
+
+@skipIfNoCuda
+class TestFunctionalCUDAOnlyFloat32(FunctionalCUDAOnly, PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device("cuda")
+
+
+@skipIfNoCuda
+@skipIfkmeMark
+class TestFunctionalCUDAOnlyFloat64(FunctionalCUDAOnly, PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device("cuda")
test/torchaudio_unittest/functional/functional_impl.py
@@ -20,7 +20,7 @@ from torchaudio_unittest.common_utils import (
 class Functional(TestBaseMixin):
     def _test_resample_waveform_accuracy(
-        self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4
+        self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interp_hann", atol=1e-1, rtol=1e-4
     ):
         # resample the signal and compare it to the ground truth
         n_to_trim = 20
@@ -51,6 +51,7 @@ class Functional(TestBaseMixin):
     def _test_costs_and_gradients(self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2):
         logits_shape = data["logits"].shape
         costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
         self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
         self.assertEqual(logits_shape, gradients.shape)
         self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
@@ -396,22 +397,38 @@ class Functional(TestBaseMixin):
         close_to_limit = decibels < 6.0207
         assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}"

+    @parameterized.expand(list(itertools.product([(1, 201, 100), (10, 2, 201, 300)])))
+    def test_mask_along_axis_input_axis_check(self, shape):
+        specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
+        message = "Only Frequency and Time masking are supported"
+        with self.assertRaisesRegex(ValueError, message):
+            F.mask_along_axis(specgram, 100, 0.0, 0, 1.0)
+
     @parameterized.expand(
-        list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0]))
+        list(
+            itertools.product([(1025, 400), (1, 201, 100), (10, 2, 201, 300)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0])
+        )
     )
-    def test_mask_along_axis(self, shape, mask_param, mask_value, axis, p):
+    def test_mask_along_axis(self, shape, mask_param, mask_value, last_axis, p):
         specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
+        # last_axis = 1 means the last axis; 2 means the second-to-last axis.
+        axis = len(shape) - last_axis
         if p != 1.0:
             mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis, p=p)
         else:
             mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)

-        other_axis = 1 if axis == 2 else 2
+        other_axis = axis - 1 if last_axis == 1 else axis + 1

         masked_columns = (mask_specgram == mask_value).sum(other_axis)
         num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
-        num_masked_columns = torch.div(num_masked_columns, mask_specgram.size(0), rounding_mode="floor")
+        den = 1
+        for i in range(len(shape) - 2):
+            den *= mask_specgram.size(i)
+        num_masked_columns = torch.div(num_masked_columns, den, rounding_mode="floor")

         if p != 1.0:
             mask_param = min(mask_param, int(specgram.shape[axis] * p))
@@ -470,7 +487,7 @@ class Functional(TestBaseMixin):
     @parameterized.expand(
         list(
             itertools.product(
-                ["sinc_interpolation", "kaiser_window"],
+                ["sinc_interp_hann", "sinc_interp_kaiser"],
                 [16000, 44100],
             )
         )
@@ -481,7 +498,7 @@ class Functional(TestBaseMixin):
         resampled = F.resample(waveform, sample_rate, sample_rate)
         self.assertEqual(waveform, resampled)

-    @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
+    @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
     def test_resample_waveform_upsample_size(self, resampling_method):
         sr = 16000
         waveform = get_whitenoise(
@@ -491,7 +508,7 @@ class Functional(TestBaseMixin):
         upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method)
         assert upsampled.size(-1) == waveform.size(-1) * 2

-    @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
+    @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
     def test_resample_waveform_downsample_size(self, resampling_method):
         sr = 16000
         waveform = get_whitenoise(
@@ -501,7 +518,7 @@ class Functional(TestBaseMixin):
         downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
         assert downsampled.size(-1) == waveform.size(-1) // 2

-    @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
+    @parameterized.expand([("sinc_interp_hann"), ("sinc_interp_kaiser")])
     def test_resample_waveform_identity_size(self, resampling_method):
         sr = 16000
         waveform = get_whitenoise(
@@ -514,7 +531,7 @@ class Functional(TestBaseMixin):
     @parameterized.expand(
         list(
             itertools.product(
-                ["sinc_interpolation", "kaiser_window"],
+                ["sinc_interp_hann", "sinc_interp_kaiser"],
                 list(range(1, 20)),
             )
         )
@@ -525,7 +542,7 @@ class Functional(TestBaseMixin):
     @parameterized.expand(
         list(
             itertools.product(
-                ["sinc_interpolation", "kaiser_window"],
+                ["sinc_interp_hann", "sinc_interp_kaiser"],
                 list(range(1, 20)),
             )
         )
@@ -637,13 +654,25 @@ class Functional(TestBaseMixin):
                 rtol=rtol,
             )

-    def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self, fused_log_softmax):
         seed = 777
         for i in range(5):
-            data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
+            data = rnnt_utils.get_random_data(
+                fused_log_softmax=fused_log_softmax, dtype=torch.float32, device=self.device, seed=(seed + i)
+            )
             ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
             self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)

+    def test_rnnt_loss_nonfused_softmax(self):
+        data = rnnt_utils.get_B1_T10_U3_D4_data()
+        ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
+        self._test_costs_and_gradients(
+            data=data,
+            ref_costs=ref_costs,
+            ref_gradients=ref_gradients,
+        )
+
     def test_psd(self):
         """Verify the ``F.psd`` method by the numpy implementation.
         Given the multi-channel complex-valued spectrum as the input,
@@ -879,6 +908,412 @@ class Functional(TestBaseMixin):
             torch.tensor(specgram_enhanced, dtype=self.complex_dtype, device=self.device), specgram_enhanced_audio
         )

+    @nested_params(
+        [(10, 4), (4, 3, 1, 2), (2,), ()],
+        [(100, 43), (21, 45)],
+        ["full", "valid", "same"],
+    )
+    def test_convolve_numerics(self, leading_dims, lengths, mode):
+        """Check that convolve returns values identical to those that SciPy produces."""
+        L_x, L_y = lengths
+
+        x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
+        y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
+
+        actual = F.convolve(x, y, mode=mode)
+
+        num_signals = torch.tensor(leading_dims).prod() if leading_dims else 1
+        x_reshaped = x.reshape((num_signals, L_x))
+        y_reshaped = y.reshape((num_signals, L_y))
+        expected = [
+            signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy(), mode=mode)
+            for i in range(num_signals)
+        ]
+        expected = torch.tensor(np.array(expected))
+        expected = expected.reshape(leading_dims + (-1,))
+
+        self.assertEqual(expected, actual)
+
+    @nested_params(
+        [(10, 4), (4, 3, 1, 2), (2,), ()],
+        [(100, 43), (21, 45)],
+        ["full", "valid", "same"],
+    )
+    def test_fftconvolve_numerics(self, leading_dims, lengths, mode):
+        """Check that fftconvolve returns values identical to those that SciPy produces."""
+        L_x, L_y = lengths
+
+        x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
+        y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
+
+        actual = F.fftconvolve(x, y, mode=mode)
+
+        expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1, mode=mode)
+        expected = torch.tensor(expected)
+
+        self.assertEqual(expected, actual)
+
+    @nested_params(
+        ["convolve", "fftconvolve"],
+        [(5, 2, 3)],
+        [(5, 1, 3), (1, 2, 3), (1, 1, 3)],
+    )
+    def test_convolve_broadcast(self, fn, x_shape, y_shape):
+        """convolve works for Tensors for different shapes if they are broadcast-able"""
+        # 1. Test broadcast case
+        x = torch.rand(x_shape, dtype=self.dtype, device=self.device)
+        y = torch.rand(y_shape, dtype=self.dtype, device=self.device)
+        out1 = getattr(F, fn)(x, y)
+        # 2. Test without broadcast
+        y_clone = y.expand(x_shape).clone()
+        assert y is not y_clone
+        assert y_clone.shape == x.shape
+        out2 = getattr(F, fn)(x, y_clone)
+        # check that they are same
+        self.assertEqual(out1, out2)
+
+    @parameterized.expand(
+        [
+            # fmt: off
+            # different ndim
+            (0, F.convolve, (4, 3, 1, 2), (10, 4)),
+            (0, F.convolve, (4, 3, 1, 2), (2, 2, 2)),
+            (0, F.convolve, (1, ), (10, 4)),
+            (0, F.convolve, (1, ), (2, 2, 2)),
+            (0, F.fftconvolve, (4, 3, 1, 2), (10, 4)),
+            (0, F.fftconvolve, (4, 3, 1, 2), (2, 2, 2)),
+            (0, F.fftconvolve, (1, ), (10, 4)),
+            (0, F.fftconvolve, (1, ), (2, 2, 2)),
+            # non-broadcastable leading dimensions
+            (1, F.convolve, (5, 2, 3), (5, 3, 3)),
+            (1, F.convolve, (5, 2, 3), (5, 3, 4)),
+            (1, F.convolve, (5, 2, 3), (5, 3, 5)),
+            (1, F.fftconvolve, (5, 2, 3), (5, 3, 3)),
+            (1, F.fftconvolve, (5, 2, 3), (5, 3, 4)),
+            (1, F.fftconvolve, (5, 2, 3), (5, 3, 5)),
+            # fmt: on
+        ],
+    )
+    def test_convolve_input_dim_check(self, case, fn, x_shape, y_shape):
+        """Check that convolve properly rejects inputs with incompatible dimensions."""
+        x = torch.rand(*x_shape, dtype=self.dtype, device=self.device)
+        y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)
+
+        message = [
+            "The operands must be the same dimension",
+            "Leading dimensions of x and y are not broadcastable",
+        ][case]
+        with self.assertRaisesRegex(ValueError, message):
+            fn(x, y)
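As background for the `mode` argument exercised above: the three modes only differ in how much of the full convolution is returned. A small sketch of the output-length relationships (standard convolution arithmetic, sizes chosen for illustration):

import torch
import torchaudio.functional as F

x = torch.rand(1, 100)  # signal of length L_x = 100
y = torch.rand(1, 43)   # kernel of length L_y = 43

print(F.fftconvolve(x, y, mode="full").shape[-1])   # L_x + L_y - 1 = 142
print(F.fftconvolve(x, y, mode="same").shape[-1])   # L_x = 100 (centered crop of "full")
print(F.fftconvolve(x, y, mode="valid").shape[-1])  # L_x - L_y + 1 = 58 (no zero padding)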
+    def test_add_noise_broadcast(self):
+        """Check that add_noise produces correct outputs when broadcasting input dimensions."""
+        leading_dims = (5, 2, 3)
+        L = 51
+
+        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
+        noise = torch.rand(5, 1, 1, L, dtype=self.dtype, device=self.device)
+        lengths = torch.rand(5, 1, 3, dtype=self.dtype, device=self.device)
+        snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10
+        actual = F.add_noise(waveform, noise, snr, lengths)
+
+        noise_expanded = noise.expand(*leading_dims, L)
+        snr_expanded = snr.expand(*leading_dims)
+        lengths_expanded = lengths.expand(*leading_dims)
+        expected = F.add_noise(waveform, noise_expanded, snr_expanded, lengths_expanded)
+
+        self.assertEqual(expected, actual)
+
+    @parameterized.expand(
+        [((5, 2, 3), (2, 1, 1), (5, 2), (5, 2, 3)), ((2, 1), (5,), (5,), (5,)), ((3,), (5, 2, 3), (2, 1, 1), (5, 2))]
+    )
+    def test_add_noise_leading_dim_check(self, waveform_dims, noise_dims, lengths_dims, snr_dims):
+        """Check that add_noise properly rejects inputs with different leading dimension lengths."""
+        L = 51
+
+        waveform = torch.rand(*waveform_dims, L, dtype=self.dtype, device=self.device)
+        noise = torch.rand(*noise_dims, L, dtype=self.dtype, device=self.device)
+        lengths = torch.rand(*lengths_dims, dtype=self.dtype, device=self.device)
+        snr = torch.rand(*snr_dims, dtype=self.dtype, device=self.device) * 10
+
+        with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
+            F.add_noise(waveform, noise, snr, lengths)
+
+    def test_add_noise_length_check(self):
+        """Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
+        leading_dims = (5, 2, 3)
+        L = 51
+
+        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
+        noise = torch.rand(*leading_dims, 50, dtype=self.dtype, device=self.device)
+        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
+        snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
+
+        with self.assertRaisesRegex(ValueError, "Length dimensions"):
+            F.add_noise(waveform, noise, snr, lengths)
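For orientation, the `snr` argument used above is a signal-to-noise ratio in decibels and `lengths` marks the valid (non-padded) portion of each signal. A rough sketch of the underlying idea, ignoring the lengths masking; this is a simplified illustration, not torchaudio's exact implementation:

import torch


def add_noise_sketch(waveform, noise, snr_db):
    # Scale the noise so the mixed signal reaches the requested SNR (in dB),
    # with energies computed over the last dimension, then mix.
    signal_power = waveform.pow(2).sum(-1, keepdim=True)
    noise_power = noise.pow(2).sum(-1, keepdim=True)
    snr_linear = 10 ** (snr_db / 10)
    scale = torch.sqrt(signal_power / (snr_linear * noise_power))
    return waveform + scale * noise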
+    def test_speed_identity(self):
+        """speed of 1.0 does not alter input waveform and length"""
+        leading_dims = (5, 4, 2)
+        T = 1000
+        waveform = torch.rand(*leading_dims, T)
+        lengths = torch.randint(1, 1000, leading_dims)
+        actual_waveform, actual_lengths = F.speed(waveform, orig_freq=1000, factor=1.0, lengths=lengths)
+        self.assertEqual(waveform, actual_waveform)
+        self.assertEqual(lengths, actual_lengths)
+
+    @nested_params([0.8, 1.1, 1.2], [True, False])
+    def test_speed_accuracy(self, factor, use_lengths):
+        """sinusoidal waveform is properly compressed by factor"""
+        n_to_trim = 20
+
+        sample_rate = 1000
+        freq = 2
+        times = torch.arange(0, 5, 1.0 / sample_rate)
+        waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype)
+        if use_lengths:
+            lengths = torch.tensor([waveform.size(1)])
+        else:
+            lengths = None
+
+        output, output_lengths = F.speed(waveform, orig_freq=sample_rate, factor=factor, lengths=lengths)
+
+        if use_lengths:
+            self.assertEqual(output.size(1), output_lengths[0])
+        else:
+            self.assertEqual(None, output_lengths)
+
+        new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate)
+        expected_waveform = torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype)
+
+        self.assertEqual(
+            expected_waveform[..., n_to_trim:-n_to_trim], output[..., n_to_trim:-n_to_trim], atol=1e-1, rtol=1e-4
+        )
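The accuracy check above relies on the fact that changing playback speed by `factor` behaves like resampling the signal to a different rate. A short sketch of that relationship; this is an assumption for illustration, and the actual kernel may round the intermediate frequency differently:

import torch
import torchaudio.functional as F

waveform = torch.rand(1, 1000)
orig_freq, factor = 1000, 0.8

sped, lengths = F.speed(waveform, orig_freq=orig_freq, factor=factor, lengths=torch.tensor([1000]))
resampled = F.resample(waveform, orig_freq, int(orig_freq / factor))
# Both produce roughly 1000 / 0.8 = 1250 samples of the same content.
print(sped.shape[-1], resampled.shape[-1], lengths)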
+    @nested_params(
+        [(3, 2, 100), (95,)],
+        [0.97, 0.9, 0.68],
+    )
+    def test_preemphasis(self, input_shape, coeff):
+        waveform = torch.rand(*input_shape, device=self.device, dtype=self.dtype)
+        actual = F.preemphasis(waveform, coeff=coeff)
+
+        a_coeffs = torch.tensor([1.0, 0.0], device=self.device, dtype=self.dtype)
+        b_coeffs = torch.tensor([1.0, -coeff], device=self.device, dtype=self.dtype)
+        expected = F.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
+        self.assertEqual(actual, expected)
+
+    @nested_params(
+        [(3, 2, 100), (95,)],
+        [0.97, 0.9, 0.68],
+    )
+    def test_preemphasis_deemphasis_roundtrip(self, input_shape, coeff):
+        waveform = torch.rand(*input_shape, device=self.device, dtype=self.dtype)
+        preemphasized = F.preemphasis(waveform, coeff=coeff)
+        deemphasized = F.deemphasis(preemphasized, coeff=coeff)
+
+        self.assertEqual(deemphasized, waveform)
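The lfilter comparison above encodes the usual first-order definitions: pre-emphasis computes y[n] = x[n] - coeff * x[n-1], and de-emphasis inverts it with y[n] = x[n] + coeff * y[n-1]. A direct sketch of the same recurrence, matching the filter coefficients used in the test (values are illustrative):

import torch
import torchaudio.functional as F

waveform = torch.rand(1, 100)
coeff = 0.97

manual = waveform.clone()
manual[..., 1:] = waveform[..., 1:] - coeff * waveform[..., :-1]  # y[0] = x[0] since x[-1] is taken as 0

torch.testing.assert_close(manual, F.preemphasis(waveform, coeff=coeff))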
+    @parameterized.expand(
+        [
+            ([[0, 1, 1, 0]], [[0, 1, 5, 1, 0]], torch.int32),
+            ([[0, 1, 2, 3, 4]], [[0, 1, 2, 3, 4]], torch.int32),
+            ([[3, 3, 3]], [[3, 5, 3, 5, 3]], torch.int64),
+            ([[0, 1, 2]], [[0, 1, 1, 1, 2]], torch.int64),
+        ]
+    )
+    def test_forced_align(self, targets, ref_path, targets_dtype):
+        emission = torch.tensor(
+            [
+                [
+                    [0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+                    [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
+                    [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
+                    [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+                    [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
+                ]
+            ],
+            dtype=self.dtype,
+            device=self.device,
+        )
+        blank = 5
+        batch_index = 0
+        ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
+        ref_scores = torch.tensor(
+            [torch.log(emission[batch_index, i, ref_path[batch_index, i]]).item() for i in range(emission.shape[1])],
+            dtype=emission.dtype,
+            device=self.device,
+        ).unsqueeze(0)
+        log_probs = torch.log(emission)
+        targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
+        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
+        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        assert hyp_path.shape == ref_path.shape
+        assert hyp_scores.shape == ref_scores.shape
+        self.assertEqual(hyp_path, ref_path)
+        self.assertEqual(hyp_scores, ref_scores)
+
+    @parameterized.expand([(torch.int32,), (torch.int64,)])
+    def test_forced_align_fail(self, targets_dtype):
+        log_probs = torch.rand(1, 5, 6, dtype=self.dtype, device=self.device)
+        targets = torch.tensor([[0, 1, 2, 3, 4, 4]], dtype=targets_dtype, device=self.device)
+        blank = 5
+        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
+        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        targets = torch.tensor([[5, 3, 3]], dtype=targets_dtype, device=self.device)
+        with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        log_probs = log_probs.int()
+        targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        log_probs = log_probs.float()
+        targets = targets.float()
+        with self.assertRaisesRegex(RuntimeError, r"targets must be int32 or int64 type"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
+        targets = targets.int()
+        with self.assertRaisesRegex(RuntimeError, r"The batch dimension for log_probs must be 1 at the current version"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        targets = torch.randint(0, 4, (3, 4), device=self.device)
+        log_probs = torch.rand(1, 3, 6, dtype=self.dtype, device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"The batch dimension for targets must be 1 at the current version"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
+        input_lengths = torch.randint(1, 5, (3, 5), device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 1-D"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        input_lengths = torch.tensor([log_probs.shape[0]], device=self.device)
+        target_lengths = torch.randint(1, 5, (3, 5), device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 1-D"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        input_lengths = torch.tensor([10000], device=self.device)
+        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
+        target_lengths = torch.tensor([10000], device=self.device)
+        with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        targets = torch.tensor([[7, 8, 9, 10]], dtype=targets_dtype, device=self.device)
+        log_probs = torch.rand(1, 10, 5, dtype=self.dtype, device=self.device)
+        with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+        targets = torch.tensor([[1, 3, 3]], dtype=targets_dtype, device=self.device)
+        blank = 10000
+        with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
+            hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+
+    def _assert_tokens(self, first, second):
+        assert len(first) == len(second)
+        for f, s in zip(first, second):
+            self.assertEqual(f.token, s.token)
+            self.assertEqual(f.score, s.score)
+            self.assertEqual(f.start, s.start)
+            self.assertEqual(f.end, s.end)
+
+    @parameterized.expand(
+        [
+            ([], [], []),
+            ([F.TokenSpan(1, 0, 1, 1.0)], [1], [1.0]),
+            ([F.TokenSpan(1, 0, 2, 0.5)], [1, 1], [0.4, 0.6]),
+            ([F.TokenSpan(1, 0, 3, 0.6)], [1, 1, 1], [0.5, 0.6, 0.7]),
+            ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 1, 2, 0.9)], [1, 2], [0.8, 0.9]),
+            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 1, 3, 0.5)], [1, 2, 2], [1.0, 0.4, 0.6]),
+            ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(1, 2, 3, 1.0)], [1, 0, 1], [0.8, 0.9, 1.0]),
+            ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 2, 3, 1.0)], [1, 0, 2], [0.8, 0.9, 1.0]),
+            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 2, 4, 0.5)], [1, 0, 1, 1], [1.0, 0.1, 0.4, 0.6]),
+            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 2, 4, 0.5)], [1, 0, 2, 2], [1.0, 0.1, 0.4, 0.6]),
+            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 4, 0.4)], [1, 0, 0, 1], [1.0, 0.9, 0.7, 0.4]),
+            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 4, 0.4)], [1, 0, 0, 2], [1.0, 0.9, 0.7, 0.4]),
+            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 5, 0.5)], [1, 0, 0, 1, 1], [1.0, 0.9, 0.8, 0.6, 0.4]),
+            ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 5, 0.5)], [1, 0, 0, 2, 2], [1.0, 0.9, 0.8, 0.6, 0.4]),
+            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 2, 3, 0.5)], [1, 1, 2], [1.0, 0.8, 0.5]),
+            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 4, 0.7)], [1, 1, 0, 1], [1.0, 0.8, 0.1, 0.7]),
+            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 4, 0.7)], [1, 1, 0, 2], [1.0, 0.8, 0.1, 0.7]),
+            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 5, 0.4)], [1, 1, 0, 1, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
+            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 5, 0.4)], [1, 1, 0, 2, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
+            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 5, 0.3)], [1, 1, 0, 0, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
+            ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 5, 0.3)], [1, 1, 0, 0, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
+            (
+                [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 6, 0.2)],
+                [1, 1, 0, 0, 1, 1],
+                [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
+            ),
+            (
+                [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 6, 0.2)],
+                [1, 1, 0, 0, 2, 2],
+                [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
+            ),
+        ]
+    )
+    def test_merge_repeated_tokens(self, expected, tokens, scores):
+        scores_ = torch.tensor(scores, dtype=torch.float32, device=self.device)
+        tokens_ = torch.tensor(tokens, dtype=torch.int64, device=self.device)
+
+        spans = F.merge_tokens(tokens_, scores_, blank=0)
+        print(tokens_, scores_)
+        self._assert_tokens(spans, expected)
+
+        # Append blanks at the beginning and at the end.
+        for num_prefix, num_suffix in itertools.product([0, 1, 2], repeat=2):
+            tokens_ = ([0] * num_prefix) + tokens + ([0] * num_suffix)
+            scores_ = ([0.1] * num_prefix) + scores + ([0.1] * num_suffix)
+            tokens_ = torch.tensor(tokens_, dtype=torch.int64, device=self.device)
+            scores_ = torch.tensor(scores_, dtype=torch.float32, device=self.device)
+            expected_ = [
+                F.TokenSpan(s.token, s.start + num_prefix, s.end + num_prefix, s.score) for s in expected
+            ]
+            print(tokens_, scores_)
+            spans = F.merge_tokens(tokens_, scores_, blank=0)
+            self._assert_tokens(spans, expected_)
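Taken together, the two APIs exercised above produce a frame-level alignment and then collapse it into per-token spans. A minimal end-to-end sketch with toy sizes (the emission, token ids, and blank id 0 are made up for illustration):

import torch
import torchaudio.functional as F

log_probs = torch.randn(1, 50, 10).log_softmax(-1)         # (batch=1, frames, classes)
targets = torch.tensor([[3, 5, 5, 2]], dtype=torch.int32)  # token ids, blank excluded
input_lengths = torch.tensor([50])
target_lengths = torch.tensor([4])

path, scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank=0)
spans = F.merge_tokens(path[0], scores[0].exp(), blank=0)   # collapse repeats, drop blanks
for span in spans:
    print(span.token, span.start, span.end, span.score)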
+    def test_frechet_distance_univariate(self):
+        r"""Check that Frechet distance is computed correctly for simple one-dimensional case."""
+        mu_x = torch.rand((1,), device=self.device)
+        sigma_x = torch.rand((1, 1), device=self.device)
+        mu_y = torch.rand((1,), device=self.device)
+        sigma_y = torch.rand((1, 1), device=self.device)
+
+        # Matrix square root reduces to scalar square root.
+        expected = (mu_x - mu_y) ** 2 + sigma_x + sigma_y - 2 * torch.sqrt(sigma_x * sigma_y)
+        expected = expected.item()
+
+        actual = F.frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
+
+        self.assertEqual(expected, actual)
+
+    def test_frechet_distance_diagonal_covariance(self):
+        r"""Check that Frechet distance is computed correctly for case where covariance matrices are diagonal."""
+        N = 15
+        mu_x = torch.rand((N,), device=self.device)
+        sigma_x = torch.diag(torch.rand((N,), device=self.device))
+        mu_y = torch.rand((N,), device=self.device)
+        sigma_y = torch.diag(torch.rand((N,), device=self.device))
+
+        expected = (
+            torch.sum((mu_x - mu_y) ** 2)
+            + torch.sum(sigma_x + sigma_y)
+            - 2 * torch.sum(torch.sqrt(sigma_x * sigma_y))
+        )
+        expected = expected.item()
+
+        actual = F.frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
+
+        self.assertEqual(expected, actual)
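Both checks above are special cases of the general closed form for the Frechet distance between two Gaussians, d^2 = |mu_x - mu_y|^2 + Tr(Sigma_x + Sigma_y - 2 (Sigma_x Sigma_y)^{1/2}). A reference sketch using SciPy's matrix square root; this is an independent illustration, not the torchaudio implementation:

import numpy as np
from scipy import linalg


def frechet_distance_ref(mu_x, sigma_x, mu_y, sigma_y):
    # Closed-form Frechet (a.k.a. 2-Wasserstein squared) distance between Gaussians.
    diff = mu_x - mu_y
    covmean = linalg.sqrtm(sigma_x @ sigma_y).real
    return float(diff @ diff + np.trace(sigma_x + sigma_y - 2.0 * covmean))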
 class FunctionalCPUOnly(TestBaseMixin):
     def test_melscale_fbanks_no_warning_high_n_freq(self):
@@ -898,3 +1333,27 @@ class FunctionalCPUOnly(TestBaseMixin):
             warnings.simplefilter("always")
             F.melscale_fbanks(201, 0, 8000, 128, 16000)
         assert len(w) == 1
+
+
+class FunctionalCUDAOnly(TestBaseMixin):
+    @nested_params(
+        [torch.half, torch.float, torch.double],
+        [torch.int32, torch.int64],
+        [(1, 50, 100), (1, 100, 100)],
+        [(1, 10), (1, 40), (1, 45)],
+    )
+    def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
+        log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
+        targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
+        input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
+        target_lengths = torch.tensor([targets.shape[1]], device=self.device)
+        log_probs_cuda = log_probs.cuda()
+        targets_cuda = targets.cuda()
+        input_lengths_cuda = input_lengths.cuda()
+        target_lengths_cuda = target_lengths.cuda()
+        hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths)
+        hyp_path_cuda, hyp_scores_cuda = F.forced_align(log_probs_cuda, targets_cuda, input_lengths_cuda, target_lengths_cuda)
+        self.assertEqual(hyp_path, hyp_path_cuda.cpu())
+        self.assertEqual(hyp_scores, hyp_scores_cuda.cpu())
test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py
 import torch
 from torchaudio_unittest.common_utils import PytorchTestCase
-from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly
+from .kaldi_compatibility_test_impl import Kaldi
-
-
-class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
-    dtype = torch.float32
-    device = torch.device("cpu")


 class TestKaldiFloat32(Kaldi, PytorchTestCase):
test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py
 import torch
 import torchaudio.functional as F
-from parameterized import parameterized
-from torchaudio_unittest.common_utils import (
-    get_sinusoid,
-    load_params,
-    save_wav,
-    skipIfNoExec,
-    TempDirMixin,
-    TestBaseMixin,
-)
+from torchaudio_unittest.common_utils import skipIfNoExec, TempDirMixin, TestBaseMixin
 from torchaudio_unittest.common_utils.kaldi_utils import convert_args, run_kaldi
@@ -32,25 +24,3 @@ class Kaldi(TempDirMixin, TestBaseMixin):
         command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"]
         kaldi_result = run_kaldi(command, "ark", tensor)
         self.assert_equal(result, expected=kaldi_result)
-
-
-class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
-    def assert_equal(self, output, *, expected, rtol=None, atol=None):
-        expected = expected.to(dtype=self.dtype, device=self.device)
-        self.assertEqual(output, expected, rtol=rtol, atol=atol)
-
-    @parameterized.expand(load_params("kaldi_test_pitch_args.jsonl"))
-    @skipIfNoExec("compute-kaldi-pitch-feats")
-    def test_pitch_feats(self, kwargs):
-        """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
-        sample_rate = kwargs["sample_rate"]
-        waveform = get_sinusoid(dtype="float32", sample_rate=sample_rate)
-        result = F.compute_kaldi_pitch(waveform[0], **kwargs)
-
-        waveform = get_sinusoid(dtype="int16", sample_rate=sample_rate)
-        wave_file = self.get_temp_path("test.wav")
-        save_wav(wave_file, waveform, sample_rate)
-
-        command = ["compute-kaldi-pitch-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
-        kaldi_result = run_kaldi(command, "scp", wave_file)
-        self.assert_equal(result, expected=kaldi_result)