Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
e7b43dde
Unverified
Commit
e7b43dde
authored
Jul 20, 2021
by
hwangjeff
Committed by
GitHub
Jul 20, 2021
Browse files
Make buffer size for function info configurable (#1634)
parent
8ec6b873
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
78 additions
and
6 deletions
+78
-6
test/torchaudio_unittest/backend/sox_io/info_test.py
test/torchaudio_unittest/backend/sox_io/info_test.py
+52
-4
test/torchaudio_unittest/common_utils/sox_utils.py
test/torchaudio_unittest/common_utils/sox_utils.py
+3
-1
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+4
-1
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+7
-0
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+2
-0
torchaudio/utils/sox_utils.py
torchaudio/utils/sox_utils.py
+10
-0
No files found.
test/torchaudio_unittest/backend/sox_io/info_test.py
View file @
e7b43dde
from
contextlib
import
contextmanager
import
io
import
io
import
os
import
os
import
itertools
import
itertools
...
@@ -5,6 +6,7 @@ import tarfile
...
@@ -5,6 +6,7 @@ import tarfile
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
torchaudio.backend
import
sox_io_backend
from
torchaudio.backend
import
sox_io_backend
from
torchaudio.utils.sox_utils
import
get_buffer_size
,
set_buffer_size
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.backend.common
import
(
from
torchaudio_unittest.backend.common
import
(
...
@@ -293,24 +295,33 @@ class TestLoadWithoutExtension(PytorchTestCase):
...
@@ -293,24 +295,33 @@ class TestLoadWithoutExtension(PytorchTestCase):
class
FileObjTestBase
(
TempDirMixin
):
class
FileObjTestBase
(
TempDirMixin
):
def
_gen_file
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
def
_gen_file
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
*
,
comments
=
None
):
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
duration
=
num_frames
/
sample_rate
duration
=
num_frames
/
sample_rate
comment_file
=
self
.
_gen_comment_file
(
comments
)
if
comments
else
None
sox_utils
.
gen_audio_file
(
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
num_channels
,
path
,
sample_rate
,
num_channels
=
num_channels
,
encoding
=
sox_utils
.
get_encoding
(
dtype
),
encoding
=
sox_utils
.
get_encoding
(
dtype
),
bit_depth
=
bit_depth
,
bit_depth
=
bit_depth
,
duration
=
duration
)
duration
=
duration
,
comment_file
=
comment_file
,
)
return
path
return
path
def
_gen_comment_file
(
self
,
comments
):
comment_path
=
self
.
get_temp_path
(
"comment.txt"
)
with
open
(
comment_path
,
"w"
)
as
file_
:
file_
.
writelines
(
comments
)
return
comment_path
@
skipIfNoSox
@
skipIfNoSox
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
class
TestFileObject
(
FileObjTestBase
,
PytorchTestCase
):
class
TestFileObject
(
FileObjTestBase
,
PytorchTestCase
):
def
_query_fileobj
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
def
_query_fileobj
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
*
,
comments
=
None
):
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
comments
=
comments
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
open
(
path
,
'rb'
)
as
fileobj
:
with
open
(
path
,
'rb'
)
as
fileobj
:
return
sox_io_backend
.
info
(
fileobj
,
format_
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
...
@@ -333,6 +344,15 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
...
@@ -333,6 +344,15 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
fileobj
=
tarobj
.
extractfile
(
audio_file
)
fileobj
=
tarobj
.
extractfile
(
audio_file
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
@
contextmanager
def
_set_buffer_size
(
self
,
buffer_size
):
try
:
original_buffer_size
=
get_buffer_size
()
set_buffer_size
(
buffer_size
)
yield
finally
:
set_buffer_size
(
original_buffer_size
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
"float32"
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int32"
),
...
@@ -359,6 +379,34 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
...
@@ -359,6 +379,34 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
(
'vorbis'
,
"float32"
),
])
def
test_fileobj_large_header
(
self
,
ext
,
dtype
):
"""
For audio file with header size exceeding default buffer size:
- Querying audio via file object without enlarging buffer size fails.
- Querying audio via file object after enlarging buffer size succeeds.
"""
sample_rate
=
16000
num_frames
=
3
*
sample_rate
num_channels
=
2
comments
=
"metadata="
+
" "
.
join
([
"value"
for
_
in
range
(
1000
)])
with
self
.
assertRaisesRegex
(
RuntimeError
,
"^Error loading audio file:"
):
sinfo
=
self
.
_query_fileobj
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
comments
=
comments
)
with
self
.
_set_buffer_size
(
16384
):
sinfo
=
self
.
_query_fileobj
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
comments
=
comments
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
"float32"
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int32"
),
...
...
test/torchaudio_unittest/common_utils/sox_utils.py
View file @
e7b43dde
...
@@ -25,7 +25,7 @@ def get_bit_depth(dtype):
...
@@ -25,7 +25,7 @@ def get_bit_depth(dtype):
def
gen_audio_file
(
def
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
path
,
sample_rate
,
num_channels
,
*
,
encoding
=
None
,
bit_depth
=
None
,
compression
=
None
,
attenuation
=
None
,
duration
=
1
,
*
,
encoding
=
None
,
bit_depth
=
None
,
compression
=
None
,
attenuation
=
None
,
duration
=
1
,
comment_file
=
None
,
):
):
"""Generate synthetic audio file with `sox` command."""
"""Generate synthetic audio file with `sox` command."""
if
path
.
endswith
(
'.wav'
):
if
path
.
endswith
(
'.wav'
):
...
@@ -53,6 +53,8 @@ def gen_audio_file(
...
@@ -53,6 +53,8 @@ def gen_audio_file(
command
+=
[
'--bits'
,
str
(
bit_depth
)]
command
+=
[
'--bits'
,
str
(
bit_depth
)]
if
encoding
is
not
None
:
if
encoding
is
not
None
:
command
+=
[
'--encoding'
,
str
(
encoding
)]
command
+=
[
'--encoding'
,
str
(
encoding
)]
if
comment_file
is
not
None
:
command
+=
[
'--comment-file'
,
str
(
comment_file
)]
command
+=
[
command
+=
[
str
(
path
),
str
(
path
),
'synth'
,
str
(
duration
),
# synthesizes for the given duration [sec]
'synth'
,
str
(
duration
),
# synthesizes for the given duration [sec]
...
...
torchaudio/csrc/sox/io.cpp
View file @
e7b43dde
...
@@ -161,7 +161,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
...
@@ -161,7 +161,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
//
//
// See:
// See:
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
auto
capacity
=
4096
;
const
int
kDefaultCapacityInBytes
=
4096
;
auto
capacity
=
(
sox_get_globals
()
->
bufsiz
>
kDefaultCapacityInBytes
)
?
sox_get_globals
()
->
bufsiz
:
kDefaultCapacityInBytes
;
std
::
string
buffer
(
capacity
,
'\0'
);
std
::
string
buffer
(
capacity
,
'\0'
);
auto
*
buf
=
const_cast
<
char
*>
(
buffer
.
data
());
auto
*
buf
=
const_cast
<
char
*>
(
buffer
.
data
());
auto
num_read
=
read_fileobj
(
&
fileobj
,
capacity
,
buf
);
auto
num_read
=
read_fileobj
(
&
fileobj
,
capacity
,
buf
);
...
...
torchaudio/csrc/sox/utils.cpp
View file @
e7b43dde
...
@@ -22,6 +22,10 @@ void set_buffer_size(const int64_t buffer_size) {
...
@@ -22,6 +22,10 @@ void set_buffer_size(const int64_t buffer_size) {
sox_get_globals
()
->
bufsiz
=
static_cast
<
size_t
>
(
buffer_size
);
sox_get_globals
()
->
bufsiz
=
static_cast
<
size_t
>
(
buffer_size
);
}
}
int64_t
get_buffer_size
()
{
return
sox_get_globals
()
->
bufsiz
;
}
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
()
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
()
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
effects
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
effects
;
for
(
const
sox_effect_fn_t
*
fns
=
sox_get_effect_fns
();
*
fns
;
++
fns
)
{
for
(
const
sox_effect_fn_t
*
fns
=
sox_get_effect_fns
();
*
fns
;
++
fns
)
{
...
@@ -538,6 +542,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...
@@ -538,6 +542,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m
.
def
(
m
.
def
(
"torchaudio::sox_utils_list_write_formats"
,
"torchaudio::sox_utils_list_write_formats"
,
&
torchaudio
::
sox_utils
::
list_write_formats
);
&
torchaudio
::
sox_utils
::
list_write_formats
);
m
.
def
(
"torchaudio::sox_utils_get_buffer_size"
,
&
torchaudio
::
sox_utils
::
get_buffer_size
);
}
}
}
// namespace sox_utils
}
// namespace sox_utils
...
...
torchaudio/csrc/sox/utils.h
View file @
e7b43dde
...
@@ -24,6 +24,8 @@ void set_use_threads(const bool use_threads);
...
@@ -24,6 +24,8 @@ void set_use_threads(const bool use_threads);
void
set_buffer_size
(
const
int64_t
buffer_size
);
void
set_buffer_size
(
const
int64_t
buffer_size
);
int64_t
get_buffer_size
();
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
();
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
();
std
::
vector
<
std
::
string
>
list_read_formats
();
std
::
vector
<
std
::
string
>
list_read_formats
();
...
...
torchaudio/utils/sox_utils.py
View file @
e7b43dde
...
@@ -90,3 +90,13 @@ def list_write_formats() -> List[str]:
...
@@ -90,3 +90,13 @@ def list_write_formats() -> List[str]:
List[str]: List of supported audio formats
List[str]: List of supported audio formats
"""
"""
return
torch
.
ops
.
torchaudio
.
sox_utils_list_write_formats
()
return
torch
.
ops
.
torchaudio
.
sox_utils_list_write_formats
()
@
_mod_utils
.
requires_sox
()
def
get_buffer_size
()
->
int
:
"""Get buffer size for sox effect chain
Returns:
int: size in bytes of buffers used for processing audio.
"""
return
torch
.
ops
.
torchaudio
.
sox_utils_get_buffer_size
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment