Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
e7b43dde
"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "8b4dcf1dfb0368bf856b361dd14ebe81b22971ac"
Unverified
Commit
e7b43dde
authored
Jul 20, 2021
by
hwangjeff
Committed by
GitHub
Jul 20, 2021
Browse files
Make buffer size for function info configurable (#1634)
parent
8ec6b873
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
78 additions
and
6 deletions
+78
-6
test/torchaudio_unittest/backend/sox_io/info_test.py
test/torchaudio_unittest/backend/sox_io/info_test.py
+52
-4
test/torchaudio_unittest/common_utils/sox_utils.py
test/torchaudio_unittest/common_utils/sox_utils.py
+3
-1
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+4
-1
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+7
-0
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+2
-0
torchaudio/utils/sox_utils.py
torchaudio/utils/sox_utils.py
+10
-0
No files found.
test/torchaudio_unittest/backend/sox_io/info_test.py
View file @
e7b43dde
from
contextlib
import
contextmanager
import
io
import
io
import
os
import
os
import
itertools
import
itertools
...
@@ -5,6 +6,7 @@ import tarfile
...
@@ -5,6 +6,7 @@ import tarfile
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
torchaudio.backend
import
sox_io_backend
from
torchaudio.backend
import
sox_io_backend
from
torchaudio.utils.sox_utils
import
get_buffer_size
,
set_buffer_size
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.backend.common
import
(
from
torchaudio_unittest.backend.common
import
(
...
@@ -293,24 +295,33 @@ class TestLoadWithoutExtension(PytorchTestCase):
...
@@ -293,24 +295,33 @@ class TestLoadWithoutExtension(PytorchTestCase):
class
FileObjTestBase
(
TempDirMixin
):
class
FileObjTestBase
(
TempDirMixin
):
def
_gen_file
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
def
_gen_file
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
*
,
comments
=
None
):
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
duration
=
num_frames
/
sample_rate
duration
=
num_frames
/
sample_rate
comment_file
=
self
.
_gen_comment_file
(
comments
)
if
comments
else
None
sox_utils
.
gen_audio_file
(
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
num_channels
,
path
,
sample_rate
,
num_channels
=
num_channels
,
encoding
=
sox_utils
.
get_encoding
(
dtype
),
encoding
=
sox_utils
.
get_encoding
(
dtype
),
bit_depth
=
bit_depth
,
bit_depth
=
bit_depth
,
duration
=
duration
)
duration
=
duration
,
comment_file
=
comment_file
,
)
return
path
return
path
def
_gen_comment_file
(
self
,
comments
):
comment_path
=
self
.
get_temp_path
(
"comment.txt"
)
with
open
(
comment_path
,
"w"
)
as
file_
:
file_
.
writelines
(
comments
)
return
comment_path
@
skipIfNoSox
@
skipIfNoSox
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
class
TestFileObject
(
FileObjTestBase
,
PytorchTestCase
):
class
TestFileObject
(
FileObjTestBase
,
PytorchTestCase
):
def
_query_fileobj
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
def
_query_fileobj
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
*
,
comments
=
None
):
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
comments
=
comments
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
open
(
path
,
'rb'
)
as
fileobj
:
with
open
(
path
,
'rb'
)
as
fileobj
:
return
sox_io_backend
.
info
(
fileobj
,
format_
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
...
@@ -333,6 +344,15 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
...
@@ -333,6 +344,15 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
fileobj
=
tarobj
.
extractfile
(
audio_file
)
fileobj
=
tarobj
.
extractfile
(
audio_file
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
@
contextmanager
def
_set_buffer_size
(
self
,
buffer_size
):
try
:
original_buffer_size
=
get_buffer_size
()
set_buffer_size
(
buffer_size
)
yield
finally
:
set_buffer_size
(
original_buffer_size
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
"float32"
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int32"
),
...
@@ -359,6 +379,34 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
...
@@ -359,6 +379,34 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
(
'vorbis'
,
"float32"
),
])
def
test_fileobj_large_header
(
self
,
ext
,
dtype
):
"""
For audio file with header size exceeding default buffer size:
- Querying audio via file object without enlarging buffer size fails.
- Querying audio via file object after enlarging buffer size succeeds.
"""
sample_rate
=
16000
num_frames
=
3
*
sample_rate
num_channels
=
2
comments
=
"metadata="
+
" "
.
join
([
"value"
for
_
in
range
(
1000
)])
with
self
.
assertRaisesRegex
(
RuntimeError
,
"^Error loading audio file:"
):
sinfo
=
self
.
_query_fileobj
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
comments
=
comments
)
with
self
.
_set_buffer_size
(
16384
):
sinfo
=
self
.
_query_fileobj
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
,
comments
=
comments
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
"float32"
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int32"
),
...
...
test/torchaudio_unittest/common_utils/sox_utils.py
View file @
e7b43dde
...
@@ -25,7 +25,7 @@ def get_bit_depth(dtype):
...
@@ -25,7 +25,7 @@ def get_bit_depth(dtype):
def
gen_audio_file
(
def
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
path
,
sample_rate
,
num_channels
,
*
,
encoding
=
None
,
bit_depth
=
None
,
compression
=
None
,
attenuation
=
None
,
duration
=
1
,
*
,
encoding
=
None
,
bit_depth
=
None
,
compression
=
None
,
attenuation
=
None
,
duration
=
1
,
comment_file
=
None
,
):
):
"""Generate synthetic audio file with `sox` command."""
"""Generate synthetic audio file with `sox` command."""
if
path
.
endswith
(
'.wav'
):
if
path
.
endswith
(
'.wav'
):
...
@@ -53,6 +53,8 @@ def gen_audio_file(
...
@@ -53,6 +53,8 @@ def gen_audio_file(
command
+=
[
'--bits'
,
str
(
bit_depth
)]
command
+=
[
'--bits'
,
str
(
bit_depth
)]
if
encoding
is
not
None
:
if
encoding
is
not
None
:
command
+=
[
'--encoding'
,
str
(
encoding
)]
command
+=
[
'--encoding'
,
str
(
encoding
)]
if
comment_file
is
not
None
:
command
+=
[
'--comment-file'
,
str
(
comment_file
)]
command
+=
[
command
+=
[
str
(
path
),
str
(
path
),
'synth'
,
str
(
duration
),
# synthesizes for the given duration [sec]
'synth'
,
str
(
duration
),
# synthesizes for the given duration [sec]
...
...
torchaudio/csrc/sox/io.cpp
View file @
e7b43dde
...
@@ -161,7 +161,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
...
@@ -161,7 +161,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
//
//
// See:
// See:
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
auto
capacity
=
4096
;
const
int
kDefaultCapacityInBytes
=
4096
;
auto
capacity
=
(
sox_get_globals
()
->
bufsiz
>
kDefaultCapacityInBytes
)
?
sox_get_globals
()
->
bufsiz
:
kDefaultCapacityInBytes
;
std
::
string
buffer
(
capacity
,
'\0'
);
std
::
string
buffer
(
capacity
,
'\0'
);
auto
*
buf
=
const_cast
<
char
*>
(
buffer
.
data
());
auto
*
buf
=
const_cast
<
char
*>
(
buffer
.
data
());
auto
num_read
=
read_fileobj
(
&
fileobj
,
capacity
,
buf
);
auto
num_read
=
read_fileobj
(
&
fileobj
,
capacity
,
buf
);
...
...
torchaudio/csrc/sox/utils.cpp
View file @
e7b43dde
...
@@ -22,6 +22,10 @@ void set_buffer_size(const int64_t buffer_size) {
...
@@ -22,6 +22,10 @@ void set_buffer_size(const int64_t buffer_size) {
sox_get_globals
()
->
bufsiz
=
static_cast
<
size_t
>
(
buffer_size
);
sox_get_globals
()
->
bufsiz
=
static_cast
<
size_t
>
(
buffer_size
);
}
}
int64_t
get_buffer_size
()
{
return
sox_get_globals
()
->
bufsiz
;
}
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
()
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
()
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
effects
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
effects
;
for
(
const
sox_effect_fn_t
*
fns
=
sox_get_effect_fns
();
*
fns
;
++
fns
)
{
for
(
const
sox_effect_fn_t
*
fns
=
sox_get_effect_fns
();
*
fns
;
++
fns
)
{
...
@@ -538,6 +542,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...
@@ -538,6 +542,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m
.
def
(
m
.
def
(
"torchaudio::sox_utils_list_write_formats"
,
"torchaudio::sox_utils_list_write_formats"
,
&
torchaudio
::
sox_utils
::
list_write_formats
);
&
torchaudio
::
sox_utils
::
list_write_formats
);
m
.
def
(
"torchaudio::sox_utils_get_buffer_size"
,
&
torchaudio
::
sox_utils
::
get_buffer_size
);
}
}
}
// namespace sox_utils
}
// namespace sox_utils
...
...
torchaudio/csrc/sox/utils.h
View file @
e7b43dde
...
@@ -24,6 +24,8 @@ void set_use_threads(const bool use_threads);
...
@@ -24,6 +24,8 @@ void set_use_threads(const bool use_threads);
void
set_buffer_size
(
const
int64_t
buffer_size
);
void
set_buffer_size
(
const
int64_t
buffer_size
);
int64_t
get_buffer_size
();
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
();
std
::
vector
<
std
::
vector
<
std
::
string
>>
list_effects
();
std
::
vector
<
std
::
string
>
list_read_formats
();
std
::
vector
<
std
::
string
>
list_read_formats
();
...
...
torchaudio/utils/sox_utils.py
View file @
e7b43dde
...
@@ -90,3 +90,13 @@ def list_write_formats() -> List[str]:
...
@@ -90,3 +90,13 @@ def list_write_formats() -> List[str]:
List[str]: List of supported audio formats
List[str]: List of supported audio formats
"""
"""
return
torch
.
ops
.
torchaudio
.
sox_utils_list_write_formats
()
return
torch
.
ops
.
torchaudio
.
sox_utils_list_write_formats
()
@
_mod_utils
.
requires_sox
()
def
get_buffer_size
()
->
int
:
"""Get buffer size for sox effect chain
Returns:
int: size in bytes of buffers used for processing audio.
"""
return
torch
.
ops
.
torchaudio
.
sox_utils_get_buffer_size
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment