Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b8fd5e94
Unverified
Commit
b8fd5e94
authored
Feb 23, 2021
by
Prabhat Roy
Committed by
GitHub
Feb 23, 2021
Browse files
Added encoding and bits_per_sample to soundfile's backend save() (#1274)
parent
e1bea2b7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
230 additions
and
32 deletions
+230
-32
test/torchaudio_unittest/backend/soundfile/common.py
test/torchaudio_unittest/backend/soundfile/common.py
+23
-0
test/torchaudio_unittest/backend/soundfile/save_test.py
test/torchaudio_unittest/backend/soundfile/save_test.py
+54
-13
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+152
-18
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+1
-1
No files found.
test/torchaudio_unittest/backend/soundfile/common.py
View file @
b8fd5e94
...
@@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):
...
@@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):
def
parameterize
(
*
params
):
def
parameterize
(
*
params
):
return
parameterized
.
expand
(
list
(
itertools
.
product
(
*
params
)),
name_func
=
name_func
)
return
parameterized
.
expand
(
list
(
itertools
.
product
(
*
params
)),
name_func
=
name_func
)
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Return the soundfile subtype expected for a WAV save request.

    Test-side lookup used to check the ``subtype`` that the soundfile
    backend's ``save`` hands to ``soundfile.write``.

    Args:
        dtype: dtype identifier forwarded to ``dtype2subtype``. It is only
            consulted when neither ``encoding`` nor ``bits_per_sample`` is
            given, because that is the only case where the subtype depends
            on the data type.
        encoding (str or None): one of ``"PCM_S"``, ``"PCM_U"``,
            ``"PCM_F"``, ``"ULAW"``, ``"ALAW"`` or ``None``.
        bits_per_sample (int or None): requested bit depth, or ``None``.

    Returns:
        str: the soundfile subtype name, e.g. ``"PCM_16"``.

    Raises:
        ValueError: if the (encoding, bits_per_sample) pair is not a
            supported WAV configuration.
    """
    subtypes = {
        # No explicit encoding: 8-bit WAV is unsigned, wider depths are
        # signed integer PCM.
        (None, 8): "PCM_U8",
        (None, 16): "PCM_16",
        (None, 24): "PCM_24",
        (None, 32): "PCM_32",
        ('PCM_U', None): "PCM_U8",
        ('PCM_U', 8): "PCM_U8",
        ('PCM_S', None): "PCM_32",
        ('PCM_S', 16): "PCM_16",
        ('PCM_S', 24): "PCM_24",
        ('PCM_S', 32): "PCM_32",
        ('PCM_F', None): "FLOAT",
        ('PCM_F', 32): "FLOAT",
        ('PCM_F', 64): "DOUBLE",
        ('ULAW', None): "ULAW",
        ('ULAW', 8): "ULAW",
        ('ALAW', None): "ALAW",
        ('ALAW', 8): "ALAW",
    }
    key = (encoding, bits_per_sample)
    if key == (None, None):
        # Only the fully-default request depends on the tensor dtype.
        subtype = dtype2subtype(dtype)
    else:
        subtype = subtypes.get(key)
    if subtype:
        return subtype
    raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
test/torchaudio_unittest/backend/soundfile/save_test.py
View file @
b8fd5e94
...
@@ -11,7 +11,11 @@ from torchaudio_unittest.common_utils import (
...
@@ -11,7 +11,11 @@ from torchaudio_unittest.common_utils import (
get_wav_data
,
get_wav_data
,
load_wav
,
load_wav
,
)
)
from
.common
import
parameterize
,
dtype2subtype
,
skipIfFormatNotSupported
from
.common
import
(
fetch_wav_subtype
,
parameterize
,
skipIfFormatNotSupported
,
)
if
_mod_utils
.
is_module_available
(
"soundfile"
):
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
import
soundfile
...
@@ -20,28 +24,47 @@ if _mod_utils.is_module_available("soundfile"):
...
@@ -20,28 +24,47 @@ if _mod_utils.is_module_available("soundfile"):
class
MockedSaveTest
(
PytorchTestCase
):
class
MockedSaveTest
(
PytorchTestCase
):
@
parameterize
(
@
parameterize
(
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
(
None
,
None
),
(
'PCM_U'
,
None
),
(
'PCM_U'
,
8
),
(
'PCM_S'
,
None
),
(
'PCM_S'
,
16
),
(
'PCM_S'
,
32
),
(
'PCM_F'
,
None
),
(
'PCM_F'
,
32
),
(
'PCM_F'
,
64
),
(
'ULAW'
,
None
),
(
'ULAW'
,
8
),
(
'ALAW'
,
None
),
(
'ALAW'
,
8
),
],
)
)
@
patch
(
"soundfile.write"
)
@
patch
(
"soundfile.write"
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
,
mocked_write
):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath
=
"foo.wav"
filepath
=
"foo.wav"
input_tensor
=
get_wav_data
(
input_tensor
=
get_wav_data
(
dtype
,
dtype
,
num_channels
,
num_channels
,
num_frames
=
3
*
sample_rate
,
num_frames
=
3
*
sample_rate
,
normalize
=
dtype
==
"fl
a
ot32"
,
normalize
=
dtype
==
"flo
a
t32"
,
channels_first
=
channels_first
,
channels_first
=
channels_first
,
).
t
()
).
t
()
encoding
,
bits_per_sample
=
enc_params
soundfile_backend
.
save
(
soundfile_backend
.
save
(
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
)
# on +Py3.8 call_args.kwargs is more descriptive
# on +Py3.8 call_args.kwargs is more descriptive
args
=
mocked_write
.
call_args
[
1
]
args
=
mocked_write
.
call_args
[
1
]
assert
args
[
"file"
]
==
filepath
assert
args
[
"file"
]
==
filepath
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"subtype"
]
==
dtype2subtype
(
dtype
)
assert
args
[
"subtype"
]
==
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
)
assert
args
[
"format"
]
is
None
assert
args
[
"format"
]
is
None
self
.
assertEqual
(
self
.
assertEqual
(
args
[
"data"
],
input_tensor
.
t
()
if
channels_first
else
input_tensor
args
[
"data"
],
input_tensor
.
t
()
if
channels_first
else
input_tensor
...
@@ -49,7 +72,8 @@ class MockedSaveTest(PytorchTestCase):
...
@@ -49,7 +72,8 @@ class MockedSaveTest(PytorchTestCase):
@
patch
(
"soundfile.write"
)
@
patch
(
"soundfile.write"
)
def
assert_non_wav
(
def
assert_non_wav
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
,
encoding
=
None
,
bits_per_sample
=
None
,
):
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath
=
f
"foo.
{
fmt
}
"
filepath
=
f
"foo.
{
fmt
}
"
...
@@ -63,14 +87,14 @@ class MockedSaveTest(PytorchTestCase):
...
@@ -63,14 +87,14 @@ class MockedSaveTest(PytorchTestCase):
expected_data
=
input_tensor
.
t
()
if
channels_first
else
input_tensor
expected_data
=
input_tensor
.
t
()
if
channels_first
else
input_tensor
soundfile_backend
.
save
(
soundfile_backend
.
save
(
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
filepath
,
input_tensor
,
sample_rate
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
)
)
# on +Py3.8 call_args.kwargs is more descriptive
# on +Py3.8 call_args.kwargs is more descriptive
args
=
mocked_write
.
call_args
[
1
]
args
=
mocked_write
.
call_args
[
1
]
assert
args
[
"file"
]
==
filepath
assert
args
[
"file"
]
==
filepath
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"subtype"
]
is
None
if
fmt
in
[
"sph"
,
"nist"
,
"nis"
]:
if
fmt
in
[
"sph"
,
"nist"
,
"nis"
]:
assert
args
[
"format"
]
==
"NIST"
assert
args
[
"format"
]
==
"NIST"
else
:
else
:
...
@@ -83,19 +107,36 @@ class MockedSaveTest(PytorchTestCase):
...
@@ -83,19 +107,36 @@ class MockedSaveTest(PytorchTestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
False
,
True
],
[
False
,
True
],
[
(
'PCM_S'
,
8
),
(
'PCM_S'
,
16
),
(
'PCM_S'
,
24
),
(
'PCM_S'
,
32
),
(
'ULAW'
,
8
),
(
'ALAW'
,
8
),
(
'ALAW'
,
16
),
(
'ALAW'
,
24
),
(
'ALAW'
,
32
),
],
)
)
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
):
"""soundfile_backend.save passes default format and subtype (None-s) to
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
soundfile.write when not WAV"""
self
.
assert_non_wav
(
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
encoding
,
bits_per_sample
=
enc_params
self
.
assert_non_wav
(
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
@
parameterize
(
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
8
,
16
,
24
],
)
)
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
):
"""soundfile_backend.save passes default format and subtype (None-s) to
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
soundfile.write when not WAV"""
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
@
parameterize
(
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
...
@@ -228,7 +269,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
...
@@ -228,7 +269,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
found
,
sr
=
soundfile
.
read
(
fileobj
,
dtype
=
'float32'
)
found
,
sr
=
soundfile
.
read
(
fileobj
,
dtype
=
'float32'
)
assert
sr
==
sample_rate
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
self
.
assertEqual
(
expected
,
found
,
atol
=
1e-4
,
rtol
=
1e-8
)
def
test_fileobj_wav
(
self
):
def
test_fileobj_wav
(
self
):
"""Saving audio via file-like object works"""
"""Saving audio via file-like object works"""
...
...
torchaudio/backend/_soundfile_backend.py
View file @
b8fd5e94
...
@@ -209,6 +209,93 @@ def load(
...
@@ -209,6 +209,93 @@ def load(
return
waveform
,
sample_rate
return
waveform
,
sample_rate
def _get_subtype_for_wav(
        dtype: torch.dtype,
        encoding: Optional[str],
        bits_per_sample: Optional[int]):
    """Pick the soundfile subtype for a WAV file.

    Args:
        dtype (torch.dtype): dtype of the tensor being saved. Used only when
            neither ``encoding`` nor ``bits_per_sample`` is given.
        encoding (str or None): ``"PCM_S"``, ``"PCM_U"``, ``"PCM_F"``,
            ``"ULAW"``, ``"ALAW"`` or ``None``.
        bits_per_sample (int or None): requested bit depth, or ``None``.

    Returns:
        str: subtype name to pass to ``soundfile.write``.

    Raises:
        ValueError: if the dtype, the encoding, or the combination of
            encoding and bit depth is not supported for WAV.
    """
    if not encoding:
        if not bits_per_sample:
            # Nothing requested explicitly: derive the subtype from the dtype.
            subtype = {
                torch.uint8: "PCM_U8",
                torch.int16: "PCM_16",
                torch.int32: "PCM_32",
                torch.float32: "FLOAT",
                torch.float64: "DOUBLE",
            }.get(dtype)
            if not subtype:
                raise ValueError(f"Unsupported dtype for wav: {dtype}")
            return subtype
        if bits_per_sample == 8:
            return "PCM_U8"
        # Reject depths that have no integer PCM subtype (e.g. 64) here,
        # instead of handing soundfile a made-up name such as "PCM_64".
        if bits_per_sample not in (16, 24, 32):
            raise ValueError(
                f"wav does not support {bits_per_sample}-bit integer PCM encoding.")
        return f"PCM_{bits_per_sample}"
    if encoding == "PCM_S":
        if not bits_per_sample:
            return "PCM_32"
        if bits_per_sample == 8:
            raise ValueError("wav does not support 8-bit signed PCM encoding.")
        if bits_per_sample not in (16, 24, 32):
            raise ValueError(
                f"wav does not support {bits_per_sample}-bit signed PCM encoding.")
        return f"PCM_{bits_per_sample}"
    if encoding == "PCM_U":
        if bits_per_sample in (None, 8):
            return "PCM_U8"
        raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
    if encoding == "PCM_F":
        if bits_per_sample in (None, 32):
            return "FLOAT"
        if bits_per_sample == 64:
            return "DOUBLE"
        raise ValueError("wav only supports 32/64-bit float PCM encoding.")
    if encoding == "ULAW":
        if bits_per_sample in (None, 8):
            return "ULAW"
        raise ValueError("wav only supports 8-bit mu-law encoding.")
    if encoding == "ALAW":
        if bits_per_sample in (None, 8):
            return "ALAW"
        raise ValueError("wav only supports 8-bit a-law encoding.")
    raise ValueError(f"wav does not support {encoding}.")
def _get_subtype_for_sphere(
        encoding: Optional[str],
        bits_per_sample: Optional[int]):
    """Pick the soundfile subtype for a NIST SPHERE (``sph``) file.

    Args:
        encoding (str or None): ``"PCM_S"``, ``"ULAW"``, ``"ALAW"`` or
            ``None`` (treated as signed PCM).
        bits_per_sample (int or None): requested bit depth, or ``None``.

    Returns:
        str: subtype name to pass to ``soundfile.write``.

    Raises:
        ValueError: if the encoding, or the bit depth for that encoding,
            is not supported for sph.
    """
    if encoding in (None, "PCM_S"):
        if not bits_per_sample:
            return "PCM_32"
        if bits_per_sample == 8:
            # soundfile names signed 8-bit PCM "PCM_S8"; "PCM_8" does not exist.
            return "PCM_S8"
        if bits_per_sample not in (16, 24, 32):
            raise ValueError(
                f"sph does not support {bits_per_sample}-bit signed PCM encoding.")
        return f"PCM_{bits_per_sample}"
    if encoding in ("PCM_U", "PCM_F"):
        raise ValueError(f"sph does not support {encoding} encoding.")
    if encoding == "ULAW":
        if bits_per_sample in (None, 8):
            return "ULAW"
        raise ValueError("sph only supports 8-bit for mu-law encoding.")
    if encoding == "ALAW":
        # Every requested bit depth maps to the single a-law subtype.
        return "ALAW"
    raise ValueError(f"sph does not support {encoding}.")
def _get_subtype(
        dtype: torch.dtype,
        format: str,
        encoding: str,
        bits_per_sample: int):
    """Select the soundfile subtype string for the requested output format.

    Args:
        dtype (torch.dtype): dtype of the tensor being saved; forwarded to
            the WAV helper only.
        format (str): lower-case format name (``"wav"``, ``"flac"``,
            ``"ogg"``, ``"vorbis"``, ``"sph"``, ``"nis"`` or ``"nist"``).
        encoding (str or None): requested encoding, if any.
        bits_per_sample (int or None): requested bit depth, if any.

    Returns:
        str: subtype name to pass to ``soundfile.write``.

    Raises:
        ValueError: if the format is unknown or does not accept the given
            encoding / bit depth.
    """
    if format == "wav":
        return _get_subtype_for_wav(dtype, encoding, bits_per_sample)

    if format == "flac":
        if encoding:
            raise ValueError("flac does not support encoding.")
        # An unspecified depth falls back to 24 bits.
        depth = bits_per_sample or 24
        if depth > 24:
            raise ValueError("flac does not support bits_per_sample > 24.")
        if depth == 8:
            return "PCM_S8"
        return f"PCM_{depth}"

    if format in ("ogg", "vorbis"):
        if not (encoding or bits_per_sample):
            return "VORBIS"
        raise ValueError(
            "ogg/vorbis does not support encoding/bits_per_sample.")

    if format == "sph":
        return _get_subtype_for_sphere(encoding, bits_per_sample)

    if format in ("nis", "nist"):
        # These extensions always use 16-bit PCM, whatever was requested.
        return "PCM_16"

    raise ValueError(f"Unsupported format: {format}")
@
_mod_utils
.
requires_module
(
"soundfile"
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
def
save
(
def
save
(
filepath
:
str
,
filepath
:
str
,
...
@@ -217,6 +304,8 @@ def save(
...
@@ -217,6 +304,8 @@ def save(
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
):
"""Save audio data to file.
"""Save audio data to file.
...
@@ -246,9 +335,65 @@ def save(
...
@@ -246,9 +335,65 @@ def save(
otherwise ``[time, channel]``.
otherwise ``[time, channel]``.
compression (Optional[float]): Not used.
compression (Optional[float]): Not used.
It is here only for interface compatibility reason with "sox_io" backend.
It is here only for interface compatibility reason with "sox_io" backend.
format (str, optional): Output audio format.
format (str, optional): Override the audio format.
This is required when the output audio format cannot be infered from
When ``filepath`` argument is path-like object, audio format is
``filepath``, (such as file extension or ``name`` attribute of the given file object).
inferred from file extension. If the file extension is missing or
different, you can specify the correct format with this argument.
When ``filepath`` argument is file-like object,
this argument is required.
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for supported formats.
This argument is effective only for supported formats, such as
``"wav"`` and ``"sph"``. Valid values are:
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
bits_per_sample (int, optional): Changes the bit depth for the
supported formats.
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
you can change the bit depth.
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
Supported formats/encodings/bit depth/compression are:
``"wav"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of
the input Tensor.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Doesn't accept changing configuration.
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
"""
"""
if
src
.
ndim
!=
2
:
if
src
.
ndim
!=
2
:
raise
ValueError
(
f
"Expected 2D Tensor, got
{
src
.
ndim
}
D."
)
raise
ValueError
(
f
"Expected 2D Tensor, got
{
src
.
ndim
}
D."
)
...
@@ -260,24 +405,13 @@ def save(
...
@@ -260,24 +405,13 @@ def save(
if
hasattr
(
filepath
,
'write'
):
if
hasattr
(
filepath
,
'write'
):
if
format
is
None
:
if
format
is
None
:
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
ext
=
format
ext
=
format
.
lower
()
else
:
else
:
ext
=
str
(
filepath
).
split
(
"."
)[
-
1
].
lower
()
ext
=
str
(
filepath
).
split
(
"."
)[
-
1
].
lower
()
if
ext
!=
"wav"
:
if
bits_per_sample
not
in
(
None
,
8
,
16
,
24
,
32
,
64
):
subtype
=
None
raise
ValueError
(
"Invalid bits_per_sample."
)
elif
src
.
dtype
==
torch
.
uint8
:
subtype
=
_get_subtype
(
src
.
dtype
,
ext
,
encoding
,
bits_per_sample
)
subtype
=
"PCM_U8"
elif
src
.
dtype
==
torch
.
int16
:
subtype
=
"PCM_16"
elif
src
.
dtype
==
torch
.
int32
:
subtype
=
"PCM_32"
elif
src
.
dtype
==
torch
.
float32
:
subtype
=
"FLOAT"
elif
src
.
dtype
==
torch
.
float64
:
subtype
=
"DOUBLE"
else
:
raise
ValueError
(
f
"Unsupported dtype for WAV:
{
src
.
dtype
}
"
)
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
# so we extend the extensions manually here
...
...
torchaudio/backend/sox_io_backend.py
View file @
b8fd5e94
...
@@ -198,7 +198,7 @@ def save(
...
@@ -198,7 +198,7 @@ def save(
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
encoding (str, optional): Changes the encoding for the supported formats.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats,
cus
h as ``"wav"``, ``""amb"``
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
and ``"sph"``. Valid values are;
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_S"`` (signed integer Linear PCM)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment