Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
463a8b2c
Unverified
Commit
463a8b2c
authored
Jan 07, 2021
by
moto
Committed by
GitHub
Jan 07, 2021
Browse files
Support file-like object in load function (#1158)
parent
422edb18
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
578 additions
and
27 deletions
+578
-27
.circleci/unittest/linux/scripts/install.sh
.circleci/unittest/linux/scripts/install.sh
+2
-2
test/torchaudio_unittest/common_utils/__init__.py
test/torchaudio_unittest/common_utils/__init__.py
+1
-0
test/torchaudio_unittest/common_utils/case_utils.py
test/torchaudio_unittest/common_utils/case_utils.py
+28
-0
test/torchaudio_unittest/soundfile_backend/load_test.py
test/torchaudio_unittest/soundfile_backend/load_test.py
+56
-0
test/torchaudio_unittest/sox_io_backend/load_test.py
test/torchaudio_unittest/sox_io_backend/load_test.py
+163
-1
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+6
-4
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+20
-5
torchaudio/csrc/pybind.cpp
torchaudio/csrc/pybind.cpp
+6
-0
torchaudio/csrc/sox/effects.cpp
torchaudio/csrc/sox/effects.cpp
+83
-0
torchaudio/csrc/sox/effects.h
torchaudio/csrc/sox/effects.h
+15
-0
torchaudio/csrc/sox/effects_chain.cpp
torchaudio/csrc/sox/effects_chain.cpp
+132
-5
torchaudio/csrc/sox/effects_chain.h
torchaudio/csrc/sox/effects_chain.h
+14
-0
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+32
-6
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+16
-0
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+3
-3
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+1
-1
No files found.
.circleci/unittest/linux/scripts/install.sh
View file @
463a8b2c
...
...
@@ -46,9 +46,9 @@ if [ "${os}" == Linux ] ; then
# TODO: move this to docker
apt
install
-y
-q
libsndfile1
conda
install
-y
-c
conda-forge codecov pytest pytest-cov
pip
install
kaldi-io
'librosa>=0.8.0'
parameterized SoundFile scipy
pip
install
kaldi-io
'librosa>=0.8.0'
parameterized SoundFile scipy
'requests>=2.20'
else
# Note: installing librosa via pip fail because it will try to compile numba.
conda
install
-y
-c
conda-forge codecov pytest pytest-cov
'librosa>=0.8.0'
parameterized scipy
pip
install
kaldi-io SoundFile
pip
install
kaldi-io SoundFile
'requests>=2.20'
fi
test/torchaudio_unittest/common_utils/__init__.py
View file @
463a8b2c
...
...
@@ -8,6 +8,7 @@ from .backend_utils import (
)
from
.case_utils
import
(
TempDirMixin
,
HttpServerMixin
,
TestBaseMixin
,
PytorchTestCase
,
TorchaudioTestCase
,
...
...
test/torchaudio_unittest/common_utils/case_utils.py
View file @
463a8b2c
import
shutil
import
os.path
import
subprocess
import
tempfile
import
time
import
unittest
import
torch
...
...
@@ -40,6 +42,32 @@ class TempDirMixin:
return
path
class
HttpServerMixin
(
TempDirMixin
):
"""Mixin that serves temporary directory as web server
This class creates temporary directory and serve the directory as HTTP service.
The server is up through the execution of all the test suite defined under the subclass.
"""
_proc
=
None
_port
=
8000
@
classmethod
def
setUpClass
(
cls
):
super
().
setUpClass
()
cls
.
_proc
=
subprocess
.
Popen
(
[
'python'
,
'-m'
,
'http.server'
,
f
'
{
cls
.
_port
}
'
],
cwd
=
cls
.
get_base_temp_dir
())
time
.
sleep
(
1.0
)
@
classmethod
def
tearDownClass
(
cls
):
super
().
tearDownClass
()
cls
.
_proc
.
kill
()
def
get_url
(
self
,
*
route
):
return
f
'http://localhost:
{
self
.
_port
}
/
{
self
.
id
()
}
/
{
"/"
.
join
(
route
)
}
'
class
TestBaseMixin
:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype
=
None
...
...
test/torchaudio_unittest/soundfile_backend/load_test.py
View file @
463a8b2c
import
os
import
tarfile
from
unittest.mock
import
patch
import
torch
...
...
@@ -299,3 +300,58 @@ class TestLoadFormat(TempDirMixin, PytorchTestCase):
@
skipIfFormatNotSupported
(
"FLAC"
)
def
test_flac
(
self
,
format_
):
self
.
_test_format
(
format_
)
@
skipIfNoModule
(
"soundfile"
)
class
TestFileObject
(
TempDirMixin
,
PytorchTestCase
):
def
_test_fileobj
(
self
,
ext
):
"""Loading audio via file-like object works"""
sample_rate
=
16000
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
data
=
get_wav_data
(
'float32'
,
num_channels
=
2
).
numpy
().
T
soundfile
.
write
(
path
,
data
,
sample_rate
)
expected
=
soundfile
.
read
(
path
,
dtype
=
'float32'
)[
0
].
T
with
open
(
path
,
'rb'
)
as
fileobj
:
found
,
sr
=
soundfile_backend
.
load
(
fileobj
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
def
test_fileobj_wav
(
self
):
"""Loading audio via file-like object works"""
self
.
_test_fileobj
(
'wav'
)
@
skipIfFormatNotSupported
(
"FLAC"
)
def
test_fileobj_flac
(
self
):
"""Loading audio via file-like object works"""
self
.
_test_fileobj
(
'flac'
)
def
_test_tarfile
(
self
,
ext
):
"""Loading audio via file-like object works"""
sample_rate
=
16000
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
data
=
get_wav_data
(
'float32'
,
num_channels
=
2
).
numpy
().
T
soundfile
.
write
(
audio_path
,
data
,
sample_rate
)
expected
=
soundfile
.
read
(
audio_path
,
dtype
=
'float32'
)[
0
].
T
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
found
,
sr
=
soundfile_backend
.
load
(
fileobj
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
def
test_tarfile_wav
(
self
):
"""Loading audio via file-like object works"""
self
.
_test_tarfile
(
'wav'
)
@
skipIfFormatNotSupported
(
"FLAC"
)
def
test_tarfile_flac
(
self
):
"""Loading audio via file-like object works"""
self
.
_test_tarfile
(
'flac'
)
test/torchaudio_unittest/sox_io_backend/load_test.py
View file @
463a8b2c
import
io
import
itertools
import
tarfile
from
torchaudio.backend
import
sox_io_backend
from
parameterized
import
parameterized
from
torchaudio.backend
import
sox_io_backend
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
HttpServerMixin
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoModule
,
get_asset_path
,
get_wav_data
,
load_wav
,
...
...
@@ -19,6 +24,10 @@ from .common import (
)
if
_mod_utils
.
is_module_available
(
"requests"
):
import
requests
class
LoadTestBase
(
TempDirMixin
,
PytorchTestCase
):
def
assert_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
):
"""`sox_io_backend.load` can load wav format correctly.
...
...
@@ -369,3 +378,156 @@ class TestLoadWithoutExtension(PytorchTestCase):
path
=
get_asset_path
(
"mp3_without_ext"
)
_
,
sr
=
sox_io_backend
.
load
(
path
,
format
=
"mp3"
)
assert
sr
==
16000
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
class
TestFileObject
(
TempDirMixin
,
PytorchTestCase
):
"""
In this test suite, the result of file-like object input is compared against file path input,
because `load` function is rigrously tested for file path inputs to match libsox's result,
"""
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_fileobj
(
self
,
ext
,
compression
):
"""Loading audio via file object returns the same result as via file path."""
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
)
expected
,
_
=
sox_io_backend
.
load
(
path
)
with
open
(
path
,
'rb'
)
as
fileobj
:
found
,
sr
=
sox_io_backend
.
load
(
fileobj
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_bytesio
(
self
,
ext
,
compression
):
"""Loading audio via BytesIO object returns the same result as via file path."""
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
)
expected
,
_
=
sox_io_backend
.
load
(
path
)
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
found
,
sr
=
sox_io_backend
.
load
(
fileobj
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_tarfile
(
self
,
ext
,
compression
):
"""Loading compressed audio via file-like object returns the same result as via file path."""
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
sox_utils
.
gen_audio_file
(
audio_path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
)
expected
,
_
=
sox_io_backend
.
load
(
audio_path
)
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
found
,
sr
=
sox_io_backend
.
load
(
fileobj
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
@
skipIfNoModule
(
"requests"
)
class
TestFileObjectHttp
(
HttpServerMixin
,
PytorchTestCase
):
@
parameterized
.
expand
([
(
'wav'
,
None
),
(
'mp3'
,
128
),
(
'mp3'
,
320
),
(
'flac'
,
0
),
(
'flac'
,
5
),
(
'flac'
,
8
),
(
'vorbis'
,
-
1
),
(
'vorbis'
,
10
),
(
'amb'
,
None
),
])
def
test_requests
(
self
,
ext
,
compression
):
sample_rate
=
16000
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
sox_utils
.
gen_audio_file
(
audio_path
,
sample_rate
,
num_channels
=
2
,
compression
=
compression
)
expected
,
_
=
sox_io_backend
.
load
(
audio_path
)
url
=
self
.
get_url
(
audio_file
)
with
requests
.
get
(
url
,
stream
=
True
)
as
resp
:
found
,
sr
=
sox_io_backend
.
load
(
resp
.
raw
,
format
=
format_
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
0
,
1
,
10
,
100
,
1000
],
[
-
1
,
1
,
10
,
100
,
1000
],
)),
name_func
=
name_func
)
def
test_frame
(
self
,
frame_offset
,
num_frames
):
"""num_frames and frame_offset correctly specify the region of data"""
sample_rate
=
8000
audio_file
=
'test.wav'
audio_path
=
self
.
get_temp_path
(
audio_file
)
original
=
get_wav_data
(
'float32'
,
num_channels
=
2
)
save_wav
(
audio_path
,
original
,
sample_rate
)
frame_end
=
None
if
num_frames
==
-
1
else
frame_offset
+
num_frames
expected
=
original
[:,
frame_offset
:
frame_end
]
url
=
self
.
get_url
(
audio_file
)
with
requests
.
get
(
url
,
stream
=
True
)
as
resp
:
found
,
sr
=
sox_io_backend
.
load
(
resp
.
raw
,
frame_offset
,
num_frames
)
assert
sr
==
sample_rate
self
.
assertEqual
(
expected
,
found
)
torchaudio/backend/_soundfile_backend.py
View file @
463a8b2c
...
...
@@ -82,10 +82,12 @@ def load(
``[-1.0, 1.0]``.
Args:
filepath (str or pathlib.Path): Path to audio file.
This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str``
for the consistency with "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatiblity.
filepath (path-like object or file-like object):
Source of audio data.
Note:
* This argument is intentionally annotated as ``str`` only,
for the consistency with "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatiblity.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
...
...
torchaudio/backend/sox_io_backend.py
View file @
463a8b2c
import
os
from
typing
import
Tuple
,
Optional
import
torch
...
...
@@ -5,6 +6,7 @@ from torchaudio._internal import (
module_utils
as
_mod_utils
,
)
import
torchaudio
from
.common
import
AudioMetaData
...
...
@@ -82,9 +84,17 @@ def load(
``[-1.0, 1.0]``.
Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
...
...
@@ -112,8 +122,13 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``.
"""
# Cast to str in case type is `pathlib.Path`
filepath
=
str
(
filepath
)
if
not
torch
.
jit
.
is_scripting
():
if
hasattr
(
filepath
,
'read'
):
return
torchaudio
.
_torchaudio
.
load_audio_fileobj
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
signal
=
torch
.
ops
.
torchaudio
.
sox_io_load_audio_file
(
os
.
fspath
(
filepath
),
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
return
signal
.
get_tensor
(),
signal
.
get_sample_rate
()
signal
=
torch
.
ops
.
torchaudio
.
sox_io_load_audio_file
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
,
format
)
return
signal
.
get_tensor
(),
signal
.
get_sample_rate
()
...
...
torchaudio/csrc/pybind.cpp
View file @
463a8b2c
#include <torch/extension.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>
PYBIND11_MODULE
(
_torchaudio
,
m
)
{
py
::
class_
<
sox_signalinfo_t
>
(
m
,
"sox_signalinfo_t"
)
.
def
(
py
::
init
<>
())
...
...
@@ -94,4 +96,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info"
,
&
torch
::
audio
::
get_info
,
"Gets information about an audio file"
);
m
.
def
(
"load_audio_fileobj"
,
&
torchaudio
::
sox_io
::
load_audio_fileobj
,
"Load audio from file object."
);
}
torchaudio/csrc/sox/effects.cpp
View file @
463a8b2c
...
...
@@ -135,5 +135,88 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
tensor
,
chain
.
getOutputSampleRate
(),
channels_first_
);
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
apply_effects_fileobj
(
py
::
object
fileobj
,
std
::
vector
<
std
::
vector
<
std
::
string
>>
effects
,
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
)
{
// Streaming decoding over file-like object is tricky because libsox operates on FILE pointer.
// The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and discard.
//
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial chunk of byte string
// This will perform header-based format detection, if necessary, then fill the metadata of
// sox_format_t. Internally, sox_open_mem_read uses fmemopen, which returns FILE* which points the
// buffer of the provided byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying buffer in a way that it
// starts with unseen data, and append the new data read from the given fileobj.
// This will trick libsox as if it keeps reading from the FILE* continuously.
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Using std::string and let it manage memory.
// 4096 is minimum size requried by auto_detect_format
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L48
const
size_t
in_buffer_size
=
4096
;
std
::
string
in_buffer
(
in_buffer_size
,
'x'
);
auto
*
in_buf
=
const_cast
<
char
*>
(
in_buffer
.
data
());
// Fetch the header, and copy it to the buffer.
auto
header
=
static_cast
<
std
::
string
>
(
static_cast
<
py
::
bytes
>
(
fileobj
.
attr
(
"read"
)(
4096
)));
memcpy
(
static_cast
<
void
*>
(
in_buf
),
static_cast
<
void
*>
(
const_cast
<
char
*>
(
header
.
data
())),
header
.
length
());
// Open file (this starts reading the header)
SoxFormat
sf
(
sox_open_mem_read
(
in_buf
,
in_buffer_size
,
/*signal=*/
nullptr
,
/*encoding=*/
nullptr
,
/*filetype=*/
format
.
has_value
()
?
format
.
value
().
c_str
()
:
nullptr
));
// In case of streamed data, length can be 0
validate_input_file
(
sf
,
/*check_length=*/
false
);
// Prepare output buffer
std
::
vector
<
sox_sample_t
>
out_buffer
;
out_buffer
.
reserve
(
sf
->
signal
.
length
);
// Create and run SoxEffectsChain
const
auto
dtype
=
get_dtype
(
sf
->
encoding
.
encoding
,
sf
->
signal
.
precision
);
torchaudio
::
sox_effects_chain
::
SoxEffectsChain
chain
(
/*input_encoding=*/
sf
->
encoding
,
/*output_encoding=*/
get_encodinginfo
(
"wav"
,
dtype
,
0.
));
chain
.
addInputFileObj
(
sf
,
in_buf
,
in_buffer_size
,
&
fileobj
);
for
(
const
auto
&
effect
:
effects
)
{
chain
.
addEffect
(
effect
);
}
chain
.
addOutputBuffer
(
&
out_buffer
);
chain
.
run
();
// Create tensor from buffer
bool
channels_first_
=
channels_first
.
value_or
(
true
);
auto
tensor
=
convert_to_tensor
(
/*buffer=*/
out_buffer
.
data
(),
/*num_samples=*/
out_buffer
.
size
(),
/*num_channels=*/
chain
.
getOutputNumChannels
(),
dtype
,
normalize
.
value_or
(
true
),
channels_first_
);
return
std
::
make_tuple
(
tensor
,
static_cast
<
int64_t
>
(
chain
.
getOutputSampleRate
()));
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
}
// namespace sox_effects
}
// namespace torchaudio
torchaudio/csrc/sox/effects.h
View file @
463a8b2c
#ifndef TORCHAUDIO_SOX_EFFECTS_H
#define TORCHAUDIO_SOX_EFFECTS_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
...
...
@@ -22,6 +26,17 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file(
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
apply_effects_fileobj
(
py
::
object
fileobj
,
std
::
vector
<
std
::
vector
<
std
::
string
>>
effects
,
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
);
#endif // TORCH_API_INCLUDE_EXTENSION_H
}
// namespace sox_effects
}
// namespace torchaudio
...
...
torchaudio/csrc/sox/effects_chain.cpp
View file @
463a8b2c
...
...
@@ -198,7 +198,7 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) {
priv
->
signal
=
signal
;
priv
->
index
=
0
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Failed to add effect: input_tensor"
);
throw
std
::
runtime_error
(
"
Internal Error:
Failed to add effect: input_tensor"
);
}
}
...
...
@@ -207,7 +207,7 @@ void SoxEffectsChain::addOutputBuffer(
SoxEffect
e
(
sox_create_effect
(
get_tensor_output_handler
()));
static_cast
<
TensorOutputPriv
*>
(
e
->
priv
)
->
buffer
=
output_buffer
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Failed to add effect: output_tensor"
);
throw
std
::
runtime_error
(
"
Internal Error:
Failed to add effect: output_tensor"
);
}
}
...
...
@@ -219,7 +219,7 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) {
sox_effect_options
(
e
,
1
,
opts
);
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
std
::
ostringstream
stream
;
stream
<<
"Failed to add effect: input "
<<
sf
->
filename
;
stream
<<
"
Internal Error:
Failed to add effect: input "
<<
sf
->
filename
;
throw
std
::
runtime_error
(
stream
.
str
());
}
}
...
...
@@ -230,7 +230,7 @@ void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
static_cast
<
FileOutputPriv
*>
(
e
->
priv
)
->
sf
=
sf
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
out_sig_
)
!=
SOX_SUCCESS
)
{
std
::
ostringstream
stream
;
stream
<<
"Failed to add effect: output "
<<
sf
->
filename
;
stream
<<
"
Internal Error:
Failed to add effect: output "
<<
sf
->
filename
;
throw
std
::
runtime_error
(
stream
.
str
());
}
}
...
...
@@ -266,7 +266,7 @@ void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
std
::
ostringstream
stream
;
stream
<<
"Failed to add effect:
\"
"
<<
name
;
stream
<<
"
Internal Error:
Failed to add effect:
\"
"
<<
name
;
for
(
size_t
i
=
1
;
i
<
num_args
;
++
i
)
{
stream
<<
" "
<<
effect
[
i
];
}
...
...
@@ -283,5 +283,132 @@ int64_t SoxEffectsChain::getOutputSampleRate() {
return
interm_sig_
.
rate
;
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
namespace
{
/// helper classes for passing file-like object to SoxEffectChain
struct
FileObjInputPriv
{
sox_format_t
*
sf
;
py
::
object
*
fileobj
;
char
*
buffer
;
uint64_t
buffer_size
;
};
/// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
int
fileobj_input_drain
(
sox_effect_t
*
effp
,
sox_sample_t
*
obuf
,
size_t
*
osamp
)
{
auto
priv
=
static_cast
<
FileObjInputPriv
*>
(
effp
->
priv
);
auto
sf
=
priv
->
sf
;
auto
fileobj
=
priv
->
fileobj
;
auto
buffer
=
priv
->
buffer
;
auto
buffer_size
=
priv
->
buffer_size
;
// 1. Refresh the buffer
//
// NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detect EOF is reaching the end of the buffer. (null byte won't help)
// Therefore we need to align the content at the end of buffer, otherwise,
// libsox will keep reading the content beyond intended length.
//
// Before:
//
// |<--------consumed------->|<-remaining->|
// |*************************|-------------|
// ^ ftell
//
// After:
//
// |<-offset->|<-remaining->|<--new data-->|
// |**********|-------------|++++++++++++++|
// ^ ftell
const
auto
num_consumed
=
sf
->
tell_off
;
const
auto
num_remain
=
buffer_size
-
num_consumed
;
// 1.1. First, we fetch the data to see if there is data to fill the buffer
py
::
bytes
chunk_
=
fileobj
->
attr
(
"read"
)(
num_consumed
);
const
auto
num_refill
=
py
::
len
(
chunk_
);
const
auto
offset
=
buffer_size
-
(
num_remain
+
num_refill
);
if
(
num_refill
>
num_consumed
)
{
std
::
ostringstream
message
;
message
<<
"Tried to read up to "
<<
num_consumed
<<
" bytes but, "
<<
"recieved "
<<
num_refill
<<
" bytes. "
<<
"The given object does not confirm to read protocol of file object."
;
throw
std
::
runtime_error
(
message
.
str
());
}
// 1.2. Move the unconsumed data towards the beginning of buffer.
if
(
num_remain
)
{
auto
src
=
static_cast
<
void
*>
(
buffer
+
num_consumed
);
auto
dst
=
static_cast
<
void
*>
(
buffer
+
offset
);
memmove
(
dst
,
src
,
num_remain
);
}
// 1.3. Refill the remaining buffer.
if
(
num_refill
)
{
auto
chunk
=
static_cast
<
std
::
string
>
(
chunk_
);
auto
src
=
static_cast
<
void
*>
(
const_cast
<
char
*>
(
chunk
.
c_str
()));
auto
dst
=
buffer
+
offset
+
num_remain
;
memcpy
(
dst
,
src
,
num_refill
);
}
// 1.4. Set the file pointer to the new offset
sf
->
tell_off
=
offset
;
fseek
((
FILE
*
)
sf
->
fp
,
offset
,
SEEK_SET
);
// 2. Perform decoding operation
// The following part is practically same as "input" effect
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48
// Ensure that it's a multiple of the number of channels
*
osamp
-=
*
osamp
%
effp
->
out_signal
.
channels
;
// Read up to *osamp samples into obuf;
// store the actual number read back to *osamp
*
osamp
=
sox_read
(
sf
,
obuf
,
*
osamp
);
return
*
osamp
?
SOX_SUCCESS
:
SOX_EOF
;
}
sox_effect_handler_t
*
get_fileobj_input_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"input_fileobj_object"
,
/*usage=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*getopts=*/
NULL
,
/*start=*/
NULL
,
/*flow=*/
NULL
,
/*drain=*/
fileobj_input_drain
,
/*stop=*/
NULL
,
/*kill=*/
NULL
,
/*priv_size=*/
sizeof
(
FileObjInputPriv
)};
return
&
handler
;
}
}
// namespace
void
SoxEffectsChain
::
addInputFileObj
(
sox_format_t
*
sf
,
char
*
buffer
,
uint64_t
buffer_size
,
py
::
object
*
fileobj
)
{
in_sig_
=
sf
->
signal
;
interm_sig_
=
in_sig_
;
SoxEffect
e
(
sox_create_effect
(
get_fileobj_input_handler
()));
auto
priv
=
static_cast
<
FileObjInputPriv
*>
(
e
->
priv
);
priv
->
sf
=
sf
;
priv
->
fileobj
=
fileobj
;
priv
->
buffer
=
buffer
;
priv
->
buffer_size
=
buffer_size
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input fileobj"
);
}
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
}
// namespace sox_effects_chain
}
// namespace torchaudio
torchaudio/csrc/sox/effects_chain.h
View file @
463a8b2c
...
...
@@ -4,6 +4,10 @@
#include <sox.h>
#include <torchaudio/csrc/sox/utils.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
namespace
torchaudio
{
namespace
sox_effects_chain
{
...
...
@@ -33,6 +37,16 @@ class SoxEffectsChain {
void
addEffect
(
const
std
::
vector
<
std
::
string
>
effect
);
int64_t
getOutputNumChannels
();
int64_t
getOutputSampleRate
();
#ifdef TORCH_API_INCLUDE_EXTENSION_H
void
addInputFileObj
(
sox_format_t
*
sf
,
char
*
buffer
,
uint64_t
buffer_size
,
py
::
object
*
fileobj
);
#endif // TORCH_API_INCLUDE_EXTENSION_H
};
}
// namespace sox_effects_chain
...
...
torchaudio/csrc/sox/io.cpp
View file @
463a8b2c
...
...
@@ -49,13 +49,11 @@ c10::intrusive_ptr<SignalInfo> get_info(
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
));
}
c10
::
intrusive_ptr
<
TensorSignal
>
load_audio_file
(
const
std
::
string
&
path
,
namespace
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
get_effects
(
c10
::
optional
<
int64_t
>&
frame_offset
,
c10
::
optional
<
int64_t
>&
num_frames
,
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
)
{
c10
::
optional
<
int64_t
>&
num_frames
)
{
const
auto
offset
=
frame_offset
.
value_or
(
0
);
if
(
offset
<
0
)
{
throw
std
::
runtime_error
(
...
...
@@ -79,7 +77,19 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
os_offset
<<
offset
<<
"s"
;
effects
.
emplace_back
(
std
::
vector
<
std
::
string
>
{
"trim"
,
os_offset
.
str
()});
}
return
effects
;
}
}
// namespace
c10
::
intrusive_ptr
<
TensorSignal
>
load_audio_file
(
const
std
::
string
&
path
,
c10
::
optional
<
int64_t
>&
frame_offset
,
c10
::
optional
<
int64_t
>&
num_frames
,
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
)
{
auto
effects
=
get_effects
(
frame_offset
,
num_frames
);
return
torchaudio
::
sox_effects
::
apply_effects_file
(
path
,
effects
,
normalize
,
channels_first
,
format
);
}
...
...
@@ -123,5 +133,21 @@ void save_audio_file(
chain
.
run
();
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
py
::
object
fileobj
,
c10
::
optional
<
int64_t
>&
frame_offset
,
c10
::
optional
<
int64_t
>&
num_frames
,
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
)
{
auto
effects
=
get_effects
(
frame_offset
,
num_frames
);
return
torchaudio
::
sox_effects
::
apply_effects_fileobj
(
fileobj
,
effects
,
normalize
,
channels_first
,
format
);
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
}
// namespace sox_io
}
// namespace torchaudio
torchaudio/csrc/sox/io.h
View file @
463a8b2c
#ifndef TORCHAUDIO_SOX_IO_H
#define TORCHAUDIO_SOX_IO_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
...
...
@@ -38,6 +42,18 @@ void save_audio_file(
const
c10
::
intrusive_ptr
<
torchaudio
::
sox_utils
::
TensorSignal
>&
signal
,
const
double
compression
=
0.
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
py
::
object
fileobj
,
c10
::
optional
<
int64_t
>&
frame_offset
,
c10
::
optional
<
int64_t
>&
num_frames
,
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
);
#endif // TORCH_API_INCLUDE_EXTENSION_H
}
// namespace sox_io
}
// namespace torchaudio
...
...
torchaudio/csrc/sox/utils.cpp
View file @
463a8b2c
...
...
@@ -92,15 +92,15 @@ SoxFormat::operator sox_format_t*() const noexcept {
return
fd_
;
}
void
validate_input_file
(
const
SoxFormat
&
sf
)
{
void
validate_input_file
(
const
SoxFormat
&
sf
,
bool
check_length
)
{
if
(
static_cast
<
sox_format_t
*>
(
sf
)
==
nullptr
)
{
throw
std
::
runtime_error
(
"Error loading audio file: failed to open file."
);
}
if
(
sf
->
encoding
.
encoding
==
SOX_ENCODING_UNKNOWN
)
{
throw
std
::
runtime_error
(
"Error loading audio file: unknown encoding."
);
}
if
(
sf
->
signal
.
length
==
0
)
{
throw
std
::
runtime_error
(
"Error reading audio file: unkown length."
);
if
(
check_length
&&
sf
->
signal
.
length
==
0
)
{
throw
std
::
runtime_error
(
"Error reading audio file: unk
n
own length."
);
}
}
...
...
torchaudio/csrc/sox/utils.h
View file @
463a8b2c
...
...
@@ -67,7 +67,7 @@ struct SoxFormat {
///
/// Verify that input file is found, has known encoding, and not empty
void
validate_input_file
(
const
SoxFormat
&
sf
);
void
validate_input_file
(
const
SoxFormat
&
sf
,
bool
check_length
=
true
);
///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment