Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ffeba11a
Commit
ffeba11a
authored
Sep 02, 2024
by
mayp777
Browse files
UPDATE
parent
29deb085
Changes
337
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4237 additions
and
0 deletions
+4237
-0
test/torchaudio_unittest/assets/wav2vec2/huggingface/wavlm-large.json
...dio_unittest/assets/wav2vec2/huggingface/wavlm-large.json
+98
-0
test/torchaudio_unittest/backend/dispatcher/__init__.py
test/torchaudio_unittest/backend/dispatcher/__init__.py
+0
-0
test/torchaudio_unittest/backend/dispatcher/dispatcher_test.py
...torchaudio_unittest/backend/dispatcher/dispatcher_test.py
+129
-0
test/torchaudio_unittest/backend/dispatcher/ffmpeg/__init__.py
...torchaudio_unittest/backend/dispatcher/ffmpeg/__init__.py
+0
-0
test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py
...orchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py
+611
-0
test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
...orchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
+617
-0
test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
...orchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
+455
-0
test/torchaudio_unittest/backend/dispatcher/smoke_test.py
test/torchaudio_unittest/backend/dispatcher/smoke_test.py
+56
-0
test/torchaudio_unittest/backend/dispatcher/soundfile/__init__.py
...chaudio_unittest/backend/dispatcher/soundfile/__init__.py
+0
-0
test/torchaudio_unittest/backend/dispatcher/soundfile/common.py
...orchaudio_unittest/backend/dispatcher/soundfile/common.py
+56
-0
test/torchaudio_unittest/backend/dispatcher/soundfile/info_test.py
...haudio_unittest/backend/dispatcher/soundfile/info_test.py
+191
-0
test/torchaudio_unittest/backend/dispatcher/soundfile/load_test.py
...haudio_unittest/backend/dispatcher/soundfile/load_test.py
+369
-0
test/torchaudio_unittest/backend/dispatcher/soundfile/save_test.py
...haudio_unittest/backend/dispatcher/soundfile/save_test.py
+319
-0
test/torchaudio_unittest/backend/dispatcher/sox/__init__.py
test/torchaudio_unittest/backend/dispatcher/sox/__init__.py
+0
-0
test/torchaudio_unittest/backend/dispatcher/sox/common.py
test/torchaudio_unittest/backend/dispatcher/sox/common.py
+14
-0
test/torchaudio_unittest/backend/dispatcher/sox/info_test.py
test/torchaudio_unittest/backend/dispatcher/sox/info_test.py
+398
-0
test/torchaudio_unittest/backend/dispatcher/sox/load_test.py
test/torchaudio_unittest/backend/dispatcher/sox/load_test.py
+369
-0
test/torchaudio_unittest/backend/dispatcher/sox/roundtrip_test.py
...chaudio_unittest/backend/dispatcher/sox/roundtrip_test.py
+59
-0
test/torchaudio_unittest/backend/dispatcher/sox/save_test.py
test/torchaudio_unittest/backend/dispatcher/sox/save_test.py
+416
-0
test/torchaudio_unittest/backend/dispatcher/sox/smoke_test.py
.../torchaudio_unittest/backend/dispatcher/sox/smoke_test.py
+80
-0
No files found.
Too many changes to show.
To preserve performance only
337 of 337+
files are displayed.
Plain diff
Email patch
test/torchaudio_unittest/assets/wav2vec2/huggingface/wavlm-large.json
0 → 100644
View file @
ffeba11a
{
"activation_dropout"
:
0.0
,
"adapter_kernel_size"
:
3
,
"adapter_stride"
:
2
,
"add_adapter"
:
false
,
"apply_spec_augment"
:
true
,
"architectures"
:
[
"WavLMModel"
],
"attention_dropout"
:
0.1
,
"bos_token_id"
:
1
,
"classifier_proj_size"
:
256
,
"codevector_dim"
:
768
,
"contrastive_logits_temperature"
:
0.1
,
"conv_bias"
:
false
,
"conv_dim"
:
[
512
,
512
,
512
,
512
,
512
,
512
,
512
],
"conv_kernel"
:
[
10
,
3
,
3
,
3
,
3
,
2
,
2
],
"conv_stride"
:
[
5
,
2
,
2
,
2
,
2
,
2
,
2
],
"ctc_loss_reduction"
:
"sum"
,
"ctc_zero_infinity"
:
false
,
"diversity_loss_weight"
:
0.1
,
"do_stable_layer_norm"
:
true
,
"eos_token_id"
:
2
,
"feat_extract_activation"
:
"gelu"
,
"feat_extract_dropout"
:
0.0
,
"feat_extract_norm"
:
"layer"
,
"feat_proj_dropout"
:
0.1
,
"feat_quantizer_dropout"
:
0.0
,
"final_dropout"
:
0.0
,
"gradient_checkpointing"
:
false
,
"hidden_act"
:
"gelu"
,
"hidden_dropout"
:
0.1
,
"hidden_size"
:
1024
,
"initializer_range"
:
0.02
,
"intermediate_size"
:
4096
,
"layer_norm_eps"
:
1e-05
,
"layerdrop"
:
0.1
,
"mask_channel_length"
:
10
,
"mask_channel_min_space"
:
1
,
"mask_channel_other"
:
0.0
,
"mask_channel_prob"
:
0.0
,
"mask_channel_selection"
:
"static"
,
"mask_feature_length"
:
10
,
"mask_feature_min_masks"
:
0
,
"mask_feature_prob"
:
0.0
,
"mask_time_length"
:
10
,
"mask_time_min_masks"
:
2
,
"mask_time_min_space"
:
1
,
"mask_time_other"
:
0.0
,
"mask_time_prob"
:
0.075
,
"mask_time_selection"
:
"static"
,
"max_bucket_distance"
:
800
,
"model_type"
:
"wavlm"
,
"num_adapter_layers"
:
3
,
"num_attention_heads"
:
16
,
"num_buckets"
:
320
,
"num_codevector_groups"
:
2
,
"num_codevectors_per_group"
:
320
,
"num_conv_pos_embedding_groups"
:
16
,
"num_conv_pos_embeddings"
:
128
,
"num_ctc_classes"
:
80
,
"num_feat_extract_layers"
:
7
,
"num_hidden_layers"
:
24
,
"num_negatives"
:
100
,
"output_hidden_size"
:
1024
,
"pad_token_id"
:
0
,
"proj_codevector_dim"
:
768
,
"replace_prob"
:
0.5
,
"tokenizer_class"
:
"Wav2Vec2CTCTokenizer"
,
"torch_dtype"
:
"float32"
,
"transformers_version"
:
"4.15.0.dev0"
,
"use_weighted_layer_sum"
:
false
,
"vocab_size"
:
32
}
test/torchaudio_unittest/backend/dispatcher/__init__.py
0 → 100644
View file @
ffeba11a
test/torchaudio_unittest/backend/dispatcher/dispatcher_test.py
0 → 100644
View file @
ffeba11a
import
io
from
unittest.mock
import
patch
import
torch
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
(
FFmpegBackend
,
get_info_func
,
get_load_func
,
get_save_func
,
SoundfileBackend
,
SoXBackend
,
)
from
torchaudio_unittest.common_utils
import
PytorchTestCase
class DispatcherTest(PytorchTestCase):
    """Check which backend the dispatcher selects when the caller names none.

    Each test replaces ``torchaudio._backend.utils.get_available_backends``
    with a fixed mapping, mocks one method of the backend that is expected to
    be chosen, and then verifies the exact positional arguments that the
    dispatcher forwards to that method.
    """

    # Mapping returned when all three backends are available.
    _ALL_BACKENDS = {"ffmpeg": FFmpegBackend, "sox": SoXBackend, "soundfile": SoundfileBackend}
    # Mapping returned when FFmpeg is not available.
    _NO_FFMPEG = {"sox": SoXBackend, "soundfile": SoundfileBackend}

    # Cases for path inputs.
    _PATH_CASES = [
        # FFmpeg backend is used when no backend is specified.
        (_ALL_BACKENDS, FFmpegBackend),
        # SoX backend is used when no backend is specified and FFmpeg is not available.
        (_NO_FFMPEG, SoXBackend),
    ]
    # Cases for file-like inputs.
    _FILEOBJ_CASES = [
        # FFmpeg backend is used when no backend is specified.
        (_ALL_BACKENDS, FFmpegBackend),
        # Soundfile backend is used when no backend is specified, FFmpeg is not available,
        # and input is file-like object (i.e. SoX is properly skipped over).
        (_NO_FFMPEG, SoundfileBackend),
    ]

    @staticmethod
    def _patchers(available_backends, expected_backend, method):
        """Build the two patchers every test needs, without starting them.

        Args:
            available_backends: mapping to return from ``get_available_backends``.
            expected_backend: backend class whose ``method`` gets mocked.
            method: name of the backend method to mock (``info``/``load``/``save``).

        Returns:
            ``(backends_patch, method_patch)``; enter them in that order.
        """
        backends_patch = patch(
            "torchaudio._backend.utils.get_available_backends",
            return_value=available_backends,
        )
        method_patch = patch(f"torchaudio._backend.utils.{expected_backend.__name__}.{method}")
        return backends_patch, method_patch

    @parameterized.expand(_PATH_CASES)
    def test_info(self, available_backends, expected_backend):
        """``info`` on a path: 4096 is forwarded when no ``buffer_size`` is passed."""
        uri, fmt = "test.wav", "wav"
        backends_patch, info_patch = self._patchers(available_backends, expected_backend, "info")
        with backends_patch, info_patch as mock_info:
            get_info_func()(uri, format=fmt)
            mock_info.assert_called_once_with(uri, fmt, 4096)

    @parameterized.expand(_FILEOBJ_CASES)
    def test_info_fileobj(self, available_backends, expected_backend):
        """``info`` on a file-like object forwards the caller's ``buffer_size``."""
        stream = io.BytesIO()
        fmt, chunk_size = "wav", 8192
        backends_patch, info_patch = self._patchers(available_backends, expected_backend, "info")
        with backends_patch, info_patch as mock_info:
            get_info_func()(stream, format=fmt, buffer_size=chunk_size)
            mock_info.assert_called_once_with(stream, fmt, chunk_size)

    @parameterized.expand(_PATH_CASES)
    def test_load(self, available_backends, expected_backend):
        """``load`` on a path forwards ``(uri, 0, -1, True, True, format, 4096)``."""
        uri, fmt = "test.wav", "wav"
        backends_patch, load_patch = self._patchers(available_backends, expected_backend, "load")
        with backends_patch, load_patch as mock_load:
            get_load_func()(uri, format=fmt)
            mock_load.assert_called_once_with(uri, 0, -1, True, True, fmt, 4096)

    @parameterized.expand(_FILEOBJ_CASES)
    def test_load_fileobj(self, available_backends, expected_backend):
        """``load`` on a file-like object forwards the caller's ``buffer_size``."""
        stream = io.BytesIO()
        fmt, chunk_size = "wav", 8192
        backends_patch, load_patch = self._patchers(available_backends, expected_backend, "load")
        with backends_patch, load_patch as mock_load:
            get_load_func()(stream, format=fmt, buffer_size=chunk_size)
            mock_load.assert_called_once_with(stream, 0, -1, True, True, fmt, chunk_size)

    @parameterized.expand(_PATH_CASES)
    def test_save(self, available_backends, expected_backend):
        """``save`` to a path: 4096 is forwarded when no ``buffer_size`` is passed."""
        waveform = torch.zeros((2, 10))
        uri, fmt, rate = "test.wav", "wav", 16000
        backends_patch, save_patch = self._patchers(available_backends, expected_backend, "save")
        with backends_patch, save_patch as mock_save:
            get_save_func()(uri, waveform, rate, format=fmt)
            mock_save.assert_called_once_with(uri, waveform, rate, True, fmt, None, None, 4096, None)

    @parameterized.expand(_FILEOBJ_CASES)
    def test_save_fileobj(self, available_backends, expected_backend):
        """``save`` to a file-like object forwards the caller's ``buffer_size``."""
        waveform = torch.zeros((2, 10))
        stream = io.BytesIO()
        fmt, chunk_size, rate = "wav", 8192, 16000
        backends_patch, save_patch = self._patchers(available_backends, expected_backend, "save")
        with backends_patch, save_patch as mock_save:
            get_save_func()(stream, waveform, rate, format=fmt, buffer_size=chunk_size)
            mock_save.assert_called_once_with(
                stream, waveform, rate, True, fmt, None, None, chunk_size, None
            )
test/torchaudio_unittest/backend/dispatcher/ffmpeg/__init__.py
0 → 100644
View file @
ffeba11a
test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py
0 → 100644
View file @
ffeba11a
import
io
import
itertools
import
os
import
pathlib
import
tarfile
from
contextlib
import
contextmanager
from
functools
import
partial
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
get_info_func
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio.utils.sox_utils
import
get_buffer_size
,
set_buffer_size
from
torchaudio_unittest.backend.common
import
get_bits_per_sample
,
get_encoding
from
torchaudio_unittest.backend.dispatcher.sox.common
import
name_func
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
get_wav_data
,
HttpServerMixin
,
PytorchTestCase
,
save_wav
,
skipIfNoExec
,
skipIfNoFFmpeg
,
skipIfNoModule
,
sox_utils
,
TempDirMixin
,
)
if
_mod_utils
.
is_module_available
(
"requests"
):
import
requests
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestInfo(TempDirMixin, PytorchTestCase):
    """Metadata queries through the dispatcher with ``backend="ffmpeg"`` forced.

    WAV inputs are written with ``save_wav``; every other format is generated
    with ``sox_utils.gen_audio_file`` into the test's temporary directory.
    """

    _info = partial(get_info_func(), backend="ffmpeg")

    def _save_wav_fixture(self, dtype, sample_rate, num_channels, duration=1):
        """Write ``data.wav`` with ``duration * sample_rate`` frames and return its path."""
        wav_path = self.get_temp_path("data.wav")
        samples = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
        save_wav(wav_path, samples, sample_rate)
        return wav_path

    def _assert_signal_info(self, info, *, sample_rate, num_channels, bits_per_sample, encoding, num_frames=None):
        """Compare the fields of ``info`` against the expected values.

        ``num_frames`` is only compared when it is given; some formats do not
        report the requested frame count.
        """
        assert info.sample_rate == sample_rate
        if num_frames is not None:
            assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == encoding

    def test_pathlike(self):
        """FFmpeg dispatcher can query audio data from a ``pathlib.Path``."""
        sample_rate, num_channels, dtype, duration = 16000, 2, "float32", 1
        wav_path = self._save_wav_fixture(dtype, sample_rate, num_channels, duration)

        info = self._info(pathlib.Path(wav_path))

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=sox_utils.get_bit_depth(dtype),
            encoding=get_encoding("wav", dtype),
        )

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`info` reports mono and stereo wav files correctly."""
        duration = 1
        wav_path = self._save_wav_fixture(dtype, sample_rate, num_channels, duration)

        info = self._info(wav_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=sox_utils.get_bit_depth(dtype),
            encoding=get_encoding("wav", dtype),
        )

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                # NOTE: ffmpeg can't handle more than 16 channels.
                [4, 8, 16],
            )
        ),
        name_func=name_func,
    )
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`info` reports wav files with more than two channels correctly."""
        duration = 1
        wav_path = self._save_wav_fixture(dtype, sample_rate, num_channels, duration)

        info = self._info(wav_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=sox_utils.get_bit_depth(dtype),
            encoding=get_encoding("wav", dtype),
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [96, 128, 160, 192, 224, 256, 320],
            )
        ),
        name_func=name_func,
    )
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`info` reports mp3 files correctly."""
        duration = 1
        mp3_path = self.get_temp_path("data.mp3")
        sox_utils.gen_audio_file(
            mp3_path,
            sample_rate,
            num_channels,
            compression=bit_rate,
            duration=duration,
        )

        info = self._info(mp3_path)

        # mp3 does not preserve the number of samples, so num_frames is not compared.
        # bits_per_sample is irrelevant for compressed formats, hence 0.
        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bits_per_sample=0,
            encoding="MP3",
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                list(range(9)),
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`info` reports flac files correctly."""
        duration = 1
        flac_path = self.get_temp_path("data.flac")
        sox_utils.gen_audio_file(
            flac_path,
            sample_rate,
            num_channels,
            compression=compression_level,
            duration=duration,
        )

        info = self._info(flac_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=24,  # FLAC standard
            encoding="FLAC",
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [-1, 0, 1, 2, 3, 3.6, 5, 10],
            )
        ),
        name_func=name_func,
    )
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`info` reports vorbis files correctly."""
        duration = 1
        vorbis_path = self.get_temp_path("data.vorbis")
        sox_utils.gen_audio_file(
            vorbis_path,
            sample_rate,
            num_channels,
            compression=quality_level,
            duration=duration,
        )

        info = self._info(vorbis_path)

        # num_frames is not compared; with FFmpeg the check failed with
        # "AssertionError: assert 16384 == (16000 * 1)".
        # bits_per_sample is irrelevant for compressed formats, hence 0.
        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bits_per_sample=0,
            encoding="VORBIS",
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [16, 32],
            )
        ),
        name_func=name_func,
    )
    def test_sphere(self, sample_rate, num_channels, bits_per_sample):
        """`info` reports sph files correctly."""
        duration = 1
        sph_path = self.get_temp_path("data.sph")
        sox_utils.gen_audio_file(sph_path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)

        info = self._info(sph_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=bits_per_sample,
            encoding="PCM_S",
        )

    @parameterized.expand(
        list(
            itertools.product(
                ["int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_amb(self, dtype, sample_rate, num_channels):
        """`info` reports amb files correctly."""
        duration = 1
        amb_path = self.get_temp_path("data.amb")
        bit_depth = sox_utils.get_bit_depth(dtype)
        sox_utils.gen_audio_file(amb_path, sample_rate, num_channels, bit_depth=bit_depth, duration=duration)

        info = self._info(amb_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=bit_depth,
            encoding=get_encoding("amb", dtype),
        )

    # # NOTE: amr-nb not yet implemented for ffmpeg
    # def test_amr_nb(self):
    #     """`info` can check amr-nb file correctly"""
    #     duration = 1
    #     num_channels = 1
    #     sample_rate = 8000
    #     path = self.get_temp_path("data.amr-nb")
    #     sox_utils.gen_audio_file(
    #         path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
    #     )
    #     info = self._info(path)
    #     assert info.sample_rate == sample_rate
    #     assert info.num_frames == sample_rate * duration
    #     assert info.num_channels == num_channels
    #     assert info.bits_per_sample == 0
    #     assert info.encoding == "AMR_NB"

    def test_ulaw(self):
        """`info` reports a u-law encoded wav file correctly."""
        duration, num_channels, sample_rate = 1, 1, 8000
        ulaw_path = self.get_temp_path("data.wav")
        sox_utils.gen_audio_file(
            ulaw_path,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_depth=8,
            encoding="u-law",
            duration=duration,
        )

        info = self._info(ulaw_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=8,
            encoding="ULAW",
        )

    def test_alaw(self):
        """`info` reports an a-law encoded wav file correctly."""
        duration, num_channels, sample_rate = 1, 1, 8000
        alaw_path = self.get_temp_path("data.wav")
        sox_utils.gen_audio_file(
            alaw_path,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_depth=8,
            encoding="a-law",
            duration=duration,
        )

        info = self._info(alaw_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_frames=sample_rate * duration,
            num_channels=num_channels,
            bits_per_sample=8,
            encoding="ALAW",
        )

    def test_gsm(self):
        """`info` reports gsm files correctly (frame count is not compared)."""
        duration, num_channels, sample_rate = 1, 1, 8000
        gsm_path = self.get_temp_path("data.gsm")
        sox_utils.gen_audio_file(gsm_path, sample_rate=sample_rate, num_channels=num_channels, duration=duration)

        info = self._info(gsm_path)

        self._assert_signal_info(
            info,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bits_per_sample=0,
            encoding="GSM",
        )

    # NOTE: htk not supported (RuntimeError: Invalid data found when processing input)
    # def test_htk(self):
    #     """`info` can check HTK file correctly"""
    #     duration = 1
    #     num_channels = 1
    #     sample_rate = 8000
    #     path = self.get_temp_path("data.htk")
    #     sox_utils.gen_audio_file(
    #         path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
    #     )
    #     info = self._info(path)
    #     assert info.sample_rate == sample_rate
    #     assert info.num_frames == sample_rate * duration
    #     assert info.num_channels == num_channels
    #     # assert info.bits_per_sample == 16
    #     assert info.encoding == "PCM_S"
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestInfoOpus(PytorchTestCase):
    """Metadata queries on the pre-generated Opus assets with ``backend="ffmpeg"``."""

    _info = partial(get_info_func(), backend="ffmpeg")

    @parameterized.expand(
        list(
            itertools.product(
                ["96k"],
                [1, 2],
                [0, 5, 10],
            )
        ),
        name_func=name_func,
    )
    def test_opus(self, bitrate, num_channels, compression_level):
        """`info` can check opus file correctly"""
        asset_name = f"{bitrate}_{compression_level}_{num_channels}ch.opus"
        opus_path = get_asset_path("io", asset_name)

        info = self._info(opus_path)

        assert info.sample_rate == 48000
        assert info.num_frames == 32768
        assert info.num_channels == num_channels
        # bits_per_sample is irrelevant for compressed formats, hence 0.
        assert info.bits_per_sample == 0
        assert info.encoding == "OPUS"
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestLoadWithoutExtension(PytorchTestCase):
    """Metadata queries on an MP3 asset whose file name has no extension."""

    _info = partial(get_info_func(), backend="ffmpeg")

    def _assert_mp3_asset(self, sinfo):
        """Check ``sinfo`` against the known properties of the ``mp3_without_ext`` asset."""
        assert sinfo.sample_rate == 16000
        assert sinfo.num_frames == 80000
        assert sinfo.num_channels == 1
        # bits_per_sample is irrelevant for compressed formats, hence 0.
        assert sinfo.bits_per_sample == 0
        assert sinfo.encoding == "MP3"

    def test_mp3(self):
        """MP3 file without extension can be queried.

        Originally, the `format` argument was added for this case, but FFmpeg is
        now used for MP3 decoding, which works even without the `format` argument.
        https://github.com/pytorch/audio/issues/1040

        The file was generated with the following command
            ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
        """
        asset = get_asset_path("mp3_without_ext")

        # Query by path, without a format hint.
        self._assert_mp3_asset(self._info(asset))

        # Query the same asset through a file object, with an explicit format.
        with open(asset, "rb") as stream:
            from_stream = self._info(stream, format="mp3")
        self._assert_mp3_asset(from_stream)
class FileObjTestBase(TempDirMixin):
    """Shared helpers that generate audio fixtures for the file-object tests."""

    def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate ``test.<ext>`` in the temporary directory and return its path.

        Args:
            ext: file extension, used to build the file name.
            dtype: sample type name; passed to ``sox_utils.get_bit_depth`` and
                ``sox_utils.get_encoding``.
            sample_rate: sample rate forwarded to ``sox_utils.gen_audio_file``.
            num_channels: channel count forwarded to ``sox_utils.gen_audio_file``.
            num_frames: requested length; converted to a duration as
                ``num_frames / sample_rate``.
            comments: optional lines written to a comment file that is handed
                to ``sox_utils.gen_audio_file``; skipped when falsy.
        """
        target = self.get_temp_path(f"test.{ext}")
        depth = sox_utils.get_bit_depth(dtype)
        seconds = num_frames / sample_rate

        comment_file = None
        if comments:
            comment_file = self._gen_comment_file(comments)

        sox_utils.gen_audio_file(
            target,
            sample_rate,
            num_channels=num_channels,
            encoding=sox_utils.get_encoding(dtype),
            bit_depth=depth,
            duration=seconds,
            comment_file=comment_file,
        )
        return target

    def _gen_comment_file(self, comments):
        """Write ``comments`` to ``comment.txt`` in the temporary directory and return its path."""
        destination = self.get_temp_path("comment.txt")
        with open(destination, "w") as handle:
            handle.writelines(comments)
        return destination
class Unseekable:
    """Wrapper that exposes only ``read`` of the wrapped file-like object.

    The wrapper defines no ``seek``/``tell``, so code that receives it can only
    consume the stream sequentially.
    """

    def __init__(self, fileobj):
        # Keep a reference to the wrapped object; ``read`` delegates to it.
        self.fileobj = fileobj

    def read(self, n):
        """Return the result of ``fileobj.read(n)``."""
        wrapped = self.fileobj
        return wrapped.read(n)
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestFileObject(FileObjTestBase, PytorchTestCase):
    """Metadata queries on file-like inputs with ``backend="ffmpeg"`` forced.

    The class is skipped when FFmpeg is unavailable, like the other classes in
    this module that force the FFmpeg backend.
    """

    _info = partial(get_info_func(), backend="ffmpeg")

    # (extension, dtype) pairs exercised by every test in this class.
    _FORMATS = [
        ("wav", "float32"),
        ("wav", "int32"),
        ("wav", "int16"),
        ("wav", "uint8"),
        ("mp3", "float32"),
        ("flac", "float32"),
        ("vorbis", "float32"),
        ("amb", "int16"),
    ]

    @staticmethod
    def _format_hint(ext):
        """Return ``ext`` as the explicit format for mp3, ``None`` for everything else."""
        return ext if ext in ["mp3"] else None

    def _assert_sinfo(self, sinfo, ext, dtype, sample_rate, num_channels, num_frames):
        """Compare ``sinfo`` with the expected properties of a generated fixture."""
        bits_per_sample = get_bits_per_sample(ext, dtype)
        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        assert sinfo.num_frames == num_frames
        assert sinfo.bits_per_sample == bits_per_sample
        assert sinfo.encoding == get_encoding(ext, dtype)

    def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate a fixture and query it through an open binary file handle."""
        path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
        format_ = self._format_hint(ext)
        with open(path, "rb") as fileobj:
            return self._info(fileobj, format_)

    def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a fixture and query it through an in-memory ``io.BytesIO`` copy."""
        path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        format_ = self._format_hint(ext)
        with open(path, "rb") as file_:
            fileobj = io.BytesIO(file_.read())
        return self._info(fileobj, format_)

    def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a fixture, pack it into a tar archive and query the archive member."""
        audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        audio_file = os.path.basename(audio_path)
        # NOTE: despite the ".tar.gz" name, ``tarfile.TarFile`` in "w" mode
        # writes a plain, uncompressed tar archive.
        archive_path = self.get_temp_path("archive.tar.gz")
        with tarfile.TarFile(archive_path, "w") as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        format_ = self._format_hint(ext)
        with tarfile.TarFile(archive_path, "r") as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            return self._info(fileobj, format_)

    @contextmanager
    def _set_buffer_size(self, buffer_size):
        """Temporarily set the sox buffer size, restoring the previous value on exit."""
        # Read the current value *before* entering ``try``: if this call fails
        # there is nothing to restore, and ``finally`` must not touch an
        # unbound name (which would mask the real error).
        original_buffer_size = get_buffer_size()
        try:
            set_buffer_size(buffer_size)
            yield
        finally:
            set_buffer_size(original_buffer_size)

    @parameterized.expand(_FORMATS)
    def test_fileobj(self, ext, dtype):
        """Querying audio via file object works"""
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)

        # The lossy formats report a different frame count than requested.
        expected_frames = {"vorbis": 48128, "mp3": 49536}.get(ext, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, expected_frames)

    @parameterized.expand(_FORMATS)
    def test_bytesio(self, ext, dtype):
        """Querying audio via BytesIO object works"""
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)

        # The lossy formats report a different frame count than requested.
        expected_frames = {"vorbis": 48128, "mp3": 49536}.get(ext, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, expected_frames)

    @parameterized.expand(_FORMATS)
    def test_bytesio_tiny(self, ext, dtype):
        """Querying audio via BytesIO object works for tiny data (4 frames)"""
        sample_rate = 8000
        num_frames = 4
        num_channels = 2
        sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)

        # The lossy formats report a different frame count than requested.
        expected_frames = {"vorbis": 256, "mp3": 1728}.get(ext, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, expected_frames)

    @parameterized.expand(_FORMATS)
    def test_tarfile(self, ext, dtype):
        """Querying audio via a file-like object extracted from a tar archive works"""
        sample_rate = 16000
        num_frames = 3 * sample_rate
        num_channels = 2
        sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)

        # The lossy formats report a different frame count than requested.
        expected_frames = {"vorbis": 48128, "mp3": 49536}.get(ext, num_frames)
        self._assert_sinfo(sinfo, ext, dtype, sample_rate, num_channels, expected_frames)
@skipIfNoFFmpeg
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
    """Metadata queries on a streamed HTTP response with ``backend="ffmpeg"``."""

    _info = partial(get_info_func(), backend="ffmpeg")

    def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate a fixture, fetch it via ``requests`` and query the raw response stream.

        The raw stream is wrapped in ``Unseekable`` so only ``read`` is exposed.
        """
        generated = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        served_name = os.path.basename(generated)
        url = self.get_url(served_name)

        # Only mp3 gets an explicit format; other formats are queried without one.
        format_ = None
        if ext in ["mp3"]:
            format_ = ext

        with requests.get(url, stream=True) as resp:
            return self._info(Unseekable(resp.raw), format=format_)

    @parameterized.expand(
        [
            ("wav", "float32"),
            ("wav", "int32"),
            ("wav", "int16"),
            ("wav", "uint8"),
            ("mp3", "float32"),
            ("flac", "float32"),
            ("vorbis", "float32"),
            ("amb", "int16"),
        ]
    )
    def test_requests(self, ext, dtype):
        """Querying compressed audio via requests works"""
        sample_rate, num_channels = 16000, 2
        requested_frames = 3.0 * sample_rate

        sinfo = self._query_http(ext, dtype, sample_rate, num_channels, requested_frames)

        expected_bits = get_bits_per_sample(ext, dtype)
        # The lossy formats report a different frame count than requested.
        expected_frames = {"vorbis": 48128, "mp3": 49536}.get(ext, requested_frames)

        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        assert sinfo.num_frames == expected_frames
        assert sinfo.bits_per_sample == expected_bits
        assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestInfoNoSuchFile(PytorchTestCase):
    """Error reporting of ``info`` with ``backend="ffmpeg"`` for a missing file."""

    _info = partial(get_info_func(), backend="ffmpeg")

    def test_info_fail(self):
        """
        When attempting to get info on a non-existing file, the error message must contain the file path.
        """
        import re

        path = "non_existing_audio.wav"
        # ``assertRaisesRegex`` treats its second argument as a regular
        # expression; escape the path so that "." only matches a literal dot.
        with self.assertRaisesRegex(RuntimeError, re.escape(path)):
            self._info(path)
test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
0 → 100644
View file @
ffeba11a
import
io
import
itertools
import
pathlib
import
tarfile
from
functools
import
partial
from
parameterized
import
parameterized
from
torchaudio._backend.ffmpeg
import
_parse_save_args
from
torchaudio._backend.utils
import
get_load_func
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.backend.dispatcher.sox.common
import
name_func
from
torchaudio_unittest.common_utils
import
(
disabledInCI
,
get_asset_path
,
get_wav_data
,
HttpServerMixin
,
load_wav
,
PytorchTestCase
,
save_wav
,
skipIfNoExec
,
skipIfNoFFmpeg
,
skipIfNoModule
,
sox_utils
,
TempDirMixin
,
)
from
.save_test
import
_convert_audio_file
if
_mod_utils
.
is_module_available
(
"requests"
):
import
requests
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertions for tests of the FFmpeg-backed ``load`` function."""

    _load = partial(get_load_func(), backend="ffmpeg")

    def assert_format(
        self,
        format: str,
        sample_rate: float,
        num_channels: int,
        compression: float = None,
        bit_depth: int = None,
        duration: float = 1,
        normalize: bool = True,
        encoding: str = None,
        atol: float = 4e-05,
        rtol: float = 1.3e-06,
    ):
        """`self._load` can load given format correctly.

        File encodings introduce delay and boundary effects, so the reference
        is a wav file derived from the very same encoded file.

             x
             |
             |    1. Generate given format with Sox
             |
             + ----------------------------------+ 3. Convert to wav with FFmpeg
             |                                   |
             |    2. Load the given format       | 4. Load with scipy
             |       with torchaudio             |
             v                                   v
          tensor ----------> x <----------- tensor
                           5. Compare

        Underlying assumptions are;
        i. Conversion of given format to wav with FFmpeg preserves data.
        ii. Loading wav file with scipy is correct.

        By combining i & ii, step 2. and 4. allow for loading reference given format
        data without using torchaudio
        """
        encoded_path = self.get_temp_path(f"1.original.{format}")
        wav_path = self.get_temp_path("2.reference.wav")

        # 1. Generate the given format with sox
        sox_utils.gen_audio_file(
            encoded_path,
            sample_rate,
            num_channels,
            encoding=encoding,
            compression=compression,
            bit_depth=bit_depth,
            duration=duration,
        )

        # 2. Load the given format with torchaudio
        loaded, loaded_rate = self._load(encoded_path, normalize=normalize)

        # 3. Convert to wav with ffmpeg
        # Normalized loads are compared against 32-bit float PCM; otherwise the
        # encoder is whatever `_parse_save_args` picks for the encoding/bit depth.
        encoder = "pcm_f32le"
        if not normalize:
            sox_to_save_encoding = {
                "floating-point": "PCM_F",
                "signed-integer": "PCM_S",
                "unsigned-integer": "PCM_U",
            }
            _, encoder, _ = _parse_save_args(format, format, sox_to_save_encoding.get(encoding), bit_depth)
        _convert_audio_file(encoded_path, wav_path, encoder=encoder)

        # 4. Load wav with scipy
        reference = load_wav(wav_path, normalize=normalize)[0]

        # 5. Compare
        assert loaded_rate == sample_rate
        self.assertEqual(loaded, reference, atol=atol, rtol=rtol)

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`self._load` can load wav format correctly.

        Wav data loaded with `self._load` should match the data loaded with `load_wav`.
        """
        wav_path = self.get_temp_path("reference.wav")
        frame_count = duration * sample_rate
        source = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=frame_count)
        save_wav(wav_path, source, sample_rate)
        reference = load_wav(wav_path, normalize=normalize)[0]
        loaded, loaded_rate = self._load(wav_path, normalize=normalize)
        assert loaded_rate == sample_rate
        self.assertEqual(loaded, reference)
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestLoad(LoadTestBase):
    """Test the correctness of `self._load` for various formats"""

    def test_pathlike(self):
        """FFmpeg dispatcher can load waveform from pathlike object"""
        sample_rate, num_channels = 16000, 2
        dtype = "float32"
        duration = 1

        wav_path = self.get_temp_path("data.wav")
        source = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
        save_wav(wav_path, source, sample_rate)

        waveform, loaded_rate = self._load(pathlib.Path(wav_path))
        self.assertEqual(loaded_rate, sample_rate)
        self.assertEqual(waveform, source)

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
                [False, True],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize):
        """`self._load` can load wav format correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [False, True],
            )
        ),
        name_func=name_func,
    )
    def test_24bit_wav(self, sample_rate, num_channels, normalize):
        """`self._load` can load 24bit wav format correctly. Correctly casts it to ``int32`` tensor dtype."""
        self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1)

    @parameterized.expand(
        list(
            itertools.product(
                ["int16"],
                [16000],
                [2],
                [False],
            )
        ),
        name_func=name_func,
    )
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """`self._load` can load large wav file correctly."""
        hours = 2
        duration_seconds = hours * 60 * 60
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration_seconds)

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [4, 8, 16],
            )
        ),
        name_func=name_func,
    )
    def test_multiple_channels(self, dtype, num_channels):
        """`self._load` can load wav file with more than 2 channels."""
        self.assert_wav(dtype, 8000, num_channels, False, duration=1)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                list(range(9)),
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`self._load` can load flac format correctly."""
        self.assert_format(
            "flac",
            sample_rate,
            num_channels,
            compression=compression_level,
            bit_depth=16,
            duration=1,
        )

    @parameterized.expand(
        list(
            itertools.product(
                [16000],
                [2],
                [0],
            )
        ),
        name_func=name_func,
    )
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """`self._load` can load large flac file correctly."""
        hours = 2
        duration_seconds = hours * 60 * 60
        self.assert_format(
            "flac",
            sample_rate,
            num_channels,
            compression=compression_level,
            bit_depth=16,
            duration=duration_seconds,
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [-1, 0, 1, 2, 3, 3.6, 5, 10],
            )
        ),
        name_func=name_func,
    )
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`self._load` can load vorbis format correctly."""
        self.assert_format(
            "vorbis",
            sample_rate,
            num_channels,
            compression=quality_level,
            bit_depth=16,
            duration=1,
        )

    @parameterized.expand(
        list(
            itertools.product(
                [16000],
                [2],
                [10],
            )
        ),
        name_func=name_func,
    )
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """`self._load` can load large vorbis file correctly."""
        hours = 2
        duration_seconds = hours * 60 * 60
        self.assert_format(
            "vorbis",
            sample_rate,
            num_channels,
            compression=quality_level,
            bit_depth=16,
            duration=duration_seconds,
        )

    @parameterized.expand(
        list(
            itertools.product(
                ["96k"],
                [1, 2],
                [0, 5, 10],
            )
        ),
        name_func=name_func,
    )
    def test_opus(self, bitrate, num_channels, compression_level):
        """`self._load` can load opus file correctly."""
        # The opus input is a pre-built asset; only the wav reference is created here.
        asset_name = f"{bitrate}_{compression_level}_{num_channels}ch.opus"
        opus_path = get_asset_path("io", asset_name)
        wav_path = self.get_temp_path(f"{asset_name}.wav")
        _convert_audio_file(opus_path, wav_path, encoder="pcm_f32le")

        reference, reference_rate = load_wav(wav_path)
        loaded, loaded_rate = self._load(opus_path)
        assert reference_rate == loaded_rate
        self.assertEqual(reference, loaded)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_sphere(self, sample_rate, num_channels):
        """`self._load` can load sph format correctly."""
        self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)

    @parameterized.expand(
        list(
            itertools.product(
                ["int16"],
                [3, 4, 16],
                [False, True],
            )
        ),
        name_func=name_func,
    )
    def test_amb(self, dtype, num_channels, normalize, sample_rate=8000):
        """`self._load` can load amb format correctly."""
        # Bit depth and encoding are both derived from the dtype via sox_utils.
        depth = sox_utils.get_bit_depth(dtype)
        sox_encoding = sox_utils.get_encoding(dtype)
        self.assert_format(
            "amb",
            sample_rate,
            num_channels,
            bit_depth=depth,
            duration=1,
            encoding=sox_encoding,
            normalize=normalize,
        )
# # NOTE: FFmpeg: RuntimeError: Failed to process a packet. (Not yet implemented in FFmpeg, patches welcome).
# def test_amr_nb(self):
# """`self._load` can load amr_nb format correctly."""
# self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestLoadWithoutExtension(PytorchTestCase):
    """Loading of files whose names carry no extension, using the FFmpeg backend."""

    _load = partial(get_load_func(), backend="ffmpeg")

    def test_mp3(self):
        """MP3 file without extension can be loaded

        Originally, we added `format` argument for this case, but now we use FFmpeg
        for MP3 decoding, which works even without `format` argument.
        https://github.com/pytorch/audio/issues/1040

        The file was generated with the following command
            ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
        """
        expected_rate = 16000
        asset = get_asset_path("mp3_without_ext")

        # First by path ...
        _, rate_from_path = self._load(asset)
        assert rate_from_path == expected_rate

        # ... then through an open binary file object.
        with open(asset, "rb") as fileobj:
            _, rate_from_fileobj = self._load(fileobj)
        assert rate_from_fileobj == expected_rate
class CloggedFileObj:
    """File-object wrapper that never returns more than two bytes per ``read``.

    Used to check that a reader copes with a file object that returns fewer
    bytes than requested.
    """

    def __init__(self, fileobj):
        # The wrapped object must provide ``read`` and ``seek``.
        self.fileobj = fileobj

    def read(self, _):
        """Return at most two bytes, ignoring the requested size."""
        return self.fileobj.read(2)

    def seek(self, offset, whence=0):
        """Forward the seek to the wrapped object and return its result.

        ``whence`` defaults to 0 (start of stream), like the ``seek`` of
        standard file objects, so ``seek(offset)`` works as well as the
        two-argument form.
        """
        return self.fileobj.seek(offset, whence)
@skipIfNoFFmpeg
@skipIfNoExec("sox")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """
    In this test suite, the result of file-like object input is compared against file path input,
    because `load` function is rigorously tested for file path inputs to match libsox's result,
    """

    _load = partial(get_load_func(), backend="ffmpeg")

    # Sample rate used for every generated file in this suite.
    _SAMPLE_RATE = 16000

    # (extension, extra keyword arguments for ``sox_utils.gen_audio_file``),
    # shared by every parameterized test below.
    _FORMATS = [
        ("wav", {"bit_depth": 16}),
        ("wav", {"bit_depth": 24}),
        ("wav", {"bit_depth": 32}),
        ("mp3", {"compression": 128}),
        ("mp3", {"compression": 320}),
        ("flac", {"compression": 0}),
        ("flac", {"compression": 5}),
        ("flac", {"compression": 8}),
        ("vorbis", {"compression": -1}),
        ("vorbis", {"compression": 10}),
        ("amb", {}),
    ]

    def _gen_reference(self, ext, kwargs, **gen_kwargs):
        """Generate ``test.<ext>`` with sox and load it by path.

        Args:
            ext: File extension / format of the generated file.
            kwargs: Per-format keyword arguments for ``sox_utils.gen_audio_file``.
            **gen_kwargs: Additional keyword arguments for ``sox_utils.gen_audio_file``.

        Returns:
            Tuple of (file path, waveform loaded from that path, ``format``
            argument to use when loading the same data from a file object).
        """
        path = self.get_temp_path(f"test.{ext}")
        sox_utils.gen_audio_file(path, self._SAMPLE_RATE, num_channels=2, **gen_kwargs, **kwargs)
        expected, _ = self._load(path)
        # Only mp3 gets an explicit format hint.
        format_ = ext if ext in ["mp3"] else None
        return path, expected, format_

    def _assert_same_as_path(self, fileobj, format_, expected):
        """Load ``fileobj`` and check sample rate and data against the path-based result."""
        found, sr = self._load(fileobj, format=format_)
        assert sr == self._SAMPLE_RATE
        self.assertEqual(expected, found)

    @parameterized.expand(_FORMATS)
    def test_fileobj(self, ext, kwargs):
        """Loading audio via file object returns the same result as via file path."""
        path, expected, format_ = self._gen_reference(ext, kwargs)
        with open(path, "rb") as fileobj:
            self._assert_same_as_path(fileobj, format_, expected)

    @parameterized.expand(_FORMATS)
    def test_bytesio(self, ext, kwargs):
        """Loading audio via BytesIO object returns the same result as via file path."""
        path, expected, format_ = self._gen_reference(ext, kwargs)
        with open(path, "rb") as file_:
            fileobj = io.BytesIO(file_.read())
        self._assert_same_as_path(fileobj, format_, expected)

    @parameterized.expand(_FORMATS)
    def test_bytesio_clogged(self, ext, kwargs):
        """Loading audio via clogged file object returns the same result as via file path.

        This test case validates the case where fileobject returns shorter bytes than requested.
        """
        path, expected, format_ = self._gen_reference(ext, kwargs)
        with open(path, "rb") as file_:
            fileobj = CloggedFileObj(io.BytesIO(file_.read()))
        self._assert_same_as_path(fileobj, format_, expected)

    @parameterized.expand(_FORMATS)
    def test_bytesio_tiny(self, ext, kwargs):
        """Loading very small audio via file object returns the same result as via file path."""
        path, expected, format_ = self._gen_reference(ext, kwargs, duration=1 / 1600)
        with open(path, "rb") as file_:
            fileobj = io.BytesIO(file_.read())
        self._assert_same_as_path(fileobj, format_, expected)

    @parameterized.expand(_FORMATS)
    def test_tarfile(self, ext, kwargs):
        """Loading compressed audio via file-like object returns the same result as via file path."""
        audio_path, expected, format_ = self._gen_reference(ext, kwargs)
        audio_file = f"test.{ext}"
        archive_path = self.get_temp_path("archive.tar.gz")
        with tarfile.TarFile(archive_path, "w") as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, "r") as tarobj:
            # The extracted member must be read while the archive is still open.
            fileobj = tarobj.extractfile(audio_file)
            self._assert_same_as_path(fileobj, format_, expected)
class Unseekable:
    """Wrap a file object so that only ``read`` is exposed (there is no ``seek``)."""

    def __init__(self, fileobj):
        # Kept as a public attribute; reads are delegated to it.
        self.fileobj = fileobj

    def read(self, n):
        """Read up to ``n`` bytes from the wrapped object and return them."""
        chunk = self.fileobj.read(n)
        return chunk
@disabledInCI
@skipIfNoFFmpeg
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Load audio from a streamed HTTP response body using the FFmpeg backend."""

    _load = partial(get_load_func(), backend="ffmpeg")

    @parameterized.expand(
        [
            ("wav", {"bit_depth": 16}),
            ("wav", {"bit_depth": 24}),
            ("wav", {"bit_depth": 32}),
            ("mp3", {"compression": 128}),
            ("mp3", {"compression": 320}),
            ("flac", {"compression": 0}),
            ("flac", {"compression": 5}),
            ("flac", {"compression": 8}),
            ("vorbis", {"compression": -1}),
            ("vorbis", {"compression": 10}),
            ("amb", {}),
        ]
    )
    def test_requests(self, ext, kwargs):
        """Audio fetched via ``requests`` reports the expected sample rate.

        The waveform is also compared against the path-based load, except for mp3.
        """
        sample_rate = 16000
        file_name = f"test.{ext}"
        local_path = self.get_temp_path(file_name)

        sox_utils.gen_audio_file(local_path, sample_rate, num_channels=2, **kwargs)
        reference, _ = self._load(local_path)

        # Only mp3 gets an explicit format hint.
        format_hint = ext if ext == "mp3" else None
        with requests.get(self.get_url(file_name), stream=True) as response:
            loaded, loaded_rate = self._load(Unseekable(response.raw), format=format_hint)

        assert loaded_rate == sample_rate
        if ext == "mp3":
            # Waveform comparison is skipped for mp3.
            return
        self.assertEqual(reference, loaded)

    @parameterized.expand(
        list(
            itertools.product(
                [0, 1, 10, 100, 1000],
                [-1, 1, 10, 100, 1000],
            )
        ),
        name_func=name_func,
    )
    def test_frame(self, frame_offset, num_frames):
        """num_frames and frame_offset correctly specify the region of data"""
        sample_rate = 8000
        file_name = "test.wav"
        local_path = self.get_temp_path(file_name)

        source = get_wav_data("float32", num_channels=2)
        save_wav(local_path, source, sample_rate)

        # num_frames == -1 means "until the end", i.e. an open-ended slice.
        if num_frames == -1:
            stop = None
        else:
            stop = frame_offset + num_frames
        reference = source[:, frame_offset:stop]

        with requests.get(self.get_url(file_name), stream=True) as response:
            loaded, loaded_rate = self._load(Unseekable(response.raw), frame_offset, num_frames)

        assert loaded_rate == sample_rate
        self.assertEqual(reference, loaded)
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestLoadNoSuchFile(PytorchTestCase):
    """Error reporting of the FFmpeg-backed ``load`` function for missing files."""

    _load = partial(get_load_func(), backend="ffmpeg")

    def test_load_fail(self):
        """
        When attempted to load a non-existing file, error message must contain the file path.
        """
        missing = "non_existing_audio.wav"
        # The path doubles as the regex the RuntimeError message must match.
        expect_failure = self.assertRaisesRegex(RuntimeError, missing)
        with expect_failure:
            self._load(missing)
test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
0 → 100644
View file @
ffeba11a
import
io
import
os
import
pathlib
import
subprocess
import
sys
from
functools
import
partial
from
typing
import
Optional
import
torch
from
parameterized
import
parameterized
from
torchaudio._backend.ffmpeg
import
_parse_save_args
from
torchaudio._backend.utils
import
get_save_func
from
torchaudio.io
import
CodecConfig
from
torchaudio_unittest.backend.dispatcher.sox.common
import
get_enc_params
,
name_func
from
torchaudio_unittest.common_utils
import
(
disabledInCI
,
get_wav_data
,
load_wav
,
nested_params
,
PytorchTestCase
,
save_wav
,
skipIfNoExec
,
skipIfNoFFmpeg
,
TempDirMixin
,
TorchaudioTestCase
,
)
def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None):
    """Convert ``src_path`` into ``dst_path`` by invoking the ``ffmpeg`` executable.

    Args:
        src_path: Input file, passed to ``-i``.
        dst_path: Output file (overwritten because of ``-y``).
        muxer: If truthy, passed as ``-f``.
        encoder: If truthy, passed as ``-acodec``.
        sample_fmt: If truthy, passed as ``-sample_fmt``.

    Raises:
        subprocess.CalledProcessError: If ``ffmpeg`` exits with a non-zero status.
    """
    command = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"]
    # Optional switches are appended in this fixed order, each only when set.
    for flag, value in (("-f", muxer), ("-acodec", encoder), ("-sample_fmt", sample_fmt)):
        if value:
            command.extend([flag, value])
    command.append(dst_path)
    # Echo the command line to stderr before running it.
    print(" ".join(command), file=sys.stderr)
    subprocess.run(command, check=True)
class SaveTestBase(TempDirMixin, TorchaudioTestCase):
    """Shared assertion for tests of the FFmpeg-backed ``save`` function."""

    _save = partial(get_save_func(), backend="ffmpeg")

    def assert_save_consistency(
        self,
        format: str,
        *,
        compression: Optional[CodecConfig] = None,
        encoding: Optional[str] = None,
        bits_per_sample: Optional[int] = None,
        sample_rate: float = 8000,
        num_channels: int = 2,
        num_frames: float = 3 * 8000,
        src_dtype: str = "int32",
        test_mode: str = "path",
    ):
        """`save` function produces file that is comparable with `ffmpeg` command

        To compare that the file produced by `save` function against the file produced by
        the equivalent `ffmpeg` command, we need to load both files.
        But there are many formats that cannot be opened with common Python modules (like
        SciPy).
        So we use `ffmpeg` command to prepare the original data and convert the saved files
        into a format that SciPy can read (PCM wav).
        The following diagram illustrates this process. The difference is 2.1. and 3.1.

        This assumes that
         - loading data with SciPy preserves the data well.
         - converting the resulting files into WAV format with `ffmpeg` preserve the data well.

                          x
                          | 1. Generate source wav file with SciPy
                          |
                          v
          -------------- wav ----------------
         |                                   |
         | 2.1. load with scipy              | 3.1. Convert to the target
         |   then save it into the target    |      format depth with ffmpeg
         |   format with torchaudio          |
         v                                   v
        target format                       target format
         |                                   |
         | 2.2. Convert to wav with ffmpeg   | 3.2. Convert to wav with ffmpeg
         |                                   |
         v                                   v
        wav                                 wav
         |                                   |
         | 2.3. load with scipy              | 3.3. load with scipy
         |                                   |
         v                                   v
        tensor -------> compare <--------- tensor

        Args:
            format: Target format; also used as the extension of the saved files.
            compression: Codec configuration forwarded to ``self._save``.
            encoding: Encoding forwarded to ``self._save`` and ``_parse_save_args``.
            bits_per_sample: Bit depth forwarded to ``self._save`` and ``_parse_save_args``.
            sample_rate: Sample rate of the generated source wav.
            num_channels: Number of channels of the generated source wav.
            num_frames: Number of frames of the generated source wav.
            src_dtype: dtype of the generated source wav.
            test_mode: ``"path"``, ``"fileobj"`` or ``"bytesio"``; selects what is
                handed to ``self._save`` as the destination.

        Raises:
            ValueError: If ``test_mode`` is none of the supported values.
        """
        src_path = self.get_temp_path("1.source.wav")
        tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
        tst_path = self.get_temp_path("2.2.result.wav")
        ffmpeg_path = self.get_temp_path(f"3.1.ffmpeg.{format}")
        ref_path = self.get_temp_path("3.2.ref.wav")

        # 1. Generate original wav
        data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
        save_wav(src_path, data, sample_rate)

        # 2.1. Convert the original wav to target format with torchaudio
        data = load_wav(src_path, normalize=False)[0]
        # Identical for every test mode; only the destination object differs.
        save_kwargs = {
            "compression": compression,
            "format": format,
            "encoding": encoding,
            "bits_per_sample": bits_per_sample,
        }
        if test_mode == "path":
            ext = format
            self._save(tgt_path, data, sample_rate, **save_kwargs)
        elif test_mode == "fileobj":
            ext = None
            with open(tgt_path, "bw") as file_:
                self._save(file_, data, sample_rate, **save_kwargs)
        elif test_mode == "bytesio":
            ext = None
            buffer = io.BytesIO()
            self._save(buffer, data, sample_rate, **save_kwargs)
            # Persist the in-memory result so ffmpeg can read it in step 2.2.
            with open(tgt_path, "bw") as f:
                f.write(buffer.getvalue())
        else:
            raise ValueError(f"Unexpected test mode: {test_mode}")

        # 2.2. Convert the target format to wav with ffmpeg
        _convert_audio_file(tgt_path, tst_path, encoder="pcm_f32le")
        # 2.3. Load with SciPy
        found = load_wav(tst_path, normalize=False)[0]

        # 3.1. Convert the original wav to target format with ffmpeg
        # `ext` is None for the file-object modes, where no path was given to `_save`.
        muxer, encoder, sample_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
        _convert_audio_file(src_path, ffmpeg_path, muxer=muxer, encoder=encoder, sample_fmt=sample_fmt)
        # 3.2. Convert the target format to wav with ffmpeg
        _convert_audio_file(ffmpeg_path, ref_path, encoder="pcm_f32le")
        # 3.3. Load with SciPy
        expected = load_wav(ref_path, normalize=False)[0]

        self.assertEqual(found, expected)
@disabledInCI
@skipIfNoExec("sox")
@skipIfNoExec("ffmpeg")
@skipIfNoFFmpeg
class SaveTest(SaveTestBase):
    """Consistency of `self._save` with the `ffmpeg` command for the supported formats."""

    def test_pathlike(self):
        """FFmpeg dispatcher can save audio data to pathlike object"""
        sample_rate = 16000
        dtype = "float32"
        num_channels = 2
        duration = 1

        path = self.get_temp_path("data.wav")
        data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
        self._save(pathlib.Path(path), data, sample_rate)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            ("PCM_U", 8),
            ("PCM_S", 16),
            ("PCM_S", 32),
            ("PCM_F", 32),
            ("PCM_F", 64),
            ("ULAW", 8),
            ("ALAW", 8),
        ],
    )
    def test_save_wav(self, test_mode, enc_params):
        """`self._save` can save wav with the given encoding and bit depth."""
        encoding, bits_per_sample = enc_params
        self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        [
            ("float32",),
            ("int32",),
            ("int16",),
            ("uint8",),
        ],
    )
    def test_save_wav_dtype(self, test_mode, params):
        """`self._save` can save wav from source data of the given dtype."""
        (dtype,) = params
        self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)

    @nested_params(
        ["path", "fileobj", "bytesio"],
        # NOTE: Supported sample formats: s16 s32 (24 bits)
        # [8, 16, 24],
        [16, 24],
        [
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
        ],
    )
    def test_save_flac(self, test_mode, bits_per_sample, compression_level):
        """`self._save` can save flac at the given bit depth and compression level."""
        # -acodec flac -sample_fmt s16
        # 24 bits needs to be mapped to s32
        codec_config = CodecConfig(
            compression_level=compression_level,
        )
        self.assert_save_consistency(
            "flac", compression=codec_config, bits_per_sample=bits_per_sample, test_mode=test_mode
        )

    # @nested_params(
    #     ["path", "fileobj", "bytesio"],
    # )
    # # NOTE: FFmpeg: Unable to find a suitable output format
    # def test_save_htk(self, test_mode):
    #     self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)

    @nested_params(
        [
            None,
            -1,
            0,
            1,
            2,
            3,
            5,
            10,
        ],
        ["path", "fileobj", "bytesio"],
    )
    def test_save_vorbis(self, quality_level, test_mode):
        """`self._save` can save vorbis (as "ogg") at the given quality level."""
        # NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg"
        # self.assert_save_consistency("vorbis", test_mode=test_mode)
        codec_config = CodecConfig(
            qscale=quality_level,
        )
        self.assert_save_consistency("ogg", compression=codec_config, test_mode=test_mode)

    # @nested_params(
    #     ["path", "fileobj", "bytesio"],
    #     [
    #         ("PCM_S", 8),
    #         ("PCM_S", 16),
    #         ("PCM_S", 24),
    #         ("PCM_S", 32),
    #         ("ULAW", 8),
    #         ("ALAW", 8),
    #         ("ALAW", 16),
    #         ("ALAW", 24),
    #         ("ALAW", 32),
    #     ],
    # )
    # NOTE: FFmpeg doesn't support encoding sphere files.
    # def test_save_sphere(self, test_mode, enc_params):
    #     encoding, bits_per_sample = enc_params
    #     self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)

    # @nested_params(
    #     ["path", "fileobj", "bytesio"],
    #     [
    #         ("PCM_U", 8),
    #         ("PCM_S", 16),
    #         ("PCM_S", 24),
    #         ("PCM_S", 32),
    #         ("PCM_F", 32),
    #         ("PCM_F", 64),
    #         ("ULAW", 8),
    #         ("ALAW", 8),
    #     ],
    # )
    # NOTE: FFmpeg doesn't support amb.
    # def test_save_amb(self, test_mode, enc_params):
    #     encoding, bits_per_sample = enc_params
    #     self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)

    # @nested_params(
    #     ["path", "fileobj", "bytesio"],
    # )
    # # NOTE: FFmpeg: Unable to find a suitable output format
    # def test_save_amr_nb(self, test_mode):
    #     self.assert_save_consistency("amr-nb", num_channels=1, test_mode=test_mode)

    # @nested_params(
    #     ["path", "fileobj", "bytesio"],
    # )
    # # NOTE: FFmpeg: RuntimeError: Unexpected codec: gsm
    # def test_save_gsm(self, test_mode):
    #     self.assert_save_consistency("gsm", num_channels=1, test_mode=test_mode)
    #     with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
    #         self.assert_save_consistency("gsm", num_channels=2, test_mode=test_mode)
    #     with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
    #         self.assert_save_consistency("gsm", sample_rate=16000, test_mode=test_mode)

    @parameterized.expand(
        [
            ("wav", "PCM_S", 16),
            ("flac",),
            ("ogg",),
            # ("sph", "PCM_S", 16),
            # ("amr-nb",),
            # ("amb", "PCM_S", 16),
        ],
        name_func=name_func,
    )
    def test_save_large(self, format, encoding=None, bits_per_sample=None):
        """`self._save` can save large files."""
        sample_rate = 8000
        one_hour = 60 * 60 * sample_rate
        self.assert_save_consistency(
            format,
            # NOTE: for ogg, ffmpeg only supports >= 2 channels
            num_channels=2,
            # Same variable that `one_hour` is derived from, so the two cannot drift apart.
            sample_rate=sample_rate,
            num_frames=one_hour,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )

    @parameterized.expand(
        [
            (16,),
            # NOTE: FFmpeg doesn't support more than 16 channels.
            # (32,),
            # (64,),
            # (128,),
            # (256,),
        ],
        name_func=name_func,
    )
    def test_save_multi_channels(self, num_channels):
        """`self._save` can save audio with many channels"""
        self.assert_save_consistency("wav", encoding="PCM_S", bits_per_sample=16, num_channels=num_channels)
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of optional parameters of `self._save`"""

    _save = partial(get_save_func(), backend="ffmpeg")

    @parameterized.expand([(True,), (False,)], name_func=name_func)
    def test_save_channels_first(self, channels_first):
        """channels_first swaps axes"""
        wav_path = self.get_temp_path("data.wav")
        source = get_wav_data("int16", 2, channels_first=channels_first, normalize=False)
        self._save(wav_path, source, 8000, channels_first=channels_first)
        loaded = load_wav(wav_path, normalize=False)[0]
        # The loaded data is compared against the source with its two axes
        # swapped whenever the source was not channels-first.
        if channels_first:
            reference = source
        else:
            reference = source.transpose(1, 0)
        self.assertEqual(loaded, reference)

    @parameterized.expand(["float32", "int32", "int16", "uint8"], name_func=name_func)
    def test_save_noncontiguous(self, dtype):
        """Noncontiguous tensors are saved correctly"""
        wav_path = self.get_temp_path("data.wav")
        encoding, bits = get_enc_params(dtype)
        # Striding both axes yields a non-contiguous view, verified just below.
        strided = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
        assert not strided.is_contiguous()
        self._save(wav_path, strided, 8000, encoding=encoding, bits_per_sample=bits)
        loaded = load_wav(wav_path, normalize=False)[0]
        self.assertEqual(loaded, strided)

    @parameterized.expand(
        [
            "float32",
            "int32",
            "int16",
            "uint8",
        ]
    )
    def test_save_tensor_preserve(self, dtype):
        """save function should not alter Tensor"""
        wav_path = self.get_temp_path("data.wav")
        reference = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
        # Save a copy, then check that the copy still equals the untouched reference.
        saved = reference.clone()
        self._save(wav_path, saved, 8000)
        self.assertEqual(saved, reference)
@disabledInCI
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestSaveNonExistingDirectory(PytorchTestCase):
    """Error reporting of the FFmpeg-backed ``save`` function for unwritable paths."""

    _save = partial(get_save_func(), backend="ffmpeg")

    def test_save_fail(self):
        """
        When attempted to save into a non-existing dir, error message must contain the file path.
        """
        target = os.path.join("non_existing_directory", "foo.wav")
        waveform = torch.zeros(1, 1)
        # The path doubles as the regex the RuntimeError message must match.
        expect_failure = self.assertRaisesRegex(RuntimeError, target)
        with expect_failure:
            self._save(target, waveform, 8000)
test/torchaudio_unittest/backend/dispatcher/smoke_test.py
0 → 100644
View file @
ffeba11a
import
io
from
torchaudio._backend.utils
import
get_info_func
,
get_load_func
,
get_save_func
from
torchaudio_unittest.common_utils
import
get_wav_data
,
PytorchTestCase
,
skipIfNoFFmpeg
,
TempDirMixin
@skipIfNoFFmpeg
class SmokeTest(TempDirMixin, PytorchTestCase):
    """Save / info / load round trip through a file path with the default dispatcher functions."""

    def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
        """Save one second of generated audio to ``test.<ext>``, then query and load it.

        Only the sample rate and the channel count are checked, not the samples.
        """
        duration = 1
        path = self.get_temp_path(f"test.{ext}")
        original = get_wav_data(dtype, num_channels, normalize=False, num_frames=sample_rate * duration)

        save = get_save_func()
        save(path, original, sample_rate)

        query = get_info_func()
        info = query(path)
        assert info.sample_rate == sample_rate
        assert info.num_channels == num_channels

        load = get_load_func()
        loaded, sr = load(path, normalize=False)
        assert sr == sample_rate
        # Dimension 0 of the loaded tensor is compared with the channel count.
        assert loaded.shape[0] == num_channels

    def test_wav(self):
        """Smoke test for 16 kHz, 2-channel float32 wav."""
        self.run_smoke_test("wav", 16000, 2, dtype="float32")
@skipIfNoFFmpeg
class SmokeTestFileObj(TempDirMixin, PytorchTestCase):
    """Save / info / load round trip through the dispatcher functions using an in-memory buffer."""

    def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
        """Write one second of audio into a ``BytesIO`` and check what info and load report.

        The buffer is rewound before each read. Only the sample rate and the
        channel count are verified.
        """
        # Keyword arguments shared by the save, info and load calls.
        stream_options = {"format": ext, "buffer_size": 8192}
        seconds = 1
        frame_count = sample_rate * seconds
        buffer = io.BytesIO()
        waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=frame_count)

        save = get_save_func()
        save(buffer, waveform, sample_rate, **stream_options)

        buffer.seek(0)
        query = get_info_func()
        metadata = query(buffer, **stream_options)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_channels == num_channels

        buffer.seek(0)
        load = get_load_func()
        decoded, decoded_rate = load(buffer, normalize=False, **stream_options)
        assert decoded_rate == sample_rate
        assert decoded.shape[0] == num_channels

    def test_wav(self):
        """Smoke test for 16 kHz stereo float32 wav."""
        self.run_smoke_test("wav", 16000, 2, dtype="float32")
test/torchaudio_unittest/backend/dispatcher/soundfile/__init__.py
0 → 100644
View file @
ffeba11a
test/torchaudio_unittest/backend/dispatcher/soundfile/common.py
0 → 100644
View file @
ffeba11a
import
itertools
from
unittest
import
skipIf
from
parameterized
import
parameterized
from
torchaudio._internal.module_utils
import
is_module_available
def name_func(func, _, params):
    """Build a test name from the function name and its positional parameters.

    Args:
        func: The function being parameterized; its ``__name__`` is the prefix.
        _: Unused second argument.
        params: Object whose ``args`` attribute holds the positional parameters.

    Returns:
        ``"<func name>_<arg0>_<arg1>_..."`` with every argument passed through ``str``.
    """
    suffix = "_".join(map(str, params.args))
    return f"{func.__name__}_{suffix}"
def dtype2subtype(dtype):
    """Translate a dtype name into the matching subtype string.

    Args:
        dtype: One of ``"float64"``, ``"float32"``, ``"int32"``, ``"int16"``,
            ``"uint8"`` or ``"int8"``.

    Returns:
        The subtype name associated with ``dtype``.

    Raises:
        KeyError: If ``dtype`` is not one of the names listed above.
    """
    subtype_by_dtype = {
        "float64": "DOUBLE",
        "float32": "FLOAT",
        "int32": "PCM_32",
        "int16": "PCM_16",
        "uint8": "PCM_U8",
        "int8": "PCM_S8",
    }
    return subtype_by_dtype[dtype]
def skipIfFormatNotSupported(fmt):
    """Return a skip decorator for tests that need soundfile support for ``fmt``.

    Args:
        fmt: Format name looked up in ``soundfile.available_formats()``.

    Returns:
        A ``unittest.skipIf`` decorator. It always skips when the ``soundfile``
        module is unavailable, and otherwise skips only when ``fmt`` is not
        among the available formats.
    """
    if not is_module_available("soundfile"):
        return skipIf(True, '"soundfile" not available.')

    # Imported lazily so that this module can be loaded without soundfile.
    import soundfile

    supported = soundfile.available_formats()
    return skipIf(fmt not in supported, f'"{fmt}" is not supported by soundfile')
def parameterize(*params):
    """Expand a test over the Cartesian product of the given parameter lists.

    Args:
        *params: One iterable of candidate values per test argument.

    Returns:
        The ``parameterized.expand`` decorator built from every combination,
        with test names produced by ``name_func``.
    """
    combinations = list(itertools.product(*params))
    return parameterized.expand(combinations, name_func=name_func)
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Return the subtype string expected for a wav save request.

    Args:
        dtype: Dtype name of the data; only consulted when both ``encoding``
            and ``bits_per_sample`` are ``None``.
        encoding: Requested encoding name or ``None``.
        bits_per_sample: Requested bit depth or ``None``.

    Returns:
        The subtype name for the ``(encoding, bits_per_sample)`` combination.

    Raises:
        ValueError: If the combination is not in the table below.
        KeyError: Propagated from ``dtype2subtype`` when no encoding and no
            bit depth are given and ``dtype`` is unknown to it.
    """
    # The dtype only matters when nothing else was requested. Resolving it
    # lazily keeps an unrelated dtype lookup failure from masking the
    # explicit (encoding, bits_per_sample) result or the ValueError below.
    if encoding is None and bits_per_sample is None:
        subtype = dtype2subtype(dtype)
    else:
        subtype = {
            (None, 8): "PCM_U8",
            ("PCM_U", None): "PCM_U8",
            ("PCM_U", 8): "PCM_U8",
            ("PCM_S", None): "PCM_32",
            ("PCM_S", 16): "PCM_16",
            ("PCM_S", 32): "PCM_32",
            ("PCM_F", None): "FLOAT",
            ("PCM_F", 32): "FLOAT",
            ("PCM_F", 64): "DOUBLE",
            ("ULAW", None): "ULAW",
            ("ULAW", 8): "ULAW",
            ("ALAW", None): "ALAW",
            ("ALAW", 8): "ALAW",
        }.get((encoding, bits_per_sample))
    if subtype:
        return subtype
    raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
test/torchaudio_unittest/backend/dispatcher/soundfile/info_test.py
0 → 100644
View file @
ffeba11a
import
tarfile
import
warnings
from
functools
import
partial
from
unittest.mock
import
patch
import
torch
from
torchaudio._backend.utils
import
get_info_func
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.backend.common
import
get_bits_per_sample
,
get_encoding
from
torchaudio_unittest.common_utils
import
(
get_wav_data
,
nested_params
,
PytorchTestCase
,
save_wav
,
skipIfNoModule
,
TempDirMixin
,
)
from
.common
import
parameterize
,
skipIfFormatNotSupported
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
    """Metadata queries through the dispatcher info function with ``backend="soundfile"``."""

    _info = partial(get_info_func(), backend="soundfile")

    @parameterize(
        ["float32", "int32", "int16", "uint8"],
        [8000, 16000],
        [1, 2],
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`self._info` reports the metadata of a wav file correctly"""
        seconds = 1
        wav_path = self.get_temp_path("data.wav")
        waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=seconds * sample_rate)
        save_wav(wav_path, waveform, sample_rate)

        metadata = self._info(wav_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == get_bits_per_sample("wav", dtype)
        assert metadata.encoding == get_encoding("wav", dtype)

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """`self._info` reports the metadata of a flac file correctly"""
        seconds = 1
        frame_count = sample_rate * seconds
        samples = torch.randn(frame_count, num_channels).numpy()
        flac_path = self.get_temp_path("data.flac")
        soundfile.write(flac_path, samples, sample_rate)

        metadata = self._info(flac_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == frame_count
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 16
        assert metadata.encoding == "FLAC"

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """`self._info` reports the metadata of an ogg file correctly

        The expected bit depth is 0 and the expected encoding is VORBIS.
        """
        seconds = 1
        frame_count = sample_rate * seconds
        samples = torch.randn(frame_count, num_channels).numpy()
        ogg_path = self.get_temp_path("data.ogg")
        soundfile.write(ogg_path, samples, sample_rate)

        metadata = self._info(ogg_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 0
        assert metadata.encoding == "VORBIS"

    @nested_params(
        [8000, 16000],
        [1, 2],
        [("PCM_24", 24), ("PCM_32", 32)],
    )
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
        """`self._info` reports the metadata of a sph file correctly"""
        subtype, bits_per_sample = subtype_and_bit_depth
        seconds = 1
        frame_count = sample_rate * seconds
        samples = torch.randn(frame_count, num_channels).numpy()
        nist_path = self.get_temp_path("data.nist")
        soundfile.write(nist_path, samples, sample_rate, subtype=subtype)

        metadata = self._info(nist_path)

        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == bits_per_sample
        assert metadata.encoding == "PCM_S"

    def test_unknown_subtype_warning(self):
        """self._info issues a warning when the subtype is unknown

        This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
        dict should be updated.
        """

        class MockSoundFileInfo:
            # Attributes returned in place of a real ``soundfile.info`` result.
            samplerate = 8000
            frames = 356
            channels = 2
            subtype = "UNSEEN_SUBTYPE"
            format = "UNKNOWN"

        def _mock_info_func(_):
            return MockSoundFileInfo()

        with patch("soundfile.info", _mock_info_func), warnings.catch_warnings(record=True) as caught:
            metadata = self._info("foo")

        assert len(caught) == 1
        assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(caught[-1].message)
        assert metadata.bits_per_sample == 0
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Metadata queries on file-like objects with ``backend="soundfile"``."""

    _info = partial(get_info_func(), backend="soundfile")

    def _test_fileobj(self, ext, subtype, bits_per_sample):
        """Query audio via file-like object works

        Args:
            ext: File extension used for the temporary file.
            subtype: Subtype passed to ``soundfile.write``.
            bits_per_sample: Bit depth the info result is expected to report.
        """
        duration = 2
        sample_rate = 16000
        num_channels = 2
        num_frames = sample_rate * duration
        path = self.get_temp_path(f"test.{ext}")

        data = torch.randn(num_frames, num_channels).numpy()
        soundfile.write(path, data, sample_rate, subtype=subtype)

        with open(path, "rb") as fileobj:
            info = self._info(fileobj)
        # Resolve the expected encoding first: writing
        # ``assert a == "FLAC" if cond else "PCM_S"`` parses as
        # ``assert (a == "FLAC") if cond else "PCM_S"``, which is always
        # true for the non-flac branch and therefore checks nothing.
        expected_encoding = "FLAC" if ext == "flac" else "PCM_S"
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == expected_encoding

    def test_fileobj_wav(self):
        """Querying wav audio via file-like object works"""
        self._test_fileobj("wav", "PCM_16", 16)

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Querying flac audio via file-like object works"""
        self._test_fileobj("flac", "PCM_16", 16)

    def _test_tarobj(self, ext, subtype, bits_per_sample):
        """Query audio stored inside a tar archive via file-like object works

        Args:
            ext: File extension used for the archived audio file.
            subtype: Subtype passed to ``soundfile.write``.
            bits_per_sample: Bit depth the info result is expected to report.
        """
        duration = 2
        sample_rate = 16000
        num_channels = 2
        num_frames = sample_rate * duration
        audio_file = f"test.{ext}"
        audio_path = self.get_temp_path(audio_file)
        archive_path = self.get_temp_path("archive.tar.gz")

        data = torch.randn(num_frames, num_channels).numpy()
        soundfile.write(audio_path, data, sample_rate, subtype=subtype)

        with tarfile.TarFile(archive_path, "w") as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, "r") as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            info = self._info(fileobj)
        # See _test_fileobj: the conditional must be evaluated before the
        # comparison, otherwise the non-flac case asserts a truthy string.
        expected_encoding = "FLAC" if ext == "flac" else "PCM_S"
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample
        assert info.encoding == expected_encoding

    def test_tarobj_wav(self):
        """Querying wav audio from a tar archive via file-like object works"""
        self._test_tarobj("wav", "PCM_16", 16)

    @skipIfFormatNotSupported("FLAC")
    def test_tarobj_flac(self):
        """Querying flac audio from a tar archive via file-like object works"""
        self._test_tarobj("flac", "PCM_16", 16)
test/torchaudio_unittest/backend/dispatcher/soundfile/load_test.py
0 → 100644
View file @
ffeba11a
import
os
import
tarfile
from
functools
import
partial
from
unittest.mock
import
patch
import
torch
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
get_load_func
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.common_utils
import
(
get_wav_data
,
load_wav
,
normalize_wav
,
PytorchTestCase
,
save_wav
,
skipIfNoModule
,
TempDirMixin
,
)
from
.common
import
dtype2subtype
,
parameterize
,
skipIfFormatNotSupported
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
def
_get_mock_path
(
ext
:
str
,
dtype
:
str
,
sample_rate
:
int
,
num_channels
:
int
,
num_frames
:
int
,
):
return
f
"
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
_
{
num_frames
}
.
{
ext
}
"
def
_get_mock_params
(
path
:
str
):
filename
,
ext
=
path
.
split
(
"."
)
parts
=
filename
.
split
(
"_"
)
return
{
"ext"
:
ext
,
"dtype"
:
parts
[
0
],
"sample_rate"
:
int
(
parts
[
1
]),
"num_channels"
:
int
(
parts
[
2
]),
"num_frames"
:
int
(
parts
[
3
]),
}
class SoundFileMock:
    """Stand-in object that serves generated audio described by its file name.

    The parameters (extension, dtype, sample rate, channel count, frame
    count) are decoded from ``path`` with ``_get_mock_params``; no file is
    opened.
    """

    # Reported format name for each recognised file extension.
    _FORMAT_BY_EXT = {
        "wav": "WAV",
        "flac": "FLAC",
        "ogg": "OGG",
        "sph": "NIST",
        "nis": "NIST",
        "nist": "NIST",
    }

    def __init__(self, path, mode):
        assert mode == "r"
        self.path = path
        self._params = _get_mock_params(path)
        # Read offset; set by _prepare_read.
        self._start = None

    @property
    def samplerate(self):
        """Sample rate decoded from the file name."""
        return self._params["sample_rate"]

    @property
    def format(self):
        """Format name for the extension, or ``None`` when it is not recognised."""
        return self._FORMAT_BY_EXT.get(self._params["ext"])

    @property
    def subtype(self):
        """``"VORBIS"`` for ogg, otherwise the subtype matching the dtype."""
        if self._params["ext"] == "ogg":
            return "VORBIS"
        return dtype2subtype(self._params["dtype"])

    def _prepare_read(self, start, stop, frames):
        """Remember the start offset and hand back ``frames`` unchanged."""
        assert stop is None
        self._start = start
        return frames

    def read(self, frames, dtype, always_2d):
        """Return generated data starting at the remembered offset.

        Args:
            frames: Number of frames to return; a negative value means
                everything from the start offset to the end.
            dtype: Dtype name forwarded to ``get_wav_data``.
            always_2d: Must be truthy.
        """
        assert always_2d
        data = get_wav_data(
            dtype,
            self._params["num_channels"],
            normalize=False,
            num_frames=self._params["num_frames"],
            channels_first=False,
        ).numpy()
        # _prepare_read passes ``frames`` through untouched, so a negative
        # "read everything" request arrives here as-is. Slicing with
        # ``start + frames`` would then silently drop trailing frames.
        stop = None if frames < 0 else self._start + frames
        return data[self._start : stop]

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        pass
class MockedLoadTest(PytorchTestCase):
    """Dtype checks for load with ``soundfile.SoundFile`` replaced by ``SoundFileMock``."""

    _load = partial(get_load_func(), backend="soundfile")

    def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
        """Check the dtype and sample rate returned by ``self._load`` for a mocked file.

        The native dtype is expected only when ``normalize`` is false and
        ``ext`` is ``"wav"`` or ``"nist"``; float32 is expected otherwise.
        """
        frame_count = 3 * sample_rate
        mock_path = _get_mock_path(ext, dtype, sample_rate, num_channels, frame_count)
        if normalize or ext not in ["wav", "nist"]:
            wanted_dtype = torch.float32
        else:
            wanted_dtype = getattr(torch, dtype)
        with patch("soundfile.SoundFile", SoundFileMock):
            waveform, loaded_rate = self._load(mock_path, normalize=normalize, channels_first=channels_first)
            assert waveform.dtype == wanted_dtype
            assert sample_rate == loaded_rate

    @parameterize(
        ["uint8", "int16", "int32", "float32", "float64"],
        [8000, 16000],
        [1, 2],
        [True, False],
        [True, False],
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """Returns native dtype when normalize=False else float32"""
        self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)

    @parameterize(
        ["int8", "int16", "int32"],
        [8000, 16000],
        [1, 2],
        [True, False],
        [True, False],
    )
    def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """Returns float32 always"""
        self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)

    @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
    def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
        """Returns float32 always"""
        self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)

    @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
    def test_flac(self, sample_rate, num_channels, normalize, channels_first):
        """Returns float32 always"""
        self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertion helpers for loading real files with ``backend="soundfile"``."""

    _load = partial(get_load_func(), backend="soundfile")

    def assert_wav(
        self,
        dtype,
        sample_rate,
        num_channels,
        normalize,
        channels_first=True,
        duration=1,
    ):
        """`self._load` can load wav format correctly.

        The file is written with ``save_wav`` and the loaded data must equal
        what ``load_wav`` returns for the same file.
        """
        wav_path = self.get_temp_path("reference.wav")
        frame_count = duration * sample_rate
        source = get_wav_data(
            dtype,
            num_channels,
            normalize=normalize,
            num_frames=frame_count,
            channels_first=channels_first,
        )
        save_wav(wav_path, source, sample_rate, channels_first=channels_first)
        reference = load_wav(wav_path, normalize=normalize, channels_first=channels_first)[0]

        loaded, loaded_rate = self._load(wav_path, normalize=normalize, channels_first=channels_first)

        assert loaded_rate == sample_rate
        self.assertEqual(loaded, reference)

    def assert_sphere(
        self,
        dtype,
        sample_rate,
        num_channels,
        channels_first=True,
        duration=1,
    ):
        """`self._load` can load SPHERE format correctly."""
        sph_path = self.get_temp_path("reference.sph")
        frame_count = duration * sample_rate
        raw = get_wav_data(
            dtype,
            num_channels,
            num_frames=frame_count,
            normalize=False,
            channels_first=False,
        )
        soundfile.write(sph_path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
        # The raw data is generated channels-last; transpose it when the
        # caller asks for channels-first output.
        reference = normalize_wav(raw.t() if channels_first else raw)

        loaded, loaded_rate = self._load(sph_path, channels_first=channels_first)

        assert loaded_rate == sample_rate
        self.assertEqual(loaded, reference, atol=1e-4, rtol=1e-8)

    def assert_flac(
        self,
        dtype,
        sample_rate,
        num_channels,
        channels_first=True,
        duration=1,
    ):
        """`self._load` can load FLAC format correctly."""
        flac_path = self.get_temp_path("reference.flac")
        frame_count = duration * sample_rate
        raw = get_wav_data(
            dtype,
            num_channels,
            num_frames=frame_count,
            normalize=False,
            channels_first=False,
        )
        soundfile.write(flac_path, raw, sample_rate)
        # The raw data is generated channels-last; transpose it when the
        # caller asks for channels-first output.
        reference = normalize_wav(raw.t() if channels_first else raw)

        loaded, loaded_rate = self._load(flac_path, channels_first=channels_first)

        assert loaded_rate == sample_rate
        self.assertEqual(loaded, reference, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestLoad(LoadTestBase):
    """Correctness of load with ``backend="soundfile"`` across several formats"""

    @parameterize(
        ["float32", "int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [False, True],
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """wav files are loaded correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)

    @parameterize(
        ["int16"],
        [16000],
        [2],
        [False],
    )
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """A two-hour wav file is loaded correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=2 * 60 * 60)

    @parameterize(["float32", "int32", "int16"], [4, 8, 16, 32], [False, True])
    def test_multiple_channels(self, dtype, num_channels, channels_first):
        """wav files with more than 2 channels are loaded correctly (8 kHz, normalize=False)."""
        self.assert_wav(dtype, 8000, num_channels, False, channels_first)

    @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
        """sphere files are loaded correctly."""
        self.assert_sphere(dtype, sample_rate, num_channels, channels_first)

    @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, dtype, sample_rate, num_channels, channels_first):
        """flac files are loaded correctly."""
        self.assert_flac(dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class TestLoadFormat(TempDirMixin, PytorchTestCase):
    """Files whose extension has been stripped can still be loaded"""

    _load = partial(get_load_func(), backend="soundfile")
    original = None
    path = None

    def _make_file(self, format_):
        """Write ``test.<format_>``, read it back as reference, then rename it without extension.

        Returns:
            Tuple of the extension-less path and the reference data
            (transposed result of ``soundfile.read``).
        """
        sample_rate = 8000
        named_path = self.get_temp_path(f"test.{format_}")
        samples = get_wav_data("float32", num_channels=2).numpy().T
        soundfile.write(named_path, samples, sample_rate)
        reference = soundfile.read(named_path, dtype="float32")[0].T
        bare_path, _ext = os.path.splitext(named_path)
        os.rename(named_path, bare_path)
        return bare_path, reference

    def _test_format(self, format_):
        """A file written as ``format_`` loads correctly after its extension is removed"""
        bare_path, reference = self._make_file(format_)
        loaded, _ = self._load(bare_path)
        self.assertEqual(loaded, reference)

    @parameterized.expand(
        [
            ("WAV",),
            ("wav",),
        ]
    )
    def test_wav(self, format_):
        """Extension-less wav files load correctly"""
        self._test_format(format_)

    @parameterized.expand(
        [
            ("FLAC",),
            ("flac",),
        ]
    )
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, format_):
        """Extension-less flac files load correctly"""
        self._test_format(format_)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Loading from file-like objects with ``backend="soundfile"``."""

    _load = partial(get_load_func(), backend="soundfile")

    def _test_fileobj(self, ext):
        """Loading audio via file-like object works"""
        sample_rate = 16000
        audio_path = self.get_temp_path(f"test.{ext}")

        samples = get_wav_data("float32", num_channels=2).numpy().T
        soundfile.write(audio_path, samples, sample_rate)
        reference = soundfile.read(audio_path, dtype="float32")[0].T

        with open(audio_path, "rb") as stream:
            loaded, loaded_rate = self._load(stream)
            assert loaded_rate == sample_rate
            self.assertEqual(reference, loaded)

    def test_fileobj_wav(self):
        """Loading wav audio via file-like object works"""
        self._test_fileobj("wav")

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Loading flac audio via file-like object works"""
        self._test_fileobj("flac")

    def _test_tarfile(self, ext):
        """Loading audio from a tar archive member via file-like object works"""
        sample_rate = 16000
        member_name = f"test.{ext}"
        audio_path = self.get_temp_path(member_name)
        archive_path = self.get_temp_path("archive.tar.gz")

        samples = get_wav_data("float32", num_channels=2).numpy().T
        soundfile.write(audio_path, samples, sample_rate)
        reference = soundfile.read(audio_path, dtype="float32")[0].T

        with tarfile.TarFile(archive_path, "w") as archive:
            archive.add(audio_path, arcname=member_name)
        with tarfile.TarFile(archive_path, "r") as archive:
            stream = archive.extractfile(member_name)
            loaded, loaded_rate = self._load(stream)

        assert loaded_rate == sample_rate
        self.assertEqual(reference, loaded)

    def test_tarfile_wav(self):
        """Loading wav audio from a tar archive works"""
        self._test_tarfile("wav")

    @skipIfFormatNotSupported("FLAC")
    def test_tarfile_flac(self):
        """Loading flac audio from a tar archive works"""
        self._test_tarfile("flac")
test/torchaudio_unittest/backend/dispatcher/soundfile/save_test.py
0 → 100644
View file @
ffeba11a
import
io
from
functools
import
partial
from
unittest.mock
import
patch
from
torchaudio._backend.utils
import
get_save_func
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.common_utils
import
(
get_wav_data
,
load_wav
,
nested_params
,
PytorchTestCase
,
skipIfNoModule
,
TempDirMixin
,
)
from
.common
import
fetch_wav_subtype
,
parameterize
,
skipIfFormatNotSupported
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
class MockedSaveTest(PytorchTestCase):
    """Checks the keyword arguments that save hands to a patched ``soundfile.write``."""

    _save = partial(get_save_func(), backend="soundfile")

    @nested_params(
        ["float32", "int32", "int16", "uint8"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [
            (None, None),
            ("PCM_U", None),
            ("PCM_U", 8),
            ("PCM_S", None),
            ("PCM_S", 16),
            ("PCM_S", 32),
            ("PCM_F", None),
            ("PCM_F", 32),
            ("PCM_F", 64),
            ("ULAW", None),
            ("ULAW", 8),
            ("ALAW", None),
            ("ALAW", 8),
        ],
    )
    @patch("soundfile.write")
    def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
        """self._save passes correct subtype to soundfile.write when WAV"""
        encoding, bits_per_sample = enc_params
        filepath = "foo.wav"
        input_tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=dtype == "float32",
            channels_first=channels_first,
        ).t()

        self._save(
            filepath,
            input_tensor,
            sample_rate,
            channels_first=channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )

        # Index 1 of call_args holds the keyword arguments
        # (on Python 3.8+ call_args.kwargs is more descriptive).
        write_kwargs = mocked_write.call_args[1]
        assert write_kwargs["file"] == filepath
        assert write_kwargs["samplerate"] == sample_rate
        assert write_kwargs["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
        assert write_kwargs["format"] is None
        self.assertEqual(write_kwargs["data"], input_tensor.t() if channels_first else input_tensor)

    @patch("soundfile.write")
    def assert_non_wav(
        self,
        fmt,
        dtype,
        sample_rate,
        num_channels,
        channels_first,
        mocked_write,
        encoding=None,
        bits_per_sample=None,
    ):
        """self._save passes the expected file, sample rate, format and data to soundfile.write

        The format is expected to be ``"NIST"`` for the extensions ``sph``,
        ``nist`` and ``nis``, and ``None`` for every other extension.
        """
        filepath = f"foo.{fmt}"
        input_tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=False,
            channels_first=channels_first,
        ).t()
        if channels_first:
            expected_data = input_tensor.t()
        else:
            expected_data = input_tensor

        self._save(
            filepath,
            input_tensor,
            sample_rate,
            channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )

        # Index 1 of call_args holds the keyword arguments
        # (on Python 3.8+ call_args.kwargs is more descriptive).
        write_kwargs = mocked_write.call_args[1]
        assert write_kwargs["file"] == filepath
        assert write_kwargs["samplerate"] == sample_rate
        is_sphere = fmt in ["sph", "nist", "nis"]
        if is_sphere:
            assert write_kwargs["format"] == "NIST"
        else:
            assert write_kwargs["format"] is None
        self.assertEqual(write_kwargs["data"], expected_data)

    @nested_params(
        ["sph", "nist", "nis"],
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [
            ("PCM_S", 8),
            ("PCM_S", 16),
            ("PCM_S", 24),
            ("PCM_S", 32),
            ("ULAW", 8),
            ("ALAW", 8),
            ("ALAW", 16),
            ("ALAW", 24),
            ("ALAW", 32),
        ],
    )
    def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
        """self._save passes the NIST format to soundfile.write for sphere extensions"""
        encoding, bits_per_sample = enc_params
        self.assert_non_wav(
            fmt,
            dtype,
            sample_rate,
            num_channels,
            channels_first,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )

    @parameterize(
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [8, 16, 24],
    )
    def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
        """self._save passes the default format (None) to soundfile.write for flac"""
        self.assert_non_wav(
            "flac",
            dtype,
            sample_rate,
            num_channels,
            channels_first,
            bits_per_sample=bits_per_sample,
        )

    @parameterize(
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
    )
    def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
        """self._save passes the default format (None) to soundfile.write for ogg"""
        self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class SaveTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertion helpers for saving real files with ``backend="soundfile"``."""

    _save = partial(get_save_func(), backend="soundfile")

    def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
        """`self._save` can save wav format.

        The saved file is read back with ``load_wav`` and must equal the input.
        """
        wav_path = self.get_temp_path("data.wav")
        source = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
        self._save(wav_path, source, sample_rate)

        reloaded, reloaded_rate = load_wav(wav_path, normalize=False)

        assert sample_rate == reloaded_rate
        self.assertEqual(reloaded, source)

    def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
        """`self._save` can save non-wav format.

        Due to precision mismatch, and the lack of alternative way to decode the
        resulting files without using soundfile, only meta data are validated.
        """
        frame_count = sample_rate * 3
        out_path = self.get_temp_path(f"data.{fmt}")
        source = get_wav_data(dtype, num_channels, num_frames=frame_count, normalize=False)
        self._save(out_path, source, sample_rate)

        file_info = soundfile.info(out_path)

        assert file_info.format == fmt.upper()
        assert file_info.frames == frame_count
        assert file_info.channels == num_channels
        assert file_info.samplerate == sample_rate

    def assert_flac(self, dtype, sample_rate, num_channels):
        """`self._save` can save flac format."""
        self._assert_non_wav("flac", dtype, sample_rate, num_channels)

    def assert_sphere(self, dtype, sample_rate, num_channels):
        """`self._save` can save sph format (written with the ``nist`` extension)."""
        self._assert_non_wav("nist", dtype, sample_rate, num_channels)

    def assert_ogg(self, dtype, sample_rate, num_channels):
        """`self._save` can save ogg format.

        As we cannot inspect the OGG format (it's lossy), we only check the metadata.
        """
        self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSave(SaveTestBase):
    """Parameterized save tests built on the ``SaveTestBase`` helpers."""

    @parameterize(
        ["float32", "int32", "int16"],
        [8000, 16000],
        [1, 2],
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """wav files are saved correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)

    @parameterize(
        ["float32", "int32", "int16"],
        [4, 8, 16, 32],
    )
    def test_multiple_channels(self, dtype, num_channels):
        """wav files with more than 2 channels are saved correctly (8 kHz)."""
        self.assert_wav(dtype, 8000, num_channels, num_frames=None)

    @parameterize(
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
    )
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, dtype, sample_rate, num_channels):
        """sph files are saved with the expected metadata."""
        self.assert_sphere(dtype, sample_rate, num_channels)

    @parameterize(
        [8000, 16000],
        [1, 2],
    )
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """flac files are saved with the expected metadata (float32 input)."""
        self.assert_flac("float32", sample_rate, num_channels)

    @parameterize(
        [8000, 16000],
        [1, 2],
    )
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """ogg/vorbis files are saved with the expected metadata (float32 input)."""
        self.assert_ogg("float32", sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of optional parameters of `self._save`"""

    _save = partial(get_save_func(), backend="soundfile")

    @parameterize([True, False])
    def test_channels_first(self, channels_first):
        """channels_first swaps axes

        Data generated with the given layout is saved with the same flag; the
        file read back by ``load_wav`` must equal the data itself when
        ``channels_first`` is true and its transpose otherwise.
        """
        wav_path = self.get_temp_path("data.wav")
        source = get_wav_data("int32", 2, channels_first=channels_first)
        self._save(wav_path, source, 8000, channels_first=channels_first)

        reloaded = load_wav(wav_path)[0]
        if channels_first:
            reference = source
        else:
            reference = source.transpose(1, 0)

        self.assertEqual(reloaded, reference, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Saving into file-like objects with ``backend="soundfile"``."""

    _save = partial(get_save_func(), backend="soundfile")

    def _test_fileobj(self, ext):
        """Saving audio to file-like object works

        A reference file is written and read back directly with soundfile;
        the bytes produced by ``self._save`` into a ``BytesIO`` must decode
        to the same data within tolerance.
        """
        sample_rate = 16000
        reference_path = self.get_temp_path(f"test.{ext}")
        # An explicit FLOAT subtype is requested only for the "wav" extension.
        if ext == "wav":
            subtype = "FLOAT"
        else:
            subtype = None
        waveform = get_wav_data("float32", num_channels=2)

        soundfile.write(reference_path, waveform.numpy().T, sample_rate, subtype=subtype)
        reference = soundfile.read(reference_path, dtype="float32")[0]

        buffer = io.BytesIO()
        self._save(buffer, waveform, sample_rate, format=ext)
        buffer.seek(0)
        decoded, decoded_rate = soundfile.read(buffer, dtype="float32")

        assert decoded_rate == sample_rate
        self.assertEqual(reference, decoded, atol=1e-4, rtol=1e-8)

    def test_fileobj_wav(self):
        """Saving wav audio via file-like object works"""
        self._test_fileobj("wav")

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Saving flac audio via file-like object works"""
        self._test_fileobj("flac")

    @skipIfFormatNotSupported("NIST")
    def test_fileobj_nist(self):
        """Saving NIST audio via file-like object works"""
        self._test_fileobj("NIST")

    @skipIfFormatNotSupported("OGG")
    def test_fileobj_ogg(self):
        """Saving OGG audio via file-like object works"""
        self._test_fileobj("OGG")
test/torchaudio_unittest/backend/dispatcher/sox/__init__.py
0 → 100644
View file @
ffeba11a
test/torchaudio_unittest/backend/dispatcher/sox/common.py
0 → 100644
View file @
ffeba11a
def name_func(func, _, params):
    """Compose a parameterized test name.

    Args:
        func: Function under parameterization; supplies the name prefix.
        _: Ignored.
        params: Object exposing the positional parameters as ``params.args``.

    Returns:
        The function name, an underscore, then the ``str`` form of each
        positional parameter joined by underscores.
    """
    rendered_args = [str(arg) for arg in params.args]
    joined = "_".join(rendered_args)
    return f"{func.__name__}_{joined}"
def get_enc_params(dtype):
    """Return the ``(encoding, bits_per_sample)`` pair for a dtype name.

    Args:
        dtype: ``"float32"``, ``"int32"``, ``"int16"`` or ``"uint8"``.

    Returns:
        Tuple of the encoding name and the bit depth.

    Raises:
        ValueError: If ``dtype`` is none of the names above.
    """
    known = (
        ("float32", ("PCM_F", 32)),
        ("int32", ("PCM_S", 32)),
        ("int16", ("PCM_S", 16)),
        ("uint8", ("PCM_U", 8)),
    )
    for name, enc_params in known:
        if dtype == name:
            return enc_params
    raise ValueError(f"Unexpected dtype: {dtype}")
test/torchaudio_unittest/backend/dispatcher/sox/info_test.py
0 → 100644
View file @
ffeba11a
import
itertools
import
os
from
functools
import
partial
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
get_info_func
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.backend.common
import
get_encoding
from
torchaudio_unittest.common_utils
import
(
disabledInCI
,
get_asset_path
,
get_wav_data
,
HttpServerMixin
,
PytorchTestCase
,
save_wav
,
skipIfNoExec
,
skipIfNoModule
,
skipIfNoSox
,
skipIfNoSoxDecoder
,
sox_utils
,
TempDirMixin
,
)
from
.common
import
name_func
if
_mod_utils
.
is_module_available
(
"requests"
):
import
requests
@skipIfNoExec("sox")
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
    """Check the metadata reported by the dispatcher ``info`` function with ``backend="sox"``."""

    _info = partial(get_info_func(), backend="sox")

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`self._info` reports correct metadata for a wav file"""
        seconds = 1
        wav_path = self.get_temp_path("data.wav")
        waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=seconds * sample_rate)
        save_wav(wav_path, waveform, sample_rate)

        metadata = self._info(wav_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == sox_utils.get_bit_depth(dtype)
        assert metadata.encoding == get_encoding("wav", dtype)

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [4, 8, 16, 32],
            )
        ),
        name_func=name_func,
    )
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`self._info` reports correct metadata for a wav file with more than 2 channels"""
        seconds = 1
        wav_path = self.get_temp_path("data.wav")
        waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=seconds * sample_rate)
        save_wav(wav_path, waveform, sample_rate)

        metadata = self._info(wav_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == sox_utils.get_bit_depth(dtype)
        assert metadata.encoding == get_encoding("wav", dtype)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                list(range(9)),
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`self._info` reports correct metadata for a flac file"""
        seconds = 1
        flac_path = self.get_temp_path("data.flac")
        sox_utils.gen_audio_file(
            flac_path,
            sample_rate,
            num_channels,
            compression=compression_level,
            duration=seconds,
        )

        metadata = self._info(flac_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 24  # FLAC standard
        assert metadata.encoding == "FLAC"

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [-1, 0, 1, 2, 3, 3.6, 5, 10],
            )
        ),
        name_func=name_func,
    )
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`self._info` reports correct metadata for a vorbis file"""
        seconds = 1
        vorbis_path = self.get_temp_path("data.vorbis")
        sox_utils.gen_audio_file(
            vorbis_path,
            sample_rate,
            num_channels,
            compression=quality_level,
            duration=seconds,
        )

        metadata = self._info(vorbis_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 0  # bit_per_sample is irrelevant for compressed formats
        assert metadata.encoding == "VORBIS"

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [16, 32],
            )
        ),
        name_func=name_func,
    )
    def test_sphere(self, sample_rate, num_channels, bits_per_sample):
        """`self._info` reports correct metadata for a sph file"""
        seconds = 1
        sph_path = self.get_temp_path("data.sph")
        sox_utils.gen_audio_file(
            sph_path, sample_rate, num_channels, duration=seconds, bit_depth=bits_per_sample
        )

        metadata = self._info(sph_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == bits_per_sample
        assert metadata.encoding == "PCM_S"

    @parameterized.expand(
        list(
            itertools.product(
                ["int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_amb(self, dtype, sample_rate, num_channels):
        """`self._info` reports correct metadata for an amb file"""
        seconds = 1
        amb_path = self.get_temp_path("data.amb")
        bits_per_sample = sox_utils.get_bit_depth(dtype)
        sox_utils.gen_audio_file(
            amb_path, sample_rate, num_channels, bit_depth=bits_per_sample, duration=seconds
        )

        metadata = self._info(amb_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == bits_per_sample
        assert metadata.encoding == get_encoding("amb", dtype)

    @skipIfNoSoxDecoder("amr-nb")
    def test_amr_nb(self):
        """`self._info` reports correct metadata for an amr-nb file"""
        sample_rate, num_channels, seconds = 8000, 1, 1
        amr_path = self.get_temp_path("data.amr-nb")
        sox_utils.gen_audio_file(
            amr_path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=seconds
        )

        metadata = self._info(amr_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 0
        assert metadata.encoding == "AMR_NB"

    def test_ulaw(self):
        """`self._info` reports correct metadata for a u-law encoded wav file"""
        sample_rate, num_channels, seconds = 8000, 1, 1
        wav_path = self.get_temp_path("data.wav")
        sox_utils.gen_audio_file(
            wav_path,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_depth=8,
            encoding="u-law",
            duration=seconds,
        )

        metadata = self._info(wav_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 8
        assert metadata.encoding == "ULAW"

    def test_alaw(self):
        """`self._info` reports correct metadata for an a-law encoded wav file"""
        sample_rate, num_channels, seconds = 8000, 1, 1
        wav_path = self.get_temp_path("data.wav")
        sox_utils.gen_audio_file(
            wav_path,
            sample_rate=sample_rate,
            num_channels=num_channels,
            bit_depth=8,
            encoding="a-law",
            duration=seconds,
        )

        metadata = self._info(wav_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 8
        assert metadata.encoding == "ALAW"

    def test_gsm(self):
        """`self._info` reports correct metadata for a gsm file

        The number of frames is not checked in this test.
        """
        sample_rate, num_channels, seconds = 8000, 1, 1
        gsm_path = self.get_temp_path("data.gsm")
        sox_utils.gen_audio_file(
            gsm_path, sample_rate=sample_rate, num_channels=num_channels, duration=seconds
        )

        metadata = self._info(gsm_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 0
        assert metadata.encoding == "GSM"

    def test_htk(self):
        """`self._info` reports correct metadata for an HTK file"""
        sample_rate, num_channels, seconds = 8000, 1, 1
        htk_path = self.get_temp_path("data.htk")
        sox_utils.gen_audio_file(
            htk_path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=seconds
        )

        metadata = self._info(htk_path)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_frames == sample_rate * seconds
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 16
        assert metadata.encoding == "PCM_S"
@disabledInCI
@skipIfNoSoxDecoder("opus")
class TestInfoOpus(PytorchTestCase):
    """Check the metadata reported for the pre-generated opus assets."""

    _info = partial(get_info_func(), backend="sox")

    @parameterized.expand(
        list(
            itertools.product(
                ["96k"],
                [1, 2],
                [0, 5, 10],
            )
        ),
        name_func=name_func,
    )
    def test_opus(self, bitrate, num_channels, compression_level):
        """`self._info` reports correct metadata for an opus file"""
        asset_name = f"{bitrate}_{compression_level}_{num_channels}ch.opus"
        opus_path = get_asset_path("io", asset_name)

        metadata = self._info(opus_path)
        assert metadata.sample_rate == 48000
        assert metadata.num_frames == 32768
        assert metadata.num_channels == num_channels
        assert metadata.bits_per_sample == 0  # bit_per_sample is irrelevant for compressed formats
        assert metadata.encoding == "OPUS"
class FileObjTestBase(TempDirMixin):
    """Helpers that generate audio files for the file-object test cases."""

    def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate ``test.<ext>`` in the temp directory and return its path.

        The duration handed to the generator is ``num_frames / sample_rate``.
        When ``comments`` is truthy it is first written to a comment file
        whose path is forwarded to the generator.
        """
        audio_path = self.get_temp_path(f"test.{ext}")
        depth = sox_utils.get_bit_depth(dtype)
        seconds = num_frames / sample_rate
        if comments:
            comment_file = self._gen_comment_file(comments)
        else:
            comment_file = None
        sox_utils.gen_audio_file(
            audio_path,
            sample_rate,
            num_channels=num_channels,
            encoding=sox_utils.get_encoding(dtype),
            bit_depth=depth,
            duration=seconds,
            comment_file=comment_file,
        )
        return audio_path

    def _gen_comment_file(self, comments):
        """Write ``comments`` to ``comment.txt`` in the temp directory and return its path."""
        target = self.get_temp_path("comment.txt")
        with open(target, "w") as handle:
            handle.writelines(comments)
        return target
class Unseekable:
    """Wrapper around a file-like object that exposes only ``read``."""

    def __init__(self, fileobj):
        # The wrapped object stays reachable as ``self.fileobj``.
        self.fileobj = fileobj

    def read(self, n):
        """Read from the wrapped object, forwarding ``n`` unchanged."""
        chunk = self.fileobj.read(n)
        return chunk
@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
    """Querying file objects through ``self._info`` (``backend="sox"``) is expected to fail."""

    _info = partial(get_info_func(), backend="sox")

    def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
        """Generate an audio file, open it in binary mode and pass the handle to ``self._info``."""
        audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
        with open(audio_path, "rb") as handle:
            return self._info(handle, None)

    @parameterized.expand(
        [
            ("wav", "float32"),
            ("wav", "int32"),
            ("wav", "int16"),
            ("wav", "uint8"),
            # ("mp3", "float32"),
            ("flac", "float32"),
            ("vorbis", "float32"),
            ("amb", "int16"),
        ]
    )
    def test_fileobj(self, ext, dtype):
        """Querying audio via file object raises ``ValueError``"""
        sample_rate = 16000
        num_channels = 2
        num_frames = 3 * sample_rate
        expected_message = "SoX backend does not support reading"
        with self.assertRaisesRegex(ValueError, expected_message):
            self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
    """Querying HTTP response streams through ``self._info`` (``backend="sox"``) is expected to fail."""

    _info = partial(get_info_func(), backend="sox")

    def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate an audio file, fetch it by URL and pass the wrapped raw stream to ``self._info``."""
        generated = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
        url = self.get_url(os.path.basename(generated))
        # format_ = ext if ext in ["mp3"] else None
        with requests.get(url, stream=True) as response:
            stream = Unseekable(response.raw)
            return self._info(stream, format=None)

    @parameterized.expand(
        [
            ("wav", "float32"),
            ("wav", "int32"),
            ("wav", "int16"),
            ("wav", "uint8"),
            # ("mp3", "float32"),
            ("flac", "float32"),
            ("vorbis", "float32"),
            ("amb", "int16"),
        ]
    )
    def test_requests(self, ext, dtype):
        """Querying audio via a requests stream raises ``ValueError``"""
        sample_rate = 16000
        num_channels = 2
        num_frames = 3.0 * sample_rate
        expected_message = "SoX backend does not support reading"
        with self.assertRaisesRegex(ValueError, expected_message):
            self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
@skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase):
    """Error reporting of ``self._info`` (``backend="sox"``) for a missing file."""

    _info = partial(get_info_func(), backend="sox")

    def test_info_fail(self):
        """
        When attempting to get info on a non-existing file, the error message must contain the file path.
        """
        missing = "non_existing_audio.wav"
        # The path doubles as the regex the error message has to match.
        with self.assertRaisesRegex(RuntimeError, missing):
            self._info(missing)
test/torchaudio_unittest/backend/dispatcher/sox/load_test.py
0 → 100644
View file @
ffeba11a
import
itertools
from
functools
import
partial
import
torch
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
get_load_func
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
get_wav_data
,
load_wav
,
nested_params
,
PytorchTestCase
,
save_wav
,
skipIfNoExec
,
skipIfNoSox
,
skipIfNoSoxDecoder
,
sox_utils
,
TempDirMixin
,
)
from
.common
import
name_func
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertions for tests of ``self._load`` (dispatcher load with ``backend="sox"``)."""

    _load = partial(get_load_func(), backend="sox")

    def assert_format(
        self,
        format: str,
        sample_rate: float,
        num_channels: int,
        compression: float = None,
        bit_depth: int = None,
        duration: float = 1,
        normalize: bool = True,
        encoding: str = None,
        atol: float = 4e-05,
        rtol: float = 1.3e-06,
    ):
        """`self._load` can load the given format correctly.

        File encodings introduce delay and boundary effects, so a reference
        wav file is created from the file in the original format.

         x
         |
         |    1. Generate given format with Sox
         |
         v    2. Convert to wav with Sox
        given format ----------------------> wav
         |                                   |
         |    3. Load with torchaudio        | 4. Load with scipy
         |                                   |
         v                                   v
        tensor ----------> x <----------- tensor
                       5. Compare

        Underlying assumptions are;
        i. Conversion of given format to wav with Sox preserves data.
        ii. Loading wav file with scipy is correct.

        By combining i & ii, steps 2. and 4. allow loading reference data of
        the given format without using torchaudio.
        """
        src_path = self.get_temp_path(f"1.original.{format}")
        wav_path = self.get_temp_path("2.reference.wav")

        # 1. Generate the given format with sox
        sox_utils.gen_audio_file(
            src_path,
            sample_rate,
            num_channels,
            encoding=encoding,
            compression=compression,
            bit_depth=bit_depth,
            duration=duration,
        )

        # 2. Convert to wav with sox (24-bit input is converted with a bit depth of 32)
        if bit_depth == 24:
            wav_bit_depth = 32
        else:
            wav_bit_depth = None
        sox_utils.convert_audio_file(src_path, wav_path, bit_depth=wav_bit_depth)

        # 3. Load the given format with torchaudio
        loaded, loaded_rate = self._load(src_path, normalize=normalize)

        # 4. Load wav with scipy
        reference = load_wav(wav_path, normalize=normalize)[0]

        # 5. Compare
        assert loaded_rate == sample_rate
        self.assertEqual(loaded, reference, atol=atol, rtol=rtol)

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`self._load` can load wav format correctly.

        Wav data loaded with ``self._load`` should match the data loaded with scipy.
        """
        wav_path = self.get_temp_path("reference.wav")
        source = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
        save_wav(wav_path, source, sample_rate)

        reference = load_wav(wav_path, normalize=normalize)[0]
        loaded, loaded_rate = self._load(wav_path, normalize=normalize)

        assert loaded_rate == sample_rate
        self.assertEqual(loaded, reference)
@skipIfNoExec("sox")
@skipIfNoSox
class TestLoad(LoadTestBase):
    """Test the correctness of `self._load` (sox backend) for various formats"""

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
                [False, True],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize):
        """`self._load` can load wav format correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [False, True],
            )
        ),
        name_func=name_func,
    )
    def test_24bit_wav(self, sample_rate, num_channels, normalize):
        """`self._load` can load 24bit wav format correctly."""
        self.assert_format(
            "wav",
            sample_rate,
            num_channels,
            bit_depth=24,
            normalize=normalize,
            duration=1,
        )

    @parameterized.expand(
        list(
            itertools.product(
                ["int16"],
                [16000],
                [2],
                [False],
            )
        ),
        name_func=name_func,
    )
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """`self._load` can load large wav file correctly."""
        two_hours = 2 * 60 * 60  # duration in seconds
        self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours)

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [4, 8, 16, 32],
            )
        ),
        name_func=name_func,
    )
    def test_multiple_channels(self, dtype, num_channels):
        """`self._load` can load wav file with more than 2 channels."""
        sample_rate, normalize = 8000, False
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                list(range(9)),
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`self._load` can load flac format correctly."""
        self.assert_format(
            "flac",
            sample_rate,
            num_channels,
            compression=compression_level,
            bit_depth=16,
            duration=1,
        )

    @parameterized.expand(
        list(
            itertools.product(
                [16000],
                [2],
                [0],
            )
        ),
        name_func=name_func,
    )
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """`self._load` can load large flac file correctly."""
        two_hours = 2 * 60 * 60  # duration in seconds
        self.assert_format(
            "flac",
            sample_rate,
            num_channels,
            compression=compression_level,
            bit_depth=16,
            duration=two_hours,
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                [-1, 0, 1, 2, 3, 3.6, 5, 10],
            )
        ),
        name_func=name_func,
    )
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`self._load` can load vorbis format correctly."""
        self.assert_format(
            "vorbis",
            sample_rate,
            num_channels,
            compression=quality_level,
            bit_depth=16,
            duration=1,
        )

    @parameterized.expand(
        list(
            itertools.product(
                [16000],
                [2],
                [10],
            )
        ),
        name_func=name_func,
    )
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """`self._load` can load large vorbis file correctly."""
        two_hours = 2 * 60 * 60  # duration in seconds
        self.assert_format(
            "vorbis",
            sample_rate,
            num_channels,
            compression=quality_level,
            bit_depth=16,
            duration=two_hours,
        )

    @parameterized.expand(
        list(
            itertools.product(
                ["96k"],
                [1, 2],
                [0, 5, 10],
            )
        ),
        name_func=name_func,
    )
    @skipIfNoSoxDecoder("opus")
    def test_opus(self, bitrate, num_channels, compression_level):
        """`self._load` can load opus file correctly."""
        stem = f"{bitrate}_{compression_level}_{num_channels}ch.opus"
        ops_path = get_asset_path("io", stem)
        wav_path = self.get_temp_path(f"{stem}.wav")
        # Reference data: the opus asset converted to wav with sox.
        sox_utils.convert_audio_file(ops_path, wav_path)

        expected, expected_rate = load_wav(wav_path)
        found, found_rate = self._load(ops_path)

        assert expected_rate == found_rate
        self.assertEqual(expected, found)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_sphere(self, sample_rate, num_channels):
        """`self._load` can load sph format correctly."""
        self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16"],
                [8000, 16000],
                [1, 2],
                [False, True],
            )
        ),
        name_func=name_func,
    )
    def test_amb(self, dtype, sample_rate, num_channels, normalize):
        """`self._load` can load amb format correctly."""
        depth = sox_utils.get_bit_depth(dtype)
        sox_encoding = sox_utils.get_encoding(dtype)
        self.assert_format(
            "amb",
            sample_rate,
            num_channels,
            bit_depth=depth,
            duration=1,
            encoding=sox_encoding,
            normalize=normalize,
        )

    @skipIfNoSoxDecoder("amr-nb")
    def test_amr_nb(self):
        """`self._load` can load amr_nb format correctly."""
        self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of the frame parameters of the sox load op"""

    def _test(self, func, frame_offset, num_frames, channels_first, normalize):
        """Load a slice of a generated int16 wav file with ``func`` and compare it with the expected slice."""
        source = get_wav_data("int16", num_channels=2, normalize=False)
        wav_path = self.get_temp_path("test.wav")
        save_wav(wav_path, source, sample_rate=8000)

        loaded, _ = func(wav_path, frame_offset, num_frames, normalize, channels_first, None)

        # ``num_frames == -1`` means the slice runs to the end of the data.
        if num_frames == -1:
            stop = None
        else:
            stop = frame_offset + num_frames
        reference = source[:, frame_offset:stop]
        if not channels_first:
            reference = reference.T
        if normalize:
            reference = reference.to(torch.float32) / (2**15)
        self.assertEqual(loaded, reference)

    @nested_params(
        [0, 1, 10, 100, 1000],
        [-1, 1, 10, 100, 1000],
        [True, False],
        [True, False],
    )
    def test_sox(self, frame_offset, num_frames, channels_first, normalize):
        """Each combination of the frame parameters changes the output tensor as expected"""
        load_op = torch.ops.torchaudio.sox_io_load_audio_file
        self._test(load_op, frame_offset, num_frames, channels_first, normalize)
@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """
    File-like object inputs are expected to be rejected by ``self._load``
    (dispatcher load with ``backend="sox"``), while loading the same file
    by path keeps working.
    """

    _load = partial(get_load_func(), backend="sox")

    @parameterized.expand(
        [
            ("wav", {"bit_depth": 16}),
            ("wav", {"bit_depth": 24}),
            ("wav", {"bit_depth": 32}),
            ("flac", {"compression": 0}),
            ("flac", {"compression": 5}),
            ("flac", {"compression": 8}),
            ("vorbis", {"compression": -1}),
            ("vorbis", {"compression": 10}),
            ("amb", {}),
        ]
    )
    def test_fileobj(self, ext, kwargs):
        """Loading audio via file object raises ``ValueError``; loading the same file via path works."""
        sample_rate = 16000
        # Only "mp3" would get an explicit format hint; every other extension passes None.
        format_ = ext if ext in ["mp3"] else None
        path = self.get_temp_path(f"test.{ext}")

        sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)

        # Sanity check: the generated file must be loadable by path.
        # The result itself is not needed for the assertion below.
        self._load(path)

        with open(path, "rb") as fileobj:
            with self.assertRaisesRegex(ValueError, "SoX backend does not support loading"):
                self._load(fileobj, format=format_)
@skipIfNoSox
class TestLoadNoSuchFile(PytorchTestCase):
    """Error reporting of ``self._load`` (``backend="sox"``) for a missing file."""

    _load = partial(get_load_func(), backend="sox")

    def test_load_fail(self):
        """
        When attempting to load a non-existing file, the error message must contain the file path.
        """
        missing = "non_existing_audio.wav"
        # The path doubles as the regex the error message has to match.
        with self.assertRaisesRegex(RuntimeError, missing):
            self._load(missing)
test/torchaudio_unittest/backend/dispatcher/sox/roundtrip_test.py
0 → 100644
View file @
ffeba11a
import
itertools
from
functools
import
partial
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
get_load_func
,
get_save_func
from
torchaudio_unittest.common_utils
import
get_wav_data
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoSox
,
TempDirMixin
from
.common
import
get_enc_params
,
name_func
@skipIfNoExec("sox")
@skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
    """save/load round trip should not degrade data for lossless formats"""

    _load = partial(get_load_func(), backend="sox")
    _save = partial(get_save_func(), backend="sox")

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """Ten consecutive save/load round trips keep wav data identical to the original"""
        original = get_wav_data(dtype, num_channels, normalize=False)
        encoding, bits_per_sample = get_enc_params(dtype)

        current = original
        for round_no in range(10):
            wav_path = self.get_temp_path(f"{round_no}.wav")
            self._save(wav_path, current, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
            # The data loaded here is what gets saved in the next round.
            current, loaded_rate = self._load(wav_path, normalize=False)
            assert loaded_rate == sample_rate
            self.assertEqual(original, current)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels):
        """Ten consecutive save/load round trips keep flac data identical to the original"""
        original = get_wav_data("float32", num_channels)

        current = original
        for round_no in range(10):
            flac_path = self.get_temp_path(f"{round_no}.flac")
            self._save(flac_path, current, sample_rate)
            # The data loaded here is what gets saved in the next round.
            current, loaded_rate = self._load(flac_path)
            assert loaded_rate == sample_rate
            self.assertEqual(original, current)
test/torchaudio_unittest/backend/dispatcher/sox/save_test.py
0 → 100644
View file @
ffeba11a
import
io
import
os
from
functools
import
partial
import
torch
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
get_save_func
from
torchaudio_unittest.common_utils
import
(
get_wav_data
,
load_wav
,
nested_params
,
PytorchTestCase
,
save_wav
,
skipIfNoExec
,
skipIfNoSox
,
skipIfNoSoxEncoder
,
sox_utils
,
TempDirMixin
,
TorchaudioTestCase
,
)
from
.common
import
get_enc_params
,
name_func
def _get_sox_encoding(encoding):
    """Translate an encoding name such as ``"PCM_S"`` into the spelling passed to the sox helpers.

    Returns ``None`` when ``encoding`` is not one of the known names.
    """
    sox_names = dict(
        PCM_F="floating-point",
        PCM_S="signed-integer",
        PCM_U="unsigned-integer",
        ULAW="u-law",
        ALAW="a-law",
    )
    return sox_names.get(encoding)
class SaveTestBase(TempDirMixin, TorchaudioTestCase):
    """Shared assertion for tests of ``self._save`` (dispatcher save with ``backend="sox"``)."""

    _save = partial(get_save_func(), backend="sox")

    def assert_save_consistency(
        self,
        format: str,
        *,
        compression: float = None,
        encoding: str = None,
        bits_per_sample: int = None,
        sample_rate: float = 8000,
        num_channels: int = 2,
        num_frames: float = 3 * 8000,
        src_dtype: str = "int32",
        test_mode: str = "path",
    ):
        """`save` function produces a file that is comparable with the `sox` command

        To compare the file produced by the `save` function against the file
        produced by the equivalent `sox` command, both files have to be loaded.
        But there are many formats that cannot be opened with common Python
        modules (like SciPy).
        So the `sox` command is used to prepare the original data and to convert
        the saved files into a format that SciPy can read (PCM wav).
        The following diagram illustrates this process. The difference is 2.1. and 3.1.

        This assumes that
         - loading data with SciPy preserves the data well.
         - converting the resulting files into WAV format with `sox` preserves the data well.

                          x
                          | 1. Generate source wav file with SciPy
                          |
                          v
          -------------- wav ----------------
         |                                   |
         | 2.1. load with scipy              | 3.1. Convert to the target
         |   then save it into the target    |      format depth with sox
         |   format with torchaudio          |
         v                                   v
        target format                       target format
         |                                   |
         | 2.2. Convert to wav with sox      | 3.2. Convert to wav with sox
         |                                   |
         v                                   v
        wav                                 wav
         |                                   |
         | 2.3. load with scipy              | 3.3. load with scipy
         |                                   |
         v                                   v
        tensor -------> compare <--------- tensor

        ``test_mode`` selects how the data is handed to ``self._save``:
        ``"path"`` (file path), ``"fileobj"`` (file opened for binary writing)
        or ``"bytesio"`` (in-memory buffer that is then written to disk).
        Any other value raises ``ValueError``.
        """
        # Both result files are converted to this common format before comparison.
        cmp_encoding = "floating-point"
        cmp_bit_depth = 32

        src_path = self.get_temp_path("1.source.wav")
        tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
        tst_path = self.get_temp_path("2.2.result.wav")
        sox_path = self.get_temp_path(f"3.1.sox.{format}")
        ref_path = self.get_temp_path("3.2.ref.wav")

        # 1. Generate original wav
        source = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
        save_wav(src_path, source, sample_rate)

        # 2.1. Convert the original wav to target format with torchaudio
        waveform = load_wav(src_path, normalize=False)[0]
        save_options = {
            "compression": compression,
            "encoding": encoding,
            "bits_per_sample": bits_per_sample,
        }
        if test_mode == "path":
            self._save(tgt_path, waveform, sample_rate, **save_options)
        elif test_mode == "fileobj":
            with open(tgt_path, "bw") as sink:
                self._save(sink, waveform, sample_rate, format=format, **save_options)
        elif test_mode == "bytesio":
            buffer = io.BytesIO()
            self._save(buffer, waveform, sample_rate, format=format, **save_options)
            buffer.seek(0)
            with open(tgt_path, "bw") as sink:
                sink.write(buffer.read())
        else:
            raise ValueError(f"Unexpected test mode: {test_mode}")

        # 2.2. Convert the target format to wav with sox
        sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
        # 2.3. Load with SciPy
        found = load_wav(tst_path, normalize=False)[0]

        # 3.1. Convert the original wav to target format with sox
        sox_encoding = _get_sox_encoding(encoding)
        sox_utils.convert_audio_file(
            src_path,
            sox_path,
            compression=compression,
            encoding=sox_encoding,
            bit_depth=bits_per_sample,
        )
        # 3.2. Convert the target format to wav with sox
        sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
        # 3.3. Load with SciPy
        expected = load_wav(ref_path, normalize=False)[0]

        self.assertEqual(found, expected)
@skipIfNoExec("sox")
@skipIfNoSox
class SaveTest(SaveTestBase):
    """Compare files written by ``self._save`` with files produced by the ``sox`` command."""

    @nested_params(
        [
            ("PCM_U", 8),
            ("PCM_S", 16),
            ("PCM_S", 32),
            ("PCM_F", 32),
            ("PCM_F", 64),
            ("ULAW", 8),
            ("ALAW", 8),
        ],
    )
    def test_save_wav(self, enc_params):
        """wav is saved consistently for each encoding / bit depth pair"""
        enc, bps = enc_params
        self.assert_save_consistency("wav", encoding=enc, bits_per_sample=bps, test_mode="path")

    @nested_params(
        [
            ("float32",),
            ("int32",),
            ("int16",),
            ("uint8",),
        ],
    )
    def test_save_wav_dtype(self, params):
        """wav is saved consistently for each source dtype"""
        (source_dtype,) = params
        self.assert_save_consistency("wav", src_dtype=source_dtype, test_mode="path")

    @nested_params(
        [8, 16, 24],
        [
            None,
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
        ],
    )
    def test_save_flac(self, bits_per_sample, compression_level):
        """flac is saved consistently for each bit depth / compression level"""
        self.assert_save_consistency(
            "flac",
            compression=compression_level,
            bits_per_sample=bits_per_sample,
            test_mode="path",
        )

    def test_save_htk(self):
        """htk is saved consistently for single channel audio"""
        self.assert_save_consistency("htk", test_mode="path", num_channels=1)

    @nested_params(
        [
            None,
            -1,
            0,
            1,
            2,
            3,
            3.6,
            5,
            10,
        ],
    )
    def test_save_vorbis(self, quality_level):
        """vorbis is saved consistently for each quality level"""
        self.assert_save_consistency("vorbis", compression=quality_level, test_mode="path")

    @nested_params(
        [
            (
                "PCM_S",
                8,
            ),
            (
                "PCM_S",
                16,
            ),
            (
                "PCM_S",
                24,
            ),
            (
                "PCM_S",
                32,
            ),
            ("ULAW", 8),
            ("ALAW", 8),
            ("ALAW", 16),
            ("ALAW", 24),
            ("ALAW", 32),
        ],
    )
    def test_save_sphere(self, enc_params):
        """sph is saved consistently for each encoding / bit depth pair"""
        enc, bps = enc_params
        self.assert_save_consistency("sph", encoding=enc, bits_per_sample=bps, test_mode="path")

    @nested_params(
        [
            (
                "PCM_U",
                8,
            ),
            (
                "PCM_S",
                16,
            ),
            (
                "PCM_S",
                24,
            ),
            (
                "PCM_S",
                32,
            ),
            (
                "PCM_F",
                32,
            ),
            (
                "PCM_F",
                64,
            ),
            (
                "ULAW",
                8,
            ),
            (
                "ALAW",
                8,
            ),
        ],
    )
    def test_save_amb(self, enc_params):
        """amb is saved consistently for each encoding / bit depth pair"""
        enc, bps = enc_params
        self.assert_save_consistency("amb", encoding=enc, bits_per_sample=bps, test_mode="path")

    @nested_params(
        [
            None,
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
        ],
    )
    @skipIfNoSoxEncoder("amr-nb")
    def test_save_amr_nb(self, bit_rate):
        """amr-nb is saved consistently for each bit rate setting (single channel)"""
        self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode="path")

    def test_save_gsm(self):
        """gsm is saved consistently for single channel audio; other configurations raise RuntimeError"""
        self.assert_save_consistency("gsm", num_channels=1, test_mode="path")

        # NOTE(review): ``msg`` is only the message reported when no RuntimeError
        # is raised; it is not matched against the text of the raised exception.
        with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
            self.assert_save_consistency("gsm", num_channels=2, test_mode="path")

        # NOTE(review): this call does not pass num_channels, so the RuntimeError
        # may not be caused by the sample rate alone -- confirm.
        with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
            self.assert_save_consistency("gsm", sample_rate=16000, test_mode="path")

    @parameterized.expand(
        [
            ("wav", "PCM_S", 16),
            ("flac",),
            ("vorbis",),
            ("sph", "PCM_S", 16),
            ("amb", "PCM_S", 16),
        ],
        name_func=name_func,
    )
    def test_save_large(self, format, encoding=None, bits_per_sample=None):
        """`self._save` can save large files in each listed format"""
        self._test_save_large(format, encoding, bits_per_sample)

    @skipIfNoSoxEncoder("amr-nb")
    def test_save_large_amr_nb(self):
        """`self._save` can save large amr-nb files"""
        self._test_save_large("amr-nb")

    def _test_save_large(self, format, encoding=None, bits_per_sample=None):
        """`self._save` can save large files."""
        sample_rate = 8000
        frames_in_one_hour = 60 * 60 * sample_rate
        self.assert_save_consistency(
            format,
            num_channels=1,
            sample_rate=sample_rate,
            num_frames=frames_in_one_hour,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )

    @parameterized.expand(
        [
            (32,),
            (64,),
            (128,),
            (256,),
        ],
        name_func=name_func,
    )
    def test_save_multi_channels(self, num_channels):
        """`self._save` can save audio with many channels"""
        self.assert_save_consistency(
            "wav",
            encoding="PCM_S",
            bits_per_sample=16,
            num_channels=num_channels,
        )
@
skipIfNoExec
(
"sox"
)
@
skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Checks for the optional parameters accepted by `self._save`."""

    # Save function bound to the sox backend.
    _save = partial(get_save_func(), backend="sox")

    @parameterized.expand([(True,), (False,)], name_func=name_func)
    def test_save_channels_first(self, channels_first):
        """Data saved with either ``channels_first`` setting loads back in the same layout as `load_wav` returns."""
        wav_path = self.get_temp_path("data.wav")
        waveform = get_wav_data("int16", 2, channels_first=channels_first, normalize=False)
        self._save(wav_path, waveform, 8000, channels_first=channels_first)

        loaded = load_wav(wav_path, normalize=False)
        found = loaded[0]

        # When the input was not channels-first, swap its two axes before comparing.
        if channels_first:
            expected = waveform
        else:
            expected = waveform.transpose(1, 0)
        self.assertEqual(found, expected)

    @parameterized.expand(["float32", "int32", "int16", "uint8"], name_func=name_func)
    def test_save_noncontiguous(self, dtype):
        """A non-contiguous (strided) tensor is written out with the right values."""
        wav_path = self.get_temp_path("data.wav")
        encoding, bits = get_enc_params(dtype)

        # Taking every second element along both axes gives a non-contiguous view.
        full = get_wav_data(dtype, 4, normalize=False)
        expected = full[::2, ::2]
        assert not expected.is_contiguous()

        self._save(wav_path, expected, 8000, encoding=encoding, bits_per_sample=bits)

        loaded = load_wav(wav_path, normalize=False)
        found = loaded[0]
        self.assertEqual(found, expected)

    @parameterized.expand(
        [
            "float32",
            "int32",
            "int16",
            "uint8",
        ]
    )
    def test_save_tensor_preserve(self, dtype):
        """Saving must leave the input tensor's values unchanged."""
        wav_path = self.get_temp_path("data.wav")
        full = get_wav_data(dtype, 4, normalize=False)
        expected = full[::2, ::2]

        # Save a copy, then compare the copy against the untouched reference.
        saved_input = expected.clone()
        self._save(wav_path, saved_input, 8000)

        self.assertEqual(saved_input, expected)
@
skipIfNoSox
class TestSaveNonExistingDirectory(PytorchTestCase):
    """Saving into a directory that does not exist must fail with a helpful error."""

    # Save function bound to the sox backend.
    _save = partial(get_save_func(), backend="sox")

    def test_save_fail(self):
        """
        When attempted to save into a non-existing dir, error message must contain the file path.
        """
        # Local import so this block does not depend on edits to the file's import section.
        import re

        path = os.path.join("non_existing_directory", "foo.wav")
        # `assertRaisesRegex` interprets its second argument as a regular expression.
        # Escape the path so that "." only matches a literal dot and a backslash
        # path separator is not read as the start of a regex escape sequence.
        with self.assertRaisesRegex(RuntimeError, re.escape(path)):
            self._save(path, torch.zeros(1, 1), 8000)
test/torchaudio_unittest/backend/dispatcher/sox/smoke_test.py
0 → 100644
View file @
ffeba11a
import
itertools
from
functools
import
partial
from
parameterized
import
parameterized
from
torchaudio._backend.utils
import
get_info_func
,
get_load_func
,
get_save_func
from
torchaudio_unittest.common_utils
import
get_wav_data
,
skipIfNoSox
,
TempDirMixin
,
TorchaudioTestCase
from
.common
import
name_func
@
skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
    """Smoke tests over several audio formats using the sox backend.

    Each test saves generated audio, then queries and reloads it, checking only
    the reported sample rate and channel count. The aim is to catch abnormal
    behavior in the save/info/load functions, not to validate sample values.

    The suite is meant to run without additional tools (such as the sox command);
    without such tools the correctness of each function cannot be verified.
    """

    # info/load/save functions, each bound to the sox backend.
    _info = partial(get_info_func(), backend="sox")
    _load = partial(get_load_func(), backend="sox")
    _save = partial(get_save_func(), backend="sox")

    def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
        """Save, inspect and reload one second of generated audio.

        Args:
            ext: File extension used for the temporary file name.
            sample_rate: Sample rate passed to save and expected from info/load.
            num_channels: Channel count of the generated data.
            dtype: Data type name handed to ``get_wav_data``.
        """
        seconds = 1
        frame_count = sample_rate * seconds
        target = self.get_temp_path(f"test.{ext}")
        waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=frame_count)

        # Step 1: save
        self._save(target, waveform, sample_rate)

        # Step 2: info
        metadata = self._info(target)
        assert metadata.sample_rate == sample_rate
        assert metadata.num_channels == num_channels

        # Step 3: load
        loaded, loaded_rate = self._load(target, normalize=False)
        assert loaded_rate == sample_rate
        assert loaded.shape[0] == num_channels

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """Smoke-test the wav format for each dtype / sample rate / channel count."""
        self.run_smoke_test(
            "wav",
            sample_rate,
            num_channels,
            dtype=dtype,
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
            )
        )
    )
    def test_vorbis(self, sample_rate, num_channels):
        """Smoke-test the vorbis format for each sample rate / channel count."""
        fmt = "vorbis"
        self.run_smoke_test(fmt, sample_rate, num_channels)

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels):
        """Smoke-test the flac format for each sample rate / channel count."""
        fmt = "flac"
        self.run_smoke_test(fmt, sample_rate, num_channels)
Prev
1
…
7
8
9
10
11
12
13
14
15
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment