Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
674a71d1
Unverified
Commit
674a71d1
authored
Jan 28, 2021
by
Caroline Chen
Committed by
GitHub
Jan 28, 2021
Browse files
Add target `dtype` argument to `save` function for sox backend (#1204)
parent
47d97e30
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
92 additions
and
26 deletions
+92
-26
test/torchaudio_unittest/sox_io_backend/save_test.py
test/torchaudio_unittest/sox_io_backend/save_test.py
+29
-14
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+17
-4
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+22
-6
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+4
-2
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+18
-0
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+2
-0
No files found.
test/torchaudio_unittest/sox_io_backend/save_test.py
View file @
674a71d1
import
io
import
itertools
import
torch
from
torchaudio.backend
import
sox_io_backend
from
parameterized
import
parameterized
...
...
@@ -24,7 +25,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
"""`sox_io_backend.save` can save wav format."""
path
=
self
.
get_temp_path
(
'data.wav'
)
expected
=
get_wav_data
(
dtype
,
num_channels
,
num_frames
=
num_frames
)
sox_io_backend
.
save
(
path
,
expected
,
sample_rate
)
sox_io_backend
.
save
(
path
,
expected
,
sample_rate
,
dtype
=
None
)
found
,
sr
=
load_wav
(
path
)
assert
sample_rate
==
sr
self
.
assertEqual
(
found
,
expected
)
...
...
@@ -68,7 +69,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to mp3 with torchaudio
sox_io_backend
.
save
(
mp3_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
bit_rate
)
mp3_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
bit_rate
,
dtype
=
None
)
# 2.2. Convert the mp3 to wav with Sox
sox_utils
.
convert_audio_file
(
mp3_path
,
wav_path
)
# 2.3. Load
...
...
@@ -99,7 +100,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to flac with torchaudio
sox_io_backend
.
save
(
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
compression_level
)
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
compression_level
,
dtype
=
None
)
# 2.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path
,
wav_path
,
bit_depth
=
32
)
...
...
@@ -132,7 +133,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to vorbis with torchaudio
sox_io_backend
.
save
(
vbs_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
quality_level
)
vbs_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
quality_level
,
dtype
=
None
)
# 2.2. Convert the vorbis to wav with Sox
sox_utils
.
convert_audio_file
(
vbs_path
,
wav_path
)
# 2.3. Load
...
...
@@ -184,7 +185,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
data
=
get_wav_data
(
'float32'
,
num_channels
,
normalize
=
True
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to sph with torchaudio
sox_io_backend
.
save
(
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
)
sox_io_backend
.
save
(
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
dtype
=
None
)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path
,
wav_path
,
bit_depth
=
32
)
...
...
@@ -216,7 +217,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to amb with torchaudio
sox_io_backend
.
save
(
amb_path
,
load_wav
(
src_path
,
normalize
=
False
)[
0
],
sample_rate
)
sox_io_backend
.
save
(
amb_path
,
load_wav
(
src_path
,
normalize
=
False
)[
0
],
sample_rate
,
dtype
=
None
)
# 2.2. Convert the amb to wav with Sox
sox_utils
.
convert_audio_file
(
amb_path
,
wav_path
)
# 2.3. Load
...
...
@@ -248,7 +249,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
data
=
get_wav_data
(
'int16'
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend
.
save
(
amr_path
,
load_wav
(
src_path
,
normalize
=
False
)[
0
],
sample_rate
)
sox_io_backend
.
save
(
amr_path
,
load_wav
(
src_path
,
normalize
=
False
)[
0
],
sample_rate
,
dtype
=
None
)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils
.
convert_audio_file
(
amr_path
,
wav_path
)
# 2.3. Load
...
...
@@ -389,7 +390,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
path
=
self
.
get_temp_path
(
'data.wav'
)
data
=
get_wav_data
(
'int32'
,
2
,
channels_first
=
channels_first
)
sox_io_backend
.
save
(
path
,
data
,
8000
,
channels_first
=
channels_first
)
path
,
data
,
8000
,
channels_first
=
channels_first
,
dtype
=
None
)
found
=
load_wav
(
path
)[
0
]
expected
=
data
if
channels_first
else
data
.
transpose
(
1
,
0
)
self
.
assertEqual
(
found
,
expected
)
...
...
@@ -402,7 +403,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
path
=
self
.
get_temp_path
(
'data.wav'
)
expected
=
get_wav_data
(
dtype
,
4
)[::
2
,
::
2
]
assert
not
expected
.
is_contiguous
()
sox_io_backend
.
save
(
path
,
expected
,
8000
)
sox_io_backend
.
save
(
path
,
expected
,
8000
,
dtype
=
None
)
found
=
load_wav
(
path
)[
0
]
self
.
assertEqual
(
found
,
expected
)
...
...
@@ -415,10 +416,24 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
expected
=
get_wav_data
(
dtype
,
4
)[::
2
,
::
2
]
data
=
expected
.
clone
()
sox_io_backend
.
save
(
path
,
data
,
8000
)
sox_io_backend
.
save
(
path
,
data
,
8000
,
dtype
=
None
)
self
.
assertEqual
(
data
,
expected
)
@parameterized.expand([
    ('float32', torch.tensor([-1.0, -0.5, 0, 0.5, 1.0], dtype=torch.float32)),
    ('int32', torch.tensor([-2147483648, -1073741824, 0, 1073741824, 2147483647], dtype=torch.int32)),
    ('int16', torch.tensor([-32768, -16384, 0, 16384, 32767], dtype=torch.int16)),
    ('uint8', torch.tensor([0, 64, 128, 192, 255], dtype=torch.uint8)),
])
def test_dtype_conversion(self, dtype, expected):
    """`save` performs dtype conversion on float32 src tensors only.

    A fixed float32 ramp covering [-1.0, 1.0] is saved with the requested
    target ``dtype``; the file is then read back without normalization and
    compared against the ``expected`` values for that dtype.
    """
    wav_path = self.get_temp_path("data.wav")
    # Source tensor is always float32; only the target dtype varies.
    src = torch.tensor([-1.0, -0.5, 0, 0.5, 1.0], dtype=torch.float32).view(-1, 1)
    sox_io_backend.save(wav_path, src, 8000, dtype=dtype)
    # normalize=False keeps the samples in the dtype stored in the file.
    loaded = load_wav(wav_path, normalize=False)[0]
    self.assertEqual(loaded, expected.view(-1, 1))
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
...
...
@@ -452,11 +467,11 @@ class TestFileObject(SaveTestBase):
res_path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_io_backend
.
save
(
ref_path
,
data
,
channels_first
=
channels_first
,
sample_rate
=
sample_rate
,
compression
=
compression
)
sample_rate
=
sample_rate
,
compression
=
compression
,
dtype
=
None
)
with
open
(
res_path
,
'wb'
)
as
fileobj
:
sox_io_backend
.
save
(
fileobj
,
data
,
channels_first
=
channels_first
,
sample_rate
=
sample_rate
,
compression
=
compression
,
format
=
ext
)
sample_rate
=
sample_rate
,
compression
=
compression
,
format
=
ext
,
dtype
=
None
)
expected_data
,
_
=
sox_io_backend
.
load
(
ref_path
)
data
,
sr
=
sox_io_backend
.
load
(
res_path
)
...
...
@@ -489,11 +504,11 @@ class TestFileObject(SaveTestBase):
res_path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_io_backend
.
save
(
ref_path
,
data
,
channels_first
=
channels_first
,
sample_rate
=
sample_rate
,
compression
=
compression
)
sample_rate
=
sample_rate
,
compression
=
compression
,
dtype
=
None
)
fileobj
=
io
.
BytesIO
()
sox_io_backend
.
save
(
fileobj
,
data
,
channels_first
=
channels_first
,
sample_rate
=
sample_rate
,
compression
=
compression
,
format
=
ext
)
sample_rate
=
sample_rate
,
compression
=
compression
,
format
=
ext
,
dtype
=
None
)
fileobj
.
seek
(
0
)
with
open
(
res_path
,
'wb'
)
as
file_
:
file_
.
write
(
fileobj
.
read
())
...
...
torchaudio/backend/sox_io_backend.py
View file @
674a71d1
import
os
import
warnings
from
typing
import
Tuple
,
Optional
import
torch
...
...
@@ -178,15 +179,16 @@ def _save(
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
):
if
hasattr
(
filepath
,
'write'
):
if
format
is
None
:
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
torchaudio
.
_torchaudio
.
save_audio_fileobj
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
)
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
else
:
torch
.
ops
.
torchaudio
.
sox_io_save_audio_file
(
os
.
fspath
(
filepath
),
src
,
sample_rate
,
channels_first
,
compression
,
format
)
os
.
fspath
(
filepath
),
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
@@ -197,6 +199,7 @@ def save(
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
):
"""Save audio data to file.
...
...
@@ -243,12 +246,22 @@ def save(
format (str, optional):
Output audio format. This is required when the output audio format cannot be inferred from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
dtype (str, optional):
Output tensor dtype.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed.
``dtype`` parameter is only effective for ``float32`` Tensor.
"""
if
src
.
dtype
==
torch
.
float32
and
dtype
is
None
:
warnings
.
warn
(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if
not
torch
.
jit
.
is_scripting
():
_save
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
)
_save
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
return
torch
.
ops
.
torchaudio
.
sox_io_save_audio_file
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
)
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/sox/io.cpp
View file @
674a71d1
...
...
@@ -107,10 +107,19 @@ void save_audio_file(
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
std
::
string
>
format
)
{
c10
::
optional
<
std
::
string
>
format
,
c10
::
optional
<
std
::
string
>
dtype
)
{
validate_input_tensor
(
tensor
);
auto
signal
=
TensorSignal
(
tensor
,
sample_rate
,
channels_first
);
if
(
tensor
.
dtype
()
!=
torch
::
kFloat32
&&
dtype
.
has_value
())
{
throw
std
::
runtime_error
(
"dtype conversion only supported for float32 tensors"
);
}
const
auto
tgt_dtype
=
(
tensor
.
dtype
()
==
torch
::
kFloat32
&&
dtype
.
has_value
())
?
get_dtype_from_str
(
dtype
.
value
())
:
tensor
.
dtype
();
const
auto
filetype
=
[
&
]()
{
if
(
format
.
has_value
())
...
...
@@ -124,8 +133,7 @@ void save_audio_file(
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tgt_dtype
,
compression
);
SoxFormat
sf
(
sox_open_write
(
path
.
c_str
(),
...
...
@@ -239,10 +247,19 @@ void save_audio_fileobj(
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
std
::
string
filetype
)
{
std
::
string
filetype
,
c10
::
optional
<
std
::
string
>
dtype
)
{
validate_input_tensor
(
tensor
);
auto
signal
=
TensorSignal
(
tensor
,
sample_rate
,
channels_first
);
if
(
tensor
.
dtype
()
!=
torch
::
kFloat32
&&
dtype
.
has_value
())
{
throw
std
::
runtime_error
(
"dtype conversion only supported for float32 tensors"
);
}
const
auto
tgt_dtype
=
(
tensor
.
dtype
()
==
torch
::
kFloat32
&&
dtype
.
has_value
())
?
get_dtype_from_str
(
dtype
.
value
())
:
tensor
.
dtype
();
if
(
filetype
==
"amr-nb"
)
{
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
...
...
@@ -253,8 +270,7 @@ void save_audio_fileobj(
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tgt_dtype
,
compression
);
AutoReleaseBuffer
buffer
;
...
...
torchaudio/csrc/sox/io.h
View file @
674a71d1
...
...
@@ -46,7 +46,8 @@ void save_audio_file(
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
std
::
string
>
format
);
c10
::
optional
<
std
::
string
>
format
,
c10
::
optional
<
std
::
string
>
dtype
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
...
...
@@ -68,7 +69,8 @@ void save_audio_fileobj(
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
std
::
string
filetype
);
std
::
string
filetype
,
c10
::
optional
<
std
::
string
>
dtype
);
#endif // TORCH_API_INCLUDE_EXTENSION_H
...
...
torchaudio/csrc/sox/utils.cpp
View file @
674a71d1
...
...
@@ -156,6 +156,24 @@ caffe2::TypeMeta get_dtype(
return
c10
::
scalarTypeToTypeMeta
(
dtype
);
}
/// Translate a dtype name into the corresponding TypeMeta.
///
/// Recognized names are "uint8", "int16", "int32", "float32" and "float64".
/// Any other name raises std::runtime_error.
caffe2::TypeMeta get_dtype_from_str(const std::string dtype) {
  if (dtype == "uint8") {
    return c10::scalarTypeToTypeMeta(torch::kUInt8);
  }
  if (dtype == "int16") {
    return c10::scalarTypeToTypeMeta(torch::kInt16);
  }
  if (dtype == "int32") {
    return c10::scalarTypeToTypeMeta(torch::kInt32);
  }
  if (dtype == "float32") {
    return c10::scalarTypeToTypeMeta(torch::kFloat32);
  }
  if (dtype == "float64") {
    return c10::scalarTypeToTypeMeta(torch::kFloat64);
  }
  // No known name matched.
  throw std::runtime_error("Unsupported dtype");
}
torch
::
Tensor
convert_to_tensor
(
sox_sample_t
*
buffer
,
const
int32_t
num_samples
,
...
...
torchaudio/csrc/sox/utils.h
View file @
674a71d1
...
...
@@ -85,6 +85,8 @@ caffe2::TypeMeta get_dtype(
const
sox_encoding_t
encoding
,
const
unsigned
precision
);
///
/// Convert a dtype name into the corresponding TypeMeta.
/// Accepted names: "uint8", "int16", "int32", "float32", "float64".
/// Throws std::runtime_error for any other name.
caffe2
::
TypeMeta
get_dtype_from_str
(
const
std
::
string
dtype
);
///
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
/// NOTE: This function might modify the values in the input buffer to
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment