Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b8fd5e94
"src/vscode:/vscode.git/clone" did not exist on "80c00e5451e0ced32043fbb0ed06eb6f3c427f82"
Unverified
Commit
b8fd5e94
authored
Feb 23, 2021
by
Prabhat Roy
Committed by
GitHub
Feb 23, 2021
Browse files
Added encoding and bits_per_sample to soundfile's backend save() (#1274)
parent
e1bea2b7
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
230 additions
and
32 deletions
+230
-32
test/torchaudio_unittest/backend/soundfile/common.py
test/torchaudio_unittest/backend/soundfile/common.py
+23
-0
test/torchaudio_unittest/backend/soundfile/save_test.py
test/torchaudio_unittest/backend/soundfile/save_test.py
+54
-13
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+152
-18
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+1
-1
No files found.
test/torchaudio_unittest/backend/soundfile/common.py
View file @
b8fd5e94
...
@@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):
...
@@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):
def
parameterize
(
*
params
):
def
parameterize
(
*
params
):
return
parameterized
.
expand
(
list
(
itertools
.
product
(
*
params
)),
name_func
=
name_func
)
return
parameterized
.
expand
(
list
(
itertools
.
product
(
*
params
)),
name_func
=
name_func
)
def
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
):
subtype
=
{
(
None
,
None
):
dtype2subtype
(
dtype
),
(
None
,
8
):
"PCM_U8"
,
(
'PCM_U'
,
None
):
"PCM_U8"
,
(
'PCM_U'
,
8
):
"PCM_U8"
,
(
'PCM_S'
,
None
):
"PCM_32"
,
(
'PCM_S'
,
16
):
"PCM_16"
,
(
'PCM_S'
,
32
):
"PCM_32"
,
(
'PCM_F'
,
None
):
"FLOAT"
,
(
'PCM_F'
,
32
):
"FLOAT"
,
(
'PCM_F'
,
64
):
"DOUBLE"
,
(
'ULAW'
,
None
):
"ULAW"
,
(
'ULAW'
,
8
):
"ULAW"
,
(
'ALAW'
,
None
):
"ALAW"
,
(
'ALAW'
,
8
):
"ALAW"
,
}.
get
((
encoding
,
bits_per_sample
))
if
subtype
:
return
subtype
raise
ValueError
(
f
"wav does not support (
{
encoding
}
,
{
bits_per_sample
}
)."
)
test/torchaudio_unittest/backend/soundfile/save_test.py
View file @
b8fd5e94
...
@@ -11,7 +11,11 @@ from torchaudio_unittest.common_utils import (
...
@@ -11,7 +11,11 @@ from torchaudio_unittest.common_utils import (
get_wav_data
,
get_wav_data
,
load_wav
,
load_wav
,
)
)
from
.common
import
parameterize
,
dtype2subtype
,
skipIfFormatNotSupported
from
.common
import
(
fetch_wav_subtype
,
parameterize
,
skipIfFormatNotSupported
,
)
if
_mod_utils
.
is_module_available
(
"soundfile"
):
if
_mod_utils
.
is_module_available
(
"soundfile"
):
import
soundfile
import
soundfile
...
@@ -20,28 +24,47 @@ if _mod_utils.is_module_available("soundfile"):
...
@@ -20,28 +24,47 @@ if _mod_utils.is_module_available("soundfile"):
class
MockedSaveTest
(
PytorchTestCase
):
class
MockedSaveTest
(
PytorchTestCase
):
@
parameterize
(
@
parameterize
(
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
(
None
,
None
),
(
'PCM_U'
,
None
),
(
'PCM_U'
,
8
),
(
'PCM_S'
,
None
),
(
'PCM_S'
,
16
),
(
'PCM_S'
,
32
),
(
'PCM_F'
,
None
),
(
'PCM_F'
,
32
),
(
'PCM_F'
,
64
),
(
'ULAW'
,
None
),
(
'ULAW'
,
8
),
(
'ALAW'
,
None
),
(
'ALAW'
,
8
),
],
)
)
@
patch
(
"soundfile.write"
)
@
patch
(
"soundfile.write"
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
,
mocked_write
):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath
=
"foo.wav"
filepath
=
"foo.wav"
input_tensor
=
get_wav_data
(
input_tensor
=
get_wav_data
(
dtype
,
dtype
,
num_channels
,
num_channels
,
num_frames
=
3
*
sample_rate
,
num_frames
=
3
*
sample_rate
,
normalize
=
dtype
==
"fl
a
ot32"
,
normalize
=
dtype
==
"flo
a
t32"
,
channels_first
=
channels_first
,
channels_first
=
channels_first
,
).
t
()
).
t
()
encoding
,
bits_per_sample
=
enc_params
soundfile_backend
.
save
(
soundfile_backend
.
save
(
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
)
# on +Py3.8 call_args.kwargs is more descreptive
# on +Py3.8 call_args.kwargs is more descreptive
args
=
mocked_write
.
call_args
[
1
]
args
=
mocked_write
.
call_args
[
1
]
assert
args
[
"file"
]
==
filepath
assert
args
[
"file"
]
==
filepath
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"subtype"
]
==
dtype2subtype
(
dtype
)
assert
args
[
"subtype"
]
==
fetch_wav_subtype
(
dtype
,
encoding
,
bits_per_sample
)
assert
args
[
"format"
]
is
None
assert
args
[
"format"
]
is
None
self
.
assertEqual
(
self
.
assertEqual
(
args
[
"data"
],
input_tensor
.
t
()
if
channels_first
else
input_tensor
args
[
"data"
],
input_tensor
.
t
()
if
channels_first
else
input_tensor
...
@@ -49,7 +72,8 @@ class MockedSaveTest(PytorchTestCase):
...
@@ -49,7 +72,8 @@ class MockedSaveTest(PytorchTestCase):
@
patch
(
"soundfile.write"
)
@
patch
(
"soundfile.write"
)
def
assert_non_wav
(
def
assert_non_wav
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
mocked_write
,
encoding
=
None
,
bits_per_sample
=
None
,
):
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath
=
f
"foo.
{
fmt
}
"
filepath
=
f
"foo.
{
fmt
}
"
...
@@ -63,14 +87,14 @@ class MockedSaveTest(PytorchTestCase):
...
@@ -63,14 +87,14 @@ class MockedSaveTest(PytorchTestCase):
expected_data
=
input_tensor
.
t
()
if
channels_first
else
input_tensor
expected_data
=
input_tensor
.
t
()
if
channels_first
else
input_tensor
soundfile_backend
.
save
(
soundfile_backend
.
save
(
filepath
,
input_tensor
,
sample_rate
,
channels_first
=
channels_first
filepath
,
input_tensor
,
sample_rate
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
)
)
# on +Py3.8 call_args.kwargs is more descreptive
# on +Py3.8 call_args.kwargs is more descreptive
args
=
mocked_write
.
call_args
[
1
]
args
=
mocked_write
.
call_args
[
1
]
assert
args
[
"file"
]
==
filepath
assert
args
[
"file"
]
==
filepath
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"samplerate"
]
==
sample_rate
assert
args
[
"subtype"
]
is
None
if
fmt
in
[
"sph"
,
"nist"
,
"nis"
]:
if
fmt
in
[
"sph"
,
"nist"
,
"nis"
]:
assert
args
[
"format"
]
==
"NIST"
assert
args
[
"format"
]
==
"NIST"
else
:
else
:
...
@@ -83,19 +107,36 @@ class MockedSaveTest(PytorchTestCase):
...
@@ -83,19 +107,36 @@ class MockedSaveTest(PytorchTestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
False
,
True
],
[
False
,
True
],
[
(
'PCM_S'
,
8
),
(
'PCM_S'
,
16
),
(
'PCM_S'
,
24
),
(
'PCM_S'
,
32
),
(
'ULAW'
,
8
),
(
'ALAW'
,
8
),
(
'ALAW'
,
16
),
(
'ALAW'
,
24
),
(
'ALAW'
,
32
),
],
)
)
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
def
test_sph
(
self
,
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
enc_params
):
"""soundfile_backend.save passes default format and subtype (None-s) to
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
soundfile.write when not WAV"""
self
.
assert_non_wav
(
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
encoding
,
bits_per_sample
=
enc_params
self
.
assert_non_wav
(
fmt
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
@
parameterize
(
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
8
,
16
,
24
],
)
)
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
):
def
test_flac
(
self
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
):
"""soundfile_backend.save passes default format and subtype (None-s) to
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
soundfile.write when not WAV"""
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
)
self
.
assert_non_wav
(
"flac"
,
dtype
,
sample_rate
,
num_channels
,
channels_first
,
bits_per_sample
=
bits_per_sample
)
@
parameterize
(
@
parameterize
(
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
[
"int32"
,
"int16"
],
[
8000
,
16000
],
[
1
,
2
],
[
False
,
True
],
...
@@ -228,7 +269,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
...
@@ -228,7 +269,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
found
,
sr
=
soundfile
.
read
(
fileobj
,
dtype
=
'float32'
)
found
,
sr
=
soundfile
.
read
(
fileobj
,
dtype
=
'float32'
)
assert
sr
==
sample_rate
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
self
.
assertEqual
(
expected
,
found
,
atol
=
1e-4
,
rtol
=
1e-8
)
def
test_fileobj_wav
(
self
):
def
test_fileobj_wav
(
self
):
"""Saving audio via file-like object works"""
"""Saving audio via file-like object works"""
...
...
torchaudio/backend/_soundfile_backend.py
View file @
b8fd5e94
...
@@ -209,6 +209,93 @@ def load(
...
@@ -209,6 +209,93 @@ def load(
return
waveform
,
sample_rate
return
waveform
,
sample_rate
def
_get_subtype_for_wav
(
dtype
:
torch
.
dtype
,
encoding
:
str
,
bits_per_sample
:
int
):
if
not
encoding
:
if
not
bits_per_sample
:
subtype
=
{
torch
.
uint8
:
"PCM_U8"
,
torch
.
int16
:
"PCM_16"
,
torch
.
int32
:
"PCM_32"
,
torch
.
float32
:
"FLOAT"
,
torch
.
float64
:
"DOUBLE"
,
}.
get
(
dtype
)
if
not
subtype
:
raise
ValueError
(
f
"Unsupported dtype for wav:
{
dtype
}
"
)
return
subtype
if
bits_per_sample
==
8
:
return
"PCM_U8"
return
f
"PCM_
{
bits_per_sample
}
"
if
encoding
==
"PCM_S"
:
if
not
bits_per_sample
:
return
"PCM_32"
if
bits_per_sample
==
8
:
raise
ValueError
(
"wav does not support 8-bit signed PCM encoding."
)
return
f
"PCM_
{
bits_per_sample
}
"
if
encoding
==
"PCM_U"
:
if
bits_per_sample
in
(
None
,
8
):
return
"PCM_U8"
raise
ValueError
(
"wav only supports 8-bit unsigned PCM encoding."
)
if
encoding
==
"PCM_F"
:
if
bits_per_sample
in
(
None
,
32
):
return
"FLOAT"
if
bits_per_sample
==
64
:
return
"DOUBLE"
raise
ValueError
(
"wav only supports 32/64-bit float PCM encoding."
)
if
encoding
==
"ULAW"
:
if
bits_per_sample
in
(
None
,
8
):
return
"ULAW"
raise
ValueError
(
"wav only supports 8-bit mu-law encoding."
)
if
encoding
==
"ALAW"
:
if
bits_per_sample
in
(
None
,
8
):
return
"ALAW"
raise
ValueError
(
"wav only supports 8-bit a-law encoding."
)
raise
ValueError
(
f
"wav does not support
{
encoding
}
."
)
def
_get_subtype_for_sphere
(
encoding
:
str
,
bits_per_sample
:
int
):
if
encoding
in
(
None
,
"PCM_S"
):
return
f
"PCM_
{
bits_per_sample
}
"
if
bits_per_sample
else
"PCM_32"
if
encoding
in
(
"PCM_U"
,
"PCM_F"
):
raise
ValueError
(
f
"sph does not support
{
encoding
}
encoding."
)
if
encoding
==
"ULAW"
:
if
bits_per_sample
in
(
None
,
8
):
return
"ULAW"
raise
ValueError
(
"sph only supports 8-bit for mu-law encoding."
)
if
encoding
==
"ALAW"
:
return
"ALAW"
raise
ValueError
(
f
"sph does not support
{
encoding
}
."
)
def
_get_subtype
(
dtype
:
torch
.
dtype
,
format
:
str
,
encoding
:
str
,
bits_per_sample
:
int
):
if
format
==
"wav"
:
return
_get_subtype_for_wav
(
dtype
,
encoding
,
bits_per_sample
)
if
format
==
"flac"
:
if
encoding
:
raise
ValueError
(
"flac does not support encoding."
)
if
not
bits_per_sample
:
return
"PCM_24"
if
bits_per_sample
>
24
:
raise
ValueError
(
"flac does not support bits_per_sample > 24."
)
return
"PCM_S8"
if
bits_per_sample
==
8
else
f
"PCM_
{
bits_per_sample
}
"
if
format
in
(
"ogg"
,
"vorbis"
):
if
encoding
or
bits_per_sample
:
raise
ValueError
(
"ogg/vorbis does not support encoding/bits_per_sample."
)
return
"VORBIS"
if
format
==
"sph"
:
return
_get_subtype_for_sphere
(
encoding
,
bits_per_sample
)
if
format
in
(
"nis"
,
"nist"
):
return
"PCM_16"
raise
ValueError
(
f
"Unsupported format:
{
format
}
"
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
def
save
(
def
save
(
filepath
:
str
,
filepath
:
str
,
...
@@ -217,6 +304,8 @@ def save(
...
@@ -217,6 +304,8 @@ def save(
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
):
"""Save audio data to file.
"""Save audio data to file.
...
@@ -246,9 +335,65 @@ def save(
...
@@ -246,9 +335,65 @@ def save(
otherwise ``[time, channel]``.
otherwise ``[time, channel]``.
compression (Optional[float]): Not used.
compression (Optional[float]): Not used.
It is here only for interface compatibility reson with "sox_io" backend.
It is here only for interface compatibility reson with "sox_io" backend.
format (str, optional): Output audio format.
format (str, optional): Override the audio format.
This is required when the output audio format cannot be infered from
When ``filepath`` argument is path-like object, audio format is
``filepath``, (such as file extension or ``name`` attribute of the given file object).
inferred from file extension. If the file extension is missing or
different, you can specify the correct format with this argument.
When ``filepath`` argument is file-like object,
this argument is required.
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for supported formats.
This argument is effective only for supported formats, sush as
``"wav"``, ``""flac"`` and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
bits_per_sample (int, optional): Changes the bit depth for the
supported formats.
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
you can change the bit depth.
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
Supported formats/encodings/bit depth/compression are:
``"wav"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of
the input Tensor.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Doesn't accept changing configuration.
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
"""
"""
if
src
.
ndim
!=
2
:
if
src
.
ndim
!=
2
:
raise
ValueError
(
f
"Expected 2D Tensor, got
{
src
.
ndim
}
D."
)
raise
ValueError
(
f
"Expected 2D Tensor, got
{
src
.
ndim
}
D."
)
...
@@ -260,24 +405,13 @@ def save(
...
@@ -260,24 +405,13 @@ def save(
if
hasattr
(
filepath
,
'write'
):
if
hasattr
(
filepath
,
'write'
):
if
format
is
None
:
if
format
is
None
:
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
ext
=
format
ext
=
format
.
lower
()
else
:
else
:
ext
=
str
(
filepath
).
split
(
"."
)[
-
1
].
lower
()
ext
=
str
(
filepath
).
split
(
"."
)[
-
1
].
lower
()
if
ext
!=
"wav"
:
if
bits_per_sample
not
in
(
None
,
8
,
16
,
24
,
32
,
64
):
subtype
=
None
raise
ValueError
(
"Invalid bits_per_sample."
)
elif
src
.
dtype
==
torch
.
uint8
:
subtype
=
_get_subtype
(
src
.
dtype
,
ext
,
encoding
,
bits_per_sample
)
subtype
=
"PCM_U8"
elif
src
.
dtype
==
torch
.
int16
:
subtype
=
"PCM_16"
elif
src
.
dtype
==
torch
.
int32
:
subtype
=
"PCM_32"
elif
src
.
dtype
==
torch
.
float32
:
subtype
=
"FLOAT"
elif
src
.
dtype
==
torch
.
float64
:
subtype
=
"DOUBLE"
else
:
raise
ValueError
(
f
"Unsupported dtype for WAV:
{
src
.
dtype
}
"
)
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
# so we extend the extensions manually here
...
...
torchaudio/backend/sox_io_backend.py
View file @
b8fd5e94
...
@@ -198,7 +198,7 @@ def save(
...
@@ -198,7 +198,7 @@ def save(
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
encoding (str, optional): Changes the encoding for the supported formats.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats,
cus
h as ``"wav"``, ``""amb"``
This argument is effective only for supported formats,
suc
h as ``"wav"``, ``""amb"``
and ``"sph"``. Valid values are;
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_S"`` (signed integer Linear PCM)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment