Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
9dcc7a15
Commit
9dcc7a15
authored
Apr 25, 2022
by
flyingdown
Browse files
init v0.10.0
parent
db2b0b79
Pipeline
#254
failed with stages
in 0 seconds
Changes
416
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3213 additions
and
0 deletions
+3213
-0
test/torchaudio_unittest/backend/soundfile/info_test.py
test/torchaudio_unittest/backend/soundfile/info_test.py
+190
-0
test/torchaudio_unittest/backend/soundfile/load_test.py
test/torchaudio_unittest/backend/soundfile/load_test.py
+357
-0
test/torchaudio_unittest/backend/soundfile/save_test.py
test/torchaudio_unittest/backend/soundfile/save_test.py
+295
-0
test/torchaudio_unittest/backend/sox_io/__init__.py
test/torchaudio_unittest/backend/sox_io/__init__.py
+0
-0
test/torchaudio_unittest/backend/sox_io/common.py
test/torchaudio_unittest/backend/sox_io/common.py
+14
-0
test/torchaudio_unittest/backend/sox_io/info_test.py
test/torchaudio_unittest/backend/sox_io/info_test.py
+537
-0
test/torchaudio_unittest/backend/sox_io/load_test.py
test/torchaudio_unittest/backend/sox_io/load_test.py
+535
-0
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
+54
-0
test/torchaudio_unittest/backend/sox_io/save_test.py
test/torchaudio_unittest/backend/sox_io/save_test.py
+402
-0
test/torchaudio_unittest/backend/sox_io/smoke_test.py
test/torchaudio_unittest/backend/sox_io/smoke_test.py
+155
-0
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
+148
-0
test/torchaudio_unittest/backend/utils_test.py
test/torchaudio_unittest/backend/utils_test.py
+36
-0
test/torchaudio_unittest/common_utils/__init__.py
test/torchaudio_unittest/common_utils/__init__.py
+63
-0
test/torchaudio_unittest/common_utils/backend_utils.py
test/torchaudio_unittest/common_utils/backend_utils.py
+21
-0
test/torchaudio_unittest/common_utils/case_utils.py
test/torchaudio_unittest/common_utils/case_utils.py
+123
-0
test/torchaudio_unittest/common_utils/data_utils.py
test/torchaudio_unittest/common_utils/data_utils.py
+155
-0
test/torchaudio_unittest/common_utils/func_utils.py
test/torchaudio_unittest/common_utils/func_utils.py
+10
-0
test/torchaudio_unittest/common_utils/kaldi_utils.py
test/torchaudio_unittest/common_utils/kaldi_utils.py
+38
-0
test/torchaudio_unittest/common_utils/parameterized_utils.py
test/torchaudio_unittest/common_utils/parameterized_utils.py
+53
-0
test/torchaudio_unittest/common_utils/psd_utils.py
test/torchaudio_unittest/common_utils/psd_utils.py
+27
-0
No files found.
Too many changes to show.
To preserve performance only
416 of 416+
files are displayed.
Plain diff
Email patch
test/torchaudio_unittest/backend/soundfile/info_test.py
0 → 100644
View file @
9dcc7a15
from
unittest.mock
import
patch
import
warnings
import
tarfile
import
torch
from
torchaudio.backend
import
soundfile_backend
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
PytorchTestCase
,
skipIfNoModule
,
get_wav_data
,
save_wav
,
nested_params
,
)
from
torchaudio_unittest.backend.common
import
(
get_bits_per_sample
,
get_encoding
,
)
from
.common
import
skipIfFormatNotSupported
,
parameterize
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
    """Metadata checks for `soundfile_backend.info` on files written to a temp dir."""

    @parameterize(
        ["float32", "int32", "int16", "uint8"],
        [8000, 16000],
        [1, 2],
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.info` reports the metadata of a wav file correctly"""
        seconds = 1
        wav_path = self.get_temp_path("data.wav")
        waveform = get_wav_data(
            dtype,
            num_channels,
            normalize=False,
            num_frames=seconds * sample_rate,
        )
        save_wav(wav_path, waveform, sample_rate)

        metadata = soundfile_backend.info(wav_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == get_bits_per_sample("wav", dtype)
        assert metadata.encoding == get_encoding("wav", dtype)

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """`soundfile_backend.info` reports the metadata of a flac file correctly"""
        seconds = 1
        frame_count = sample_rate * seconds
        samples = torch.randn(frame_count, num_channels).numpy()
        flac_path = self.get_temp_path("data.flac")
        soundfile.write(flac_path, samples, sample_rate)

        metadata = soundfile_backend.info(flac_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == frame_count
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 16
        assert metadata.encoding == "FLAC"

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """`soundfile_backend.info` reports the metadata of an ogg file correctly"""
        seconds = 1
        frame_count = sample_rate * seconds
        samples = torch.randn(frame_count, num_channels).numpy()
        ogg_path = self.get_temp_path("data.ogg")
        soundfile.write(ogg_path, samples, sample_rate)

        metadata = soundfile_backend.info(ogg_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        # The test expects a bit depth of 0 for this format.
        assert metadata.bits_per_sample == 0
        assert metadata.encoding == "VORBIS"

    @nested_params(
        [8000, 16000],
        [1, 2],
        [('PCM_24', 24), ('PCM_32', 32)],
    )
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
        """`soundfile_backend.info` reports the metadata of a sph file correctly"""
        seconds = 1
        frame_count = sample_rate * seconds
        samples = torch.randn(frame_count, num_channels).numpy()
        nist_path = self.get_temp_path("data.nist")
        # Each parameter pairs the subtype to write with the bit depth `info` should report.
        subtype, bits_per_sample = subtype_and_bit_depth
        soundfile.write(nist_path, samples, sample_rate, subtype=subtype)

        metadata = soundfile_backend.info(nist_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == bits_per_sample
        assert metadata.encoding == "PCM_S"

    def test_unknown_subtype_warning(self):
        """soundfile_backend.info issues a warning when the subtype is unknown

        This will happen if a new subtype is supported in SoundFile: the
        _SUBTYPE_TO_BITS_PER_SAMPLE dict should be updated.
        """
        class MockSoundFileInfo:
            # Stand-in for the object returned by `soundfile.info`.
            samplerate = 8000
            frames = 356
            channels = 2
            subtype = 'UNSEEN_SUBTYPE'
            format = 'UNKNOWN'

        def _mock_info_func(_):
            return MockSoundFileInfo()

        with patch("soundfile.info", _mock_info_func), warnings.catch_warnings(record=True) as caught:
            metadata = soundfile_backend.info("foo")
            assert len(caught) == 1
            assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(caught[-1].message)
            assert metadata.bits_per_sample == 0
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """`soundfile_backend.info` accepts file-like objects as well as paths."""

    def _test_fileobj(self, ext, subtype, bits_per_sample):
        """Query audio via file-like object works

        Args:
            ext: File extension used for the temporary audio file.
            subtype: Subtype passed to `soundfile.write`.
            bits_per_sample: Bit depth `info` is expected to report.
        """
        duration = 2
        sample_rate = 16000
        num_channels = 2
        num_frames = sample_rate * duration
        path = self.get_temp_path(f'test.{ext}')

        data = torch.randn(num_frames, num_channels).numpy()
        soundfile.write(path, data, sample_rate, subtype=subtype)

        with open(path, 'rb') as fileobj:
            info = soundfile_backend.info(fileobj)

        # Resolve the expected value first: writing
        # `assert a == "FLAC" if cond else "PCM_S"` parses as
        # `(a == "FLAC") if cond else "PCM_S"` and always passes for non-flac.
        expected_encoding = "FLAC" if ext == 'flac' else "PCM_S"
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == expected_encoding

    def test_fileobj_wav(self):
        """Querying wav audio via file-like object works"""
        self._test_fileobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Querying flac audio via file-like object works"""
        self._test_fileobj('flac', 'PCM_16', 16)

    def _test_tarobj(self, ext, subtype, bits_per_sample):
        """Query audio stored in a tar archive via file-like object works

        Args:
            ext: File extension used for the archived audio file.
            subtype: Subtype passed to `soundfile.write`.
            bits_per_sample: Bit depth `info` is expected to report.
        """
        duration = 2
        sample_rate = 16000
        num_channels = 2
        num_frames = sample_rate * duration
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)
        archive_path = self.get_temp_path('archive.tar.gz')

        data = torch.randn(num_frames, num_channels).numpy()
        soundfile.write(audio_path, data, sample_rate, subtype=subtype)

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            info = soundfile_backend.info(fileobj)

        # See `_test_fileobj`: the conditional must be evaluated before the
        # comparison, otherwise the non-flac branch asserts a truthy string.
        expected_encoding = "FLAC" if ext == 'flac' else "PCM_S"
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == expected_encoding

    def test_tarobj_wav(self):
        """Querying archived wav audio via file-like object works"""
        self._test_tarobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_tarobj_flac(self):
        """Querying archived flac audio via file-like object works"""
        self._test_tarobj('flac', 'PCM_16', 16)
test/torchaudio_unittest/backend/soundfile/load_test.py
0 → 100644
View file @
9dcc7a15
import
os
import
tarfile
from
unittest.mock
import
patch
import
torch
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio.backend
import
soundfile_backend
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
PytorchTestCase
,
skipIfNoModule
,
get_wav_data
,
normalize_wav
,
load_wav
,
save_wav
,
)
from
.common
import
(
parameterize
,
dtype2subtype
,
skipIfFormatNotSupported
,
)
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
def
_get_mock_path
(
ext
:
str
,
dtype
:
str
,
sample_rate
:
int
,
num_channels
:
int
,
num_frames
:
int
,
):
return
f
"
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
_
{
num_frames
}
.
{
ext
}
"
def
_get_mock_params
(
path
:
str
):
filename
,
ext
=
path
.
split
(
"."
)
parts
=
filename
.
split
(
"_"
)
return
{
"ext"
:
ext
,
"dtype"
:
parts
[
0
],
"sample_rate"
:
int
(
parts
[
1
]),
"num_channels"
:
int
(
parts
[
2
]),
"num_frames"
:
int
(
parts
[
3
]),
}
class SoundFileMock:
    """Stand-in patched over `soundfile.SoundFile` in the mocked load tests.

    The audio parameters are decoded from the file name, so no file is read.
    """

    def __init__(self, path, mode):
        assert mode == "r"
        self.path = path
        self._params = _get_mock_params(path)
        # Set by `_prepare_read`; used by `read` as the slice start.
        self._start = None

    @property
    def samplerate(self):
        """Sample rate decoded from the file name."""
        return self._params["sample_rate"]

    @property
    def format(self):
        """Format name for the extension, or None for an unlisted extension."""
        format_by_ext = {
            "wav": "WAV",
            "flac": "FLAC",
            "ogg": "OGG",
            "sph": "NIST",
            "nis": "NIST",
            "nist": "NIST",
        }
        return format_by_ext.get(self._params["ext"])

    @property
    def subtype(self):
        """"VORBIS" for ogg; otherwise the subtype `dtype2subtype` gives for the dtype."""
        if self._params["ext"] != "ogg":
            return dtype2subtype(self._params["dtype"])
        return "VORBIS"

    def _prepare_read(self, start, stop, frames):
        """Remember the start offset and hand back the requested frame count."""
        assert stop is None
        self._start = start
        return frames

    def read(self, frames, dtype, always_2d):
        """Generate the data with `get_wav_data` and return the requested slice."""
        assert always_2d
        params = self._params
        samples = get_wav_data(
            dtype,
            params["num_channels"],
            normalize=False,
            num_frames=params["num_frames"],
            channels_first=False,
        ).numpy()
        begin = self._start
        return samples[begin:begin + frames]

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        pass
class MockedLoadTest(PytorchTestCase):
    """dtype checks for `soundfile_backend.load` with `soundfile.SoundFile` mocked out."""

    def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
        """Check the dtype and sample rate returned by `load` for a mocked file.

        The native dtype is expected only when `normalize` is False and `ext`
        is "wav" or "nist"; in every other case float32 is expected.
        """
        frame_count = 3 * sample_rate
        mock_path = _get_mock_path(ext, dtype, sample_rate, num_channels, frame_count)
        if not normalize and ext in ["wav", "nist"]:
            expected_dtype = getattr(torch, dtype)
        else:
            expected_dtype = torch.float32
        with patch("soundfile.SoundFile", SoundFileMock):
            loaded, loaded_rate = soundfile_backend.load(
                mock_path,
                normalize=normalize,
                channels_first=channels_first,
            )
            assert loaded.dtype == expected_dtype
            assert sample_rate == loaded_rate

    @parameterize(
        ["uint8", "int16", "int32", "float32", "float64"],
        [8000, 16000],
        [1, 2],
        [True, False],
        [True, False],
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """wav: native dtype when normalize=False, float32 otherwise"""
        self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)

    @parameterize(
        ["int8", "int16", "int32"],
        [8000, 16000],
        [1, 2],
        [True, False],
        [True, False],
    )
    def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """sph: float32 is expected for every parameter combination"""
        self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)

    @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
    def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
        """ogg: float32 is expected for every parameter combination"""
        self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)

    @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
    def test_flac(self, sample_rate, num_channels, normalize, channels_first):
        """flac: float32 is expected for every parameter combination"""
        self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertion helpers for the `soundfile_backend.load` tests."""

    def assert_wav(
            self,
            dtype,
            sample_rate,
            num_channels,
            normalize,
            channels_first=True,
            duration=1,
    ):
        """`soundfile_backend.load` can load wav format correctly.

        The result is compared against the same file read back with `load_wav`.
        """
        wav_path = self.get_temp_path("reference.wav")
        frame_count = duration * sample_rate
        source = get_wav_data(
            dtype,
            num_channels,
            normalize=normalize,
            num_frames=frame_count,
            channels_first=channels_first,
        )
        save_wav(wav_path, source, sample_rate, channels_first=channels_first)
        # Both loaders are called with identical options.
        load_options = {"normalize": normalize, "channels_first": channels_first}
        expected = load_wav(wav_path, **load_options)[0]
        loaded, loaded_rate = soundfile_backend.load(wav_path, **load_options)
        assert loaded_rate == sample_rate
        self.assertEqual(loaded, expected)

    def assert_sphere(
            self,
            dtype,
            sample_rate,
            num_channels,
            channels_first=True,
            duration=1,
    ):
        """`soundfile_backend.load` can load SPHERE format correctly."""
        sph_path = self.get_temp_path("reference.sph")
        frame_count = duration * sample_rate
        raw = get_wav_data(
            dtype,
            num_channels,
            num_frames=frame_count,
            normalize=False,
            channels_first=False,
        )
        soundfile.write(
            sph_path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
        # `raw` was generated with channels_first=False; transpose when channels come first.
        if channels_first:
            expected = normalize_wav(raw.t())
        else:
            expected = normalize_wav(raw)
        loaded, loaded_rate = soundfile_backend.load(sph_path, channels_first=channels_first)
        assert loaded_rate == sample_rate
        self.assertEqual(loaded, expected, atol=1e-4, rtol=1e-8)

    def assert_flac(
            self,
            dtype,
            sample_rate,
            num_channels,
            channels_first=True,
            duration=1,
    ):
        """`soundfile_backend.load` can load FLAC format correctly."""
        flac_path = self.get_temp_path("reference.flac")
        frame_count = duration * sample_rate
        raw = get_wav_data(
            dtype,
            num_channels,
            num_frames=frame_count,
            normalize=False,
            channels_first=False,
        )
        soundfile.write(flac_path, raw, sample_rate)
        # `raw` was generated with channels_first=False; transpose when channels come first.
        if channels_first:
            expected = normalize_wav(raw.t())
        else:
            expected = normalize_wav(raw)
        loaded, loaded_rate = soundfile_backend.load(flac_path, channels_first=channels_first)
        assert loaded_rate == sample_rate
        self.assertEqual(loaded, expected, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestLoad(LoadTestBase):
    """Test the correctness of `soundfile_backend.load` for various formats"""

    @parameterize(
        ["float32", "int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [False, True],
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """wav files are loaded correctly for each dtype/rate/channel/option combination"""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)

    @parameterize(
        ["int16"],
        [16000],
        [2],
        [False],
    )
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """A two-hour wav file is loaded correctly"""
        seconds_per_hour = 60 * 60
        self.assert_wav(
            dtype, sample_rate, num_channels, normalize, duration=2 * seconds_per_hour)

    @parameterize(["float32", "int32", "int16"], [4, 8, 16, 32], [False, True])
    def test_multiple_channels(self, dtype, num_channels, channels_first):
        """wav files with more than 2 channels are loaded correctly"""
        self.assert_wav(dtype, 8000, num_channels, False, channels_first)

    @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
        """sphere files are loaded correctly"""
        self.assert_sphere(dtype, sample_rate, num_channels, channels_first)

    @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, dtype, sample_rate, num_channels, channels_first):
        """flac files are loaded correctly"""
        self.assert_flac(dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class TestLoadFormat(TempDirMixin, PytorchTestCase):
    """`soundfile_backend.load` can load files whose extension has been removed"""
    original = None
    path = None

    def _make_file(self, format_):
        """Write an audio file, strip its extension, and return (path, expected data)."""
        sample_rate = 8000
        named_path = self.get_temp_path(f'test.{format_}')
        waveform = get_wav_data('float32', num_channels=2)
        soundfile.write(named_path, waveform.numpy().T, sample_rate)
        # Read the reference back while the file still has its extension.
        reference, _ = soundfile.read(named_path, dtype='float32')
        expected = reference.T
        bare_path, _ = os.path.splitext(named_path)
        os.rename(named_path, bare_path)
        return bare_path, expected

    def _test_format(self, format_):
        """A file without extension loads to the same data that was read before the rename"""
        bare_path, expected = self._make_file(format_)
        loaded = soundfile_backend.load(bare_path)[0]
        self.assertEqual(loaded, expected)

    @parameterized.expand([
        ('WAV', ),
        ('wav', ),
    ])
    def test_wav(self, format_):
        """wav file without extension can be loaded"""
        self._test_format(format_)

    @parameterized.expand([
        ('FLAC', ),
        ('flac', ),
    ])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, format_):
        """flac file without extension can be loaded"""
        self._test_format(format_)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """`soundfile_backend.load` accepts file-like objects as well as paths."""

    def _test_fileobj(self, ext):
        """Loading audio via file-like object works"""
        sample_rate = 16000
        audio_path = self.get_temp_path(f'test.{ext}')

        waveform = get_wav_data('float32', num_channels=2)
        soundfile.write(audio_path, waveform.numpy().T, sample_rate)
        # Reference data is read back through soundfile directly.
        reference, _ = soundfile.read(audio_path, dtype='float32')
        expected = reference.T

        with open(audio_path, 'rb') as stream:
            loaded, loaded_rate = soundfile_backend.load(stream)
        assert loaded_rate == sample_rate
        self.assertEqual(expected, loaded)

    def test_fileobj_wav(self):
        """Loading wav audio via file-like object works"""
        self._test_fileobj('wav')

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Loading flac audio via file-like object works"""
        self._test_fileobj('flac')

    def _test_tarfile(self, ext):
        """Loading audio from a tar archive member via file-like object works"""
        sample_rate = 16000
        member_name = f'test.{ext}'
        audio_path = self.get_temp_path(member_name)
        archive_path = self.get_temp_path('archive.tar.gz')

        waveform = get_wav_data('float32', num_channels=2)
        soundfile.write(audio_path, waveform.numpy().T, sample_rate)
        reference, _ = soundfile.read(audio_path, dtype='float32')
        expected = reference.T

        with tarfile.TarFile(archive_path, 'w') as archive:
            archive.add(audio_path, arcname=member_name)
        with tarfile.TarFile(archive_path, 'r') as archive:
            stream = archive.extractfile(member_name)
            loaded, loaded_rate = soundfile_backend.load(stream)

        assert loaded_rate == sample_rate
        self.assertEqual(expected, loaded)

    def test_tarfile_wav(self):
        """Loading wav audio from a tar archive works"""
        self._test_tarfile('wav')

    @skipIfFormatNotSupported("FLAC")
    def test_tarfile_flac(self):
        """Loading flac audio from a tar archive works"""
        self._test_tarfile('flac')
test/torchaudio_unittest/backend/soundfile/save_test.py
0 → 100644
View file @
9dcc7a15
import
io
from
unittest.mock
import
patch
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio.backend
import
soundfile_backend
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
PytorchTestCase
,
skipIfNoModule
,
get_wav_data
,
load_wav
,
nested_params
,
)
from
.common
import
(
fetch_wav_subtype
,
parameterize
,
skipIfFormatNotSupported
,
)
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
class MockedSaveTest(PytorchTestCase):
    """Checks the arguments `soundfile_backend.save` passes to a mocked `soundfile.write`."""

    @nested_params(
        ["float32", "int32", "int16", "uint8"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [
            (None, None),
            ('PCM_U', None),
            ('PCM_U', 8),
            ('PCM_S', None),
            ('PCM_S', 16),
            ('PCM_S', 32),
            ('PCM_F', None),
            ('PCM_F', 32),
            ('PCM_F', 64),
            ('ULAW', None),
            ('ULAW', 8),
            ('ALAW', None),
            ('ALAW', 8),
        ],
    )
    @patch("soundfile.write")
    def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
        """soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
        target = "foo.wav"
        tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=dtype == "float32",
            channels_first=channels_first,
        ).t()
        encoding, bits_per_sample = enc_params

        soundfile_backend.save(
            target,
            tensor,
            sample_rate,
            channels_first=channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )

        # Index 1 of call_args holds the keyword arguments
        # (`call_args.kwargs` is the more descriptive spelling on Python 3.8+).
        write_kwargs = mocked_write.call_args[1]
        assert write_kwargs["file"] == target
        assert write_kwargs["samplerate"] == sample_rate
        assert write_kwargs["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
        assert write_kwargs["format"] is None
        if channels_first:
            expected_data = tensor.t()
        else:
            expected_data = tensor
        self.assertEqual(write_kwargs["data"], expected_data)

    @patch("soundfile.write")
    def assert_non_wav(
            self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write,
            encoding=None, bits_per_sample=None,
    ):
        """soundfile_backend.save passes the expected file, sample rate, format and data
        to soundfile.write for non-WAV extensions ("NIST" format for sph/nist/nis, else None)"""
        target = f"foo.{fmt}"
        tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=False,
            channels_first=channels_first,
        ).t()
        if channels_first:
            expected_data = tensor.t()
        else:
            expected_data = tensor

        soundfile_backend.save(
            target,
            tensor,
            sample_rate,
            channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )

        # Index 1 of call_args holds the keyword arguments
        # (`call_args.kwargs` is the more descriptive spelling on Python 3.8+).
        write_kwargs = mocked_write.call_args[1]
        assert write_kwargs["file"] == target
        assert write_kwargs["samplerate"] == sample_rate
        if fmt not in ["sph", "nist", "nis"]:
            assert write_kwargs["format"] is None
        else:
            assert write_kwargs["format"] == "NIST"
        self.assertEqual(write_kwargs["data"], expected_data)

    @nested_params(
        ["sph", "nist", "nis"],
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [
            ('PCM_S', 8),
            ('PCM_S', 16),
            ('PCM_S', 24),
            ('PCM_S', 32),
            ('ULAW', 8),
            ('ALAW', 8),
            ('ALAW', 16),
            ('ALAW', 24),
            ('ALAW', 32),
        ],
    )
    def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
        """soundfile_backend.save passes the "NIST" format to soundfile.write
        for the sph/nist/nis extensions"""
        enc, bits = enc_params
        self.assert_non_wav(
            fmt, dtype, sample_rate, num_channels, channels_first,
            encoding=enc, bits_per_sample=bits)

    @parameterize(
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [8, 16, 24],
    )
    def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
        """soundfile_backend.save passes a None format to soundfile.write for flac"""
        self.assert_non_wav(
            "flac", dtype, sample_rate, num_channels, channels_first,
            bits_per_sample=bits_per_sample)

    @parameterize(
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
    )
    def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
        """soundfile_backend.save passes a None format to soundfile.write for ogg"""
        self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class SaveTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertion helpers for the `soundfile_backend.save` tests."""

    def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
        """`soundfile_backend.save` can save wav format.

        The saved file is read back with `load_wav` and compared to the input.
        """
        wav_path = self.get_temp_path("data.wav")
        expected = get_wav_data(
            dtype,
            num_channels,
            num_frames=num_frames,
            normalize=False,
        )
        soundfile_backend.save(wav_path, expected, sample_rate)
        loaded, loaded_rate = load_wav(wav_path, normalize=False)
        assert sample_rate == loaded_rate
        self.assertEqual(loaded, expected)

    def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save non-wav format.

        Due to precision mismatch, and the lack of an alternative way to decode
        the resulting files without using soundfile, only metadata is validated.
        """
        frame_count = sample_rate * 3
        out_path = self.get_temp_path(f"data.{fmt}")
        waveform = get_wav_data(
            dtype,
            num_channels,
            num_frames=frame_count,
            normalize=False,
        )
        soundfile_backend.save(out_path, waveform, sample_rate)

        file_info = soundfile.info(out_path)
        assert file_info.format == fmt.upper()
        assert file_info.frames == frame_count
        assert file_info.channels == num_channels
        assert file_info.samplerate == sample_rate

    def assert_flac(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save flac format (metadata check only)."""
        self._assert_non_wav("flac", dtype, sample_rate, num_channels)

    def assert_sphere(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save sph format (metadata check only)."""
        self._assert_non_wav("nist", dtype, sample_rate, num_channels)

    def assert_ogg(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save ogg format.

        As we cannot inspect the OGG format (it's lossy), we only check the metadata.
        """
        self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSave(SaveTestBase):
    """Round-trip and metadata tests for `soundfile_backend.save` across formats."""

    @parameterize(
        ["float32", "int32", "int16"],
        [8000, 16000],
        [1, 2],
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """wav files are saved correctly"""
        self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)

    @parameterize(
        ["float32", "int32", "int16"],
        [4, 8, 16, 32],
    )
    def test_multiple_channels(self, dtype, num_channels):
        """wav files with more than 2 channels are saved correctly"""
        self.assert_wav(dtype, 8000, num_channels, num_frames=None)

    @parameterize(
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
    )
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, dtype, sample_rate, num_channels):
        """sph files are saved with the expected metadata"""
        self.assert_sphere(dtype, sample_rate, num_channels)

    @parameterize(
        [8000, 16000],
        [1, 2],
    )
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """flac files are saved with the expected metadata"""
        self.assert_flac("float32", sample_rate, num_channels)

    @parameterize(
        [8000, 16000],
        [1, 2],
    )
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """ogg/vorbis files are saved with the expected metadata"""
        self.assert_ogg("float32", sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of optional parameters of `soundfile_backend.save`"""

    @parameterize([True, False])
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        wav_path = self.get_temp_path("data.wav")
        waveform = get_wav_data("int32", 2, channels_first=channels_first)
        soundfile_backend.save(wav_path, waveform, 8000, channels_first=channels_first)
        loaded, _ = load_wav(wav_path)
        # Transpose the input when it was generated with channels_first=False.
        if channels_first:
            expected = waveform
        else:
            expected = waveform.transpose(1, 0)
        self.assertEqual(loaded, expected, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """`soundfile_backend.save` can write into file-like objects."""

    def _test_fileobj(self, ext):
        """Saving audio to file-like object works

        The bytes written to an in-memory buffer are decoded with soundfile and
        compared against a reference file written with `soundfile.write`.
        """
        sample_rate = 16000
        reference_path = self.get_temp_path(f'test.{ext}')
        # The reference wav is written with the 'FLOAT' subtype; other formats use the default.
        if ext == 'wav':
            subtype = 'FLOAT'
        else:
            subtype = None
        waveform = get_wav_data('float32', num_channels=2)

        soundfile.write(reference_path, waveform.numpy().T, sample_rate, subtype=subtype)
        expected, _ = soundfile.read(reference_path, dtype='float32')

        buffer = io.BytesIO()
        soundfile_backend.save(buffer, waveform, sample_rate, format=ext)
        # Rewind so that the reader starts at the beginning of the buffer.
        buffer.seek(0)
        decoded, decoded_rate = soundfile.read(buffer, dtype='float32')

        assert decoded_rate == sample_rate
        self.assertEqual(expected, decoded, atol=1e-4, rtol=1e-8)

    def test_fileobj_wav(self):
        """Saving wav audio via file-like object works"""
        self._test_fileobj('wav')

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Saving flac audio via file-like object works"""
        self._test_fileobj('flac')

    @skipIfFormatNotSupported("NIST")
    def test_fileobj_nist(self):
        """Saving NIST audio via file-like object works"""
        self._test_fileobj('NIST')

    @skipIfFormatNotSupported("OGG")
    def test_fileobj_ogg(self):
        """Saving OGG audio via file-like object works"""
        self._test_fileobj('OGG')
test/torchaudio_unittest/backend/sox_io/__init__.py
0 → 100644
View file @
9dcc7a15
test/torchaudio_unittest/backend/sox_io/common.py
0 → 100644
View file @
9dcc7a15
def name_func(func, _, params):
    """Build a test name from the function name and its positional parameters.

    The result is ``<func.__name__>_<arg0>_<arg1>...`` where each entry of
    ``params.args`` is converted with ``str``. The second argument is ignored.
    """
    base = func.__name__
    suffix = "_".join(map(str, params.args))
    return f'{base}_{suffix}'
def get_enc_params(dtype):
    """Return the ``(encoding, bits_per_sample)`` pair for a dtype name.

    Supported names are 'float32', 'int32', 'int16' and 'uint8'.

    Raises:
        ValueError: If ``dtype`` is none of the supported names.
    """
    known = (
        ('float32', ('PCM_F', 32)),
        ('int32', ('PCM_S', 32)),
        ('int16', ('PCM_S', 16)),
        ('uint8', ('PCM_U', 8)),
    )
    for name, enc_params in known:
        if dtype == name:
            return enc_params
    raise ValueError(f'Unexpected dtype: {dtype}')
test/torchaudio_unittest/backend/sox_io/info_test.py
0 → 100644
View file @
9dcc7a15
from
contextlib
import
contextmanager
import
io
import
os
import
itertools
import
tarfile
from
parameterized
import
parameterized
from
torchaudio.backend
import
sox_io_backend
from
torchaudio.utils.sox_utils
import
get_buffer_size
,
set_buffer_size
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.backend.common
import
(
get_bits_per_sample
,
get_encoding
,
)
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
HttpServerMixin
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoModule
,
skipIfNoSox
,
get_asset_path
,
get_wav_data
,
save_wav
,
sox_utils
,
)
from
.common
import
(
name_func
,
)
if
_mod_utils
.
is_module_available
(
"requests"
):
import
requests
@skipIfNoExec('sox')
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
    """Verify the metadata `sox_io_backend.info` reports for audio files on disk."""

    @staticmethod
    def _verify(info, sample_rate, num_channels, bits_per_sample, encoding, num_frames=None):
        """Assert the fields of ``info``; ``num_frames`` is only checked when given."""
        assert info.sample_rate == sample_rate
        if num_frames is not None:
            assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == encoding

    def _generate_and_query(self, file_name, sample_rate, num_channels, **options):
        """Generate ``file_name`` in the temp dir with sox and return its ``info``."""
        audio_path = self.get_temp_path(file_name)
        sox_utils.gen_audio_file(audio_path, sample_rate, num_channels, **options)
        return sox_io_backend.info(audio_path)

    def _check_wav(self, dtype, sample_rate, num_channels):
        """Save a one-second wav file with ``save_wav`` and verify its ``info``."""
        seconds = 1
        wav_path = self.get_temp_path('data.wav')
        waveform = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=seconds * sample_rate)
        save_wav(wav_path, waveform, sample_rate)
        self._verify(
            sox_io_backend.info(wav_path),
            sample_rate,
            num_channels,
            sox_utils.get_bit_depth(dtype),
            get_encoding('wav', dtype),
            num_frames=sample_rate * seconds,
        )

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file correctly"""
        self._check_wav(dtype, sample_rate, num_channels)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [4, 8, 16, 32],
    )), name_func=name_func)
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
        self._check_wav(dtype, sample_rate, num_channels)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=name_func)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` can check mp3 file correctly"""
        info = self._generate_and_query(
            'data.mp3', sample_rate, num_channels, compression=bit_rate, duration=1)
        # mp3 does not preserve the number of samples, so num_frames is not checked.
        # bits_per_sample is irrelevant for compressed formats; 0 is expected.
        self._verify(info, sample_rate, num_channels, 0, "MP3")

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` can check flac file correctly"""
        seconds = 1
        info = self._generate_and_query(
            'data.flac', sample_rate, num_channels,
            compression=compression_level, duration=seconds)
        # 24 bits per sample: FLAC standard
        self._verify(
            info, sample_rate, num_channels, 24, "FLAC", num_frames=sample_rate * seconds)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=name_func)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` can check vorbis file correctly"""
        seconds = 1
        info = self._generate_and_query(
            'data.vorbis', sample_rate, num_channels,
            compression=quality_level, duration=seconds)
        # bits_per_sample is irrelevant for compressed formats; 0 is expected.
        self._verify(
            info, sample_rate, num_channels, 0, "VORBIS", num_frames=sample_rate * seconds)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [16, 32],
    )), name_func=name_func)
    def test_sphere(self, sample_rate, num_channels, bits_per_sample):
        """`sox_io_backend.info` can check sph file correctly"""
        seconds = 1
        info = self._generate_and_query(
            'data.sph', sample_rate, num_channels,
            duration=seconds, bit_depth=bits_per_sample)
        self._verify(
            info, sample_rate, num_channels, bits_per_sample, "PCM_S",
            num_frames=sample_rate * seconds)

    @parameterized.expand(list(itertools.product(
        ['int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_amb(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check amb file correctly"""
        seconds = 1
        depth = sox_utils.get_bit_depth(dtype)
        info = self._generate_and_query(
            'data.amb', sample_rate, num_channels, bit_depth=depth, duration=seconds)
        self._verify(
            info, sample_rate, num_channels, depth, get_encoding("amb", dtype),
            num_frames=sample_rate * seconds)

    def test_amr_nb(self):
        """`sox_io_backend.info` can check amr-nb file correctly"""
        seconds, channels, rate = 1, 1, 8000
        info = self._generate_and_query(
            'data.amr-nb', rate, channels, bit_depth=16, duration=seconds)
        self._verify(info, rate, channels, 0, "AMR_NB", num_frames=rate * seconds)

    def test_ulaw(self):
        """`sox_io_backend.info` can check ulaw file correctly"""
        seconds, channels, rate = 1, 1, 8000
        info = self._generate_and_query(
            'data.wav', rate, channels, bit_depth=8, encoding='u-law', duration=seconds)
        self._verify(info, rate, channels, 8, "ULAW", num_frames=rate * seconds)

    def test_alaw(self):
        """`sox_io_backend.info` can check alaw file correctly"""
        seconds, channels, rate = 1, 1, 8000
        info = self._generate_and_query(
            'data.wav', rate, channels, bit_depth=8, encoding='a-law', duration=seconds)
        self._verify(info, rate, channels, 8, "ALAW", num_frames=rate * seconds)

    def test_gsm(self):
        """`sox_io_backend.info` can check gsm file correctly"""
        seconds, channels, rate = 1, 1, 8000
        info = self._generate_and_query('data.gsm', rate, channels, duration=seconds)
        # The number of frames is deliberately not checked for gsm.
        self._verify(info, rate, channels, 0, "GSM")

    def test_htk(self):
        """`sox_io_backend.info` can check HTK file correctly"""
        seconds, channels, rate = 1, 1, 8000
        info = self._generate_and_query(
            'data.htk', rate, channels, bit_depth=16, duration=seconds)
        self._verify(info, rate, channels, 16, "PCM_S", num_frames=rate * seconds)
@skipIfNoSox
class TestInfoOpus(PytorchTestCase):
    """Verify `sox_io_backend.info` against the pre-generated opus assets."""

    @parameterized.expand(list(itertools.product(
        ['96k'],
        [1, 2],
        [0, 5, 10],
    )), name_func=name_func)
    def test_opus(self, bitrate, num_channels, compression_level):
        """`sox_io_backend.info` can check opus file correctly"""
        asset_name = f'{bitrate}_{compression_level}_{num_channels}ch.opus'
        metadata = sox_io_backend.info(get_asset_path('io', asset_name))
        expected = {
            'sample_rate': 48000,
            'num_frames': 32768,
            'num_channels': num_channels,
            # bits_per_sample is irrelevant for compressed formats
            'bits_per_sample': 0,
            'encoding': "OPUS",
        }
        for field, want in expected.items():
            assert getattr(metadata, field) == want
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
    """Query an mp3 asset whose file name carries no extension."""

    def test_mp3(self):
        """Providing `format` allows to read mp3 without extension

        libsox does not check header for mp3

        https://github.com/pytorch/audio/issues/1040

        The file was generated with the following command
            ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
        """
        metadata = sox_io_backend.info(get_asset_path("mp3_without_ext"), format="mp3")
        expected = {
            'sample_rate': 16000,
            'num_frames': 81216,
            'num_channels': 1,
            # bits_per_sample is irrelevant for compressed formats
            'bits_per_sample': 0,
            'encoding': "MP3",
        }
        for field, want in expected.items():
            assert getattr(metadata, field) == want
class FileObjTestBase(TempDirMixin):
    """Mixin with helpers that generate audio fixtures in the temp directory."""

    def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate ``test.<ext>`` with sox and return its path.

        The duration handed to sox is ``num_frames / sample_rate``. When
        ``comments`` is truthy it is first written to a comment file that is
        passed along to ``sox_utils.gen_audio_file``.
        """
        target = self.get_temp_path(f'test.{ext}')
        depth = sox_utils.get_bit_depth(dtype)
        seconds = num_frames / sample_rate
        if comments:
            comment_file = self._gen_comment_file(comments)
        else:
            comment_file = None
        sox_utils.gen_audio_file(
            target,
            sample_rate,
            num_channels=num_channels,
            encoding=sox_utils.get_encoding(dtype),
            bit_depth=depth,
            duration=seconds,
            comment_file=comment_file,
        )
        return target

    def _gen_comment_file(self, comments):
        """Write ``comments`` to ``comment.txt`` in the temp dir and return the path."""
        destination = self.get_temp_path("comment.txt")
        with open(destination, "w") as sink:
            sink.writelines(comments)
        return destination
@skipIfNoSox
@skipIfNoExec('sox')
class TestFileObject(FileObjTestBase, PytorchTestCase):
    """Query audio metadata with `sox_io_backend.info` through file-like objects."""

    # (extension, dtype) pairs shared by the parameterized tests below.
    _EXT_DTYPE_PAIRS = [
        ('wav', "float32"),
        ('wav', "int32"),
        ('wav', "int16"),
        ('wav', "uint8"),
        ('mp3', "float32"),
        ('flac', "float32"),
        ('vorbis', "float32"),
        ('amb', "int16"),
    ]

    @staticmethod
    def _format_for(ext):
        """Return the explicit ``format`` argument for ``ext`` (only mp3 gets one)."""
        # NOTE(review): presumably needed because a file object carries no
        # extension and libsox does not detect mp3 from the header - confirm.
        return ext if ext in ['mp3'] else None

    def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate a file and query it through a regular binary file object."""
        path = self._gen_file(
            ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
        format_ = self._format_for(ext)
        with open(path, 'rb') as fileobj:
            return sox_io_backend.info(fileobj, format_)

    def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a file and query an in-memory ``io.BytesIO`` copy of it."""
        path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        format_ = self._format_for(ext)
        with open(path, 'rb') as file_:
            fileobj = io.BytesIO(file_.read())
        return sox_io_backend.info(fileobj, format_)

    def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a file, put it in a tar archive and query the archive member."""
        audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        audio_file = os.path.basename(audio_path)
        # NOTE: ``tarfile.TarFile`` opened with mode 'w' writes a plain tar, so
        # despite its name this archive is not gzip-compressed.
        archive_path = self.get_temp_path('archive.tar.gz')
        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        format_ = self._format_for(ext)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # The member must be read while the archive is still open.
            fileobj = tarobj.extractfile(audio_file)
            return sox_io_backend.info(fileobj, format_)

    @contextmanager
    def _set_buffer_size(self, buffer_size):
        """Temporarily set the sox buffer size, restoring the previous value on exit."""
        # Read the current value before entering ``try``: if ``get_buffer_size``
        # raised inside it, ``finally`` would fail on an unbound name and hide
        # the real error.
        original_buffer_size = get_buffer_size()
        try:
            set_buffer_size(buffer_size)
            yield
        finally:
            set_buffer_size(original_buffer_size)

    def _assert_sinfo(self, sinfo, ext, dtype, sample_rate, num_channels, num_frames):
        """Assert that ``sinfo`` matches the parameters the file was generated with.

        For 'mp3' and 'vorbis' the expected ``num_frames`` is 0 instead of the
        generated frame count.
        """
        bits_per_sample = get_bits_per_sample(ext, dtype)
        expected_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        assert sinfo.num_frames == expected_frames
        assert sinfo.bits_per_sample == bits_per_sample
        assert sinfo.encoding == get_encoding(ext, dtype)

    @parameterized.expand(_EXT_DTYPE_PAIRS)
    def test_fileobj(self, ext, dtype):
        """Querying audio via file object works"""
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, num_frames)

    @parameterized.expand([
        ('vorbis', "float32"),
    ])
    def test_fileobj_large_header(self, ext, dtype):
        """
        For audio file with header size exceeding default buffer size:
        - Querying audio via file object without enlarging buffer size fails.
        - Querying audio via file object after enlarging buffer size succeeds.
        """
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        comments = "metadata=" + " ".join(["value"] * 1000)

        with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"):
            self._query_fileobj(
                ext, dtype, sample_rate, num_channels, num_frames, comments=comments)

        with self._set_buffer_size(16384):
            sinfo = self._query_fileobj(
                ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, num_frames)

    @parameterized.expand(_EXT_DTYPE_PAIRS)
    def test_bytesio(self, ext, dtype):
        """Querying audio via BytesIO object works"""
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, num_frames)

    @parameterized.expand(_EXT_DTYPE_PAIRS)
    def test_bytesio_tiny(self, ext, dtype):
        """Querying audio via BytesIO object works for tiny data (4 frames)"""
        sample_rate = 8000
        num_frames = 4
        num_channels = 2
        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, num_frames)

    @parameterized.expand(_EXT_DTYPE_PAIRS)
    def test_tarfile(self, ext, dtype):
        """Querying audio stored in a tar archive via file-like object works"""
        sample_rate = 16000
        # Integral frame count, consistent with the other tests in this class.
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, num_frames)
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
    """Query audio metadata from the raw stream of a `requests` HTTP response."""

    def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a file, fetch it via ``self.get_url`` and query the raw response."""
        audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        audio_file = os.path.basename(audio_path)

        url = self.get_url(audio_file)
        # Only mp3 is given an explicit format; the other formats pass ``None``.
        format_ = ext if ext in ['mp3'] else None
        with requests.get(url, stream=True) as resp:
            return sox_io_backend.info(resp.raw, format=format_)

    @parameterized.expand([
        ('wav', "float32"),
        ('wav', "int32"),
        ('wav', "int16"),
        ('wav', "uint8"),
        ('mp3', "float32"),
        ('flac', "float32"),
        ('vorbis', "float32"),
        ('amb', "int16"),
    ])
    def test_requests(self, ext, dtype):
        """Querying audio via the raw stream of a requests response works"""
        sample_rate = 16000
        # Frame counts are integral; a float here was inconsistent with the
        # file-object tests in this file.
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)

        bits_per_sample = get_bits_per_sample(ext, dtype)
        # For 'mp3' and 'vorbis' the expected frame count is 0.
        num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames

        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        assert sinfo.num_frames == num_frames
        assert sinfo.bits_per_sample == bits_per_sample
        assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase):
    """Error reporting of `sox_io_backend.info` for a missing file."""

    def test_info_fail(self):
        """
        When attempted to get info on a non-existing file, error message must contain the file path.
        """
        import re  # local import: only needed to escape the path for the regex

        path = "non_existing_audio.wav"
        # Escape the path so that the '.' in the file name is matched literally
        # rather than as "any character".
        pattern = "^Error loading audio file: failed to open file {0}$".format(re.escape(path))
        with self.assertRaisesRegex(RuntimeError, pattern):
            sox_io_backend.info(path)
test/torchaudio_unittest/backend/sox_io/load_test.py
0 → 100644
View file @
9dcc7a15
import
io
import
itertools
import
tarfile
from
parameterized
import
parameterized
from
torchaudio.backend
import
sox_io_backend
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
HttpServerMixin
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoModule
,
skipIfNoSox
,
get_asset_path
,
get_wav_data
,
load_wav
,
save_wav
,
sox_utils
,
)
from
.common
import
(
name_func
,
)
if
_mod_utils
.
is_module_available
(
"requests"
):
import
requests
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertions comparing `sox_io_backend.load` against reference loaders."""

    def assert_format(
            self,
            format: str,
            sample_rate: float,
            num_channels: int,
            compression: float = None,
            bit_depth: int = None,
            duration: float = 1,
            normalize: bool = True,
            encoding: str = None,
            atol: float = 4e-05,
            rtol: float = 1.3e-06,
    ):
        """`sox_io_backend.load` can load given format correctly.

        file encodings introduce delay and boundary effects so
        we create a reference wav file from the original file format

         x
         |
         |    1. Generate given format with Sox
         |
         v    2. Convert to wav with Sox
        given format ----------------------> wav
         |                                   |
         |    3. Load with torchaudio        | 4. Load with scipy
         |                                   |
         v                                   v
        tensor ----------> x <----------- tensor
                       5. Compare

        Underlying assumptions are;
        i. Conversion of given format to wav with Sox preserves data.
        ii. Loading wav file with scipy is correct.

        By combining i & ii, step 2. and 4. allows to load reference given format
        data without using torchaudio
        """
        source_path = self.get_temp_path(f'1.original.{format}')
        reference_path = self.get_temp_path('2.reference.wav')

        # Step 1: generate the file under test with sox.
        sox_utils.gen_audio_file(
            source_path,
            sample_rate,
            num_channels,
            encoding=encoding,
            compression=compression,
            bit_depth=bit_depth,
            duration=duration,
        )

        # Step 2: convert it to wav with sox; 24-bit input gets a 32-bit reference.
        if bit_depth == 24:
            wav_bit_depth = 32
        else:
            wav_bit_depth = None
        sox_utils.convert_audio_file(source_path, reference_path, bit_depth=wav_bit_depth)

        # Step 3: load the file under test with torchaudio.
        found, found_rate = sox_io_backend.load(source_path, normalize=normalize)

        # Step 4: load the wav reference with the reference loader.
        reference = load_wav(reference_path, normalize=normalize)[0]

        # Step 5: compare.
        assert found_rate == sample_rate
        self.assertEqual(found, reference, atol=atol, rtol=rtol)

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`sox_io_backend.load` can load wav format correctly.

        Wav data loaded with sox_io backend should match those with scipy
        """
        wav_path = self.get_temp_path('reference.wav')
        waveform = get_wav_data(
            dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
        save_wav(wav_path, waveform, sample_rate)

        reference = load_wav(wav_path, normalize=normalize)[0]
        found, found_rate = sox_io_backend.load(wav_path, normalize=normalize)

        assert found_rate == sample_rate
        self.assertEqual(found, reference)
@skipIfNoExec('sox')
@skipIfNoSox
class TestLoad(LoadTestBase):
    """Test the correctness of `sox_io_backend.load` for various formats"""

    # Duration, in seconds, used by the ``*_large`` tests.
    _TWO_HOURS = 2 * 60 * 60

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
        [False, True],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load wav format correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [False, True],
    )), name_func=name_func)
    def test_24bit_wav(self, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load 24bit wav format correctly. Correctly casts it to ``int32`` tensor dtype."""
        options = {'bit_depth': 24, 'normalize': normalize, 'duration': 1}
        self.assert_format("wav", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        ['int16'],
        [16000],
        [2],
        [False],
    )), name_func=name_func)
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load large wav file correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, self._TWO_HOURS)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [4, 8, 16, 32],
    )), name_func=name_func)
    def test_multiple_channels(self, dtype, num_channels):
        """`sox_io_backend.load` can load wav file with more than 2 channels."""
        self.assert_wav(dtype, 8000, num_channels, False, duration=1)

    @parameterized.expand(list(itertools.product(
        [8000, 16000, 44100],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=name_func)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.load` can load mp3 format correctly."""
        options = {'compression': bit_rate, 'duration': 1, 'atol': 5e-05}
        self.assert_format("mp3", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [128],
    )), name_func=name_func)
    def test_mp3_large(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.load` can load large mp3 file correctly."""
        options = {'compression': bit_rate, 'duration': self._TWO_HOURS, 'atol': 5e-05}
        self.assert_format("mp3", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.load` can load flac format correctly."""
        options = {'compression': compression_level, 'bit_depth': 16, 'duration': 1}
        self.assert_format("flac", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [0],
    )), name_func=name_func)
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.load` can load large flac file correctly."""
        options = {
            'compression': compression_level,
            'bit_depth': 16,
            'duration': self._TWO_HOURS,
        }
        self.assert_format("flac", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=name_func)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.load` can load vorbis format correctly."""
        options = {'compression': quality_level, 'bit_depth': 16, 'duration': 1}
        self.assert_format("vorbis", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [10],
    )), name_func=name_func)
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.load` can load large vorbis file correctly."""
        options = {
            'compression': quality_level,
            'bit_depth': 16,
            'duration': self._TWO_HOURS,
        }
        self.assert_format("vorbis", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        ['96k'],
        [1, 2],
        [0, 5, 10],
    )), name_func=name_func)
    def test_opus(self, bitrate, num_channels, compression_level):
        """`sox_io_backend.load` can load opus file correctly."""
        stem = f'{bitrate}_{compression_level}_{num_channels}ch.opus'
        opus_path = get_asset_path('io', stem)
        converted_path = self.get_temp_path(f'{stem}.wav')
        # Reference: the asset converted to wav with sox, loaded by ``load_wav``.
        sox_utils.convert_audio_file(opus_path, converted_path)

        reference, reference_rate = load_wav(converted_path)
        loaded, loaded_rate = sox_io_backend.load(opus_path)

        assert reference_rate == loaded_rate
        self.assertEqual(reference, loaded)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_sphere(self, sample_rate, num_channels):
        """`sox_io_backend.load` can load sph format correctly."""
        options = {'bit_depth': 32, 'duration': 1}
        self.assert_format("sph", sample_rate, num_channels, **options)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16'],
        [8000, 16000],
        [1, 2],
        [False, True],
    )), name_func=name_func)
    def test_amb(self, dtype, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load amb format correctly."""
        depth = sox_utils.get_bit_depth(dtype)
        sox_encoding = sox_utils.get_encoding(dtype)
        self.assert_format(
            "amb",
            sample_rate,
            num_channels,
            bit_depth=depth,
            duration=1,
            encoding=sox_encoding,
            normalize=normalize,
        )

    def test_amr_nb(self):
        """`sox_io_backend.load` can load amr_nb format correctly."""
        options = {'sample_rate': 8000, 'num_channels': 1, 'bit_depth': 32, 'duration': 1}
        self.assert_format("amr-nb", **options)
@skipIfNoExec('sox')
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of frame parameters of `sox_io_backend.load`"""

    # Both are populated in ``setUp``: the reference data and the wav file
    # holding it.
    original = None
    path = None

    def setUp(self):
        """Save two-channel float32 reference data as ``test.wav`` at 8000 Hz."""
        super().setUp()
        self.original = get_wav_data('float32', num_channels=2)
        self.path = self.get_temp_path('test.wav')
        save_wav(self.path, self.original, 8000)

    @parameterized.expand(list(itertools.product(
        [0, 1, 10, 100, 1000],
        [-1, 1, 10, 100, 1000],
    )), name_func=name_func)
    def test_frame(self, frame_offset, num_frames):
        """num_frames and frame_offset correctly specify the region of data"""
        loaded, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
        # ``num_frames == -1`` selects everything from ``frame_offset`` onwards.
        if num_frames == -1:
            stop = None
        else:
            stop = frame_offset + num_frames
        self.assertEqual(loaded, self.original[:, frame_offset:stop])

    @parameterized.expand([(True, ), (False, )], name_func=name_func)
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        loaded, _ = sox_io_backend.load(self.path, channels_first=channels_first)
        if channels_first:
            reference = self.original
        else:
            reference = self.original.transpose(1, 0)
        self.assertEqual(loaded, reference)
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
    """Load an mp3 asset whose file name carries no extension."""

    def test_mp3(self):
        """Providing format allows to read mp3 without extension

        libsox does not check header for mp3

        https://github.com/pytorch/audio/issues/1040

        The file was generated with the following command
            ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
        """
        asset = get_asset_path("mp3_without_ext")
        loaded = sox_io_backend.load(asset, format="mp3")
        # Only the sample rate is checked; the waveform itself is ignored.
        _, rate = loaded
        assert rate == 16000
class CloggedFileObj:
    """File-like wrapper whose ``read`` hands out at most two bytes per call.

    It simulates a source that returns fewer bytes than requested. Data is
    pulled from the wrapped object only when the internal buffer is empty.
    """

    def __init__(self, fileobj):
        """Wrap ``fileobj``, which must provide ``read(n)`` returning bytes."""
        self.fileobj = fileobj
        self.buffer = b''

    def read(self, n):
        """Return up to ``min(n, 2)`` bytes, or ``b''`` once the source is exhausted.

        A negative or ``None`` ``n`` (meaning "no limit") still yields at most
        two bytes per call.
        """
        if not self.buffer:
            self.buffer += self.fileobj.read(n)
        # Never hand back more than the caller asked for: a fixed two-byte
        # slice would over-deliver when ``n`` is 0 or 1.
        size = 2 if n is None or n < 0 else min(n, 2)
        ret = self.buffer[:size]
        self.buffer = self.buffer[size:]
        return ret
@
skipIfNoSox
@
skipIfNoExec
(
'sox'
)
class
TestFileObject
(
TempDirMixin
,
PytorchTestCase
):
"""
In this test suite, the result of file-like object input is compared against file path input,
because `load` function is rigrously tested for file path inputs to match libsox's result,
"""
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_fileobj
(
self
,
ext
,
compression
):
"""Loading audio via file object returns the same result as via file path."""
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
)
expected
,
_
=
sox_io_backend
.
load
(
path
)
with
open
(
path
,
'rb'
)
as
fileobj
:
found
,
sr
=
sox_io_backend
.
load
(
fileobj
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_bytesio
(
self
,
ext
,
compression
):
"""Loading audio via BytesIO object returns the same result as via file path."""
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
)
expected
,
_
=
sox_io_backend
.
load
(
path
)
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
found
,
sr
=
sox_io_backend
.
load
(
fileobj
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_bytesio_clogged
(
self
,
ext
,
compression
):
"""Loading audio via clogged file object returns the same result as via file path.
This test case validates the case where fileobject returns shorter bytes than requeted.
"""
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
)
expected
,
_
=
sox_io_backend
.
load
(
path
)
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
CloggedFileObj
(
io
.
BytesIO
(
file_
.
read
()))
found
,
sr
=
sox_io_backend
.
load
(
fileobj
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_bytesio_tiny
(
self
,
ext
,
compression
):
"""Loading very small audio via file object returns the same result as via file path.
"""
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
,
duration
=
1
/
1600
)
expected
,
_
=
sox_io_backend
.
load
(
path
)
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
found
,
sr
=
sox_io_backend
.
load
(
fileobj
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@parameterized.expand([
    ('wav', None),
    ('mp3', 128),
    ('mp3', 320),
    ('flac', 0),
    ('flac', 5),
    ('flac', 8),
    ('vorbis', -1),
    ('vorbis', 10),
    ('amb', None),
])
def test_tarfile(self, ext, compression):
    """Loading audio from a tar member file object matches loading from the path."""
    sample_rate = 16000
    # Only mp3 gets an explicit format hint; every other extension passes None.
    format_ = 'mp3' if ext == 'mp3' else None
    audio_file = f'test.{ext}'
    audio_path = self.get_temp_path(audio_file)
    archive_path = self.get_temp_path('archive.tar.gz')

    sox_utils.gen_audio_file(
        audio_path, sample_rate, num_channels=2, compression=compression)
    reference, _ = sox_io_backend.load(audio_path)

    # Pack the generated file into an archive under its bare file name ...
    with tarfile.TarFile(archive_path, 'w') as writer:
        writer.add(audio_path, arcname=audio_file)
    # ... then read it back through the member's file object. The load has to
    # happen while the archive is still open.
    with tarfile.TarFile(archive_path, 'r') as reader:
        member = reader.extractfile(audio_file)
        loaded, loaded_rate = sox_io_backend.load(member, format=format_)

    assert loaded_rate == sample_rate
    self.assertEqual(reference, loaded)
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Load audio from the raw stream of a `requests` HTTP response."""

    @parameterized.expand([
        ('wav', None),
        ('mp3', 128),
        ('mp3', 320),
        ('flac', 0),
        ('flac', 5),
        ('flac', 8),
        ('vorbis', -1),
        ('vorbis', 10),
        ('amb', None),
    ])
    def test_requests(self, ext, compression):
        """Loading from a streamed HTTP response matches loading from the path."""
        sample_rate = 16000
        # Only mp3 gets an explicit format hint; every other extension passes None.
        format_ = 'mp3' if ext == 'mp3' else None
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)

        sox_utils.gen_audio_file(
            audio_path, sample_rate, num_channels=2, compression=compression)
        reference, _ = sox_io_backend.load(audio_path)

        # The response must still be open while its raw stream is consumed.
        with requests.get(self.get_url(audio_file), stream=True) as response:
            loaded, loaded_rate = sox_io_backend.load(response.raw, format=format_)

        assert loaded_rate == sample_rate
        self.assertEqual(reference, loaded)

    @parameterized.expand(list(itertools.product(
        [0, 1, 10, 100, 1000],
        [-1, 1, 10, 100, 1000],
    )), name_func=name_func)
    def test_frame(self, frame_offset, num_frames):
        """num_frames and frame_offset correctly specify the region of data"""
        sample_rate = 8000
        audio_file = 'test.wav'
        audio_path = self.get_temp_path(audio_file)

        original = get_wav_data('float32', num_channels=2)
        save_wav(audio_path, original, sample_rate)

        # num_frames == -1 means "until the end", i.e. an open-ended slice.
        if num_frames == -1:
            frame_end = None
        else:
            frame_end = frame_offset + num_frames
        reference = original[:, frame_offset:frame_end]

        with requests.get(self.get_url(audio_file), stream=True) as response:
            loaded, loaded_rate = sox_io_backend.load(
                response.raw, frame_offset, num_frames)

        assert loaded_rate == sample_rate
        self.assertEqual(reference, loaded)
@skipIfNoSox
class TestLoadNoSuchFile(PytorchTestCase):
    """Error reporting of `sox_io_backend.load` for a missing file."""

    def test_load_fail(self):
        """
        When attempted to load a non-existing file, error message must contain the file path.
        """
        # Local import keeps this block self-contained.
        import re

        path = "non_existing_audio.wav"
        # Escape the path so that "." (and any other regex metacharacter) is
        # matched literally instead of acting as a wildcard.
        pattern = "^Error loading audio file: failed to open file {0}$".format(
            re.escape(path))
        with self.assertRaisesRegex(RuntimeError, pattern):
            sox_io_backend.load(path)
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
0 → 100644
View file @
9dcc7a15
import
itertools
from
torchaudio.backend
import
sox_io_backend
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoSox
,
get_wav_data
,
)
from
.common
import
(
name_func
,
get_enc_params
,
)
@skipIfNoExec('sox')
@skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
    """save/load round trip should not degrade data for lossless formats"""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels):
        """Ten consecutive wav save/load cycles leave the samples unchanged."""
        original = get_wav_data(dtype, num_channels, normalize=False)
        enc, bps = get_enc_params(dtype)
        current = original
        for cycle in range(10):
            path = self.get_temp_path(f'{cycle}.wav')
            sox_io_backend.save(
                path, current, sample_rate, encoding=enc, bits_per_sample=bps)
            # Feed the loaded tensor into the next cycle so errors would accumulate.
            current, loaded_rate = sox_io_backend.load(path, normalize=False)
            assert loaded_rate == sample_rate
            self.assertEqual(original, current)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """Ten consecutive flac save/load cycles leave the samples unchanged."""
        original = get_wav_data('float32', num_channels)
        current = original
        for cycle in range(10):
            path = self.get_temp_path(f'{cycle}.flac')
            sox_io_backend.save(
                path, current, sample_rate, compression=compression_level)
            # Feed the loaded tensor into the next cycle so errors would accumulate.
            current, loaded_rate = sox_io_backend.load(path)
            assert loaded_rate == sample_rate
            self.assertEqual(original, current)
test/torchaudio_unittest/backend/sox_io/save_test.py
0 → 100644
View file @
9dcc7a15
import
io
import
os
import
unittest
import
torch
from
torchaudio.backend
import
sox_io_backend
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TorchaudioTestCase
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoSox
,
get_wav_data
,
load_wav
,
save_wav
,
sox_utils
,
nested_params
,
)
from
.common
import
(
name_func
,
get_enc_params
,
)
def _get_sox_encoding(encoding):
    """Translate a torchaudio encoding name into the `sox` encoding name.

    Returns ``None`` when ``encoding`` is not one of the known names
    (which includes ``encoding`` being ``None``).
    """
    name_pairs = (
        ('PCM_F', 'floating-point'),
        ('PCM_S', 'signed-integer'),
        ('PCM_U', 'unsigned-integer'),
        ('ULAW', 'u-law'),
        ('ALAW', 'a-law'),
    )
    return dict(name_pairs).get(encoding)
class SaveTestBase(TempDirMixin, TorchaudioTestCase):
    """Shared helper for comparing `sox_io_backend.save` against the `sox` command."""

    def assert_save_consistency(
            self,
            format: str,
            *,
            compression: float = None,
            encoding: str = None,
            bits_per_sample: int = None,
            sample_rate: float = 8000,
            num_channels: int = 2,
            num_frames: float = 3 * 8000,
            src_dtype: str = 'int32',
            test_mode: str = "path",
    ):
        """`save` function produces file that is comparable with `sox` command

        To compare the file produced by the `save` function against the file
        produced by the equivalent `sox` command, both files have to be loaded.
        But there are many formats that cannot be opened with common Python
        modules (like SciPy).
        So the `sox` command is used to prepare the original data and to convert
        the saved files into a format that SciPy can read (PCM wav).
        The following diagram illustrates this process. The difference is 2.1. and 3.1.

        This assumes that
        - loading data with SciPy preserves the data well.
        - converting the resulting files into WAV format with `sox` preserves the data well.

                         x
                         | 1. Generate source wav file with SciPy
                         |
                         v
          -------------- wav ----------------
         |                                   |
         | 2.1. load with scipy              | 3.1. Convert to the target
         |   then save it into the target    |      format depth with sox
         |   format with torchaudio          |
         v                                   v
        target format                       target format
         |                                   |
         | 2.2. Convert to wav with sox      | 3.2. Convert to wav with sox
         |                                   |
         v                                   v
        wav                                 wav
         |                                   |
         | 2.3. load with scipy              | 3.3. load with scipy
         |                                   |
         v                                   v
        tensor -------> compare <--------- tensor

        Raises:
            ValueError: if ``test_mode`` is not "path", "fileobj" or "bytesio".
        """
        # Both branches are converted to this common wav flavour before comparing.
        cmp_encoding = 'floating-point'
        cmp_bit_depth = 32

        src_path = self.get_temp_path('1.source.wav')
        tgt_path = self.get_temp_path(f'2.1.torchaudio.{format}')
        tst_path = self.get_temp_path('2.2.result.wav')
        sox_path = self.get_temp_path(f'3.1.sox.{format}')
        ref_path = self.get_temp_path('3.2.ref.wav')

        # 1. Generate original wav
        source = get_wav_data(
            src_dtype, num_channels, normalize=False, num_frames=num_frames)
        save_wav(src_path, source, sample_rate)

        # 2.1. Convert the original wav to target format with torchaudio
        data = load_wav(src_path, normalize=False)[0]
        if test_mode == "path":
            sox_io_backend.save(
                tgt_path, data, sample_rate,
                compression=compression, encoding=encoding,
                bits_per_sample=bits_per_sample)
        elif test_mode == "fileobj":
            with open(tgt_path, 'bw') as file_:
                sox_io_backend.save(
                    file_, data, sample_rate,
                    format=format, compression=compression,
                    encoding=encoding, bits_per_sample=bits_per_sample)
        elif test_mode == "bytesio":
            buffer = io.BytesIO()
            sox_io_backend.save(
                buffer, data, sample_rate,
                format=format, compression=compression,
                encoding=encoding, bits_per_sample=bits_per_sample)
            # Dump the in-memory result to disk so sox can pick it up below.
            with open(tgt_path, 'bw') as f:
                f.write(buffer.getvalue())
        else:
            raise ValueError(f"Unexpected test mode: {test_mode}")

        # 2.2. Convert the target format to wav with sox
        sox_utils.convert_audio_file(
            tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
        # 2.3. Load with SciPy
        found = load_wav(tst_path, normalize=False)[0]

        # 3.1. Convert the original wav to target format with sox
        sox_utils.convert_audio_file(
            src_path, sox_path,
            compression=compression,
            encoding=_get_sox_encoding(encoding),
            bit_depth=bits_per_sample)
        # 3.2. Convert the target format to wav with sox
        sox_utils.convert_audio_file(
            sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
        # 3.3. Load with SciPy
        expected = load_wav(ref_path, normalize=False)[0]

        self.assertEqual(found, expected)
@skipIfNoExec('sox')
@skipIfNoSox
class SaveTest(SaveTestBase):
    """Check `sox_io_backend.save` against the `sox` command for each format.

    Most tests are run once per save mode: to a path ("path"), to an open
    binary file ("fileobj") and to an in-memory buffer ("bytesio").
    """

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            ('PCM_U', 8),
            ('PCM_S', 16),
            ('PCM_S', 32),
            ('PCM_F', 32),
            ('PCM_F', 64),
            ('ULAW', 8),
            ('ALAW', 8),
        ],
    )
    def test_save_wav(self, test_mode, enc_params):
        """wav output matches sox for each (encoding, bits_per_sample) pair."""
        encoding, bits_per_sample = enc_params
        self.assert_save_consistency(
            "wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            ('float32', ),
            ('int32', ),
            ('int16', ),
            ('uint8', ),
        ],
    )
    def test_save_wav_dtype(self, test_mode, params):
        """wav output matches sox for each source tensor dtype."""
        dtype, = params
        self.assert_save_consistency(
            "wav", src_dtype=dtype, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            None,
            -4.2,
            -0.2,
            0,
            0.2,
            96,
            128,
            160,
            192,
            224,
            256,
            320,
        ],
    )
    def test_save_mp3(self, test_mode, bit_rate):
        """mp3 output matches sox for each compression value."""
        # Values below 1 are skipped for the file-object modes only.
        is_fileobj_mode = test_mode in ["fileobj", "bytesio"]
        if is_fileobj_mode and bit_rate is not None and bit_rate < 1:
            raise unittest.SkipTest(
                "mp3 format with variable bit rate is known to "
                "not yield the exact same result as sox command.")
        self.assert_save_consistency(
            "mp3", compression=bit_rate, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [8, 16, 24],
        [
            None,
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
        ],
    )
    def test_save_flac(self, test_mode, bits_per_sample, compression_level):
        """flac output matches sox for each bit depth and compression level."""
        self.assert_save_consistency(
            "flac", compression=compression_level,
            bits_per_sample=bits_per_sample, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
    )
    def test_save_htk(self, test_mode):
        """htk output (single channel) matches sox."""
        self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            None,
            -1,
            0,
            1,
            2,
            3,
            3.6,
            5,
            10,
        ],
    )
    def test_save_vorbis(self, test_mode, quality_level):
        """vorbis output matches sox for each quality level."""
        self.assert_save_consistency(
            "vorbis", compression=quality_level, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            ('PCM_S', 8, ),
            ('PCM_S', 16, ),
            ('PCM_S', 24, ),
            ('PCM_S', 32, ),
            ('ULAW', 8),
            ('ALAW', 8),
            ('ALAW', 16),
            ('ALAW', 24),
            ('ALAW', 32),
        ],
    )
    def test_save_sphere(self, test_mode, enc_params):
        """sph output matches sox for each (encoding, bits_per_sample) pair."""
        encoding, bits_per_sample = enc_params
        self.assert_save_consistency(
            "sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            ('PCM_U', 8, ),
            ('PCM_S', 16, ),
            ('PCM_S', 24, ),
            ('PCM_S', 32, ),
            ('PCM_F', 32, ),
            ('PCM_F', 64, ),
            ('ULAW', 8, ),
            ('ALAW', 8, ),
        ],
    )
    def test_save_amb(self, test_mode, enc_params):
        """amb output matches sox for each (encoding, bits_per_sample) pair."""
        encoding, bits_per_sample = enc_params
        self.assert_save_consistency(
            "amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            None,
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
        ],
    )
    def test_save_amr_nb(self, test_mode, bit_rate):
        """amr-nb output (single channel) matches sox for each compression value."""
        self.assert_save_consistency(
            "amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
    )
    def test_save_gsm(self, test_mode):
        """gsm output matches sox; 2 channels or 16 kHz must raise RuntimeError."""
        self.assert_save_consistency(
            "gsm", num_channels=1, test_mode=test_mode)
        # NOTE: `msg` is the text unittest reports when NO exception is raised;
        # it is not matched against the raised exception's message.
        with self.assertRaises(
                RuntimeError, msg="gsm format only supports single channel audio."):
            self.assert_save_consistency(
                "gsm", num_channels=2, test_mode=test_mode)
        with self.assertRaises(
                RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
            self.assert_save_consistency(
                "gsm", sample_rate=16000, test_mode=test_mode)

    @parameterized.expand([
        ("wav", "PCM_S", 16),
        ("mp3", ),
        ("flac", ),
        ("vorbis", ),
        ("sph", "PCM_S", 16),
        ("amr-nb", ),
        ("amb", "PCM_S", 16),
    ], name_func=name_func)
    def test_save_large(self, format, encoding=None, bits_per_sample=None):
        """`sox_io_backend.save` can save large files."""
        sample_rate = 8000
        one_hour = 60 * 60 * sample_rate
        # Pass the same `sample_rate` that sized `one_hour`, so the frame count
        # and the rate cannot drift apart.
        self.assert_save_consistency(
            format, num_channels=1, sample_rate=sample_rate, num_frames=one_hour,
            encoding=encoding, bits_per_sample=bits_per_sample)

    @parameterized.expand([
        (32, ),
        (64, ),
        (128, ),
        (256, ),
    ], name_func=name_func)
    def test_save_multi_channels(self, num_channels):
        """`sox_io_backend.save` can save audio with many channels"""
        self.assert_save_consistency(
            "wav", encoding="PCM_S", bits_per_sample=16,
            num_channels=num_channels)
@skipIfNoExec('sox')
@skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of optional parameters of `sox_io_backend.save`"""

    @parameterized.expand([(True, ), (False, )], name_func=name_func)
    def test_save_channels_first(self, channels_first):
        """channels_first swaps axes"""
        path = self.get_temp_path('data.wav')
        data = get_wav_data(
            'int16', 2, channels_first=channels_first, normalize=False)
        sox_io_backend.save(
            path, data, 8000, channels_first=channels_first)
        found = load_wav(path, normalize=False)[0]
        # The reference is compared in channels-first layout, so transpose
        # the input when it was generated the other way round.
        if channels_first:
            expected = data
        else:
            expected = data.transpose(1, 0)
        self.assertEqual(found, expected)

    @parameterized.expand([
        'float32', 'int32', 'int16', 'uint8'
    ], name_func=name_func)
    def test_save_noncontiguous(self, dtype):
        """Noncontiguous tensors are saved correctly"""
        path = self.get_temp_path('data.wav')
        enc, bps = get_enc_params(dtype)
        full = get_wav_data(dtype, 4, normalize=False)
        # Taking every other element on both axes yields a strided view.
        expected = full[::2, ::2]
        assert not expected.is_contiguous()
        sox_io_backend.save(
            path, expected, 8000, encoding=enc, bits_per_sample=bps)
        found = load_wav(path, normalize=False)[0]
        self.assertEqual(found, expected)

    @parameterized.expand([
        'float32', 'int32', 'int16', 'uint8',
    ])
    def test_save_tensor_preserve(self, dtype):
        """save function should not alter Tensor"""
        path = self.get_temp_path('data.wav')
        full = get_wav_data(dtype, 4, normalize=False)
        expected = full[::2, ::2]

        # Save a copy, then check the copy still equals the untouched view.
        data = expected.clone()
        sox_io_backend.save(path, data, 8000)

        self.assertEqual(data, expected)
@skipIfNoSox
class TestSaveNonExistingDirectory(PytorchTestCase):
    """Error reporting of `sox_io_backend.save` for a missing target directory."""

    def test_save_fail(self):
        """
        When attempted to save into a non-existing dir, error message must contain the file path.
        """
        # Local import keeps this block self-contained.
        import re

        path = os.path.join("non_existing_directory", "foo.wav")
        # Escape the path: it contains "." and, on platforms where os.sep is a
        # backslash, a character that would otherwise start a regex escape.
        pattern = "^Error saving audio file: failed to open file {0}$".format(
            re.escape(path))
        with self.assertRaisesRegex(RuntimeError, pattern):
            sox_io_backend.save(path, torch.zeros(1, 1), 8000)
test/torchaudio_unittest/backend/sox_io/smoke_test.py
0 → 100644
View file @
9dcc7a15
import
io
import
itertools
import
unittest
from
torchaudio.utils
import
sox_utils
from
torchaudio.backend
import
sox_io_backend
from
torchaudio._internal.module_utils
import
is_sox_available
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TorchaudioTestCase
,
skipIfNoSox
,
get_wav_data
,
)
from
.common
import
name_func
# Decorator that skips a test unless sox is available AND 'mp3' appears in both
# the read-format and the write-format lists reported by sox_utils.
# `not is_sox_available()` comes first so the format queries are short-circuited
# when sox is unavailable.
skipIfNoMP3 = unittest.skipIf(
    not is_sox_available() or
    'mp3' not in sox_utils.list_read_formats() or
    'mp3' not in sox_utils.list_write_formats(),
    '"sox_io" backend does not support MP3')
@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
    """Run smoke test on various audio format

    The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
    abnormal behaviors.

    This test suite should be able to run without any additional tools (such as sox command),
    however without such tools, the correctness of each function cannot be verified.
    """

    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
        """Save, inspect and reload one second of audio through a temp file.

        Only the sample rate and channel count are checked; sample values are not.
        """
        duration = 1
        num_frames = sample_rate * duration
        path = self.get_temp_path(f'test.{ext}')
        original = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=num_frames)

        # 1. run save
        sox_io_backend.save(path, original, sample_rate, compression=compression)
        # 2. run info
        metadata = sox_io_backend.info(path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_channels == num_channels
        # 3. run load
        loaded, loaded_rate = sox_io_backend.load(path, normalize=False)
        assert loaded_rate == sample_rate
        assert loaded.shape[0] == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels):
        """Run smoke test on wav format"""
        self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
    )))
    @skipIfNoMP3
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """Run smoke test on mp3 format"""
        self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )))
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """Run smoke test on vorbis format"""
        self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """Run smoke test on flac format"""
        self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
@skipIfNoSox
class SmokeTestFileObj(TorchaudioTestCase):
    """Run smoke test on various audio format

    The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
    abnormal behaviors.

    This test suite should be able to run without any additional tools (such as sox command),
    however without such tools, the correctness of each function cannot be verified.
    """

    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
        """Save, inspect and reload one second of audio through a BytesIO buffer.

        Only the sample rate and channel count are checked; sample values are not.
        """
        duration = 1
        num_frames = sample_rate * duration
        original = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=num_frames)

        buffer = io.BytesIO()
        # 1. run save
        sox_io_backend.save(
            buffer, original, sample_rate, compression=compression, format=ext)
        # 2. run info (rewind first: the buffer position is at the end after save)
        buffer.seek(0)
        metadata = sox_io_backend.info(buffer, format=ext)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_channels == num_channels
        # 3. run load (rewind again after info consumed the buffer)
        buffer.seek(0)
        loaded, loaded_rate = sox_io_backend.load(buffer, normalize=False, format=ext)
        assert loaded_rate == sample_rate
        assert loaded.shape[0] == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels):
        """Run smoke test on wav format"""
        self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
    )))
    @skipIfNoMP3
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """Run smoke test on mp3 format"""
        self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )))
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """Run smoke test on vorbis format"""
        self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """Run smoke test on flac format"""
        self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
0 → 100644
View file @
9dcc7a15
import
itertools
from
typing
import
Optional
import
torch
import
torchaudio
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TorchaudioTestCase
,
skipIfNoExec
,
skipIfNoSox
,
get_wav_data
,
save_wav
,
load_wav
,
sox_utils
,
torch_script
,
)
from
.common
import
(
name_func
,
get_enc_params
,
)
def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
    """Return the result of `torchaudio.info` for ``filepath``.

    Thin, fully annotated wrapper; the tests in this file pass it to
    `torch_script` and compare the eager and scripted results.
    """
    return torchaudio.info(filepath)
def py_load_func(filepath: str, normalize: bool, channels_first: bool):
    """Return the result of `torchaudio.load` for ``filepath``.

    ``normalize`` and ``channels_first`` are forwarded as keyword arguments.
    The tests in this file pass this wrapper to `torch_script` and compare
    the eager and scripted results.
    """
    return torchaudio.load(
        filepath, normalize=normalize, channels_first=channels_first)
def py_save_func(
        filepath: str,
        tensor: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        compression: Optional[float] = None,
        encoding: Optional[str] = None,
        bits_per_sample: Optional[int] = None,
):
    """Save ``tensor`` to ``filepath`` via `torchaudio.save`.

    All arguments are forwarded positionally. The tests in this file pass
    this wrapper to `torch_script` and compare files written by the eager
    and scripted versions.
    """
    # NOTE(review): the literal None sits between `compression` and `encoding`;
    # presumably it fills the `format` parameter of torchaudio.save — confirm
    # against the backend signature.
    torchaudio.save(
        filepath, tensor, sample_rate, channels_first,
        compression, None, encoding, bits_per_sample)
@skipIfNoExec('sox')
@skipIfNoSox
class SoxIO(TempDirMixin, TorchaudioTestCase):
    """TorchScript-ability Test suite for `sox_io_backend`"""
    backend = 'sox_io'

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_info_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` is torchscript-able and returns the same result"""
        audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        wav = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
        save_wav(audio_path, wav, sample_rate)

        scripted_info = torch_script(py_info_func)

        eager_result = py_info_func(audio_path)
        scripted_result = scripted_info(audio_path)

        assert eager_result.sample_rate == scripted_result.sample_rate
        assert eager_result.num_frames == scripted_result.num_frames
        assert eager_result.num_channels == scripted_result.num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
        [False, True],
        [False, True],
    )), name_func=name_func)
    def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """`sox_io_backend.load` is torchscript-able and returns the same result"""
        audio_path = self.get_temp_path(
            f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
        wav = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
        save_wav(audio_path, wav, sample_rate)

        scripted_load = torch_script(py_load_func)

        eager_data, eager_rate = py_load_func(
            audio_path, normalize=normalize, channels_first=channels_first)
        scripted_data, scripted_rate = scripted_load(
            audio_path, normalize=normalize, channels_first=channels_first)

        self.assertEqual(eager_rate, scripted_rate)
        self.assertEqual(eager_data, scripted_data)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_save_wav(self, dtype, sample_rate, num_channels):
        """Eager and scripted save both write wav files that load back to the input."""
        scripted_save = torch_script(py_save_func)

        expected = get_wav_data(dtype, num_channels, normalize=False)
        py_path = self.get_temp_path(
            f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
        ts_path = self.get_temp_path(
            f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
        enc, bps = get_enc_params(dtype)

        py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
        scripted_save(ts_path, expected, sample_rate, True, None, enc, bps)

        eager_data, eager_rate = load_wav(py_path, normalize=False)
        scripted_data, scripted_rate = load_wav(ts_path, normalize=False)

        self.assertEqual(sample_rate, eager_rate)
        self.assertEqual(sample_rate, scripted_rate)
        self.assertEqual(expected, eager_data)
        self.assertEqual(expected, scripted_data)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_save_flac(self, sample_rate, num_channels, compression_level):
        """Eager and scripted save both write flac files that load back to the input."""
        scripted_save = torch_script(py_save_func)

        expected = get_wav_data('float32', num_channels)
        py_path = self.get_temp_path(
            f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
        ts_path = self.get_temp_path(
            f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')

        py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
        scripted_save(ts_path, expected, sample_rate, True, compression_level, None, None)

        # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
        py_path_wav = f'{py_path}.wav'
        ts_path_wav = f'{ts_path}.wav'
        sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32)
        sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32)

        eager_data, eager_rate = load_wav(py_path_wav, normalize=True)
        scripted_data, scripted_rate = load_wav(ts_path_wav, normalize=True)

        self.assertEqual(sample_rate, eager_rate)
        self.assertEqual(sample_rate, scripted_rate)
        self.assertEqual(expected, eager_data)
        self.assertEqual(expected, scripted_data)
test/torchaudio_unittest/backend/utils_test.py
0 → 100644
View file @
9dcc7a15
import
torchaudio
from
torchaudio_unittest
import
common_utils
class BackendSwitchMixin:
    """Test set/get_audio_backend works"""
    # Subclasses set these: the backend name handed to set_audio_backend and
    # the module whose load/save/info torchaudio should expose afterwards.
    backend = None
    backend_module = None

    def test_switch(self):
        """Switching backend updates get_audio_backend and the load/save/info entry points."""
        torchaudio.set_audio_backend(self.backend)
        active = torchaudio.get_audio_backend()
        if self.backend is None:
            assert active is None
        else:
            assert active == self.backend
        module = self.backend_module
        assert torchaudio.load == module.load
        assert torchaudio.save == module.save
        assert torchaudio.info == module.info
class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTestCase):
    """Backend switch test for the case where no backend is selected."""
    backend = None
    backend_module = torchaudio.backend.no_backend
@common_utils.skipIfNoSox
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
    """Backend switch test for the 'sox_io' backend."""
    backend = 'sox_io'
    backend_module = torchaudio.backend.sox_io_backend
@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
    """Backend switch test for the 'soundfile' backend."""
    backend = 'soundfile'
    backend_module = torchaudio.backend.soundfile_backend
test/torchaudio_unittest/common_utils/__init__.py
0 → 100644
View file @
9dcc7a15
from
.data_utils
import
(
get_asset_path
,
get_whitenoise
,
get_sinusoid
,
get_spectrogram
,
)
from
.backend_utils
import
(
set_audio_backend
,
)
from
.case_utils
import
(
TempDirMixin
,
HttpServerMixin
,
TestBaseMixin
,
PytorchTestCase
,
TorchaudioTestCase
,
skipIfNoCuda
,
skipIfNoExec
,
skipIfNoModule
,
skipIfNoKaldi
,
skipIfNoSox
,
skipIfRocm
,
skipIfNoQengine
,
)
from
.wav_utils
import
(
get_wav_data
,
normalize_wav
,
load_wav
,
save_wav
,
)
from
.parameterized_utils
import
(
load_params
,
nested_params
)
from
.func_utils
import
torch_script
# Public names re-exported by this package. Every entry must be bound by one of
# the imports above, otherwise `from ... import *` raises AttributeError.
# 'skipIfNoSoxBackend' was dropped because nothing in this module imports it.
__all__ = [
    'get_asset_path',
    'get_whitenoise',
    'get_sinusoid',
    'get_spectrogram',
    'set_audio_backend',
    'TempDirMixin',
    'HttpServerMixin',
    'TestBaseMixin',
    'PytorchTestCase',
    'TorchaudioTestCase',
    'skipIfNoCuda',
    'skipIfNoExec',
    'skipIfNoModule',
    'skipIfNoKaldi',
    'skipIfNoSox',
    'skipIfRocm',
    'skipIfNoQengine',
    'get_wav_data',
    'normalize_wav',
    'load_wav',
    'save_wav',
    'load_params',
    'nested_params',
    'torch_script',
]
test/torchaudio_unittest/common_utils/backend_utils.py
0 → 100644
View file @
9dcc7a15
import
unittest
import
torchaudio
def set_audio_backend(backend):
    """Allow additional backend value, 'default'

    'default' resolves to 'sox_io' when listed by torchaudio, otherwise to
    'soundfile'; any other value is passed to torchaudio unchanged.

    Raises:
        unittest.SkipTest: ``backend`` is 'default' and neither 'sox_io' nor
            'soundfile' is listed.
    """
    backends = torchaudio.list_audio_backends()
    be = backend
    if backend == 'default':
        # Preference order: sox_io first, soundfile as the fallback.
        for candidate in ('sox_io', 'soundfile'):
            if candidate in backends:
                be = candidate
                break
        else:
            raise unittest.SkipTest('No default backend available')
    torchaudio.set_audio_backend(be)
test/torchaudio_unittest/common_utils/case_utils.py
0 → 100644
View file @
9dcc7a15
import
shutil
import
os.path
import
subprocess
import
tempfile
import
time
import
unittest
import
torch
from
torch.testing._internal.common_utils
import
TestCase
as
PytorchTestCase
from
torchaudio._internal.module_utils
import
(
is_module_available
,
is_sox_available
,
is_kaldi_available
)
from
.backend_utils
import
set_audio_backend
class TempDirMixin:
    """Mixin to provide easy access to temp dir"""

    # Lazily created tempfile.TemporaryDirectory; stays None until first use
    # (and always, when the environment variable override is set).
    temp_dir_ = None

    @classmethod
    def get_base_temp_dir(cls):
        """Return the directory under which per-test temp paths are created."""
        # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
        # this is handy for debugging.
        override = os.environ.get('TORCHAUDIO_TEST_TEMP_DIR')
        if override is not None:
            return override
        if cls.temp_dir_ is None:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
        return cls.temp_dir_.name

    @classmethod
    def tearDownClass(cls):
        """Run the parent teardown, then remove the temporary directory if one was made."""
        super().tearDownClass()
        if cls.temp_dir_ is None:
            return
        cls.temp_dir_.cleanup()
        cls.temp_dir_ = None

    def get_temp_path(self, *paths):
        """Return `<base temp dir>/<test id>/<paths...>`, creating its parent directory."""
        test_dir = os.path.join(self.get_base_temp_dir(), self.id())
        full_path = os.path.join(test_dir, *paths)
        parent = os.path.dirname(full_path)
        os.makedirs(parent, exist_ok=True)
        return full_path
class HttpServerMixin(TempDirMixin):
    """Mixin that serves temporary directory as web server

    This class creates temporary directory and serve the directory as HTTP service.
    The server is up through the execution of all the test suite defined under the subclass.
    """
    # Handle of the `http.server` subprocess; None while no server is running.
    _proc = None
    # Port the server listens on.
    _port = 8000

    @classmethod
    def setUpClass(cls):
        """Start `python -m http.server` with the base temp dir as its working directory."""
        super().setUpClass()
        cls._proc = subprocess.Popen(
            ['python', '-m', 'http.server', f'{cls._port}'],
            cwd=cls.get_base_temp_dir(),
            stderr=subprocess.DEVNULL)  # Disable server-side error log because it is confusing
        # Give the server time to come up before any test issues a request.
        time.sleep(2.0)

    @classmethod
    def tearDownClass(cls):
        """Run the parent teardown and always stop the server process."""
        try:
            super().tearDownClass()
        finally:
            # The server must go away even if the parent teardown raised, and
            # setUpClass may have failed before `_proc` was assigned.
            if cls._proc is not None:
                cls._proc.kill()
                # Reap the child so it does not linger as a zombie process.
                cls._proc.wait()
                cls._proc = None

    def get_url(self, *route):
        """Return the URL of ``route`` inside this test's directory on the server."""
        return f'http://localhost:{self._port}/{self.id()}/{"/".join(route)}'
class TestBaseMixin:
    """Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
    # Overridden by concrete test classes.
    dtype = None
    device = None
    backend = None

    def setUp(self):
        """Run the parent setUp, then select ``self.backend`` as the audio backend."""
        super().setUp()
        set_audio_backend(self.backend)

    @property
    def complex_dtype(self):
        """Complex dtype matching ``self.dtype``.

        Raises:
            ValueError: ``self.dtype`` is neither a single- nor a
                double-precision float spec.
        """
        # Checked in order: single precision first, then double precision.
        table = (
            (['float32', 'float', torch.float, torch.float32], torch.cfloat),
            (['float64', 'double', torch.double, torch.float64], torch.cdouble),
        )
        for accepted, complex_type in table:
            if self.dtype in accepted:
                return complex_type
        raise ValueError(f'No corresponding complex dtype for {self.dtype}')
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
    """Test-case base class combining ``TestBaseMixin`` with ``PytorchTestCase``."""
def skipIfNoExec(cmd):
    """Return a decorator that skips the test when ``cmd`` is not found by ``shutil.which``."""
    executable_missing = shutil.which(cmd) is None
    reason = f'`{cmd}` is not available'
    return unittest.skipIf(executable_missing, reason)
def skipIfNoModule(module, display_name=None):
    """Return a decorator that skips the test when ``module`` is not available.

    Args:
        module: Module name passed to ``is_module_available``.
        display_name: Name shown in the skip reason; falls back to ``module``
            when omitted or empty.
    """
    shown_name = display_name if display_name else module
    unavailable = not is_module_available(module)
    return unittest.skipIf(unavailable, f'"{shown_name}" is not available')
def skipIfNoCuda(test_item):
    """Skip ``test_item`` when CUDA is unavailable, unless CUDA tests are forced.

    The ``TORCHAUDIO_TEST_FORCE_CUDA`` environment variable (``"0"`` or
    ``"1"``, default ``"0"``) turns a missing CUDA device into a hard error
    instead of a skip.

    Raises:
        ValueError: If the environment variable holds anything other than "0" or "1".
        RuntimeError: If the variable is "1" but CUDA is not available.
    """
    if torch.cuda.is_available():
        return test_item

    forced = os.environ.get('TORCHAUDIO_TEST_FORCE_CUDA', '0')
    if forced == '1':
        raise RuntimeError('"TORCHAUDIO_TEST_FORCE_CUDA" is set but CUDA is not available.')
    if forced != '0':
        raise ValueError('"TORCHAUDIO_TEST_FORCE_CUDA" must be either "0" or "1".')

    skip_decorator = unittest.skip('CUDA is not available.')
    return skip_decorator(test_item)
# Skip decorators whose conditions are evaluated once, when this module is imported.
skipIfNoSox = unittest.skipIf(
    not is_sox_available(),
    'Sox not available',
)
skipIfNoKaldi = unittest.skipIf(
    not is_kaldi_available(),
    'Kaldi not available',
)
# Enabled by setting the TORCHAUDIO_TEST_WITH_ROCM environment variable to "1".
skipIfRocm = unittest.skipIf(
    os.environ.get('TORCHAUDIO_TEST_WITH_ROCM', '0') == '1',
    "test doesn't currently work on the ROCm stack",
)
skipIfNoQengine = unittest.skipIf(
    'fbgemm' not in torch.backends.quantized.supported_engines,
    "`fbgemm` is not available.",
)
test/torchaudio_unittest/common_utils/data_utils.py
0 → 100644
View file @
9dcc7a15
import
os.path
from
typing
import
Union
,
Optional
import
torch
# Resolved path of the parent of the directory containing this module.
_TEST_DIR_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir))


def get_asset_path(*paths):
    """Return the full path of a test asset.

    Args:
        *paths: Path components below the ``assets`` directory.

    Returns:
        str: ``<test dir>/assets/<paths...>``. The path is only composed; its
        existence is not checked.
    """
    components = (_TEST_DIR_PATH, 'assets') + paths
    return os.path.join(*components)
def convert_tensor_encoding(
    tensor: torch.Tensor,
    dtype: torch.dtype,
):
    """Convert input tensor with values between -1 and 1 to integer encoding

    Positive values are scaled by the largest positive value of the target
    integer type and negative values by the magnitude of its most negative
    value; ``uint8`` output is additionally offset by 128. For any other
    ``dtype`` the tensor is only cast.

    The input tensor is not modified.

    Args:
        tensor: input tensor, assumed between -1 and 1
        dtype: desired output tensor dtype

    Returns:
        Tensor: tensor of the given ``dtype`` with the same shape as ``tensor``
    """
    # (scale for positive values, scale for negative values, offset added afterwards)
    if dtype == torch.int32:
        pos_scale, neg_scale, offset = 2147483647, 2147483648, 0
    elif dtype == torch.int16:
        pos_scale, neg_scale, offset = 32767, 32768, 0
    elif dtype == torch.uint8:
        pos_scale, neg_scale, offset = 127, 128, 128
    else:
        # No integer encoding required; plain dtype conversion.
        return tensor.to(dtype)

    # Zero stays zero: neither comparison is true, so its scale is 0.
    scale = (tensor > 0) * pos_scale + (tensor < 0) * neg_scale
    # Out-of-place arithmetic so the caller's tensor is left untouched.
    encoded = tensor * scale
    if offset:
        encoded = encoded + offset
    return encoded.to(dtype)
def get_whitenoise(
    *,
    sample_rate: int = 16000,
    duration: float = 1,  # seconds
    n_channels: int = 1,
    seed: int = 0,
    dtype: Union[str, torch.dtype] = "float32",
    device: Union[str, torch.device] = "cpu",
    channels_first=True,
    scale_factor: float = 1,
):
    """Generate pseudo audio data with whitenoise

    Args:
        sample_rate: Sampling rate
        duration: Length of the resulting Tensor in seconds.
        n_channels: Number of channels
        seed: Seed value used for random number generation.
            Note that this function does not modify global random generator state.
        dtype: Torch dtype, or its name as an attribute of ``torch``.
        device: device
        channels_first: whether first dimension is n_channels
        scale_factor: scale the Tensor before clamping and quantization

    Returns:
        Tensor: shape of (n_channels, sample_rate * duration), or the
        transpose of that when ``channels_first`` is False.

    Raises:
        NotImplementedError: If ``dtype`` is not one of float64, float32,
            int32, int16 or uint8.
    """
    target_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
    supported = [torch.float64, torch.float32, torch.int32, torch.int16, torch.uint8]
    if target_dtype not in supported:
        raise NotImplementedError(f'dtype {target_dtype} is not supported.')

    n_frames = int(sample_rate * duration)
    # According to the doc, forking rng on all CUDA devices is slow when there are many CUDA devices,
    # so we only fork on CPU, generate values and move the data to the given device
    with torch.random.fork_rng([]):
        torch.random.manual_seed(seed)
        noise = torch.randn([n_channels, n_frames], dtype=torch.float32, device='cpu')

    # Halve, apply the requested scale, then limit the values to [-1, 1] (all in place).
    noise.div_(2.0).mul_(scale_factor).clamp_(-1.0, 1.0)
    if not channels_first:
        noise = noise.t()
    return convert_tensor_encoding(noise.to(device), target_dtype)
def get_sinusoid(
    *,
    frequency: float = 300,
    sample_rate: int = 16000,
    duration: float = 1,  # seconds
    n_channels: int = 1,
    dtype: Union[str, torch.dtype] = "float32",
    device: Union[str, torch.device] = "cpu",
    channels_first: bool = True,
):
    """Generate pseudo audio data with sine wave.

    Args:
        frequency: Frequency of sine wave
        sample_rate: Sampling rate
        duration: Length of the resulting Tensor in seconds.
        n_channels: Number of channels (every channel holds the same wave)
        dtype: Torch dtype, or its name as an attribute of ``torch``.
        device: device
        channels_first: whether first dimension is n_channels

    Returns:
        Tensor: shape of (n_channels, sample_rate * duration), or the
        transpose of that when ``channels_first`` is False.
    """
    target_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

    two_pi = 2 * 3.141592653589793
    final_phase = two_pi * frequency * duration
    n_frames = int(sample_rate * duration)
    phase = torch.linspace(0, final_phase, n_frames, dtype=torch.float32, device=device)

    # One sine wave, replicated along a new leading channel dimension.
    wave = phase.sin().repeat([n_channels, 1])
    if not channels_first:
        wave = wave.t()
    return convert_tensor_encoding(wave, target_dtype)
def get_spectrogram(
    waveform,
    *,
    n_fft: int = 2048,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[torch.Tensor] = None,
    center: bool = True,
    pad_mode: str = 'reflect',
    power: Optional[float] = None,
):
    """Generate a spectrogram of the given Tensor

    Args:
        waveform: Input Tensor passed to ``torch.stft``.
        n_fft: The number of FFT bins.
        hop_length: Stride for sliding window. default: ``n_fft // 4``.
        win_length: The size of window frame and STFT filter. default: ``n_fft``.
        window: Window function. default: Hann window
        center: Pad the input sequence if True. See ``torch.stft`` for the detail.
        pad_mode: Padding method used when center is True. Default: "reflect".
        power: If ``None``, raw spectrogram with complex values are returned,
            otherwise the norm of the spectrogram is returned.

    Returns:
        Tensor: complex STFT output when ``power`` is None, otherwise its
        magnitude raised to ``power``.
    """
    if not hop_length:
        hop_length = n_fft // 4
    if not win_length:
        win_length = n_fft
    if window is None:
        # Create the default window on the same device as the input.
        window = torch.hann_window(win_length, device=waveform.device)

    stft_options = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        center=center,
        window=window,
        pad_mode=pad_mode,
        return_complex=True,
    )
    spec = torch.stft(waveform, **stft_options)
    if power is None:
        return spec
    return spec.abs() ** power
test/torchaudio_unittest/common_utils/func_utils.py
0 → 100644
View file @
9dcc7a15
import
io
import
torch
def torch_script(obj):
    """TorchScript the given function or Module

    The scripted object is saved to an in-memory buffer and loaded back, and
    the reloaded object is what gets returned.
    """
    scripted = torch.jit.script(obj)
    stream = io.BytesIO()
    torch.jit.save(scripted, stream)
    # Rewind so that loading starts from the beginning of the serialized data.
    stream.seek(0)
    return torch.jit.load(stream)
test/torchaudio_unittest/common_utils/kaldi_utils.py
0 → 100644
View file @
9dcc7a15
import
subprocess
import
torch
def convert_args(**kwargs):
    """Translate keyword arguments into ``--key=value`` command-line options.

    ``sample_rate`` is renamed to ``sample_frequency``, underscores in keys
    become hyphens, and values equal to ``True``/``False`` have their string
    form lower-cased.

    Returns:
        list of str: One ``--key=value`` string per keyword argument, in the
        order the arguments were given.
    """
    def _render(name, value):
        if name == 'sample_rate':
            name = 'sample_frequency'
        flag = '--' + name.replace('_', '-')
        is_boolean_like = value in (True, False)
        text = str(value)
        if is_boolean_like:
            text = text.lower()
        return f'{flag}={text}'

    return [_render(name, value) for name, value in kwargs.items()]
def run_kaldi(command, input_type, input_value):
    """Run provided Kaldi command, pass a tensor and get the resulting tensor

    Args:
        command (list of str): The command with arguments
        input_type (str): 'ark' or 'scp'
        input_value (Tensor for 'ark', string for 'scp'): The input to pass.
            Must be a path to an audio file for 'scp'.

    Returns:
        Tensor: The matrix the command wrote to its standard output.

    Raises:
        NotImplementedError: If ``input_type`` is neither 'ark' nor 'scp'.
    """
    # Validate first so that an unsupported type never leaves a spawned process behind.
    if input_type not in ('ark', 'scp'):
        raise NotImplementedError('Unexpected type')

    import kaldi_io

    key = 'foo'
    # The context manager closes the pipes and waits for the process on exit,
    # including when writing or reading fails.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as process:
        if input_type == 'ark':
            kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
        else:
            process.stdin.write(f'{key} {input_value}'.encode('utf8'))
        # Closing stdin signals end of input to the command.
        process.stdin.close()
        result = dict(kaldi_io.read_mat_ark(process.stdout))[key]
    return torch.from_numpy(result.copy())  # copy suppresses some torch warning
test/torchaudio_unittest/common_utils/parameterized_utils.py
0 → 100644
View file @
9dcc7a15
import
json
from
itertools
import
product
from
parameterized
import
param
,
parameterized
from
.data_utils
import
get_asset_path
def load_params(*paths):
    """Load test parameters from an asset file holding one JSON document per line.

    Args:
        *paths: Path components of the asset, as accepted by ``get_asset_path``.

    Returns:
        list of ``param``: One ``param`` per line, wrapping the decoded JSON value.
    """
    asset_path = get_asset_path(*paths)
    loaded = []
    with open(asset_path, 'r') as file:
        for line in file:
            loaded.append(param(json.loads(line)))
    return loaded
def _name_func(func, _, params):
    """Build a test name from the function name and its positional parameters.

    Tuple arguments are flattened with underscores, and dots in the result
    are replaced with underscores.
    """
    def _stringify(arg):
        if isinstance(arg, tuple):
            return "_".join(map(str, arg))
        return str(arg)

    joined = "_".join(_stringify(arg) for arg in params.args)
    # sanitize the test name
    suffix = joined.replace(".", "_")
    return f'{func.__name__}_{suffix}'
def nested_params(*params_set):
    """Generate the cartesian product of the given list of parameters.

    Args:
        params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
            all the parameters have to be specified with the class, only using kwargs.

    Returns:
        The decorator produced by ``parameterized.expand`` for the combined parameters.

    Raises:
        TypeError: If ``param`` objects are mixed with plain objects.
        ValueError: If any ``param`` carries positional arguments.
    """
    flat = [p for group in params_set for p in group]
    is_param = [isinstance(p, param) for p in flat]

    # Parameters to be nested are given as list of plain objects
    if not any(is_param):
        combinations = list(product(*params_set))
        return parameterized.expand(combinations, name_func=_name_func)

    # Parameters to be nested are given as list of `parameterized.param`
    if not all(is_param):
        raise TypeError(
            "When using ``parameterized.param``, "
            "all the parameters have to be of the ``param`` type."
        )
    if any(p.args for p in flat):
        raise ValueError(
            "When using ``parameterized.param``, "
            "all the parameters have to be provided as keyword argument."
        )

    # Fold the groups together, merging the keyword arguments of each pairing.
    combinations = [param()]
    for group in params_set:
        combinations = [
            param(**left.kwargs, **right.kwargs)
            for left in combinations
            for right in group
        ]
    return parameterized.expand(combinations)
test/torchaudio_unittest/common_utils/psd_utils.py
0 → 100644
View file @
9dcc7a15
from
typing
import
Optional
import
numpy
as
np
import
torch
def psd_numpy(
        X: np.ndarray,
        mask: Optional[np.ndarray],
        multi_mask: bool = False,
        normalize: bool = True,
        eps: float = 1e-15
) -> torch.Tensor:
    """Compute a (optionally mask-weighted) power spectral density matrix with NumPy.

    For every index of the last two axes of ``X``, the outer product of the
    third-from-last axis with its complex conjugate is formed; the products
    are then weighted by ``mask`` (if given) and summed over the last axis of
    ``X``.

    Args:
        X: Complex array with at least three dimensions. Presumably laid out
            as (..., channel, freq, time), going by the einsum subscripts.
        mask: Weights broadcast against the last two axes of ``X``, or
            ``None`` for an unweighted sum.
        multi_mask: If True, ``mask`` carries an extra third-from-last axis
            which is averaged out first.
        normalize: If True, ``mask`` is divided by its sum over the last axis.
        eps: Added to the normalization denominator to avoid division by zero.

    Returns:
        Tensor: ``torch.cdouble`` tensor whose last two dimensions both have
        the size of the third-from-last axis of ``X``.
    """
    # psd_X[..., f, t, c, e] = X[..., c, f, t] * conj(X[..., e, f, t])
    psd_X = np.einsum("...cft,...eft->...ftce", X, np.conj(X))

    if mask is not None:
        if multi_mask:
            mask = mask.mean(axis=-3)
        if normalize:
            mask = mask / (mask.sum(axis=-1, keepdims=True) + eps)
        # Append two singleton axes so the mask broadcasts over the (c, e) matrix.
        psd = psd_X * mask[..., None, None]
    else:
        psd = psd_X

    # Axis -3 of psd is the last axis of X.
    psd = psd.sum(axis=-3)
    return torch.tensor(psd, dtype=torch.cdouble)
Prev
1
…
8
9
10
11
12
13
14
15
16
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment