Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
c3cb2015
Unverified
Commit
c3cb2015
authored
Feb 12, 2021
by
moto
Committed by
GitHub
Feb 12, 2021
Browse files
Add encoding and bits_per_sample option to save function (#1226)
parent
4f9b5520
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
883 additions
and
631 deletions
+883
-631
test/torchaudio_unittest/backend/sox_io/common.py
test/torchaudio_unittest/backend/sox_io/common.py
+12
-0
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
+3
-1
test/torchaudio_unittest/backend/sox_io/save_test.py
test/torchaudio_unittest/backend/sox_io/save_test.py
+302
-438
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
+14
-8
test/torchaudio_unittest/common_utils/sox_utils.py
test/torchaudio_unittest/common_utils/sox_utils.py
+8
-4
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+128
-63
torchaudio/csrc/CMakeLists.txt
torchaudio/csrc/CMakeLists.txt
+1
-0
torchaudio/csrc/sox/effects_chain.cpp
torchaudio/csrc/sox/effects_chain.cpp
+30
-8
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+16
-27
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+8
-6
torchaudio/csrc/sox/types.cpp
torchaudio/csrc/sox/types.cpp
+102
-0
torchaudio/csrc/sox/types.h
torchaudio/csrc/sox/types.h
+55
-0
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+200
-69
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+4
-7
No files found.
test/torchaudio_unittest/backend/sox_io/common.py
View file @
c3cb2015
def
name_func
(
func
,
_
,
params
):
return
f
'
{
func
.
__name__
}
_
{
"_"
.
join
(
str
(
arg
)
for
arg
in
params
.
args
)
}
'
def
get_enc_params
(
dtype
):
if
dtype
==
'float32'
:
return
'PCM_F'
,
32
if
dtype
==
'int32'
:
return
'PCM_S'
,
32
if
dtype
==
'int16'
:
return
'PCM_S'
,
16
if
dtype
==
'uint8'
:
return
'PCM_U'
,
8
raise
ValueError
(
f
'Unexpected dtype:
{
dtype
}
'
)
test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
View file @
c3cb2015
...
...
@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import (
)
from
.common
import
(
name_func
,
get_enc_params
,
)
...
...
@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""save/load round trip should not degrade data for wav formats"""
original
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
)
enc
,
bps
=
get_enc_params
(
dtype
)
data
=
original
for
i
in
range
(
10
):
path
=
self
.
get_temp_path
(
f
'
{
i
}
.wav'
)
sox_io_backend
.
save
(
path
,
data
,
sample_rate
)
sox_io_backend
.
save
(
path
,
data
,
sample_rate
,
encoding
=
enc
,
bits_per_sample
=
bps
)
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
False
)
assert
sr
==
sample_rate
self
.
assertEqual
(
original
,
data
)
...
...
test/torchaudio_unittest/backend/sox_io/save_test.py
View file @
c3cb2015
This diff is collapsed.
Click to expand it.
test/torchaudio_unittest/backend/sox_io/torchscript_test.py
View file @
c3cb2015
...
...
@@ -17,6 +17,7 @@ from torchaudio_unittest.common_utils import (
)
from
.common
import
(
name_func
,
get_enc_params
,
)
...
...
@@ -35,8 +36,12 @@ def py_save_func(
sample_rate
:
int
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
torchaudio
.
save
(
filepath
,
tensor
,
sample_rate
,
channels_first
,
compression
)
torchaudio
.
save
(
filepath
,
tensor
,
sample_rate
,
channels_first
,
compression
,
None
,
encoding
,
bits_per_sample
)
@
skipIfNoExec
(
'sox'
)
...
...
@@ -102,15 +107,16 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
torch
.
jit
.
script
(
py_save_func
).
save
(
script_path
)
ts_save_func
=
torch
.
jit
.
load
(
script_path
)
expected
=
get_wav_data
(
dtype
,
num_channels
)
expected
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
)
py_path
=
self
.
get_temp_path
(
f
'test_save_py_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
ts_path
=
self
.
get_temp_path
(
f
'test_save_ts_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
enc
,
bps
=
get_enc_params
(
dtype
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
None
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
None
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
None
,
enc
,
bps
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
None
,
enc
,
bps
)
py_data
,
py_sr
=
load_wav
(
py_path
)
ts_data
,
ts_sr
=
load_wav
(
ts_path
)
py_data
,
py_sr
=
load_wav
(
py_path
,
normalize
=
False
)
ts_data
,
ts_sr
=
load_wav
(
ts_path
,
normalize
=
False
)
self
.
assertEqual
(
sample_rate
,
py_sr
)
self
.
assertEqual
(
sample_rate
,
ts_sr
)
...
...
@@ -131,8 +137,8 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_path
=
self
.
get_temp_path
(
f
'test_save_py_
{
sample_rate
}
_
{
num_channels
}
_
{
compression_level
}
.flac'
)
ts_path
=
self
.
get_temp_path
(
f
'test_save_ts_
{
sample_rate
}
_
{
num_channels
}
_
{
compression_level
}
.flac'
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
compression_level
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
compression_level
)
py_save_func
(
py_path
,
expected
,
sample_rate
,
True
,
compression_level
,
None
,
None
)
ts_save_func
(
ts_path
,
expected
,
sample_rate
,
True
,
compression_level
,
None
,
None
)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav
=
f
'
{
py_path
}
.wav'
...
...
test/torchaudio_unittest/common_utils/sox_utils.py
View file @
c3cb2015
import
sys
import
subprocess
import
warnings
...
...
@@ -32,6 +33,7 @@ def gen_audio_file(
command
=
[
'sox'
,
'-V3'
,
# verbose
'--no-dither'
,
# disable automatic dithering
'-R'
,
# -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value.
...
...
@@ -61,21 +63,23 @@ def gen_audio_file(
]
if
attenuation
is
not
None
:
command
+=
[
'vol'
,
f
'-
{
attenuation
}
dB'
]
print
(
' '
.
join
(
command
))
print
(
' '
.
join
(
command
)
,
file
=
sys
.
stderr
)
subprocess
.
run
(
command
,
check
=
True
)
def
convert_audio_file
(
src_path
,
dst_path
,
*
,
bit_depth
=
None
,
compression
=
None
):
*
,
encoding
=
None
,
bit_depth
=
None
,
compression
=
None
):
"""Convert audio file with `sox` command."""
command
=
[
'sox'
,
'-V3'
,
'-R'
,
str
(
src_path
)]
command
=
[
'sox'
,
'-V3'
,
'--no-dither'
,
'-R'
,
str
(
src_path
)]
if
encoding
is
not
None
:
command
+=
[
'--encoding'
,
str
(
encoding
)]
if
bit_depth
is
not
None
:
command
+=
[
'--bits'
,
str
(
bit_depth
)]
if
compression
is
not
None
:
command
+=
[
'--compression'
,
str
(
compression
)]
command
+=
[
dst_path
]
print
(
' '
.
join
(
command
))
print
(
' '
.
join
(
command
)
,
file
=
sys
.
stderr
)
subprocess
.
run
(
command
,
check
=
True
)
...
...
torchaudio/backend/sox_io_backend.py
View file @
c3cb2015
import
os
import
warnings
from
typing
import
Tuple
,
Optional
import
torch
...
...
@@ -152,26 +151,6 @@ def load(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
@
torch
.
jit
.
unused
def
_save
(
filepath
:
str
,
src
:
torch
.
Tensor
,
sample_rate
:
int
,
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
):
if
hasattr
(
filepath
,
'write'
):
if
format
is
None
:
raise
RuntimeError
(
'`format` is required when saving to file object.'
)
torchaudio
.
_torchaudio
.
save_audio_fileobj
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
else
:
torch
.
ops
.
torchaudio
.
sox_io_save_audio_file
(
os
.
fspath
(
filepath
),
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
def
save
(
filepath
:
str
,
...
...
@@ -180,30 +159,11 @@ def save(
channels_first
:
bool
=
True
,
compression
:
Optional
[
float
]
=
None
,
format
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
bits_per_sample
:
Optional
[
int
]
=
None
,
):
"""Save audio data to file.
Note:
Supported formats are;
* WAV, AMB
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* MP3
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
Args:
filepath (str or pathlib.Path): Path to save file.
This function also handles ``pathlib.Path`` objects, but is annotated
...
...
@@ -215,32 +175,137 @@ def save(
compression (Optional[float]): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command.
* | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
| VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
* | ``FLAC``: compression level. Whole number from ``0`` to ``8``.
| ``8`` is default and highest compression.
* | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression
| and lowest quality. Default: ``3``.
``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional): Output audio format.
This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
dtype (str, optional): Output tensor dtype.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed.
``dtype`` parameter is only effective for ``float32`` Tensor.
format (str, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is infered from
file extension. If file extension is missing or different, you can specify the
correct format with this argument.
When ``filepath`` argument is file-like object, this argument is required.
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
Default values
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.
``"wav"``, ``"amb"``
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used to determine the default value.
- ``"PCM_U"`` if dtype is ``uint8``
- ``"PCM_S"`` if dtype is ``int16`` or ``int32`
- ``"PCM_F"`` if dtype is ``float32``
- ``"PCM_U"`` if ``bits_per_sample=8``
- ``"PCM_S"`` otherwise
``"sph"`` format;
- the default value is ``"PCM_S"``
bits_per_sample (int, optional): Changes the bit depth for the supported formats.
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.
Default Value;
If not provided, the default values are picked based on ``format`` and ``"encoding"``;
``"wav"``, ``"amb"``;
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used.
- ``8`` if dtype is ``uint8``
- ``16`` if dtype is ``int16``
- ``32`` if dtype is ``int32`` or ``float32``
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"``
- ``32`` if ``encoding`` is ``"PCM_F"``
``"flac"`` format;
- the default value is ``24``
``"sph"`` format;
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``
``"amb"`` format;
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
- ``32`` if ``encoding`` is ``"PCM_F"``
Supported formats/encodings/bit depth/compression are;
``"wav"``, ``"amb"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
``"mp3"``
Fixed bit rate (such as 128kHz) and variable bit rate compression.
Default: VBR with high quality.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Different quality level. Default: approx. 112kbps
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s
Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
or ``libmp3lame`` etc.
"""
if
src
.
dtype
==
torch
.
float32
and
dtype
is
None
:
warnings
.
warn
(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if
not
torch
.
jit
.
is_scripting
():
_save
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtype
)
return
if
hasattr
(
filepath
,
'write'
):
torchaudio
.
_torchaudio
.
save_audio_fileobj
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
encoding
,
bits_per_sample
)
return
filepath
=
os
.
fspath
(
filepath
)
torch
.
ops
.
torchaudio
.
sox_io_save_audio_file
(
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
dtyp
e
)
filepath
,
src
,
sample_rate
,
channels_first
,
compression
,
format
,
encoding
,
bits_per_sampl
e
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/CMakeLists.txt
View file @
c3cb2015
...
...
@@ -9,6 +9,7 @@ set(
sox/utils.cpp
sox/effects.cpp
sox/effects_chain.cpp
sox/types.cpp
)
if
(
BUILD_TRANSDUCER
)
...
...
torchaudio/csrc/sox/effects_chain.cpp
View file @
c3cb2015
...
...
@@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Ensure that it's a multiple of the number of channels
*
osamp
-=
*
osamp
%
num_channels
;
// Slice the input Tensor
and unnormalize the values
// Slice the input Tensor
const
auto
tensor_
=
[
&
]()
{
auto
i_frame
=
index
/
num_channels
;
auto
num_frames
=
*
osamp
/
num_channels
;
auto
t
=
(
priv
->
channels_first
)
?
tensor
.
index
({
Slice
(),
Slice
(
i_frame
,
i_frame
+
num_frames
)}).
t
()
:
tensor
.
index
({
Slice
(
i_frame
,
i_frame
+
num_frames
),
Slice
()});
return
unnormalize_wav
(
t
.
reshape
({
-
1
})
)
.
contiguous
();
return
t
.
reshape
({
-
1
}).
contiguous
();
}();
priv
->
index
+=
*
osamp
;
// Write data to SoxEffectsChain buffer.
auto
ptr
=
tensor_
.
data_ptr
<
int32_t
>
();
std
::
copy
(
ptr
,
ptr
+
*
osamp
,
obuf
);
// Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS
;
const
auto
dtype
=
tensor_
.
dtype
();
if
(
dtype
==
torch
::
kFloat32
)
{
auto
ptr
=
tensor_
.
data_ptr
<
float_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_FLOAT_32BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kInt32
)
{
auto
ptr
=
tensor_
.
data_ptr
<
int32_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_SIGNED_32BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kInt16
)
{
auto
ptr
=
tensor_
.
data_ptr
<
int16_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_SIGNED_16BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
if
(
dtype
==
torch
::
kUInt8
)
{
auto
ptr
=
tensor_
.
data_ptr
<
uint8_t
>
();
for
(
size_t
i
=
0
;
i
<
*
osamp
;
++
i
)
{
obuf
[
i
]
=
SOX_UNSIGNED_8BIT_TO_SAMPLE
(
ptr
[
i
],
effp
->
clips
);
}
}
else
{
throw
std
::
runtime_error
(
"Unexpected dtype."
);
}
priv
->
index
+=
*
osamp
;
return
(
priv
->
index
==
num_samples
)
?
SOX_EOF
:
SOX_SUCCESS
;
}
...
...
@@ -430,7 +452,7 @@ int fileobj_output_flow(
fflush
(
fp
);
// Copy the encoded chunk to python object.
fileobj
->
attr
(
"write"
)(
py
::
bytes
(
*
buffer
,
*
buffer_size
));
fileobj
->
attr
(
"write"
)(
py
::
bytes
(
*
buffer
,
ftell
(
fp
)
));
// Reset FILE*
sf
->
tell_off
=
0
;
...
...
torchaudio/csrc/sox/io.cpp
View file @
c3cb2015
...
...
@@ -116,35 +116,27 @@ void save_audio_file(
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
std
::
string
>
format
,
c10
::
optional
<
std
::
string
>
dtype
)
{
c10
::
optional
<
double
>&
compression
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
)
{
validate_input_tensor
(
tensor
);
if
(
tensor
.
dtype
()
!=
torch
::
kFloat32
&&
dtype
.
has_value
())
{
throw
std
::
runtime_error
(
"dtype conversion only supported for float32 tensors"
);
}
const
auto
tgt_dtype
=
(
tensor
.
dtype
()
==
torch
::
kFloat32
&&
dtype
.
has_value
())
?
get_dtype_from_str
(
dtype
.
value
())
:
tensor
.
dtype
();
const
auto
filetype
=
[
&
]()
{
if
(
format
.
has_value
())
return
format
.
value
();
return
get_filetype
(
path
);
}();
if
(
filetype
==
"amr-nb"
)
{
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
TORCH_CHECK
(
num_channels
==
1
,
"amr-nb format only supports single channel audio."
);
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
const
auto
signal_info
=
get_signalinfo
(
&
tensor
,
sample_rate
,
filetype
,
channels_first
);
const
auto
encoding_info
=
get_encodinginfo_for_save
(
filetype
,
t
gt_
dtype
,
compression
);
const
auto
encoding_info
=
get_encodinginfo_for_save
(
filetype
,
t
ensor
.
dtype
()
,
compression
,
encoding
,
bits_per_sample
);
SoxFormat
sf
(
sox_open_write
(
path
.
c_str
(),
...
...
@@ -258,19 +250,17 @@ void save_audio_fileobj(
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
std
::
string
filetype
,
c10
::
optional
<
std
::
string
>
dtype
)
{
c10
::
optional
<
double
>&
compression
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
)
{
validate_input_tensor
(
tensor
);
if
(
tensor
.
dtype
()
!=
torch
::
kFloat32
&&
dtype
.
has_value
())
{
if
(
!
format
.
has_value
())
{
throw
std
::
runtime_error
(
"
dtype conversion only supported for float32 tensors
"
);
"
`format` is required when saving to file object.
"
);
}
const
auto
tgt_dtype
=
(
tensor
.
dtype
()
==
torch
::
kFloat32
&&
dtype
.
has_value
())
?
get_dtype_from_str
(
dtype
.
value
())
:
tensor
.
dtype
();
const
auto
filetype
=
format
.
value
();
if
(
filetype
==
"amr-nb"
)
{
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
...
...
@@ -278,12 +268,11 @@ void save_audio_fileobj(
throw
std
::
runtime_error
(
"amr-nb format only supports single channel audio."
);
}
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
const
auto
signal_info
=
get_signalinfo
(
&
tensor
,
sample_rate
,
filetype
,
channels_first
);
const
auto
encoding_info
=
get_encodinginfo_for_save
(
filetype
,
t
gt_
dtype
,
compression
);
const
auto
encoding_info
=
get_encodinginfo_for_save
(
filetype
,
t
ensor
.
dtype
()
,
compression
,
encoding
,
bits_per_sample
);
AutoReleaseBuffer
buffer
;
...
...
torchaudio/csrc/sox/io.h
View file @
c3cb2015
...
...
@@ -28,9 +28,10 @@ void save_audio_file(
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
c10
::
optional
<
std
::
string
>
format
,
c10
::
optional
<
std
::
string
>
dtype
);
c10
::
optional
<
double
>&
compression
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
...
...
@@ -51,9 +52,10 @@ void save_audio_fileobj(
torch
::
Tensor
tensor
,
int64_t
sample_rate
,
bool
channels_first
,
c10
::
optional
<
double
>
compression
,
std
::
string
filetype
,
c10
::
optional
<
std
::
string
>
dtype
);
c10
::
optional
<
double
>&
compression
,
c10
::
optional
<
std
::
string
>&
format
,
c10
::
optional
<
std
::
string
>&
encoding
,
c10
::
optional
<
int64_t
>&
bits_per_sample
);
#endif // TORCH_API_INCLUDE_EXTENSION_H
...
...
torchaudio/csrc/sox/types.cpp
0 → 100644
View file @
c3cb2015
#include <torchaudio/csrc/sox/types.h>
namespace
torchaudio
{
namespace
sox_utils
{
Format
get_format_from_string
(
const
std
::
string
&
format
)
{
if
(
format
==
"wav"
)
return
Format
::
WAV
;
if
(
format
==
"mp3"
)
return
Format
::
MP3
;
if
(
format
==
"flac"
)
return
Format
::
FLAC
;
if
(
format
==
"ogg"
||
format
==
"vorbis"
)
return
Format
::
VORBIS
;
if
(
format
==
"amr-nb"
)
return
Format
::
AMR_NB
;
if
(
format
==
"amr-wb"
)
return
Format
::
AMR_WB
;
if
(
format
==
"amb"
)
return
Format
::
AMB
;
if
(
format
==
"sph"
)
return
Format
::
SPHERE
;
std
::
ostringstream
stream
;
stream
<<
"Internal Error: unexpected format value: "
<<
format
;
throw
std
::
runtime_error
(
stream
.
str
());
}
std
::
string
to_string
(
Encoding
v
)
{
switch
(
v
)
{
case
Encoding
::
UNKNOWN
:
return
"UNKNOWN"
;
case
Encoding
::
PCM_SIGNED
:
return
"PCM_S"
;
case
Encoding
::
PCM_UNSIGNED
:
return
"PCM_U"
;
case
Encoding
::
PCM_FLOAT
:
return
"PCM_F"
;
case
Encoding
::
FLAC
:
return
"FLAC"
;
case
Encoding
::
ULAW
:
return
"ULAW"
;
case
Encoding
::
ALAW
:
return
"ALAW"
;
case
Encoding
::
MP3
:
return
"MP3"
;
case
Encoding
::
VORBIS
:
return
"VORBIS"
;
case
Encoding
::
AMR_WB
:
return
"AMR_WB"
;
case
Encoding
::
AMR_NB
:
return
"AMR_NB"
;
case
Encoding
::
OPUS
:
return
"OPUS"
;
default:
throw
std
::
runtime_error
(
"Internal Error: unexpected encoding."
);
}
}
Encoding
get_encoding_from_option
(
const
c10
::
optional
<
std
::
string
>&
encoding
)
{
if
(
!
encoding
.
has_value
())
return
Encoding
::
NOT_PROVIDED
;
std
::
string
v
=
encoding
.
value
();
if
(
v
==
"PCM_S"
)
return
Encoding
::
PCM_SIGNED
;
if
(
v
==
"PCM_U"
)
return
Encoding
::
PCM_UNSIGNED
;
if
(
v
==
"PCM_F"
)
return
Encoding
::
PCM_FLOAT
;
if
(
v
==
"ULAW"
)
return
Encoding
::
ULAW
;
if
(
v
==
"ALAW"
)
return
Encoding
::
ALAW
;
std
::
ostringstream
stream
;
stream
<<
"Internal Error: unexpected encoding value: "
<<
v
;
throw
std
::
runtime_error
(
stream
.
str
());
}
BitDepth
get_bit_depth_from_option
(
const
c10
::
optional
<
int64_t
>&
bit_depth
)
{
if
(
!
bit_depth
.
has_value
())
return
BitDepth
::
NOT_PROVIDED
;
int64_t
v
=
bit_depth
.
value
();
switch
(
v
)
{
case
8
:
return
BitDepth
::
B8
;
case
16
:
return
BitDepth
::
B16
;
case
24
:
return
BitDepth
::
B24
;
case
32
:
return
BitDepth
::
B32
;
case
64
:
return
BitDepth
::
B64
;
default:
{
std
::
ostringstream
s
;
s
<<
"Internal Error: unexpected bit depth value: "
<<
v
;
throw
std
::
runtime_error
(
s
.
str
());
}
}
}
}
// namespace sox_utils
}
// namespace torchaudio
torchaudio/csrc/sox/types.h
0 → 100644
View file @
c3cb2015
#ifndef TORCHAUDIO_SOX_TYPES_H
#define TORCHAUDIO_SOX_TYPES_H
#include <torch/script.h>
namespace
torchaudio
{
namespace
sox_utils
{
enum
class
Format
{
WAV
,
MP3
,
FLAC
,
VORBIS
,
AMR_NB
,
AMR_WB
,
AMB
,
SPHERE
,
};
Format
get_format_from_string
(
const
std
::
string
&
format
);
enum
class
Encoding
{
NOT_PROVIDED
,
UNKNOWN
,
PCM_SIGNED
,
PCM_UNSIGNED
,
PCM_FLOAT
,
FLAC
,
ULAW
,
ALAW
,
MP3
,
VORBIS
,
AMR_WB
,
AMR_NB
,
OPUS
,
};
std
::
string
to_string
(
Encoding
v
);
Encoding
get_encoding_from_option
(
const
c10
::
optional
<
std
::
string
>&
encoding
);
enum
class
BitDepth
:
unsigned
{
NOT_PROVIDED
=
0
,
B8
=
8
,
B16
=
16
,
B24
=
24
,
B32
=
32
,
B64
=
64
,
};
BitDepth
get_bit_depth_from_option
(
const
c10
::
optional
<
int64_t
>&
bit_depth
);
}
// namespace sox_utils
}
// namespace torchaudio
#endif
torchaudio/csrc/sox/utils.cpp
View file @
c3cb2015
#include <c10/core/ScalarType.h>
#include <sox.h>
#include <torchaudio/csrc/sox/types.h>
#include <torchaudio/csrc/sox/utils.h>
namespace
torchaudio
{
...
...
@@ -163,22 +164,32 @@ torch::Tensor convert_to_tensor(
const
caffe2
::
TypeMeta
dtype
,
const
bool
normalize
,
const
bool
channels_first
)
{
auto
t
=
torch
::
from_blob
(
buffer
,
{
num_samples
/
num_channels
,
num_channels
},
torch
::
kInt32
);
// Note: Tensor created from_blob does not own data but borrwos
// So make sure to create a new copy after processing samples.
torch
::
Tensor
t
;
uint64_t
dummy
;
SOX_SAMPLE_LOCALS
;
if
(
normalize
||
dtype
==
torch
::
kFloat32
)
{
t
=
t
.
to
(
torch
::
kFloat32
);
t
*=
(
t
>
0
)
/
2147483647.
+
(
t
<
0
)
/
2147483648.
;
t
=
torch
::
empty
(
{
num_samples
/
num_channels
,
num_channels
},
torch
::
kFloat32
);
auto
ptr
=
t
.
data_ptr
<
float_t
>
();
for
(
int32_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
ptr
[
i
]
=
SOX_SAMPLE_TO_FLOAT_32BIT
(
buffer
[
i
],
dummy
);
}
}
else
if
(
dtype
==
torch
::
kInt32
)
{
t
=
t
.
clone
();
t
=
torch
::
from_blob
(
buffer
,
{
num_samples
/
num_channels
,
num_channels
},
torch
::
kInt32
)
.
clone
();
}
else
if
(
dtype
==
torch
::
kInt16
)
{
t
.
floor_divide_
(
1
<<
16
);
t
=
t
.
to
(
torch
::
kInt16
);
t
=
torch
::
empty
({
num_samples
/
num_channels
,
num_channels
},
torch
::
kInt16
);
auto
ptr
=
t
.
data_ptr
<
int16_t
>
();
for
(
int32_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
ptr
[
i
]
=
SOX_SAMPLE_TO_SIGNED_16BIT
(
buffer
[
i
],
dummy
);
}
}
else
if
(
dtype
==
torch
::
kUInt8
)
{
t
.
floor_divide_
(
1
<<
24
);
t
+=
128
;
t
=
t
.
to
(
torch
::
kUInt8
);
t
=
torch
::
empty
({
num_samples
/
num_channels
,
num_channels
},
torch
::
kUInt8
);
auto
ptr
=
t
.
data_ptr
<
uint8_t
>
();
for
(
int32_t
i
=
0
;
i
<
num_samples
;
++
i
)
{
ptr
[
i
]
=
SOX_SAMPLE_TO_UNSIGNED_8BIT
(
buffer
[
i
],
dummy
);
}
}
else
{
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
...
...
@@ -188,63 +199,181 @@ torch::Tensor convert_to_tensor(
return
t
.
contiguous
();
}
torch
::
Tensor
unnormalize_wav
(
const
torch
::
Tensor
input_tensor
)
{
const
auto
dtype
=
input_tensor
.
dtype
();
auto
tensor
=
input_tensor
;
if
(
dtype
==
torch
::
kFloat32
)
{
double
multi_pos
=
2147483647.
;
double
multi_neg
=
-
2147483648.
;
auto
mult
=
(
tensor
>
0
)
*
multi_pos
-
(
tensor
<
0
)
*
multi_neg
;
tensor
=
tensor
.
to
(
torch
::
dtype
(
torch
::
kFloat64
));
tensor
*=
mult
;
tensor
.
clamp_
(
multi_neg
,
multi_pos
);
tensor
=
tensor
.
to
(
torch
::
dtype
(
torch
::
kInt32
));
}
else
if
(
dtype
==
torch
::
kInt32
)
{
// already denormalized
}
else
if
(
dtype
==
torch
::
kInt16
)
{
tensor
=
tensor
.
to
(
torch
::
dtype
(
torch
::
kInt32
));
tensor
*=
((
tensor
!=
0
)
*
65536
);
}
else
if
(
dtype
==
torch
::
kUInt8
)
{
tensor
=
tensor
.
to
(
torch
::
dtype
(
torch
::
kInt32
));
tensor
-=
128
;
tensor
*=
16777216
;
}
else
{
throw
std
::
runtime_error
(
"Unexpected dtype."
);
}
return
tensor
;
}
const
std
::
string
get_filetype
(
const
std
::
string
path
)
{
std
::
string
ext
=
path
.
substr
(
path
.
find_last_of
(
"."
)
+
1
);
std
::
transform
(
ext
.
begin
(),
ext
.
end
(),
ext
.
begin
(),
::
tolower
);
return
ext
;
}
sox_encoding_t
get_encoding
(
const
std
::
string
filetype
,
const
caffe2
::
TypeMeta
dtype
)
{
if
(
filetype
==
"mp3"
)
return
SOX_ENCODING_MP3
;
if
(
filetype
==
"flac"
)
return
SOX_ENCODING_FLAC
;
if
(
filetype
==
"ogg"
||
filetype
==
"vorbis"
)
return
SOX_ENCODING_VORBIS
;
if
(
filetype
==
"wav"
||
filetype
==
"amb"
)
{
if
(
dtype
==
torch
::
kUInt8
)
return
SOX_ENCODING_UNSIGNED
;
if
(
dtype
==
torch
::
kInt16
)
return
SOX_ENCODING_SIGN2
;
if
(
dtype
==
torch
::
kInt32
)
return
SOX_ENCODING_SIGN2
;
if
(
dtype
==
torch
::
kFloat32
)
return
SOX_ENCODING_FLOAT
;
throw
std
::
runtime_error
(
"Unsupported dtype."
);
namespace
{
std
::
tuple
<
sox_encoding_t
,
unsigned
>
get_save_encoding_for_wav
(
const
std
::
string
format
,
const
caffe2
::
TypeMeta
dtype
,
const
Encoding
&
encoding
,
const
BitDepth
&
bits_per_sample
)
{
switch
(
encoding
)
{
case
Encoding
::
NOT_PROVIDED
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
if
(
dtype
==
torch
::
kFloat32
)
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLOAT
,
32
);
if
(
dtype
==
torch
::
kInt32
)
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
32
);
if
(
dtype
==
torch
::
kInt16
)
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
16
);
if
(
dtype
==
torch
::
kUInt8
)
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
throw
std
::
runtime_error
(
"Internal Error: Unexpected dtype."
);
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
default:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
static_cast
<
unsigned
>
(
bits_per_sample
));
}
case
Encoding
::
PCM_SIGNED
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
32
);
case
BitDepth
::
B8
:
throw
std
::
runtime_error
(
format
+
" does not support 8-bit signed PCM encoding."
);
default:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
static_cast
<
unsigned
>
(
bits_per_sample
));
}
case
Encoding
::
PCM_UNSIGNED
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_UNSIGNED
,
8
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 8-bit for unsigned PCM encoding."
);
}
case
Encoding
::
PCM_FLOAT
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B32
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLOAT
,
32
);
case
BitDepth
::
B64
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLOAT
,
64
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 32-bit or 64-bit for floating-point PCM encoding."
);
}
case
Encoding
::
ULAW
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_ULAW
,
8
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 8-bit for mu-law encoding."
);
}
case
Encoding
::
ALAW
:
switch
(
bits_per_sample
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_ALAW
,
8
);
default:
throw
std
::
runtime_error
(
format
+
" only supports 8-bit for a-law encoding."
);
}
default:
throw
std
::
runtime_error
(
format
+
" does not support encoding: "
+
to_string
(
encoding
));
}
}
std
::
tuple
<
sox_encoding_t
,
unsigned
>
get_save_encoding
(
const
std
::
string
&
format
,
const
caffe2
::
TypeMeta
dtype
,
const
c10
::
optional
<
std
::
string
>&
encoding
,
const
c10
::
optional
<
int64_t
>&
bits_per_sample
)
{
const
Format
fmt
=
get_format_from_string
(
format
);
const
Encoding
enc
=
get_encoding_from_option
(
encoding
);
const
BitDepth
bps
=
get_bit_depth_from_option
(
bits_per_sample
);
switch
(
fmt
)
{
case
Format
::
WAV
:
case
Format
::
AMB
:
return
get_save_encoding_for_wav
(
format
,
dtype
,
enc
,
bps
);
case
Format
::
MP3
:
if
(
enc
!=
Encoding
::
NOT_PROVIDED
)
throw
std
::
runtime_error
(
"mp3 does not support `encoding` option."
);
if
(
bps
!=
BitDepth
::
NOT_PROVIDED
)
throw
std
::
runtime_error
(
"mp3 does not support `bits_per_sample` option."
);
return
std
::
make_tuple
<>
(
SOX_ENCODING_MP3
,
16
);
case
Format
::
VORBIS
:
if
(
enc
!=
Encoding
::
NOT_PROVIDED
)
throw
std
::
runtime_error
(
"vorbis does not support `encoding` option."
);
if
(
bps
!=
BitDepth
::
NOT_PROVIDED
)
throw
std
::
runtime_error
(
"vorbis does not support `bits_per_sample` option."
);
return
std
::
make_tuple
<>
(
SOX_ENCODING_VORBIS
,
16
);
case
Format
::
AMR_NB
:
if
(
enc
!=
Encoding
::
NOT_PROVIDED
)
throw
std
::
runtime_error
(
"amr-nb does not support `encoding` option."
);
if
(
bps
!=
BitDepth
::
NOT_PROVIDED
)
throw
std
::
runtime_error
(
"amr-nb does not support `bits_per_sample` option."
);
return
std
::
make_tuple
<>
(
SOX_ENCODING_AMR_NB
,
16
);
case
Format
::
FLAC
:
if
(
enc
!=
Encoding
::
NOT_PROVIDED
)
throw
std
::
runtime_error
(
"flac does not support `encoding` option."
);
switch
(
bps
)
{
case
BitDepth
::
B32
:
case
BitDepth
::
B64
:
throw
std
::
runtime_error
(
"flac does not support `bits_per_sample` larger than 24."
);
default:
return
std
::
make_tuple
<>
(
SOX_ENCODING_FLAC
,
static_cast
<
unsigned
>
(
bps
));
}
case
Format
::
SPHERE
:
switch
(
enc
)
{
case
Encoding
::
NOT_PROVIDED
:
case
Encoding
::
PCM_SIGNED
:
switch
(
bps
)
{
case
BitDepth
::
NOT_PROVIDED
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
32
);
default:
return
std
::
make_tuple
<>
(
SOX_ENCODING_SIGN2
,
static_cast
<
unsigned
>
(
bps
));
}
case
Encoding
::
PCM_UNSIGNED
:
throw
std
::
runtime_error
(
"sph does not support unsigned integer PCM."
);
case
Encoding
::
PCM_FLOAT
:
throw
std
::
runtime_error
(
"sph does not support floating point PCM."
);
case
Encoding
::
ULAW
:
switch
(
bps
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_ULAW
,
8
);
default:
throw
std
::
runtime_error
(
"sph only supports 8-bit for mu-law encoding."
);
}
case
Encoding
::
ALAW
:
switch
(
bps
)
{
case
BitDepth
::
NOT_PROVIDED
:
case
BitDepth
::
B8
:
return
std
::
make_tuple
<>
(
SOX_ENCODING_ALAW
,
8
);
default:
return
std
::
make_tuple
<>
(
SOX_ENCODING_ALAW
,
static_cast
<
unsigned
>
(
bps
));
}
default:
throw
std
::
runtime_error
(
"sph does not support encoding: "
+
encoding
.
value
());
}
default:
throw
std
::
runtime_error
(
"Unsupported format: "
+
format
);
}
if
(
filetype
==
"sph"
)
return
SOX_ENCODING_SIGN2
;
if
(
filetype
==
"amr-nb"
)
return
SOX_ENCODING_AMR_NB
;
throw
std
::
runtime_error
(
"Unsupported file type: "
+
filetype
);
}
unsigned
get_precision
(
...
...
@@ -270,14 +399,13 @@ unsigned get_precision(
if
(
filetype
==
"sph"
)
return
32
;
if
(
filetype
==
"amr-nb"
)
{
TORCH_INTERNAL_ASSERT
(
dtype
==
torch
::
kInt16
,
"When saving to AMR-NB format, the input tensor must be int16 type."
);
return
16
;
}
throw
std
::
runtime_error
(
"Unsupported file type: "
+
filetype
);
}
}
// namespace
sox_signalinfo_t
get_signalinfo
(
const
torch
::
Tensor
*
waveform
,
const
int64_t
sample_rate
,
...
...
@@ -325,12 +453,15 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
}
sox_encodinginfo_t
get_encodinginfo_for_save
(
const
std
::
string
f
iletype
,
const
std
::
string
&
f
ormat
,
const
caffe2
::
TypeMeta
dtype
,
c10
::
optional
<
double
>&
compression
)
{
const
c10
::
optional
<
double
>&
compression
,
const
c10
::
optional
<
std
::
string
>&
encoding
,
const
c10
::
optional
<
int64_t
>&
bits_per_sample
)
{
auto
enc
=
get_save_encoding
(
format
,
dtype
,
encoding
,
bits_per_sample
);
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*encoding=*/
std
::
get
<
0
>
(
enc
),
/*bits_per_sample=*/
std
::
get
<
1
>
(
enc
),
/*compression=*/
compression
.
value_or
(
HUGE_VAL
),
/*reverse_bytes=*/
sox_option_default
,
/*reverse_nibbles=*/
sox_option_default
,
...
...
torchaudio/csrc/sox/utils.h
View file @
c3cb2015
...
...
@@ -93,11 +93,6 @@ torch::Tensor convert_to_tensor(
const
bool
normalize
,
const
bool
channels_first
);
///
/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
/// conversion.
torch
::
Tensor
unnormalize_wav
(
const
torch
::
Tensor
);
/// Extract extension from file path
const
std
::
string
get_filetype
(
const
std
::
string
path
);
...
...
@@ -113,9 +108,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
/// Get sox_encodinginfo_t for saving to file/file object
sox_encodinginfo_t
get_encodinginfo_for_save
(
const
std
::
string
f
iletype
,
const
std
::
string
&
f
ormat
,
const
caffe2
::
TypeMeta
dtype
,
c10
::
optional
<
double
>&
compression
);
const
c10
::
optional
<
double
>&
compression
,
const
c10
::
optional
<
std
::
string
>&
encoding
,
const
c10
::
optional
<
int64_t
>&
bits_per_sample
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment