Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
c3cb2015
Unverified
Commit
c3cb2015
authored
Feb 12, 2021
by
moto
Committed by
GitHub
Feb 12, 2021
Browse files
Add encoding and bits_per_sample option to save function (#1226)
parent
4f9b5520
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
883 additions
and
631 deletions
+883
-631
test/torchaudio_unittest/backend/sox_io/common.py
test/torchaudio_unittest/backend/sox_io/common.py
+12
-0
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
+3
-1
test/torchaudio_unittest/backend/sox_io/save_test.py
test/torchaudio_unittest/backend/sox_io/save_test.py
+302
-438
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
+14
-8
test/torchaudio_unittest/common_utils/sox_utils.py
test/torchaudio_unittest/common_utils/sox_utils.py
+8
-4
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+128
-63
torchaudio/csrc/CMakeLists.txt
torchaudio/csrc/CMakeLists.txt
+1
-0
torchaudio/csrc/sox/effects_chain.cpp
torchaudio/csrc/sox/effects_chain.cpp
+30
-8
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+16
-27
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+8
-6
torchaudio/csrc/sox/types.cpp
torchaudio/csrc/sox/types.cpp
+102
-0
torchaudio/csrc/sox/types.h
torchaudio/csrc/sox/types.h
+55
-0
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+200
-69
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+4
-7
No files found.
test/torchaudio_unittest/backend/sox_io/common.py
View file @
c3cb2015
def
name_func
(
func
,
_
,
params
):
def
name_func
(
func
,
_
,
params
):
return
f
'
{
func
.
__name__
}
_
{
"_"
.
join
(
str
(
arg
)
for
arg
in
params
.
args
)
}
'
return
f
'
{
func
.
__name__
}
_
{
"_"
.
join
(
str
(
arg
)
for
arg
in
params
.
args
)
}
'
def
get_enc_params
(
dtype
):
if
dtype
==
'float32'
:
return
'PCM_F'
,
32
if
dtype
==
'int32'
:
return
'PCM_S'
,
32
if
dtype
==
'int16'
:
return
'PCM_S'
,
16
if
dtype
==
'uint8'
:
return
'PCM_U'
,
8
raise
ValueError
(
f
'Unexpected dtype:
{
dtype
}
'
)
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
View file @
c3cb2015
...
@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import (
...
@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import (
)
)
from
.common
import
(
from
.common
import
(
name_func
,
name_func
,
get_enc_params
,
)
)
...
@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
...
@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""save/load round trip should not degrade data for wav formats"""
"""save/load round trip should not degrade data for wav formats"""
original
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
)
original
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
)
enc
,
bps
=
get_enc_params
(
dtype
)
data
=
original
data
=
original
for
i
in
range
(
10
):
for
i
in
range
(
10
):
path
=
self
.
get_temp_path
(
f
'
{
i
}
.wav'
)
path
=
self
.
get_temp_path
(
f
'
{
i
}
.wav'
)
sox_io_backend
.
save
(
path
,
data
,
sample_rate
)
sox_io_backend
.
save
(
path
,
data
,
sample_rate
,
encoding
=
enc
,
bits_per_sample
=
bps
)
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
False
)
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
False
)
assert
sr
==
sample_rate
assert
sr
==
sample_rate
self
.
assertEqual
(
original
,
data
)
self
.
assertEqual
(
original
,
data
)
...
...
test/torchaudio_unittest/backend/sox_io/save_test.py
View file @
c3cb2015
import
io
import
io
import
itertools
import
unittest
from
itertools
import
product
import
torch
from
torchaudio.backend
import
sox_io_backend
from
torchaudio.backend
import
sox_io_backend
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
(
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TempDirMixin
,
TorchaudioTestCase
,
PytorchTestCase
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoExtension
,
...
@@ -17,37 +18,62 @@ from torchaudio_unittest.common_utils import (
...
@@ -17,37 +18,62 @@ from torchaudio_unittest.common_utils import (
)
)
from
.common
import
(
from
.common
import
(
name_func
,
name_func
,
get_enc_params
,
)
)
class
SaveTestBase
(
TempDirMixin
,
PytorchTestCase
):
def
_get_sox_encoding
(
encoding
):
def
assert_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
encodings
=
{
"""`sox_io_backend.save` can save wav format."""
'PCM_F'
:
'floating-point'
,
path
=
self
.
get_temp_path
(
'data.wav'
)
'PCM_S'
:
'signed-integer'
,
expected
=
get_wav_data
(
dtype
,
num_channels
,
num_frames
=
num_frames
)
'PCM_U'
:
'unsigned-integer'
,
sox_io_backend
.
save
(
path
,
expected
,
sample_rate
,
dtype
=
None
)
'ULAW'
:
'u-law'
,
found
,
sr
=
load_wav
(
path
)
'ALAW'
:
'a-law'
,
assert
sample_rate
==
sr
}
self
.
assertEqual
(
found
,
expected
)
return
encodings
.
get
(
encoding
)
def
assert_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
,
duration
):
"""`sox_io_backend.save` can save mp3 format.
class
SaveTestBase
(
TempDirMixin
,
TorchaudioTestCase
):
def
assert_save_consistency
(
mp3 encoding introduces delay and boundary effects so
self
,
we convert the resulting mp3 to wav and compare the results there
format
:
str
,
*
,
|
compression
:
float
=
None
,
| 1. Generate original wav file with SciPy
encoding
:
str
=
None
,
bits_per_sample
:
int
=
None
,
sample_rate
:
float
=
8000
,
num_channels
:
int
=
2
,
num_frames
:
float
=
3
*
8000
,
src_dtype
:
str
=
'int32'
,
test_mode
:
str
=
"path"
,
):
"""`save` function produces file that is comparable with `sox` command
To compare that the file produced by `save` function agains the file produced by
the equivalent `sox` command, we need to load both files.
But there are many formats that cannot be opened with common Python modules (like
SciPy).
So we use `sox` command to prepare the original data and convert the saved files
into a format that SciPy can read (PCM wav).
The following diagram illustrates this process. The difference is 2.1. and 3.1.
This assumes that
- loading data with SciPy preserves the data well.
- converting the resulting files into WAV format with `sox` preserve the data well.
x
| 1. Generate source wav file with SciPy
|
|
v
v
-------------- wav ----------------
-------------- wav ----------------
| |
| |
| 2.1. load with scipy | 3.1. Convert to mp3 with Sox
| 2.1. load with scipy | 3.1. Convert to the target
| then save with torchaudio |
| then save it into the target | format depth with sox
| format with torchaudio |
v v
v v
mp3
mp3
target format
target format
| |
| |
| 2.2. Convert to wav with
S
ox | 3.2. Convert to wav with
S
ox
| 2.2. Convert to wav with
s
ox | 3.2. Convert to wav with
s
ox
| |
| |
v v
v v
wav wav
wav wav
...
@@ -58,326 +84,260 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
...
@@ -58,326 +84,260 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
tensor -------> compare <--------- tensor
tensor -------> compare <--------- tensor
"""
"""
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
cmp_encoding
=
'floating-point'
mp3_path
=
self
.
get_temp_path
(
'2.1.torchaudio.mp3'
)
cmp_bit_depth
=
32
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
mp3_path_sox
=
self
.
get_temp_path
(
'3.1.sox.mp3'
)
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
# 1. Generate original wav
src_path
=
self
.
get_temp_path
(
'1.source.wav'
)
data
=
get_wav_data
(
'float32'
,
num_channels
,
normalize
=
True
,
num_frames
=
duration
*
sample_rate
)
tgt_path
=
self
.
get_temp_path
(
f
'2.1.torchaudio.
{
format
}
'
)
save_wav
(
src_path
,
data
,
sample_rate
)
tst_path
=
self
.
get_temp_path
(
'2.2.result.wav'
)
# 2.1. Convert the original wav to mp3 with torchaudio
sox_path
=
self
.
get_temp_path
(
f
'3.1.sox.
{
format
}
'
)
sox_io_backend
.
save
(
ref_path
=
self
.
get_temp_path
(
'3.2.ref.wav'
)
mp3_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
bit_rate
,
dtype
=
None
)
# 2.2. Convert the mp3 to wav with Sox
sox_utils
.
convert_audio_file
(
mp3_path
,
wav_path
)
# 2.3. Load
found
=
load_wav
(
wav_path
)[
0
]
# 3.1. Convert the original wav to mp3 with SoX
sox_utils
.
convert_audio_file
(
src_path
,
mp3_path_sox
,
compression
=
bit_rate
)
# 3.2. Convert the mp3 to wav with Sox
sox_utils
.
convert_audio_file
(
mp3_path_sox
,
wav_path_sox
)
# 3.3. Load
expected
=
load_wav
(
wav_path_sox
)[
0
]
self
.
assertEqual
(
found
,
expected
)
def
assert_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
,
duration
):
"""`sox_io_backend.save` can save flac format.
This test takes the same strategy as mp3 to compare the result
"""
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
flc_path
=
self
.
get_temp_path
(
'2.1.torchaudio.flac'
)
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
flc_path_sox
=
self
.
get_temp_path
(
'3.1.sox.flac'
)
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
# 1. Generate original wav
# 1. Generate original wav
data
=
get_wav_data
(
'float32'
,
num_channels
,
normalize
=
Tru
e
,
num_frames
=
duration
*
sample_rate
)
data
=
get_wav_data
(
src_dtype
,
num_channels
,
normalize
=
Fals
e
,
num_frames
=
num_frames
)
save_wav
(
src_path
,
data
,
sample_rate
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to flac with torchaudio
sox_io_backend
.
save
(
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
compression_level
,
dtype
=
None
)
# 2.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path
,
wav_path
,
bit_depth
=
32
)
# 2.3. Load
found
=
load_wav
(
wav_path
)[
0
]
# 3.1. Convert the original wav to flac with SoX
sox_utils
.
convert_audio_file
(
src_path
,
flc_path_sox
,
compression
=
compression_level
)
# 3.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path_sox
,
wav_path_sox
,
bit_depth
=
32
)
# 3.3. Load
expected
=
load_wav
(
wav_path_sox
)[
0
]
self
.
assertEqual
(
found
,
expected
)
def
_assert_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
,
duration
):
"""`sox_io_backend.save` can save vorbis format.
This test takes the same strategy as mp3 to compare the result
"""
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
vbs_path
=
self
.
get_temp_path
(
'2.1.torchaudio.vorbis'
)
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
vbs_path_sox
=
self
.
get_temp_path
(
'3.1.sox.vorbis'
)
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
# 1. Generate original wav
# 2.1. Convert the original wav to target format with torchaudio
data
=
get_wav_data
(
'int16'
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
data
=
load_wav
(
src_path
,
normalize
=
False
)[
0
]
save_wav
(
src_path
,
data
,
sample_rate
)
if
test_mode
==
"path"
:
# 2.1. Convert the original wav to vorbis with torchaudio
sox_io_backend
.
save
(
sox_io_backend
.
save
(
tgt_path
,
data
,
sample_rate
,
vbs_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
compression
=
quality_level
,
dtype
=
None
)
compression
=
compression
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
# 2.2. Convert the vorbis to wav with Sox
elif
test_mode
==
"fileobj"
:
sox_utils
.
convert_audio_file
(
vbs_path
,
wav_path
)
with
open
(
tgt_path
,
'bw'
)
as
file_
:
# 2.3. Load
sox_io_backend
.
save
(
found
=
load_wav
(
wav_path
)[
0
]
file_
,
data
,
sample_rate
,
format
=
format
,
compression
=
compression
,
# 3.1. Convert the original wav to vorbis with SoX
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
sox_utils
.
convert_audio_file
(
src_path
,
vbs_path_sox
,
compression
=
quality_level
)
elif
test_mode
==
"bytesio"
:
# 3.2. Convert the vorbis to wav with Sox
file_
=
io
.
BytesIO
()
sox_utils
.
convert_audio_file
(
vbs_path_sox
,
wav_path_sox
)
sox_io_backend
.
save
(
# 3.3. Load
file_
,
data
,
sample_rate
,
expected
=
load_wav
(
wav_path_sox
)[
0
]
format
=
format
,
compression
=
compression
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
# sox's vorbis encoding has some random boundary effect, which cause small number of
file_
.
seek
(
0
)
# samples yields higher descrepency than the others.
with
open
(
tgt_path
,
'bw'
)
as
f
:
# so we allow small portions of data to be outside of absolute torelance.
f
.
write
(
file_
.
read
())
# make sure to pass somewhat long duration
atol
=
1.0e-4
max_failure_allowed
=
0.01
# this percent of samples are allowed to outside of atol.
failure_ratio
=
((
found
-
expected
).
abs
()
>
atol
).
sum
().
item
()
/
found
.
numel
()
if
failure_ratio
>
max_failure_allowed
:
# it's failed and this will give a better error message.
self
.
assertEqual
(
found
,
expected
,
atol
=
atol
,
rtol
=
1.3e-6
)
def
assert_vorbis
(
self
,
*
args
,
**
kwargs
):
# sox's vorbis encoding has some randomness, so we run tests multiple time
max_retry
=
5
error
=
None
for
_
in
range
(
max_retry
):
try
:
self
.
_assert_vorbis
(
*
args
,
**
kwargs
)
break
except
AssertionError
as
e
:
error
=
e
else
:
else
:
raise
error
raise
ValueError
(
f
"Unexpected test mode:
{
test_mode
}
"
)
# 2.2. Convert the target format to wav with sox
def
assert_sphere
(
self
,
sample_rate
,
num_channels
,
duration
):
sox_utils
.
convert_audio_file
(
"""`sox_io_backend.save` can save sph format.
tgt_path
,
tst_path
,
encoding
=
cmp_encoding
,
bit_depth
=
cmp_bit_depth
)
# 2.3. Load with SciPy
This test takes the same strategy as mp3 to compare the result
found
=
load_wav
(
tst_path
,
normalize
=
False
)[
0
]
"""
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
# 3.1. Convert the original wav to target format with sox
flc_path
=
self
.
get_temp_path
(
'2.1.torchaudio.sph'
)
sox_encoding
=
_get_sox_encoding
(
encoding
)
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
sox_utils
.
convert_audio_file
(
flc_path_sox
=
self
.
get_temp_path
(
'3.1.sox.sph'
)
src_path
,
sox_path
,
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
compression
=
compression
,
encoding
=
sox_encoding
,
bit_depth
=
bits_per_sample
)
# 3.2. Convert the target format to wav with sox
# 1. Generate original wav
sox_utils
.
convert_audio_file
(
data
=
get_wav_data
(
'float32'
,
num_channels
,
normalize
=
True
,
num_frames
=
duration
*
sample_rate
)
sox_path
,
ref_path
,
encoding
=
cmp_encoding
,
bit_depth
=
cmp_bit_depth
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 3.3. Load with SciPy
# 2.1. Convert the original wav to sph with torchaudio
expected
=
load_wav
(
ref_path
,
normalize
=
False
)[
0
]
sox_io_backend
.
save
(
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
,
dtype
=
None
)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path
,
wav_path
,
bit_depth
=
32
)
# 2.3. Load
found
=
load_wav
(
wav_path
)[
0
]
# 3.1. Convert the original wav to sph with SoX
sox_utils
.
convert_audio_file
(
src_path
,
flc_path_sox
)
# 3.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path_sox
,
wav_path_sox
,
bit_depth
=
32
)
# 3.3. Load
expected
=
load_wav
(
wav_path_sox
)[
0
]
self
.
assertEqual
(
found
,
expected
)
self
.
assertEqual
(
found
,
expected
)
def
assert_amb
(
self
,
dtype
,
sample_rate
,
num_channels
,
duration
):
"""`sox_io_backend.save` can save amb format.
This test takes the same strategy as mp3 to compare the result
def
nested_params
(
*
params
):
"""
def
_name_func
(
func
,
_
,
params
):
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
strs
=
[]
amb_path
=
self
.
get_temp_path
(
'2.1.torchaudio.amb'
)
for
arg
in
params
.
args
:
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
if
isinstance
(
arg
,
tuple
):
amb_path_sox
=
self
.
get_temp_path
(
'3.1.sox.amb'
)
strs
.
append
(
"_"
.
join
(
str
(
a
)
for
a
in
arg
))
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
else
:
strs
.
append
(
str
(
arg
))
return
f
'
{
func
.
__name__
}
_
{
"_"
.
join
(
strs
)
}
'
# 1. Generate original wav
return
parameterized
.
expand
(
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
list
(
product
(
*
params
)),
save_wav
(
src_path
,
data
,
sample_rate
)
name_func
=
_name_func
# 2.1. Convert the original wav to amb with torchaudio
)
sox_io_backend
.
save
(
amb_path
,
load_wav
(
src_path
,
normalize
=
False
)[
0
],
sample_rate
,
dtype
=
None
)
# 2.2. Convert the amb to wav with Sox
sox_utils
.
convert_audio_file
(
amb_path
,
wav_path
)
# 2.3. Load
found
=
load_wav
(
wav_path
)[
0
]
# 3.1. Convert the original wav to amb with SoX
sox_utils
.
convert_audio_file
(
src_path
,
amb_path_sox
)
# 3.2. Convert the amb to wav with Sox
sox_utils
.
convert_audio_file
(
amb_path_sox
,
wav_path_sox
)
# 3.3. Load
expected
=
load_wav
(
wav_path_sox
)[
0
]
self
.
assertEqual
(
found
,
expected
)
def
assert_amr_nb
(
self
,
duration
):
"""`sox_io_backend.save` can save amr_nb format.
This test takes the same strategy as mp3 to compare the result
"""
sample_rate
=
8000
num_channels
=
1
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
amr_path
=
self
.
get_temp_path
(
'2.1.torchaudio.amr-nb'
)
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
amr_path_sox
=
self
.
get_temp_path
(
'3.1.sox.amr-nb'
)
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
# 1. Generate original wav
data
=
get_wav_data
(
'int16'
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend
.
save
(
amr_path
,
load_wav
(
src_path
,
normalize
=
False
)[
0
],
sample_rate
,
dtype
=
None
)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils
.
convert_audio_file
(
amr_path
,
wav_path
)
# 2.3. Load
found
=
load_wav
(
wav_path
)[
0
]
# 3.1. Convert the original wav to amr_nb with SoX
sox_utils
.
convert_audio_file
(
src_path
,
amr_path_sox
)
# 3.2. Convert the amr_nb to wav with Sox
sox_utils
.
convert_audio_file
(
amr_path_sox
,
wav_path_sox
)
# 3.3. Load
expected
=
load_wav
(
wav_path_sox
)[
0
]
self
.
assertEqual
(
found
,
expected
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
@
skipIfNoExtension
class
TestSave
(
SaveTestBase
):
class
SaveTest
(
SaveTestBase
):
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
nested_params
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
"path"
,
"fileobj"
,
"bytesio"
],
[
8000
,
16000
],
[
[
1
,
2
],
(
'PCM_U'
,
8
),
)),
name_func
=
name_func
)
(
'PCM_S'
,
16
),
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
(
'PCM_S'
,
32
),
"""`sox_io_backend.save` can save wav format."""
(
'PCM_F'
,
32
),
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
num_frames
=
None
)
(
'PCM_F'
,
64
),
(
'ULAW'
,
8
),
@
parameterized
.
expand
(
list
(
itertools
.
product
(
(
'ALAW'
,
8
),
[
'float32'
],
],
[
16000
],
)
[
2
],
def
test_save_wav
(
self
,
test_mode
,
enc_params
):
)),
name_func
=
name_func
)
encoding
,
bits_per_sample
=
enc_params
def
test_wav_large
(
self
,
dtype
,
sample_rate
,
num_channels
):
self
.
assert_save_consistency
(
"""`sox_io_backend.save` can save large wav file."""
"wav"
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
test_mode
=
test_mode
)
two_hours
=
2
*
60
*
60
*
sample_rate
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
num_frames
=
two_hours
)
@
nested_params
(
[
"path"
,
"fileobj"
,
"bytesio"
],
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
(
'float32'
,
),
[
4
,
8
,
16
,
32
],
(
'int32'
,
),
)),
name_func
=
name_func
)
(
'int16'
,
),
def
test_multiple_channels
(
self
,
dtype
,
num_channels
):
(
'uint8'
,
),
"""`sox_io_backend.save` can save wav with more than 2 channels."""
],
)
def
test_save_wav_dtype
(
self
,
test_mode
,
params
):
dtype
,
=
params
self
.
assert_save_consistency
(
"wav"
,
src_dtype
=
dtype
,
test_mode
=
test_mode
)
@
nested_params
(
[
"path"
,
"fileobj"
,
"bytesio"
],
[
None
,
-
4.2
,
-
0.2
,
0
,
0.2
,
96
,
128
,
160
,
192
,
224
,
256
,
320
,
],
)
def
test_save_mp3
(
self
,
test_mode
,
bit_rate
):
if
test_mode
in
[
"fileobj"
,
"bytesio"
]:
if
bit_rate
is
not
None
and
bit_rate
<
1
:
raise
unittest
.
SkipTest
(
"mp3 format with variable bit rate is known to "
"not yield the exact same result as sox command."
)
self
.
assert_save_consistency
(
"mp3"
,
compression
=
bit_rate
,
test_mode
=
test_mode
)
@
nested_params
(
[
"path"
,
"fileobj"
,
"bytesio"
],
[
8
,
16
,
24
],
[
None
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
],
)
def
test_save_flac
(
self
,
test_mode
,
bits_per_sample
,
compression_level
):
self
.
assert_save_consistency
(
"flac"
,
compression
=
compression_level
,
bits_per_sample
=
bits_per_sample
,
test_mode
=
test_mode
)
@
nested_params
(
[
"path"
,
"fileobj"
,
"bytesio"
],
[
None
,
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
,
],
)
def
test_save_vorbis
(
self
,
test_mode
,
quality_level
):
self
.
assert_save_consistency
(
"vorbis"
,
compression
=
quality_level
,
test_mode
=
test_mode
)
@
nested_params
(
[
"path"
,
"fileobj"
,
"bytesio"
],
[
(
'PCM_S'
,
8
,
),
(
'PCM_S'
,
16
,
),
(
'PCM_S'
,
24
,
),
(
'PCM_S'
,
32
,
),
(
'ULAW'
,
8
),
(
'ALAW'
,
8
),
(
'ALAW'
,
16
),
(
'ALAW'
,
24
),
(
'ALAW'
,
32
),
],
)
def
test_save_sphere
(
self
,
test_mode
,
enc_params
):
encoding
,
bits_per_sample
=
enc_params
self
.
assert_save_consistency
(
"sph"
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
test_mode
=
test_mode
)
@
nested_params
(
[
"path"
,
"fileobj"
,
"bytesio"
],
[
(
'PCM_U'
,
8
,
),
(
'PCM_S'
,
16
,
),
(
'PCM_S'
,
24
,
),
(
'PCM_S'
,
32
,
),
(
'PCM_F'
,
32
,
),
(
'PCM_F'
,
64
,
),
(
'ULAW'
,
8
,
),
(
'ALAW'
,
8
,
),
],
)
def
test_save_amb
(
self
,
test_mode
,
enc_params
):
encoding
,
bits_per_sample
=
enc_params
self
.
assert_save_consistency
(
"amb"
,
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
,
test_mode
=
test_mode
)
@
nested_params
(
[
"path"
,
"fileobj"
,
"bytesio"
],
[
None
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
],
)
def
test_save_amr_nb
(
self
,
test_mode
,
bit_rate
):
self
.
assert_save_consistency
(
"amr-nb"
,
compression
=
bit_rate
,
num_channels
=
1
,
test_mode
=
test_mode
)
@
parameterized
.
expand
([
(
"wav"
,
"PCM_S"
,
16
),
(
"mp3"
,
),
(
"flac"
,
),
(
"vorbis"
,
),
(
"sph"
,
"PCM_S"
,
16
),
(
"amr-nb"
,
),
(
"amb"
,
"PCM_S"
,
16
),
],
name_func
=
name_func
)
def
test_save_large
(
self
,
format
,
encoding
=
None
,
bits_per_sample
=
None
):
"""`sox_io_backend.save` can save large files."""
sample_rate
=
8000
sample_rate
=
8000
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
num_frames
=
None
)
one_hour
=
60
*
60
*
sample_rate
self
.
assert_save_consistency
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
format
,
num_channels
=
1
,
sample_rate
=
8000
,
num_frames
=
one_hour
,
[
8000
,
16000
],
encoding
=
encoding
,
bits_per_sample
=
bits_per_sample
)
[
1
,
2
],
[
-
4.2
,
-
0.2
,
0
,
0.2
,
96
,
128
,
160
,
192
,
224
,
256
,
320
],
@
parameterized
.
expand
([
)),
name_func
=
name_func
)
(
32
,
),
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
(
64
,
),
"""`sox_io_backend.save` can save mp3 format."""
(
128
,
),
self
.
assert_mp3
(
sample_rate
,
num_channels
,
bit_rate
,
duration
=
1
)
(
256
,
),
],
name_func
=
name_func
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
def
test_save_multi_channels
(
self
,
num_channels
):
[
16000
],
"""`sox_io_backend.save` can save audio with many channels"""
[
2
],
self
.
assert_save_consistency
(
[
128
],
"wav"
,
encoding
=
"PCM_S"
,
bits_per_sample
=
16
,
)),
name_func
=
name_func
)
num_channels
=
num_channels
)
def
test_mp3_large
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.save` can save large mp3 file."""
two_hours
=
2
*
60
*
60
self
.
assert_mp3
(
sample_rate
,
num_channels
,
bit_rate
,
duration
=
two_hours
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
[
None
]
+
list
(
range
(
9
)),
)),
name_func
=
name_func
)
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.save` can save flac format."""
self
.
assert_flac
(
sample_rate
,
num_channels
,
compression_level
,
duration
=
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
16000
],
[
2
],
[
0
],
)),
name_func
=
name_func
)
def
test_flac_large
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.save` can save large flac file."""
two_hours
=
2
*
60
*
60
self
.
assert_flac
(
sample_rate
,
num_channels
,
compression_level
,
duration
=
two_hours
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
[
None
,
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
],
)),
name_func
=
name_func
)
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
"""`sox_io_backend.save` can save vorbis format."""
self
.
assert_vorbis
(
sample_rate
,
num_channels
,
quality_level
,
duration
=
20
)
# note: torchaudio can load large vorbis file, but cannot save large volbis file
# the following test causes Segmentation fault
#
'''
@parameterized.expand(list(itertools.product(
[16000],
[2],
[10],
)), name_func=name_func)
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
'''
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.save` can save sph format."""
self
.
assert_sphere
(
sample_rate
,
num_channels
,
duration
=
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_amb
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`sox_io_backend.save` can save amb format."""
self
.
assert_amb
(
dtype
,
sample_rate
,
num_channels
,
duration
=
1
)
def
test_amr_nb
(
self
):
"""`sox_io_backend.save` can save amr-nb format."""
self
.
assert_amr_nb
(
duration
=
1
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
...
@@ -385,136 +345,40 @@ class TestSave(SaveTestBase):
...
@@ -385,136 +345,40 @@ class TestSave(SaveTestBase):
class
TestSaveParams
(
TempDirMixin
,
PytorchTestCase
):
class
TestSaveParams
(
TempDirMixin
,
PytorchTestCase
):
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
@
parameterized
.
expand
([(
True
,
),
(
False
,
)],
name_func
=
name_func
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)],
name_func
=
name_func
)
def
test_channels_first
(
self
,
channels_first
):
def
test_
save_
channels_first
(
self
,
channels_first
):
"""channels_first swaps axes"""
"""channels_first swaps axes"""
path
=
self
.
get_temp_path
(
'data.wav'
)
path
=
self
.
get_temp_path
(
'data.wav'
)
data
=
get_wav_data
(
'int32'
,
2
,
channels_first
=
channels_first
)
data
=
get_wav_data
(
'int16'
,
2
,
channels_first
=
channels_first
,
normalize
=
False
)
sox_io_backend
.
save
(
sox_io_backend
.
save
(
path
,
data
,
8000
,
channels_first
=
channels_first
,
dtype
=
None
)
path
,
data
,
8000
,
channels_first
=
channels_first
)
found
=
load_wav
(
path
)[
0
]
found
=
load_wav
(
path
,
normalize
=
False
)[
0
]
expected
=
data
if
channels_first
else
data
.
transpose
(
1
,
0
)
expected
=
data
if
channels_first
else
data
.
transpose
(
1
,
0
)
self
.
assertEqual
(
found
,
expected
)
self
.
assertEqual
(
found
,
expected
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
'float32'
,
'int32'
,
'int16'
,
'uint8'
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
name_func
=
name_func
)
],
name_func
=
name_func
)
def
test_noncontiguous
(
self
,
dtype
):
def
test_
save_
noncontiguous
(
self
,
dtype
):
"""Noncontiguous tensors are saved correctly"""
"""Noncontiguous tensors are saved correctly"""
path
=
self
.
get_temp_path
(
'data.wav'
)
path
=
self
.
get_temp_path
(
'data.wav'
)
expected
=
get_wav_data
(
dtype
,
4
)[::
2
,
::
2
]
enc
,
bps
=
get_enc_params
(
dtype
)
expected
=
get_wav_data
(
dtype
,
4
,
normalize
=
False
)[::
2
,
::
2
]
assert
not
expected
.
is_contiguous
()
assert
not
expected
.
is_contiguous
()
sox_io_backend
.
save
(
path
,
expected
,
8000
,
dtype
=
None
)
sox_io_backend
.
save
(
found
=
load_wav
(
path
)[
0
]
path
,
expected
,
8000
,
encoding
=
enc
,
bits_per_sample
=
bps
)
found
=
load_wav
(
path
,
normalize
=
False
)[
0
]
self
.
assertEqual
(
found
,
expected
)
self
.
assertEqual
(
found
,
expected
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
'float32'
,
'int32'
,
'int16'
,
'uint8'
,
'float32'
,
'int32'
,
'int16'
,
'uint8'
,
])
])
def
test_tensor_preserve
(
self
,
dtype
):
def
test_
save_
tensor_preserve
(
self
,
dtype
):
"""save function should not alter Tensor"""
"""save function should not alter Tensor"""
path
=
self
.
get_temp_path
(
'data.wav'
)
path
=
self
.
get_temp_path
(
'data.wav'
)
expected
=
get_wav_data
(
dtype
,
4
)[::
2
,
::
2
]
expected
=
get_wav_data
(
dtype
,
4
,
normalize
=
False
)[::
2
,
::
2
]
data
=
expected
.
clone
()
data
=
expected
.
clone
()
sox_io_backend
.
save
(
path
,
data
,
8000
,
dtype
=
None
)
sox_io_backend
.
save
(
path
,
data
,
8000
)
self
.
assertEqual
(
data
,
expected
)
self
.
assertEqual
(
data
,
expected
)
@
parameterized
.
expand
([
(
'float32'
,
torch
.
tensor
([
-
1.0
,
-
0.5
,
0
,
0.5
,
1.0
]).
to
(
torch
.
float32
)),
(
'int32'
,
torch
.
tensor
([
-
2147483648
,
-
1073741824
,
0
,
1073741824
,
2147483647
]).
to
(
torch
.
int32
)),
(
'int16'
,
torch
.
tensor
([
-
32768
,
-
16384
,
0
,
16384
,
32767
]).
to
(
torch
.
int16
)),
(
'uint8'
,
torch
.
tensor
([
0
,
64
,
128
,
192
,
255
]).
to
(
torch
.
uint8
)),
])
def
test_dtype_conversion
(
self
,
dtype
,
expected
):
"""`save` performs dtype conversion on float32 src tensors only."""
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
torch
.
tensor
([
-
1.0
,
-
0.5
,
0
,
0.5
,
1.0
]).
to
(
torch
.
float32
).
view
(
-
1
,
1
)
sox_io_backend
.
save
(
path
,
data
,
8000
,
dtype
=
dtype
)
found
=
load_wav
(
path
,
normalize
=
False
)[
0
]
self
.
assertEqual
(
found
,
expected
.
view
(
-
1
,
1
))
@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(SaveTestBase):
    """
    We compare the result of file-like object input against file path input because
    `save` function is rigorously tested for file path inputs to match libsox's result.
    """

    # (format, compression) pairs shared by every test in this class.
    _FORMAT_PARAMS = [
        ('wav', None),
        ('mp3', 128),
        ('mp3', 320),
        ('flac', 0),
        ('flac', 5),
        ('flac', 8),
        ('vorbis', -1),
        ('vorbis', 10),
        ('amb', None),
    ]

    def _assert_same_as_path_save(self, ext, compression, save_to_result):
        """Save the same audio via file path and via ``save_to_result``; compare.

        Args:
            ext: Audio format, used both as the file extension and as the
                ``format`` argument of the file-object save.
            compression: Value forwarded to ``save`` as ``compression``.
            save_to_result: Callable ``(res_path, data, save_kwargs)`` that
                must leave the encoded audio in the file at ``res_path``.
        """
        sample_rate = 16000
        channels_first = True
        data = get_wav_data('float32', 2, num_frames=16000)

        ref_path = self.get_temp_path(f'reference.{ext}')
        res_path = self.get_temp_path(f'test.{ext}')

        save_kwargs = dict(
            channels_first=channels_first,
            sample_rate=sample_rate,
            compression=compression,
            dtype=None,
        )
        # Reference: the format is derived from the extension of ``ref_path``.
        sox_io_backend.save(ref_path, data, **save_kwargs)
        # Result: a file object carries no extension, so ``format`` is explicit.
        save_to_result(res_path, data, dict(save_kwargs, format=ext))

        expected_data, _ = sox_io_backend.load(ref_path)
        found_data, found_sr = sox_io_backend.load(res_path)

        # Use unittest assertions (not bare ``assert``) so a failure reports
        # both values and the check is not stripped under ``python -O``.
        self.assertEqual(sample_rate, found_sr)
        self.assertEqual(expected_data, found_data)

    @parameterized.expand(_FORMAT_PARAMS)
    def test_fileobj(self, ext, compression):
        """Saving audio to file object returns the same result as via file path."""
        def _save_via_file(res_path, data, save_kwargs):
            with open(res_path, 'wb') as fileobj:
                sox_io_backend.save(fileobj, data, **save_kwargs)

        self._assert_same_as_path_save(ext, compression, _save_via_file)

    @parameterized.expand(_FORMAT_PARAMS)
    def test_bytesio(self, ext, compression):
        """Saving audio to BytesIO object returns the same result as via file path."""
        def _save_via_bytesio(res_path, data, save_kwargs):
            fileobj = io.BytesIO()
            sox_io_backend.save(fileobj, data, **save_kwargs)
            # Dump the in-memory buffer to disk so it can be loaded by path.
            with open(res_path, 'wb') as file_:
                file_.write(fileobj.getvalue())

        self._assert_same_as_path_save(ext, compression, _save_via_bytesio)
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
View file @
c3cb2015
...
@@ -17,6 +17,7 @@ from torchaudio_unittest.common_utils import (
...
@@ -17,6 +17,7 @@ from torchaudio_unittest.common_utils import (
)
)
from
.common
import
(
from
.common
import
(
name_func
,
name_func
,
get_enc_params
,
)
)
...
@@ -35,8 +36,12 @@ def py_save_func(
...
@@ -35,8 +36,12 @@ def py_save_func(
sample_rate
:
int
,
sample_rate
:
int
,
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
compression
:
Optional
[
float
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
):
torchaudio
.
save
(
filepath
,
tensor
,
sample_rate
,
channels_first
,
compression
)
torchaudio
.
save
(
filepath
,
tensor
,
sample_rate
,
channels_first
,
compression
,
None
,
encoding
,
bits_per_sample
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
...
@@ -102,15 +107,16 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
...
@@ -102,15 +107,16 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
torch
.
jit
.
script
(
py_save_func
).
save
(
script_path
)
torch
.
jit
.
script
(
py_save_func
).
save
(
script_path
)
ts_save_func
=
torch
.
jit
.
load
(
script_path
)
ts_save_func
=
torch
.
jit
.
load
(
script_path
)
expected
=
get_wav_data
(
dtype
,
num_channels
)
expected
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
)
py_path
=
self
.
get_temp_path
(
f
'test_save_py_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
py_path
=
self
.
get_temp_path
(
f
'test_save_py_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
ts_path
=
self
.
get_temp_path
(
f
'test_save_ts_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
ts_path
=
self
.
get_temp_path
(
f
'test_save_ts_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
enc
,
bps
=
get_enc_params
(
dtype
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
None
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
None
,
enc
,
bps
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
None
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
None
,
enc
,
bps
)
py_data
,
py_sr
=
load_wav
(
py_path
)
py_data
,
py_sr
=
load_wav
(
py_path
,
normalize
=
False
)
ts_data
,
ts_sr
=
load_wav
(
ts_path
)
ts_data
,
ts_sr
=
load_wav
(
ts_path
,
normalize
=
False
)
self
.
assertEqual
(
sample_rate
,
py_sr
)
self
.
assertEqual
(
sample_rate
,
py_sr
)
self
.
assertEqual
(
sample_rate
,
ts_sr
)
self
.
assertEqual
(
sample_rate
,
ts_sr
)
...
@@ -131,8 +137,8 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
...
@@ -131,8 +137,8 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_path
=
self
.
get_temp_path
(
f
'test_save_py_
{
sample_rate
}
_
{
num_channels
}
_
{
compression_level
}
.flac'
)
py_path
=
self
.
get_temp_path
(
f
'test_save_py_
{
sample_rate
}
_
{
num_channels
}
_
{
compression_level
}
.flac'
)
ts_path
=
self
.
get_temp_path
(
f
'test_save_ts_
{
sample_rate
}
_
{
num_channels
}
_
{
compression_level
}
.flac'
)
ts_path
=
self
.
get_temp_path
(
f
'test_save_ts_
{
sample_rate
}
_
{
num_channels
}
_
{
compression_level
}
.flac'
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
compression_level
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
compression_level
,
None
,
None
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
compression_level
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
compression_level
,
None
,
None
)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav
=
f
'
{
py_path
}
.wav'
py_path_wav
=
f
'
{
py_path
}
.wav'
...
...
test/torchaudio_unittest/common_utils/sox_utils.py
View file @
c3cb2015
import
sys
import
subprocess
import
subprocess
import
warnings
import
warnings
...
@@ -32,6 +33,7 @@ def gen_audio_file(
...
@@ -32,6 +33,7 @@ def gen_audio_file(
command
=
[
command
=
[
'sox'
,
'sox'
,
'-V3'
,
# verbose
'-V3'
,
# verbose
'--no-dither'
,
# disable automatic dithering
'-R'
,
'-R'
,
# -R is supposed to be repeatable, though the implementation looks suspicious
# -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value.
# and not setting the seed to a fixed value.
...
@@ -61,21 +63,23 @@ def gen_audio_file(
...
@@ -61,21 +63,23 @@ def gen_audio_file(
]
]
if
attenuation
is
not
None
:
if
attenuation
is
not
None
:
command
+=
[
'vol'
,
f
'-
{
attenuation
}
dB'
]
command
+=
[
'vol'
,
f
'-
{
attenuation
}
dB'
]
print
(
' '
.
join
(
command
))
print
(
' '
.
join
(
command
)
,
file
=
sys
.
stderr
)
subprocess
.
run
(
command
,
check
=
True
)
subprocess
.
run
(
command
,
check
=
True
)
def
convert_audio_file
(
def
convert_audio_file
(
src_path
,
dst_path
,
src_path
,
dst_path
,
*
,
bit_depth
=
None
,
compression
=
None
):
*
,
encoding
=
None
,
bit_depth
=
None
,
compression
=
None
):
"""Convert audio file with `sox` command."""
"""Convert audio file with `sox` command."""
command
=
[
'sox'
,
'-V3'
,
'-R'
,
str
(
src_path
)]
command
=
[
'sox'
,
'-V3'
,
'--no-dither'
,
'-R'
,
str
(
src_path
)]
if
encoding
is
not
None
:
command
+=
[
'--encoding'
,
str
(
encoding
)]
if
bit_depth
is
not
None
:
if
bit_depth
is
not
None
:
command
+=
[
'--bits'
,
str
(
bit_depth
)]
command
+=
[
'--bits'
,
str
(
bit_depth
)]
if
compression
is
not
None
:
if
compression
is
not
None
:
command
+=
[
'--compression'
,
str
(
compression
)]
command
+=
[
'--compression'
,
str
(
compression
)]
command
+=
[
dst_path
]
command
+=
[
dst_path
]
print
(
' '
.
join
(
command
))
print
(
' '
.
join
(
command
)
,
file
=
sys
.
stderr
)
subprocess
.
run
(
command
,
check
=
True
)
subprocess
.
run
(
command
,
check
=
True
)
...
...
torchaudio/backend/sox_io_backend.py
View file @
c3cb2015
import
os
import
os
import
warnings
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
import
torch
import
torch
...
@@ -152,26 +151,6 @@ def load(
...
@@ -152,26 +151,6 @@ def load(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
@
torch
.
jit
.
unused
def
_save
(
filepath
:
str
,
src
:
torch
.
Tensor
,
sample_rate
:
int
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
):
if
hasattr
(
filepath
,
'write'
):
if
format
is
None
:
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
torchaudio
.
_torchaudio
.
save_audio_fileobj
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
else
:
torch
.
ops
.
torchaudio
.
sox_io_save_audio_file
(
os
.
fspath
(
filepath
),
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
def
save
(
def
save
(
filepath
:
str
,
filepath
:
str
,
...
@@ -180,30 +159,11 @@ def save(
...
@@ -180,30 +159,11 @@ def save(
channels_first
:
bool
=
True
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
):
"""Save audio data to file.
"""Save audio data to file.
Note:
Supported formats are;
* WAV, AMB
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* MP3
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
Args:
Args:
filepath (str or pathlib.Path): Path to save file.
filepath (str or pathlib.Path): Path to save file.
This function also handles ``pathlib.Path`` objects, but is annotated
This function also handles ``pathlib.Path`` objects, but is annotated
...
@@ -215,32 +175,137 @@ def save(
...
@@ -215,32 +175,137 @@ def save(
compression (Optional[float]): Used for formats other than WAV.
compression (Optional[float]): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command.
This corresponds to ``-C`` option of ``sox`` command.
* | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
``"mp3"``
| VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
* | ``FLAC``: compression level. Whole number from ``0`` to ``8``.
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
| ``8`` is default and highest compression.
* | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression
``"flac"``
| and lowest quality. Default: ``3``.
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
See the detail at http://sox.sourceforge.net/soxformat.html.
See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional): Output audio format.
format (str, optional): Override the audio format.
This is required when the output audio format cannot be infered from
When ``filepath`` argument is path-like object, audio format is inferred from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
file extension. If file extension is missing or different, you can specify the
dtype (str, optional): Output tensor dtype.
correct format with this argument.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed.
When ``filepath`` argument is file-like object, this argument is required.
``dtype`` parameter is only effective for ``float32`` Tensor.
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
Default values
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.
``"wav"``, ``"amb"``
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used to determine the default value.
- ``"PCM_U"`` if dtype is ``uint8``
- ``"PCM_S"`` if dtype is ``int16`` or ``int32``
- ``"PCM_F"`` if dtype is ``float32``
- ``"PCM_U"`` if ``bits_per_sample=8``
- ``"PCM_S"`` otherwise
``"sph"`` format;
- the default value is ``"PCM_S"``
bits_per_sample (int, optional): Changes the bit depth for the supported formats.
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
bit depth. Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
Default Value;
If not provided, the default values are picked based on ``format`` and ``"encoding"``;
``"wav"``, ``"amb"``;
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used.
- ``8`` if dtype is ``uint8``
- ``16`` if dtype is ``int16``
- ``32`` if dtype is ``int32`` or ``float32``
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"``
- ``32`` if ``encoding`` is ``"PCM_F"``
``"flac"`` format;
- the default value is ``24``
``"sph"`` format;
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``
``"amb"`` format;
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
- ``32`` if ``encoding`` is ``"PCM_F"``
Supported formats/encodings/bit depth/compression are;
``"wav"``, ``"amb"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
``"mp3"``
Fixed bit rate (such as 128kbps) and variable bit rate compression.
Default: VBR with high quality.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Different quality level. Default: approx. 112kbps
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s
Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
or ``libmp3lame`` etc.
"""
"""
if
src
.
dtype
==
torch
.
float32
and
dtype
is
None
:
warnings
.
warn
(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_save
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
if
hasattr
(
filepath
,
'write'
):
return
torchaudio
.
_torchaudio
.
save_audio_fileobj
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
encoding
,
bits_per_sample
)
return
filepath
=
os
.
fspath
(
filepath
)
torch
.
ops
.
torchaudio
.
sox_io_save_audio_file
(
torch
.
ops
.
torchaudio
.
sox_io_save_audio_file
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtyp
e
)
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
encoding
,
bits_per_sampl
e
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/CMakeLists.txt
View file @
c3cb2015
...
@@ -9,6 +9,7 @@ set(
...
@@ -9,6 +9,7 @@ set(
sox/utils.cpp
sox/utils.cpp
sox/effects.cpp
sox/effects.cpp
sox/effects_chain.cpp
sox/effects_chain.cpp
sox/types.cpp
)
)
if
(
BUILD_TRANSDUCER
)
if
(
BUILD_TRANSDUCER
)
...
...
torchaudio/csrc/sox/effects_chain.cpp
View file @
c3cb2015
...
@@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
...
@@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Ensure that it's a multiple of the number of channels
// Ensure that it's a multiple of the number of channels
*
osamp
-=
*
osamp
%
num_channels
;
*
osamp
-=
*
osamp
%
num_channels
;
// Slice the input Tensor
and unnormalize the values
// Slice the input Tensor
const
auto
tensor_
=
[
&
]()
{
const
auto
tensor_
=
[
&
]()
{
auto
i_frame
=
index
/
num_channels
;
auto
i_frame
=
index
/
num_channels
;
auto
num_frames
=
*
osamp
/
num_channels
;
auto
num_frames
=
*
osamp
/
num_channels
;
auto
t
=
(
priv
->
channels_first
)
auto
t
=
(
priv
->
channels_first
)
?
tensor
.
index
({
Slice
(),
Slice
(
i_frame
,
i_frame
+
num_frames
)}).
t
()
?
tensor
.
index
({
Slice
(),
Slice
(
i_frame
,
i_frame
+
num_frames
)}).
t
()
:
tensor
.
index
({
Slice
(
i_frame
,
i_frame
+
num_frames
),
Slice
()});
:
tensor
.
index
({
Slice
(
i_frame
,
i_frame
+
num_frames
),
Slice
()});
return
unnormalize_wav
(
t
.
reshape
({
-
1
})
)
.
contiguous
();
return
t
.
reshape
({
-
1
}).
contiguous
();
}();
}();
priv
->
index
+=
*
osamp
;
// Write data to SoxEffectsChain buffer.
auto
ptr
=
tensor_
.
data_ptr
<
int32_t
>
();
std
::
copy
(
ptr
,
ptr
+
*
osamp
,
obuf
);
// Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS
;
const
auto
dtype
=
tensor_
.
dtype
();
if
(
dtype
==
torch
::
kFloat32
)
{
auto
ptr
=
tensor_
.
data_ptr
<
float_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_FLOAT_32BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kInt32
)
{
auto
ptr
=
tensor_
.
data_ptr
<
int32_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_SIGNED_32BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kInt16
)
{
auto
ptr
=
tensor_
.
data_ptr
<
int16_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_SIGNED_16BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kUInt8
)
{
auto
ptr
=
tensor_
.
data_ptr
<
uint8_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_UNSIGNED_8BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
{
throw
std
::
runtime_error
(
"Unexpected dtype."
);
}
priv
->
index
+=
*
osamp
;
return
(
priv
->
index
==
num_samples
)
?
SOX_EOF
:
SOX_SUCCESS
;
return
(
priv
->
index
==
num_samples
)
?
SOX_EOF
:
SOX_SUCCESS
;
}
}
...
@@ -430,7 +452,7 @@ int fileobj_output_flow(
...
@@ -430,7 +452,7 @@ int fileobj_output_flow(
fflush
(
fp
);
fflush
(
fp
);
// Copy the encoded chunk to python object.
// Copy the encoded chunk to python object.
fileobj
->
attr
(
"write"
)(
py
::
bytes
(
*
buffer
,
*
buffer_size
));
fileobj
->
attr
(
"write"
)(
py
::
bytes
(
*
buffer
,
ftell
(
fp
)
));
// Reset FILE*
// Reset FILE*
sf
->
tell_off
=
0
;
sf
->
tell_off
=
0
;
...
...
torchaudio/csrc/sox/io.cpp
View file @
c3cb2015
...
@@ -116,35 +116,27 @@ void save_audio_file(
...
@@ -116,35 +116,27 @@ void save_audio_file(
torch
::
Tensor
tensor
,
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
int64_t
sample_rate
,
bool
channels_first
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
double
>&
compression
,
c10
::
optional
<
std
::
string
>
format
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>
dtype
)
{
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
)
{
validate_input_tensor
(
tensor
);
validate_input_tensor
(
tensor
);
if
(
tensor
.
dtype
()
!=
torch
::
kFloat32
&&
dtype
.
has_value
())
{
throw
std
::
runtime_error
(
"dtype conversion only supported for float32 tensors"
);
}
const
auto
tgt_dtype
=
(
tensor
.
dtype
()
==
torch
::
kFloat32
&&
dtype
.
has_value
())
?
get_dtype_from_str
(
dtype
.
value
())
:
tensor
.
dtype
();
const
auto
filetype
=
[
&
]()
{
const
auto
filetype
=
[
&
]()
{
if
(
format
.
has_value
())
if
(
format
.
has_value
())
return
format
.
value
();
return
format
.
value
();
return
get_filetype
(
path
);
return
get_filetype
(
path
);
}();
}();
if
(
filetype
==
"amr-nb"
)
{
if
(
filetype
==
"amr-nb"
)
{
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
TORCH_CHECK
(
TORCH_CHECK
(
num_channels
==
1
,
"amr-nb format only supports single channel audio."
);
num_channels
==
1
,
"amr-nb format only supports single channel audio."
);
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
}
const
auto
signal_info
=
const
auto
signal_info
=
get_signalinfo
(
&
tensor
,
sample_rate
,
filetype
,
channels_first
);
get_signalinfo
(
&
tensor
,
sample_rate
,
filetype
,
channels_first
);
const
auto
encoding_info
=
const
auto
encoding_info
=
get_encodinginfo_for_save
(
get_encodinginfo_for_save
(
filetype
,
t
gt_
dtype
,
compression
);
filetype
,
t
ensor
.
dtype
()
,
compression
,
encoding
,
bits_per_sample
);
SoxFormat
sf
(
sox_open_write
(
SoxFormat
sf
(
sox_open_write
(
path
.
c_str
(),
path
.
c_str
(),
...
@@ -258,19 +250,17 @@ void save_audio_fileobj(
...
@@ -258,19 +250,17 @@ void save_audio_fileobj(
torch
::
Tensor
tensor
,
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
int64_t
sample_rate
,
bool
channels_first
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
double
>&
compression
,
std
::
string
filetype
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>
dtype
)
{
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
)
{
validate_input_tensor
(
tensor
);
validate_input_tensor
(
tensor
);
if
(
tensor
.
dtype
()
!=
torch
::
kFloat32
&&
dtype
.
has_value
())
{
if
(
!
format
.
has_value
())
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"
dtype conversion only supported for float32 tensors
"
);
"
`format` is required when saving to file object.
"
);
}
}
const
auto
tgt_dtype
=
const
auto
filetype
=
format
.
value
();
(
tensor
.
dtype
()
==
torch
::
kFloat32
&&
dtype
.
has_value
())
?
get_dtype_from_str
(
dtype
.
value
())
:
tensor
.
dtype
();
if
(
filetype
==
"amr-nb"
)
{
if
(
filetype
==
"amr-nb"
)
{
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
...
@@ -278,12 +268,11 @@ void save_audio_fileobj(
...
@@ -278,12 +268,11 @@ void save_audio_fileobj(
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"amr-nb format only supports single channel audio."
);
"amr-nb format only supports single channel audio."
);
}
}
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
}
const
auto
signal_info
=
const
auto
signal_info
=
get_signalinfo
(
&
tensor
,
sample_rate
,
filetype
,
channels_first
);
get_signalinfo
(
&
tensor
,
sample_rate
,
filetype
,
channels_first
);
const
auto
encoding_info
=
const
auto
encoding_info
=
get_encodinginfo_for_save
(
get_encodinginfo_for_save
(
filetype
,
t
gt_
dtype
,
compression
);
filetype
,
t
ensor
.
dtype
()
,
compression
,
encoding
,
bits_per_sample
);
AutoReleaseBuffer
buffer
;
AutoReleaseBuffer
buffer
;
...
...
torchaudio/csrc/sox/io.h
View file @
c3cb2015
...
@@ -28,9 +28,10 @@ void save_audio_file(
...
@@ -28,9 +28,10 @@ void save_audio_file(
torch
::
Tensor
tensor
,
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
int64_t
sample_rate
,
bool
channels_first
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
double
>&
compression
,
c10
::
optional
<
std
::
string
>
format
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>
dtype
);
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
...
@@ -51,9 +52,10 @@ void save_audio_fileobj(
...
@@ -51,9 +52,10 @@ void save_audio_fileobj(
torch
::
Tensor
tensor
,
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
int64_t
sample_rate
,
bool
channels_first
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
double
>&
compression
,
std
::
string
filetype
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>
dtype
);
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
);
#endif // TORCH_API_INCLUDE_EXTENSION_H
#endif // TORCH_API_INCLUDE_EXTENSION_H
...
...
torchaudio/csrc/sox/types.cpp
0 → 100644
View file @
c3cb2015
#include <torchaudio/csrc/sox/types.h>
namespace
torchaudio
{
namespace
sox_utils
{
// Translate a format name into the corresponding `Format` enumerator.
//
// Recognized names: "wav", "mp3", "flac", "ogg", "vorbis", "amr-nb", "amr-wb",
// "amb" and "sph". "ogg" and "vorbis" both map to `Format::VORBIS`.
// The comparison is exact (case-sensitive).
//
// Throws std::runtime_error when `format` is not one of the names above.
Format get_format_from_string(const std::string& format) {
  struct Entry {
    const char* name;
    Format value;
  };
  static const Entry kKnownFormats[] = {
      {"wav", Format::WAV},
      {"mp3", Format::MP3},
      {"flac", Format::FLAC},
      {"ogg", Format::VORBIS},
      {"vorbis", Format::VORBIS},
      {"amr-nb", Format::AMR_NB},
      {"amr-wb", Format::AMR_WB},
      {"amb", Format::AMB},
      {"sph", Format::SPHERE},
  };
  for (const auto& entry : kKnownFormats) {
    if (format == entry.name) {
      return entry.value;
    }
  }
  throw std::runtime_error(
      "Internal Error: unexpected format value: " + format);
}
// Return the short textual label of an `Encoding` value (e.g. "PCM_S").
//
// Throws std::runtime_error for any value without a label below; this
// includes `Encoding::NOT_PROVIDED`, which has no entry.
std::string to_string(Encoding v) {
  struct Entry {
    Encoding value;
    const char* label;
  };
  static const Entry kLabels[] = {
      {Encoding::UNKNOWN, "UNKNOWN"},
      {Encoding::PCM_SIGNED, "PCM_S"},
      {Encoding::PCM_UNSIGNED, "PCM_U"},
      {Encoding::PCM_FLOAT, "PCM_F"},
      {Encoding::FLAC, "FLAC"},
      {Encoding::ULAW, "ULAW"},
      {Encoding::ALAW, "ALAW"},
      {Encoding::MP3, "MP3"},
      {Encoding::VORBIS, "VORBIS"},
      {Encoding::AMR_WB, "AMR_WB"},
      {Encoding::AMR_NB, "AMR_NB"},
      {Encoding::OPUS, "OPUS"},
  };
  for (const auto& entry : kLabels) {
    if (entry.value == v) {
      return entry.label;
    }
  }
  throw std::runtime_error("Internal Error: unexpected encoding.");
}
// Convert the optional `encoding` argument into an `Encoding` value.
//
// An empty optional yields `Encoding::NOT_PROVIDED`. Otherwise the string
// must be exactly one of "PCM_S", "PCM_U", "PCM_F", "ULAW" or "ALAW".
//
// Throws std::runtime_error for any other string.
Encoding get_encoding_from_option(
    const c10::optional<std::string>& encoding) {
  if (!encoding.has_value()) {
    return Encoding::NOT_PROVIDED;
  }
  const std::string& name = encoding.value();

  struct Entry {
    const char* name;
    Encoding value;
  };
  static const Entry kOptions[] = {
      {"PCM_S", Encoding::PCM_SIGNED},
      {"PCM_U", Encoding::PCM_UNSIGNED},
      {"PCM_F", Encoding::PCM_FLOAT},
      {"ULAW", Encoding::ULAW},
      {"ALAW", Encoding::ALAW},
  };
  for (const auto& option : kOptions) {
    if (name == option.name) {
      return option.value;
    }
  }
  throw std::runtime_error(
      "Internal Error: unexpected encoding value: " + name);
}
// Convert the optional `bit_depth` argument into a `BitDepth` value.
//
// An empty optional yields `BitDepth::NOT_PROVIDED`. Otherwise the value
// must be 8, 16, 24, 32 or 64.
//
// Throws std::runtime_error for any other number.
BitDepth get_bit_depth_from_option(const c10::optional<int64_t>& bit_depth) {
  if (!bit_depth.has_value()) {
    return BitDepth::NOT_PROVIDED;
  }
  const int64_t bits = bit_depth.value();
  if (bits == 8) {
    return BitDepth::B8;
  }
  if (bits == 16) {
    return BitDepth::B16;
  }
  if (bits == 24) {
    return BitDepth::B24;
  }
  if (bits == 32) {
    return BitDepth::B32;
  }
  if (bits == 64) {
    return BitDepth::B64;
  }
  throw std::runtime_error(
      "Internal Error: unexpected bit depth value: " + std::to_string(bits));
}
}
// namespace sox_utils
}
// namespace torchaudio
torchaudio/csrc/sox/types.h
0 → 100644
View file @
c3cb2015
#ifndef TORCHAUDIO_SOX_TYPES_H
#define TORCHAUDIO_SOX_TYPES_H
#include <torch/script.h>
// Enumerations for the `format`, `encoding` and bit-depth save options, plus
// the functions that convert the user-supplied option values into them.
namespace
torchaudio
{
namespace
sox_utils
{
// Audio formats that get_format_from_string() can return.
enum
class
Format
{
WAV
,
MP3
,
FLAC
,
VORBIS
,
AMR_NB
,
AMR_WB
,
AMB
,
SPHERE
,
};
// Converts a format name ("wav", "mp3", "flac", "ogg" / "vorbis", "amr-nb",
// "amr-wb", "amb", "sph") into a Format value. The definition in types.cpp
// throws std::runtime_error for any other name.
Format
get_format_from_string
(
const
std
::
string
&
format
);
// Audio encodings. NOT_PROVIDED is what get_encoding_from_option() returns
// when the caller passed no encoding.
enum
class
Encoding
{
NOT_PROVIDED
,
UNKNOWN
,
PCM_SIGNED
,
PCM_UNSIGNED
,
PCM_FLOAT
,
FLAC
,
ULAW
,
ALAW
,
MP3
,
VORBIS
,
AMR_WB
,
AMR_NB
,
OPUS
,
};
// Returns the short label of an Encoding (for example "PCM_S" for
// PCM_SIGNED). The definition in types.cpp has no case for NOT_PROVIDED and
// throws std::runtime_error for it.
std
::
string
to_string
(
Encoding
v
);
// Converts the optional encoding string into an Encoding. An empty optional
// gives NOT_PROVIDED; accepted strings are "PCM_S", "PCM_U", "PCM_F", "ULAW"
// and "ALAW"; anything else throws std::runtime_error.
Encoding
get_encoding_from_option
(
const
c10
::
optional
<
std
::
string
>&
encoding
);
// Bit depths. Each enumerator's numeric value equals its number of bits;
// 0 (NOT_PROVIDED) means the caller passed no bit depth.
enum
class
BitDepth
:
unsigned
{
NOT_PROVIDED
=
0
,
B8
=
8
,
B16
=
16
,
B24
=
24
,
B32
=
32
,
B64
=
64
,
};
// Converts the optional bit depth into a BitDepth. An empty optional gives
// NOT_PROVIDED; accepted values are 8, 16, 24, 32 and 64; anything else
// throws std::runtime_error.
BitDepth
get_bit_depth_from_option
(
const
c10
::
optional
<
int64_t
>&
bit_depth
);
}
// namespace sox_utils
}
// namespace torchaudio
#endif
torchaudio/csrc/sox/utils.cpp
View file @
c3cb2015
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarType.h>
#include <sox.h>
#include <sox.h>
#include <torchaudio/csrc/sox/types.h>
#include <torchaudio/csrc/sox/utils.h>
#include <torchaudio/csrc/sox/utils.h>
namespace
torchaudio
{
namespace
torchaudio
{
...
@@ -163,22 +164,32 @@ torch::Tensor convert_to_tensor(
...
@@ -163,22 +164,32 @@ torch::Tensor convert_to_tensor(
const
caffe2
::
TypeMeta
dtype
,
const
caffe2
::
TypeMeta
dtype
,
const
bool
normalize
,
const
bool
normalize
,
const
bool
channels_first
)
{
const
bool
channels_first
)
{
auto
t
=
torch
::
from_blob
(
torch
::
Tensor
t
;
buffer
,
{
num_samples
/
num_channels
,
num_channels
},
torch
::
kInt32
);
uint64_t
dummy
;
// Note: Tensor created from_blob does not own data but borrwos
SOX_SAMPLE_LOCALS
;
// So make sure to create a new copy after processing samples.
if
(
normalize
||
dtype
==
torch
::
kFloat32
)
{
if
(
normalize
||
dtype
==
torch
::
kFloat32
)
{
t
=
t
.
to
(
torch
::
kFloat32
);
t
=
torch
::
empty
(
t
*=
(
t
>
0
)
/
2147483647.
+
(
t
<
0
)
/
2147483648.
;
{
num_samples
/
num_channels
,
num_channels
},
torch
::
kFloat32
);
auto
ptr
=
t
.
data_ptr
<
float_t
>
();
for
(
int32_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
ptr
[
i
]
=
SOX_SAMPLE_TO_FLOAT_32BIT
(
buffer
[
i
],
dummy
);
}
}
else
if
(
dtype
==
torch
::
kInt32
)
{
}
else
if
(
dtype
==
torch
::
kInt32
)
{
t
=
t
.
clone
();
t
=
torch
::
from_blob
(
buffer
,
{
num_samples
/
num_channels
,
num_channels
},
torch
::
kInt32
)
.
clone
();
}
else
if
(
dtype
==
torch
::
kInt16
)
{
}
else
if
(
dtype
==
torch
::
kInt16
)
{
t
.
floor_divide_
(
1
<<
16
);
t
=
torch
::
empty
({
num_samples
/
num_channels
,
num_channels
},
torch
::
kInt16
);
t
=
t
.
to
(
torch
::
kInt16
);
auto
ptr
=
t
.
data_ptr
<
int16_t
>
();
for
(
int32_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
ptr
[
i
]
=
SOX_SAMPLE_TO_SIGNED_16BIT
(
buffer
[
i
],
dummy
);
}
}
else
if
(
dtype
==
torch
::
kUInt8
)
{
}
else
if
(
dtype
==
torch
::
kUInt8
)
{
t
.
floor_divide_
(
1
<<
24
);
t
=
torch
::
empty
({
num_samples
/
num_channels
,
num_channels
},
torch
::
kUInt8
);
t
+=
128
;
auto
ptr
=
t
.
data_ptr
<
uint8_t
>
();
t
=
t
.
to
(
torch
::
kUInt8
);
for
(
int32_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
ptr
[
i
]
=
SOX_SAMPLE_TO_UNSIGNED_8BIT
(
buffer
[
i
],
dummy
);
}
}
else
{
}
else
{
throw
std
::
runtime_error
(
"Unsupported dtype."
);
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
}
...
@@ -188,63 +199,181 @@ torch::Tensor convert_to_tensor(
...
@@ -188,63 +199,181 @@ torch::Tensor convert_to_tensor(
return
t
.
contiguous
();
return
t
.
contiguous
();
}
}
/// Convert a float32/int32/int16/uint8 Tensor into an int32 Tensor for the
/// Torch -> Sox direction.
///
/// @param input_tensor Tensor whose dtype is one of float32, int32, int16 or
///        uint8.
/// @return An int32 Tensor. An int32 input is handed back unchanged.
/// @throws std::runtime_error ("Unexpected dtype.") for any other dtype.
torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
  const auto dtype = input_tensor.dtype();

  if (dtype == torch::kInt32) {
    // already denormalized
    return input_tensor;
  }

  if (dtype == torch::kFloat32) {
    const double upper = 2147483647.;
    const double lower = -2147483648.;
    // Per-element factor: `upper` for positive samples, `-lower` for negative
    // samples, and 0 for samples that are neither.
    auto scale = (input_tensor > 0) * upper - (input_tensor < 0) * lower;
    // Do the multiplication in float64, then clamp into [lower, upper]
    // before narrowing to int32.
    auto scaled = input_tensor.to(torch::dtype(torch::kFloat64));
    scaled *= scale;
    scaled.clamp_(lower, upper);
    return scaled.to(torch::dtype(torch::kInt32));
  }

  if (dtype == torch::kInt16) {
    auto widened = input_tensor.to(torch::dtype(torch::kInt32));
    // Multiply every non-zero sample by 65536; zero samples are multiplied
    // by 0 and stay 0.
    widened *= ((widened != 0) * 65536);
    return widened;
  }

  if (dtype == torch::kUInt8) {
    auto widened = input_tensor.to(torch::dtype(torch::kInt32));
    // Shift by 128 first, then scale by 16777216.
    widened -= 128;
    widened *= 16777216;
    return widened;
  }

  throw std::runtime_error("Unexpected dtype.");
}
const
std
::
string
get_filetype
(
const
std
::
string
path
)
{
const
std
::
string
get_filetype
(
const
std
::
string
path
)
{
std
::
string
ext
=
path
.
substr
(
path
.
find_last_of
(
"."
)
+
1
);
std
::
string
ext
=
path
.
substr
(
path
.
find_last_of
(
"."
)
+
1
);
std
::
transform
(
ext
.
begin
(),
ext
.
end
(),
ext
.
begin
(),
::
tolower
);
std
::
transform
(
ext
.
begin
(),
ext
.
end
(),
ext
.
begin
(),
::
tolower
);
return
ext
;
return
ext
;
}
}
sox_encoding_t
get_encoding
(
namespace
{
const
std
::
string
filetype
,
const
caffe2
::
TypeMeta
dtype
)
{
std
::
tuple
<
sox_encoding_t
,
unsigned
>
get_save_encoding_for_wav
(
if
(
filetype
==
"mp3"
)
const
std
::
string
format
,
return
SOX_ENCODING_MP3
;
const
caffe2
::
TypeMeta
dtype
,
if
(
filetype
==
"flac"
)
const
Encoding
&
encoding
,
return
SOX_ENCODING_FLAC
;
const
BitDepth
&
bits_per_sample
)
{
if
(
filetype
==
"ogg"
||
filetype
==
"vorbis"
)
switch
(
encoding
)
{
return
SOX_ENCODING_VORBIS
;
case
Encoding
::
NOT_PROVIDED
:
if
(
filetype
==
"wav"
||
filetype
==
"amb"
)
{
switch
(
bits_per_sample
)
{
if
(
dtype
==
torch
::
kUInt8
)
case
BitDepth
::
NOT_PROVIDED
:
return
SOX_ENCODING_UNSIGNED
;
if
(
dtype
==
torch
::
kFloat32
)
if
(
dtype
==
torch
::
kInt16
)
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLOAT
,
32
);
return
SOX_ENCODING_SIGN2
;
if
(
dtype
==
torch
::
kInt32
)
if
(
dtype
==
torch
::
kInt32
)
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
32
);
return
SOX_ENCODING_SIGN2
;
if
(
dtype
==
torch
::
kInt16
)
if
(
dtype
==
torch
::
kFloat32
)
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
16
);
return
SOX_ENCODING_FLOAT
;
if
(
dtype
==
torch
::
kUInt8
)
throw
std
::
runtime_error
(
"Unsupported dtype."
);
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
throw
std
::
runtime_error
(
"Internal Error: Unexpected dtype."
);
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
default:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
static_cast
<
unsigned
>
(
bits_per_sample
));
}
case
Encoding
::
PCM_SIGNED
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
32
);
case
BitDepth
::
B8
:
throw
std
::
runtime_error
(
format
+
" does not support 8-bit signed PCM encoding."
);
default:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
static_cast
<
unsigned
>
(
bits_per_sample
));
}
case
Encoding
::
PCM_UNSIGNED
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 8-bit for unsigned PCM encoding."
);
}
case
Encoding
::
PCM_FLOAT
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B32
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLOAT
,
32
);
case
BitDepth
::
B64
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLOAT
,
64
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 32-bit or 64-bit for floating-point PCM encoding."
);
}
case
Encoding
::
ULAW
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_ULAW
,
8
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 8-bit for mu-law encoding."
);
}
case
Encoding
::
ALAW
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_ALAW
,
8
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 8-bit for a-law encoding."
);
}
default:
throw
std
::
runtime_error
(
format
+
" does not support encoding: "
+
to_string
(
encoding
));
}
}
/// Resolve the sox encoding and bit depth to use when saving audio.
///
/// The `format`, `encoding` and `bits_per_sample` arguments are first
/// converted with `get_format_from_string`, `get_encoding_from_option` and
/// `get_bit_depth_from_option`; the result is then chosen per format.
///
/// @param format Name of the output format; also used in error messages.
/// @param dtype dtype of the Tensor being saved. Only forwarded to
///        `get_save_encoding_for_wav` (WAV / AMB).
/// @param encoding Optional encoding requested by the caller.
/// @param bits_per_sample Optional bit depth requested by the caller.
/// @return Tuple of (sox encoding, bits per sample).
/// @throws std::runtime_error when the format is unsupported, or when the
///         requested encoding / bit depth is rejected for the format.
std::tuple<sox_encoding_t, unsigned> get_save_encoding(
    const std::string& format,
    const caffe2::TypeMeta dtype,
    const c10::optional<std::string>& encoding,
    const c10::optional<int64_t>& bits_per_sample) {
  const Format fmt = get_format_from_string(format);
  const Encoding enc = get_encoding_from_option(encoding);
  const BitDepth bps = get_bit_depth_from_option(bits_per_sample);

  switch (fmt) {
    case Format::WAV:
    case Format::AMB:
      return get_save_encoding_for_wav(format, dtype, enc, bps);
    case Format::MP3:
      if (enc != Encoding::NOT_PROVIDED)
        throw std::runtime_error("mp3 does not support `encoding` option.");
      if (bps != BitDepth::NOT_PROVIDED)
        throw std::runtime_error(
            "mp3 does not support `bits_per_sample` option.");
      return std::make_tuple<>(SOX_ENCODING_MP3, 16);
    case Format::VORBIS:
      if (enc != Encoding::NOT_PROVIDED)
        throw std::runtime_error("vorbis does not support `encoding` option.");
      if (bps != BitDepth::NOT_PROVIDED)
        throw std::runtime_error(
            "vorbis does not support `bits_per_sample` option.");
      return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
    case Format::AMR_NB:
      if (enc != Encoding::NOT_PROVIDED)
        throw std::runtime_error("amr-nb does not support `encoding` option.");
      if (bps != BitDepth::NOT_PROVIDED)
        throw std::runtime_error(
            "amr-nb does not support `bits_per_sample` option.");
      return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
    case Format::FLAC:
      if (enc != Encoding::NOT_PROVIDED)
        throw std::runtime_error("flac does not support `encoding` option.");
      switch (bps) {
        case BitDepth::B32:
        case BitDepth::B64:
          throw std::runtime_error(
              "flac does not support `bits_per_sample` larger than 24.");
        default:
          // NOTE(review): NOT_PROVIDED also lands here and is cast as-is;
          // presumably it maps to 0 -- confirm against the BitDepth enum.
          return std::make_tuple<>(
              SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
      }
    case Format::SPHERE:
      switch (enc) {
        case Encoding::NOT_PROVIDED:
        case Encoding::PCM_SIGNED:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
              return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
            default:
              return std::make_tuple<>(
                  SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
          }
        case Encoding::PCM_UNSIGNED:
          throw std::runtime_error(
              "sph does not support unsigned integer PCM.");
        case Encoding::PCM_FLOAT:
          throw std::runtime_error("sph does not support floating point PCM.");
        case Encoding::ULAW:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
            case BitDepth::B8:
              return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
            default:
              throw std::runtime_error(
                  "sph only supports 8-bit for mu-law encoding.");
          }
        case Encoding::ALAW:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
            case BitDepth::B8:
              return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
            default:
              // Reject other bit depths, mirroring the mu-law branch above,
              // instead of forwarding an a-law / non-8-bit combination.
              throw std::runtime_error(
                  "sph only supports 8-bit for a-law encoding.");
          }
        default:
          // NOTE(review): relies on `encoding` holding a value whenever
          // `enc` is not NOT_PROVIDED -- confirm against
          // get_encoding_from_option.
          throw std::runtime_error(
              "sph does not support encoding: " + encoding.value());
      }
    default:
      throw std::runtime_error("Unsupported format: " + format);
  }
}
if
(
filetype
==
"sph"
)
return
SOX_ENCODING_SIGN2
;
if
(
filetype
==
"amr-nb"
)
return
SOX_ENCODING_AMR_NB
;
throw
std
::
runtime_error
(
"Unsupported file type: "
+
filetype
);
}
}
unsigned
get_precision
(
unsigned
get_precision
(
...
@@ -270,14 +399,13 @@ unsigned get_precision(
...
@@ -270,14 +399,13 @@ unsigned get_precision(
if
(
filetype
==
"sph"
)
if
(
filetype
==
"sph"
)
return
32
;
return
32
;
if
(
filetype
==
"amr-nb"
)
{
if
(
filetype
==
"amr-nb"
)
{
TORCH_INTERNAL_ASSERT
(
dtype
==
torch
::
kInt16
,
"When saving to AMR-NB format, the input tensor must be int16 type."
);
return
16
;
return
16
;
}
}
throw
std
::
runtime_error
(
"Unsupported file type: "
+
filetype
);
throw
std
::
runtime_error
(
"Unsupported file type: "
+
filetype
);
}
}
}
// namespace
sox_signalinfo_t
get_signalinfo
(
sox_signalinfo_t
get_signalinfo
(
const
torch
::
Tensor
*
waveform
,
const
torch
::
Tensor
*
waveform
,
const
int64_t
sample_rate
,
const
int64_t
sample_rate
,
...
@@ -325,12 +453,15 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
...
@@ -325,12 +453,15 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
}
}
sox_encodinginfo_t
get_encodinginfo_for_save
(
sox_encodinginfo_t
get_encodinginfo_for_save
(
const
std
::
string
f
iletype
,
const
std
::
string
&
f
ormat
,
const
caffe2
::
TypeMeta
dtype
,
const
caffe2
::
TypeMeta
dtype
,
c10
::
optional
<
double
>&
compression
)
{
const
c10
::
optional
<
double
>&
compression
,
const
c10
::
optional
<
std
::
string
>&
encoding
,
const
c10
::
optional
<
int64_t
>&
bits_per_sample
)
{
auto
enc
=
get_save_encoding
(
format
,
dtype
,
encoding
,
bits_per_sample
);
return
sox_encodinginfo_t
{
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
/*encoding=*/
std
::
get
<
0
>
(
enc
),
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*bits_per_sample=*/
std
::
get
<
1
>
(
enc
),
/*compression=*/
compression
.
value_or
(
HUGE_VAL
),
/*compression=*/
compression
.
value_or
(
HUGE_VAL
),
/*reverse_bytes=*/
sox_option_default
,
/*reverse_bytes=*/
sox_option_default
,
/*reverse_nibbles=*/
sox_option_default
,
/*reverse_nibbles=*/
sox_option_default
,
...
...
torchaudio/csrc/sox/utils.h
View file @
c3cb2015
...
@@ -93,11 +93,6 @@ torch::Tensor convert_to_tensor(
...
@@ -93,11 +93,6 @@ torch::Tensor convert_to_tensor(
const
bool
normalize
,
const
bool
normalize
,
const
bool
channels_first
);
const
bool
channels_first
);
///
/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
/// conversion.
torch
::
Tensor
unnormalize_wav
(
const
torch
::
Tensor
);
/// Extract extension from file path
/// Extract extension from file path
const
std
::
string
get_filetype
(
const
std
::
string
path
);
const
std
::
string
get_filetype
(
const
std
::
string
path
);
...
@@ -113,9 +108,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
...
@@ -113,9 +108,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
/// Get sox_encodinginfo_t for saving to file/file object
/// Get sox_encodinginfo_t for saving to file/file object
sox_encodinginfo_t
get_encodinginfo_for_save
(
sox_encodinginfo_t
get_encodinginfo_for_save
(
const
std
::
string
f
iletype
,
const
std
::
string
&
f
ormat
,
const
caffe2
::
TypeMeta
dtype
,
const
caffe2
::
TypeMeta
dtype
,
c10
::
optional
<
double
>&
compression
);
const
c10
::
optional
<
double
>&
compression
,
const
c10
::
optional
<
std
::
string
>&
encoding
,
const
c10
::
optional
<
int64_t
>&
bits_per_sample
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment