Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b8fd5e94
Unverified
Commit
b8fd5e94
authored
Feb 23, 2021
by
Prabhat Roy
Committed by
GitHub
Feb 23, 2021
Browse files
Added encoding and bits_per_sample to soundfile's backend save() (#1274)
parent
e1bea2b7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
230 additions
and
32 deletions
+230
-32
test/torchaudio_unittest/backend/soundfile/common.py
test/torchaudio_unittest/backend/soundfile/common.py
+23
-0
test/torchaudio_unittest/backend/soundfile/save_test.py
test/torchaudio_unittest/backend/soundfile/save_test.py
+54
-13
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+152
-18
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+1
-1
No files found.
test/torchaudio_unittest/backend/soundfile/common.py
View file @
b8fd5e94
...
...
@@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):
def parameterize(*params):
    """Build a ``parameterized.expand`` decorator over the product of ``params``.

    Each positional argument is an iterable of candidate values; the test is
    expanded once per element of their Cartesian product. Test names are
    produced by the module's ``name_func``.
    """
    combinations = list(itertools.product(*params))
    return parameterized.expand(combinations, name_func=name_func)
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Return the subtype string expected for a WAV save with these options.

    Args:
        dtype: passed to ``dtype2subtype``; its result is the expectation
            when both ``encoding`` and ``bits_per_sample`` are ``None``.
        encoding: encoding name, or ``None``.
        bits_per_sample: bit depth, or ``None``.

    Raises:
        ValueError: if the ``(encoding, bits_per_sample)`` pair is not in the
            table (or maps to a falsy value).
    """
    # The dtype-derived default is computed up front, regardless of which
    # entry ends up being looked up.
    table = {(None, None): dtype2subtype(dtype)}
    groups = (
        ("PCM_U8", ((None, 8), ('PCM_U', None), ('PCM_U', 8))),
        ("PCM_32", (('PCM_S', None), ('PCM_S', 32))),
        ("PCM_16", (('PCM_S', 16),)),
        ("FLOAT", (('PCM_F', None), ('PCM_F', 32))),
        ("DOUBLE", (('PCM_F', 64),)),
        ("ULAW", (('ULAW', None), ('ULAW', 8))),
        ("ALAW", (('ALAW', None), ('ALAW', 8))),
    )
    for name, keys in groups:
        for key in keys:
            table[key] = name

    found = table.get((encoding, bits_per_sample))
    if not found:
        raise ValueError(
            f"wav does not support ({encoding}, {bits_per_sample}).")
    return found
test/torchaudio_unittest/backend/soundfile/save_test.py
View file @
b8fd5e94
...
...
@@ -11,7 +11,11 @@ from torchaudio_unittest.common_utils import (
get_wav_data
,
load_wav
,
)
from
.common
import
parameterize
,
dtype2subtype
,
skipIfFormatNotSupported
from
.common
import
(
fetch_wav_subtype
,
parameterize
,
skipIfFormatNotSupported
,
)
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
...
...
@@ -20,28 +24,47 @@ if _mod_utils.is_module_available("soundfile"):
class
MockedSaveTest
(
PytorchTestCase
):
@
parameterize
(
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
(
None
,
None
),
(
'PCM_U'
,
None
),
(
'PCM_U'
,
8
),
(
'PCM_S'
,
None
),
(
'PCM_S'
,
16
),
(
'PCM_S'
,
32
),
(
'PCM_F'
,
None
),
(
'PCM_F'
,
32
),
(
'PCM_F'
,
64
),
(
'ULAW'
,
None
),
(
'ULAW'
,
8
),
(
'ALAW'
,
None
),
(
'ALAW'
,
8
),
],
)
@
patch
(
"soundfile.write"
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
,
mocked_write
):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath
=
"foo.wav"
input_tensor
=
get_wav_data
(
dtype
,
num_channels
,
num_frames
=
3
*
sample_rate
,
normalize
=
dtype
==
"fl
a
ot32"
,
normalize
=
dtype
==
"flo
a
t32"
,
channels_first
=
channels_first
,
).
t
()
encoding
,
bits_per_sample
=
enc_params
soundfile_backend
.
save
(
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
# on Py3.8+ call_args.kwargs is more descriptive
args
=
mocked_write
.
call_args
[
1
]
assert
args
[
"file"
]
==
filepath
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"subtype"
]
==
dtype2subtype
(
dtype
)
assert
args
[
"subtype"
]
==
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
)
assert
args
[
"format"
]
is
None
self
.
assertEqual
(
args
[
"data"
],
input_tensor
.
t
()
if
channels_first
else
input_tensor
...
...
@@ -49,7 +72,8 @@ class MockedSaveTest(PytorchTestCase):
@
patch
(
"soundfile.write"
)
def
assert_non_wav
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
,
encoding
=
None
,
bits_per_sample
=
None
,
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath
=
f
"foo.
{
fmt
}
"
...
...
@@ -63,14 +87,14 @@ class MockedSaveTest(PytorchTestCase):
expected_data
=
input_tensor
.
t
()
if
channels_first
else
input_tensor
soundfile_backend
.
save
(
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
filepath
,
input_tensor
,
sample_rate
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
)
# on Py3.8+ call_args.kwargs is more descriptive
args
=
mocked_write
.
call_args
[
1
]
assert
args
[
"file"
]
==
filepath
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"subtype"
]
is
None
if
fmt
in
[
"sph"
,
"nist"
,
"nis"
]:
assert
args
[
"format"
]
==
"NIST"
else
:
...
...
@@ -83,19 +107,36 @@ class MockedSaveTest(PytorchTestCase):
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
(
'PCM_S'
,
8
),
(
'PCM_S'
,
16
),
(
'PCM_S'
,
24
),
(
'PCM_S'
,
32
),
(
'ULAW'
,
8
),
(
'ALAW'
,
8
),
(
'ALAW'
,
16
),
(
'ALAW'
,
24
),
(
'ALAW'
,
32
),
],
)
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self
.
assert_non_wav
(
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
encoding
,
bits_per_sample
=
enc_params
self
.
assert_non_wav
(
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
8
,
16
,
24
],
)
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
...
...
@@ -228,7 +269,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
found
,
sr
=
soundfile
.
read
(
fileobj
,
dtype
=
'float32'
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
self
.
assertEqual
(
expected
,
found
,
atol
=
1e-4
,
rtol
=
1e-8
)
def
test_fileobj_wav
(
self
):
"""Saving audio via file-like object works"""
...
...
torchaudio/backend/_soundfile_backend.py
View file @
b8fd5e94
...
...
@@ -209,6 +209,93 @@ def load(
return
waveform
,
sample_rate
def
_get_subtype_for_wav
(
dtype
:
torch
.
dtype
,
encoding
:
str
,
bits_per_sample
:
int
):
if
not
encoding
:
if
not
bits_per_sample
:
subtype
=
{
torch
.
uint8
:
"PCM_U8"
,
torch
.
int16
:
"PCM_16"
,
torch
.
int32
:
"PCM_32"
,
torch
.
float32
:
"FLOAT"
,
torch
.
float64
:
"DOUBLE"
,
}.
get
(
dtype
)
if
not
subtype
:
raise
ValueError
(
f
"Unsupported dtype for wav:
{
dtype
}
"
)
return
subtype
if
bits_per_sample
==
8
:
return
"PCM_U8"
return
f
"PCM_
{
bits_per_sample
}
"
if
encoding
==
"PCM_S"
:
if
not
bits_per_sample
:
return
"PCM_32"
if
bits_per_sample
==
8
:
raise
ValueError
(
"wav does not support 8-bit signed PCM encoding."
)
return
f
"PCM_
{
bits_per_sample
}
"
if
encoding
==
"PCM_U"
:
if
bits_per_sample
in
(
None
,
8
):
return
"PCM_U8"
raise
ValueError
(
"wav only supports 8-bit unsigned PCM encoding."
)
if
encoding
==
"PCM_F"
:
if
bits_per_sample
in
(
None
,
32
):
return
"FLOAT"
if
bits_per_sample
==
64
:
return
"DOUBLE"
raise
ValueError
(
"wav only supports 32/64-bit float PCM encoding."
)
if
encoding
==
"ULAW"
:
if
bits_per_sample
in
(
None
,
8
):
return
"ULAW"
raise
ValueError
(
"wav only supports 8-bit mu-law encoding."
)
if
encoding
==
"ALAW"
:
if
bits_per_sample
in
(
None
,
8
):
return
"ALAW"
raise
ValueError
(
"wav only supports 8-bit a-law encoding."
)
raise
ValueError
(
f
"wav does not support
{
encoding
}
."
)
def
_get_subtype_for_sphere
(
encoding
:
str
,
bits_per_sample
:
int
):
if
encoding
in
(
None
,
"PCM_S"
):
return
f
"PCM_
{
bits_per_sample
}
"
if
bits_per_sample
else
"PCM_32"
if
encoding
in
(
"PCM_U"
,
"PCM_F"
):
raise
ValueError
(
f
"sph does not support
{
encoding
}
encoding."
)
if
encoding
==
"ULAW"
:
if
bits_per_sample
in
(
None
,
8
):
return
"ULAW"
raise
ValueError
(
"sph only supports 8-bit for mu-law encoding."
)
if
encoding
==
"ALAW"
:
return
"ALAW"
raise
ValueError
(
f
"sph does not support
{
encoding
}
."
)
def
_get_subtype
(
dtype
:
torch
.
dtype
,
format
:
str
,
encoding
:
str
,
bits_per_sample
:
int
):
if
format
==
"wav"
:
return
_get_subtype_for_wav
(
dtype
,
encoding
,
bits_per_sample
)
if
format
==
"flac"
:
if
encoding
:
raise
ValueError
(
"flac does not support encoding."
)
if
not
bits_per_sample
:
return
"PCM_24"
if
bits_per_sample
>
24
:
raise
ValueError
(
"flac does not support bits_per_sample > 24."
)
return
"PCM_S8"
if
bits_per_sample
==
8
else
f
"PCM_
{
bits_per_sample
}
"
if
format
in
(
"ogg"
,
"vorbis"
):
if
encoding
or
bits_per_sample
:
raise
ValueError
(
"ogg/vorbis does not support encoding/bits_per_sample."
)
return
"VORBIS"
if
format
==
"sph"
:
return
_get_subtype_for_sphere
(
encoding
,
bits_per_sample
)
if
format
in
(
"nis"
,
"nist"
):
return
"PCM_16"
raise
ValueError
(
f
"Unsupported format:
{
format
}
"
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
def
save
(
filepath
:
str
,
...
...
@@ -217,6 +304,8 @@ def save(
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
"""Save audio data to file.
...
...
@@ -246,9 +335,65 @@ def save(
otherwise ``[time, channel]``.
compression (Optional[float]): Not used.
It is here only for interface compatibility reasons with the "sox_io" backend.
format (str, optional): Output audio format.
This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
format (str, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is
inferred from file extension. If the file extension is missing or
different, you can specify the correct format with this argument.
When ``filepath`` argument is file-like object,
this argument is required.
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for supported formats.
This argument is effective only for supported formats, such as
``"wav"``, ``"flac"`` and ``"sph"``. Valid values are:
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
bits_per_sample (int, optional): Changes the bit depth for the
supported formats.
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
you can change the bit depth.
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
Supported formats/encodings/bit depth/compression are:
``"wav"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of
the input Tensor.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Doesn't accept changing configuration.
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
"""
if
src
.
ndim
!=
2
:
raise
ValueError
(
f
"Expected 2D Tensor, got
{
src
.
ndim
}
D."
)
...
...
@@ -260,24 +405,13 @@ def save(
if
hasattr
(
filepath
,
'write'
):
if
format
is
None
:
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
ext
=
format
ext
=
format
.
lower
()
else
:
ext
=
str
(
filepath
).
split
(
"."
)[
-
1
].
lower
()
if
ext
!=
"wav"
:
subtype
=
None
elif
src
.
dtype
==
torch
.
uint8
:
subtype
=
"PCM_U8"
elif
src
.
dtype
==
torch
.
int16
:
subtype
=
"PCM_16"
elif
src
.
dtype
==
torch
.
int32
:
subtype
=
"PCM_32"
elif
src
.
dtype
==
torch
.
float32
:
subtype
=
"FLOAT"
elif
src
.
dtype
==
torch
.
float64
:
subtype
=
"DOUBLE"
else
:
raise
ValueError
(
f
"Unsupported dtype for WAV:
{
src
.
dtype
}
"
)
if
bits_per_sample
not
in
(
None
,
8
,
16
,
24
,
32
,
64
):
raise
ValueError
(
"Invalid bits_per_sample."
)
subtype
=
_get_subtype
(
src
.
dtype
,
ext
,
encoding
,
bits_per_sample
)
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
...
...
torchaudio/backend/sox_io_backend.py
View file @
b8fd5e94
...
...
@@ -198,7 +198,7 @@ def save(
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats,
cus
h as ``"wav"``, ``""amb"``
This argument is effective only for supported formats,
suc
h as ``"wav"``, ``"amb"``
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment