Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b33c539c
Unverified
Commit
b33c539c
authored
Jan 26, 2021
by
moto
Committed by
GitHub
Jan 26, 2021
Browse files
Fix clang-format CI job (#1198)
parent
99ed7183
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
671 additions
and
284 deletions
+671
-284
.circleci/unittest/linux/scripts/run-clang-format.py
.circleci/unittest/linux/scripts/run-clang-format.py
+340
-0
.circleci/unittest/linux/scripts/run_style_checks.sh
.circleci/unittest/linux/scripts/run_style_checks.sh
+16
-10
torchaudio/csrc/pybind.cpp
torchaudio/csrc/pybind.cpp
+80
-76
torchaudio/csrc/sox/effects.cpp
torchaudio/csrc/sox/effects.cpp
+22
-16
torchaudio/csrc/sox/effects_chain.cpp
torchaudio/csrc/sox/effects_chain.cpp
+74
-64
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+15
-8
torchaudio/csrc/sox/legacy.cpp
torchaudio/csrc/sox/legacy.cpp
+6
-10
torchaudio/csrc/sox/legacy.h
torchaudio/csrc/sox/legacy.h
+7
-5
torchaudio/csrc/sox/register.cpp
torchaudio/csrc/sox/register.cpp
+9
-7
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+19
-15
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+1
-1
torchaudio/csrc/transducer.cpp
torchaudio/csrc/transducer.cpp
+82
-72
No files found.
.circleci/unittest/linux/scripts/run-clang-format.py
0 → 100755
View file @
b33c539c
#!/usr/bin/env python
"""A wrapper script around clang-format, suitable for linting multiple files
and to use for continuous integration.
This is an alternative API for the clang-format command line.
It runs over multiple files and directories in parallel.
A diff output is produced and a sensible exit code is returned.
"""
import
argparse
import
codecs
import
difflib
import
fnmatch
import
io
import
multiprocessing
import
os
import
signal
import
subprocess
import
sys
import
traceback
from
functools
import
partial
# Python 3 ships subprocess.DEVNULL; when that import fails, fall back to a
# handle opened on the null device so main() can still discard the output of
# the `--version` probe.
try:
    from subprocess import DEVNULL  # py3k
except ImportError:
    DEVNULL = open(os.devnull, "wb")

# Comma separated list of file extensions; used as the default value of the
# --extensions command line option.
DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu'
class ExitStatus(object):
    """Exit codes returned by ``main``.

    SUCCESS: nothing to report.
    DIFF:    at least one file produced a non-empty diff.
    TROUBLE: clang-format could not be run or a file could not be processed.
    """

    SUCCESS, DIFF, TROUBLE = 0, 1, 2
def list_files(files, recursive=False, extensions=None, exclude=None):
    """Expand the command line paths into the list of files to check.

    Args:
        files: paths given by the user.
        recursive: when true, every entry that is a directory is walked and
            replaced by the matching files found below it.
        extensions: extensions (without the leading dot) a file found during
            a recursive walk must have to be kept. ``None`` means no
            extension matches.
        exclude: ``fnmatch`` patterns; directories and files whose joined
            path matches any pattern are dropped from the recursive walk.

    Returns:
        A list of paths. Entries that are not walked (plain files, or
        directories when ``recursive`` is false) are returned unchanged and
        are not filtered by extension or exclusion pattern.
    """
    wanted = extensions if extensions is not None else []
    patterns = exclude if exclude is not None else []

    def excluded(path):
        return any(fnmatch.fnmatch(path, pat) for pat in patterns)

    selected = []
    for entry in files:
        if not (recursive and os.path.isdir(entry)):
            selected.append(entry)
            continue
        for root, subdirs, names in os.walk(entry):
            # os.walk() honours in-place edits of the directory list, so
            # pruning here avoids descending into excluded directories.
            subdirs[:] = [
                d for d in subdirs if not excluded(os.path.join(root, d))
            ]
            candidates = (os.path.join(root, name) for name in names)
            selected.extend(
                path for path in candidates
                if not excluded(path)
                and os.path.splitext(path)[1][1:] in wanted
            )
    return selected
def make_diff(file, original, reformatted):
    """Return the unified diff between two versions of ``file`` as a list.

    Args:
        file: path used to label both sides of the diff.
        original: lines of the file as read from disk.
        reformatted: lines produced by clang-format.

    Returns:
        The diff lines (three lines of context); empty when both inputs are
        equal.
    """
    before_label = '{}\t(original)'.format(file)
    after_label = '{}\t(reformatted)'.format(file)
    delta = difflib.unified_diff(
        original,
        reformatted,
        fromfile=before_label,
        tofile=after_label,
        n=3,
    )
    return list(delta)
class DiffError(Exception):
    """Raised when a file cannot be read or clang-format fails on it.

    Attributes:
        errs: list of diagnostic lines attached to the failure; empty when
            none were supplied.
    """

    def __init__(self, message, errs=None):
        Exception.__init__(self, message)
        self.errs = errs if errs else []
class UnexpectedError(Exception):
    """Wraps any exception that is not a ``DiffError``.

    Attributes:
        exc: the wrapped exception, if one was given.
        formatted_traceback: output of ``traceback.format_exc()`` taken when
            this object is constructed, so construct it inside the ``except``
            block handling the original exception.
    """

    def __init__(self, message, exc=None):
        Exception.__init__(self, message)
        self.exc = exc
        self.formatted_traceback = traceback.format_exc()
def run_clang_format_diff_wrapper(args, file):
    """Run ``run_clang_format_diff`` and normalise the exceptions it raises.

    ``DiffError`` is passed through untouched; any other exception is
    re-raised as ``UnexpectedError`` whose message is
    ``<file>: <exception class name>: <exception>``.

    Returns:
        Whatever ``run_clang_format_diff(args, file)`` returns.
    """
    try:
        return run_clang_format_diff(args, file)
    except DiffError:
        raise
    except Exception as e:
        # Build the wrapper inside the handler: UnexpectedError snapshots the
        # active traceback in its constructor.
        description = '{}: {}: {}'.format(file, e.__class__.__name__, e)
        raise UnexpectedError(description, e)
def run_clang_format_diff(args, file):
    """Run clang-format on ``file`` and diff its output against the file.

    Args:
        args: parsed command line; only ``args.clang_format_executable`` is
            read here.
        file: path of the file to check; it is read as UTF-8.

    Returns:
        A ``(diff_lines, stderr_lines)`` tuple: the unified diff between the
        file content and clang-format's standard output, and the lines
        clang-format wrote to standard error.

    Raises:
        DiffError: the file cannot be read, clang-format cannot be started,
            or it exits with a non-zero status (its stderr lines are attached
            as ``errs`` in that last case).
    """
    try:
        with io.open(file, 'r', encoding='utf-8') as f:
            original = f.readlines()
    except IOError as exc:
        raise DiffError(str(exc))

    invocation = [args.clang_format_executable, file]

    # Use of utf-8 to decode the process output.
    #
    # Hopefully, this is the correct thing to do.
    #
    # It's done due to the following assumptions (which may be incorrect):
    # - clang-format will return the bytes read from the files as-is,
    #   without conversion, and it is already assumed that the files use utf-8.
    # - if the diagnostics were internationalized, they would use utf-8:
    #   > Adding Translations to Clang
    #   >
    #   > Not possible yet!
    #   > Diagnostic strings should be written in UTF-8,
    #   > the client can translate to the relevant code page if needed.
    #   > Each translation completely replaces the format string
    #   > for the diagnostic.
    #   > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
    try:
        proc = subprocess.Popen(
            invocation,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            encoding='utf-8')
    except OSError as exc:
        raise DiffError(
            "Command '{}' failed to start: {}".format(
                subprocess.list2cmdline(invocation), exc
            )
        )

    # communicate() drains stdout and stderr concurrently and then waits for
    # the process. Reading stdout to EOF before touching stderr can deadlock
    # once the child fills the stderr pipe.
    stdout_text, stderr_text = proc.communicate()

    # Split on '\n' only (what readlines() on the pipes did) so the result
    # lines up with the lines read from the file above.
    outs = io.StringIO(stdout_text).readlines()
    errs = io.StringIO(stderr_text).readlines()

    if proc.returncode:
        raise DiffError(
            "Command '{}' returned non-zero exit status {}".format(
                subprocess.list2cmdline(invocation), proc.returncode
            ),
            errs,
        )
    return make_diff(file, original, outs), errs
def bold_red(s):
    """Wrap the string ``s`` in the ANSI escapes for bold red, then reset."""
    return ''.join(('\x1b[1m\x1b[31m', s, '\x1b[0m'))
def colorize(diff_lines):
    """Yield the lines of a unified diff wrapped in ANSI colour escapes.

    File headers (``--- `` / ``+++ ``) become bold, hunk headers (``@@ ``)
    cyan, added lines green and removed lines red. Every other line is
    yielded unchanged.
    """
    reset = '\x1b[0m'

    def paint(code, text):
        return code + text + reset

    for line in diff_lines:
        # Headers are tested first: they also start with '+' or '-'.
        if line.startswith(('--- ', '+++ ')):
            yield paint('\x1b[1m', line)
        elif line.startswith('@@ '):
            yield paint('\x1b[36m', line)
        elif line.startswith('+'):
            yield paint('\x1b[32m', line)
        elif line.startswith('-'):
            yield paint('\x1b[31m', line)
        else:
            yield line
def print_diff(diff_lines, use_color):
    """Write the diff lines to standard output, colourised when requested."""
    rendered = colorize(diff_lines) if use_color else diff_lines
    sys.stdout.writelines(rendered)
def print_trouble(prog, message, use_colors):
    """Print ``<prog>: error: <message>`` on standard error.

    The ``error:`` label is rendered in bold red when ``use_colors`` is true.
    """
    label = bold_red('error:') if use_colors else 'error:'
    report = "{}: {} {}".format(prog, label, message)
    print(report, file=sys.stderr)
def main():
    """Command line entry point.

    Parses the arguments, checks that the clang-format executable can be
    run, collects the files to check and runs clang-format on each of them
    (in a process pool when more than one job is used), printing the diffs
    on standard output and diagnostics on standard error.

    Returns:
        ``ExitStatus.SUCCESS`` when there is nothing to report,
        ``ExitStatus.DIFF`` when at least one file differs, and
        ``ExitStatus.TROUBLE`` when clang-format could not be run or a file
        could not be processed.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--clang-format-executable',
        metavar='EXECUTABLE',
        help='path to the clang-format executable',
        default='clang-format')
    parser.add_argument(
        '--extensions',
        help='comma separated list of file extensions (default: {})'.format(
            DEFAULT_EXTENSIONS),
        default=DEFAULT_EXTENSIONS)
    parser.add_argument(
        '-r',
        '--recursive',
        action='store_true',
        help='run recursively over directories')
    parser.add_argument('files', metavar='file', nargs='+')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true')
    parser.add_argument(
        '-j',
        metavar='N',
        type=int,
        default=0,
        help='run N clang-format jobs in parallel'
        ' (default number of cpus + 1)')
    parser.add_argument(
        '--color',
        default='auto',
        choices=['auto', 'always', 'never'],
        help='show colored diff (default: auto)')
    parser.add_argument(
        '-e',
        '--exclude',
        metavar='PATTERN',
        action='append',
        default=[],
        help='exclude paths matching the given glob-like pattern(s)'
        ' from recursive search')

    args = parser.parse_args()

    # use default signal handling, like diff return SIGINT value on ^C
    # https://bugs.python.org/issue14229#msg156446
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    try:
        signal.SIGPIPE
    except AttributeError:
        # compatibility, SIGPIPE does not exist on Windows
        pass
    else:
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)

    colored_stdout = False
    colored_stderr = False
    if args.color == 'always':
        colored_stdout = True
        colored_stderr = True
    elif args.color == 'auto':
        colored_stdout = sys.stdout.isatty()
        colored_stderr = sys.stderr.isatty()

    # Fail early, with a single message, when clang-format cannot be run.
    version_invocation = [args.clang_format_executable, str("--version")]
    try:
        subprocess.check_call(version_invocation, stdout=DEVNULL)
    except subprocess.CalledProcessError as e:
        print_trouble(parser.prog, str(e), use_colors=colored_stderr)
        return ExitStatus.TROUBLE
    except OSError as e:
        print_trouble(
            parser.prog,
            "Command '{}' failed to start: {}".format(
                subprocess.list2cmdline(version_invocation), e
            ),
            use_colors=colored_stderr,
        )
        return ExitStatus.TROUBLE

    retcode = ExitStatus.SUCCESS
    files = list_files(
        args.files,
        recursive=args.recursive,
        exclude=args.exclude,
        extensions=args.extensions.split(','))

    if not files:
        # Nothing to check; be explicit about the (successful) exit status.
        return ExitStatus.SUCCESS

    njobs = args.j
    if njobs == 0:
        njobs = multiprocessing.cpu_count() + 1
    njobs = min(len(files), njobs)

    worker = partial(run_clang_format_diff_wrapper, args)
    if njobs == 1:
        # execute directly instead of in a pool,
        # less overhead, simpler stacktraces.
        # A lazy map() is used rather than a generator expression: a
        # generator is finished as soon as it raises, which would silently
        # skip every file after the first DiffError.
        it = map(worker, files)
        pool = None
    else:
        pool = multiprocessing.Pool(njobs)
        it = pool.imap_unordered(worker, files)

    while True:
        try:
            outs, errs = next(it)
        except StopIteration:
            break
        except DiffError as e:
            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
            retcode = ExitStatus.TROUBLE
            sys.stderr.writelines(e.errs)
        except UnexpectedError as e:
            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
            sys.stderr.write(e.formatted_traceback)
            retcode = ExitStatus.TROUBLE
            # stop at the first unexpected error,
            # something could be very wrong,
            # don't process all files unnecessarily
            if pool:
                pool.terminate()
            break
        else:
            sys.stderr.writelines(errs)
            if outs == []:
                continue
            if not args.quiet:
                print_diff(outs, use_color=colored_stdout)
            if retcode == ExitStatus.SUCCESS:
                retcode = ExitStatus.DIFF

    if pool:
        # Release the worker processes; close() is a no-op after terminate().
        pool.close()
        pool.join()
    return retcode


if __name__ == '__main__':
    sys.exit(main())
.circleci/unittest/linux/scripts/run_style_checks.sh
View file @
b33c539c
#!/usr/bin/env bash
set
-
u
set
-
eux
root_dir
=
"
$(
git rev-parse
--show-toplevel
)
"
conda_dir
=
"
${
root_dir
}
/conda"
env_dir
=
"
${
root_dir
}
/env"
this_dir
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
>
/dev/null 2>&1
&&
pwd
)
"
eval
"
$(
"
${
conda_dir
}
/bin/conda"
shell.bash hook
)
"
conda activate
"
${
env_dir
}
"
# 1. Install tools
conda
install
flake8
printf
"Installed flake8: "
flake8
--version
clangformat_path
=
"
${
root_dir
}
/clang-format"
curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64
-o
"
${
clangformat_path
}
"
chmod
+x
"
${
clangformat_path
}
"
printf
"Installed clang-fortmat"
"
${
clangformat_path
}
"
--version
# 2. Run style checks
# We want to run all the style checks even if one of them fail.
set
+e
exit_status
=
0
printf
"
\x
1b[34mRunning flake8: "
flake8
--version
printf
"
\x
1b[0m
\n
"
printf
"
\x
1b[34mRunning flake8:
\x
1b[0m
\n
"
flake8 torchaudio
test
build_tools/setup_helpers
status
=
$?
exit_status
=
"
$((
exit_status+status
))
"
...
...
@@ -30,14 +36,14 @@ if [ "${status}" -ne 0 ]; then
printf
"
\x
1b[31mflake8 failed. Check the format of Python files.
\x
1b[0m
\n
"
fi
printf
"
\x
1b[34mRunning clang-format:
"
./
clang-format
--version
printf
"
\x
1b[0m
\n
"
git
-clang-format
--binary
./
clang
-
format
origin/master
git diff
--exit-code
printf
"
\x
1b[34mRunning clang-format:
\x
1b[0m
\n
"
"
${
this_dir
}
"
/run-
clang-format
.py
\
-r
torchaudio/csrc
\
-
-clang-format
-executable
"
${
clangformat
_path
}
"
\
&&
git diff
--exit-code
status
=
$?
exit_status
=
"
$((
exit_status+status
))
"
if
[
"
${
status
}
"
-ne
0
]
;
then
printf
"
\x
1b[31mC++ files are not formatted. Please use
git-
clang-format to format CPP files.
\x
1b[0m
\n
"
printf
"
\x
1b[31mC++ files are not formatted. Please use clang-format to format CPP files.
\x
1b[0m
\n
"
fi
exit
$exit_status
torchaudio/csrc/pybind.cpp
View file @
b33c539c
...
...
@@ -2,11 +2,12 @@
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>
PYBIND11_MODULE
(
_torchaudio
,
m
)
{
py
::
class_
<
sox_signalinfo_t
>
(
m
,
"sox_signalinfo_t"
)
.
def
(
py
::
init
<>
())
.
def
(
"__repr__"
,
[](
const
sox_signalinfo_t
&
self
)
{
.
def
(
"__repr__"
,
[](
const
sox_signalinfo_t
&
self
)
{
std
::
stringstream
ss
;
ss
<<
"sox_signalinfo_t {
\n
"
<<
" rate-> "
<<
self
.
rate
<<
"
\n
"
...
...
@@ -24,7 +25,9 @@ PYBIND11_MODULE(_torchaudio, m) {
.
def_readwrite
(
"mult"
,
&
sox_signalinfo_t
::
mult
);
py
::
class_
<
sox_encodinginfo_t
>
(
m
,
"sox_encodinginfo_t"
)
.
def
(
py
::
init
<>
())
.
def
(
"__repr__"
,
[](
const
sox_encodinginfo_t
&
self
)
{
.
def
(
"__repr__"
,
[](
const
sox_encodinginfo_t
&
self
)
{
std
::
stringstream
ss
;
ss
<<
"sox_encodinginfo_t {
\n
"
<<
" encoding-> "
<<
self
.
encoding
<<
"
\n
"
...
...
@@ -72,7 +75,8 @@ PYBIND11_MODULE(_torchaudio, m) {
.
value
(
"SOX_ENCODING_AMR_WB"
,
sox_encoding_t
::
SOX_ENCODING_AMR_WB
)
.
value
(
"SOX_ENCODING_AMR_NB"
,
sox_encoding_t
::
SOX_ENCODING_AMR_NB
)
.
value
(
"SOX_ENCODING_LPC10"
,
sox_encoding_t
::
SOX_ENCODING_LPC10
)
//.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) // creates a compile error
//.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) //
// creates a compile error
.
value
(
"SOX_ENCODINGS"
,
sox_encoding_t
::
SOX_ENCODINGS
)
.
export_values
();
py
::
enum_
<
sox_option_t
>
(
m
,
"sox_option_t"
)
...
...
torchaudio/csrc/sox/effects.cpp
View file @
b33c539c
...
...
@@ -143,23 +143,27 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
)
{
// Streaming decoding over file-like object is tricky because libsox operates on FILE pointer.
// The folloing is what `sox` and `play` commands do
// Streaming decoding over file-like object is tricky because libsox operates
// on FILE pointer. The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and discard.
// We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
//
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial chunk of byte string
// This will perform header-based format detection, if necessary, then fill the metadata of
// sox_format_t. Internally, sox_open_mem_read uses fmemopen, which returns FILE* which points the
// buffer of the provided byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying buffer in a way that it
// starts with unseen data, and append the new data read from the given fileobj.
// This will trick libsox as if it keeps reading from the FILE* continuously.
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// chunk of byte string
// This will perform header-based format detection, if necessary, then fill
// the metadata of sox_format_t. Internally, sox_open_mem_read uses
// fmemopen, which returns FILE* which points the buffer of the provided
// byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Using std::string and let it manage memory.
...
...
@@ -170,9 +174,12 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
auto
*
in_buf
=
const_cast
<
char
*>
(
in_buffer
.
data
());
// Fetch the header, and copy it to the buffer.
auto
header
=
static_cast
<
std
::
string
>
(
static_cast
<
py
::
bytes
>
(
fileobj
.
attr
(
"read"
)(
4096
)));
memcpy
(
static_cast
<
void
*>
(
in_buf
),
static_cast
<
void
*>
(
const_cast
<
char
*>
(
header
.
data
())),
header
.
length
());
auto
header
=
static_cast
<
std
::
string
>
(
static_cast
<
py
::
bytes
>
(
fileobj
.
attr
(
"read"
)(
4096
)));
memcpy
(
static_cast
<
void
*>
(
in_buf
),
static_cast
<
void
*>
(
const_cast
<
char
*>
(
header
.
data
())),
header
.
length
());
// Open file (this starts reading the header)
SoxFormat
sf
(
sox_open_mem_read
(
...
...
@@ -212,8 +219,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
channels_first_
);
return
std
::
make_tuple
(
tensor
,
static_cast
<
int64_t
>
(
chain
.
getOutputSampleRate
()));
tensor
,
static_cast
<
int64_t
>
(
chain
.
getOutputSampleRate
()));
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
...
...
torchaudio/csrc/sox/effects_chain.cpp
View file @
b33c539c
...
...
@@ -123,7 +123,8 @@ int file_output_flow(
}
sox_effect_handler_t
*
get_tensor_input_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"input_tensor"
,
static
sox_effect_handler_t
handler
{
/*name=*/
"input_tensor"
,
/*usage=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*getopts=*/
NULL
,
...
...
@@ -137,7 +138,8 @@ sox_effect_handler_t* get_tensor_input_handler() {
}
sox_effect_handler_t
*
get_tensor_output_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"output_tensor"
,
static
sox_effect_handler_t
handler
{
/*name=*/
"output_tensor"
,
/*usage=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*getopts=*/
NULL
,
...
...
@@ -151,7 +153,8 @@ sox_effect_handler_t* get_tensor_output_handler() {
}
sox_effect_handler_t
*
get_file_output_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"output_file"
,
static
sox_effect_handler_t
handler
{
/*name=*/
"output_file"
,
/*usage=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*getopts=*/
NULL
,
...
...
@@ -198,7 +201,8 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) {
priv
->
signal
=
signal
;
priv
->
index
=
0
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input_tensor"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input_tensor"
);
}
}
...
...
@@ -207,7 +211,8 @@ void SoxEffectsChain::addOutputBuffer(
SoxEffect
e
(
sox_create_effect
(
get_tensor_output_handler
()));
static_cast
<
TensorOutputPriv
*>
(
e
->
priv
)
->
buffer
=
output_buffer
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output_tensor"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output_tensor"
);
}
}
...
...
@@ -305,7 +310,7 @@ struct FileObjOutputPriv {
/// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
int
fileobj_input_drain
(
sox_effect_t
*
effp
,
sox_sample_t
*
obuf
,
size_t
*
osamp
)
{
auto
priv
=
static_cast
<
FileObjInputPriv
*>
(
effp
->
priv
);
auto
priv
=
static_cast
<
FileObjInputPriv
*>
(
effp
->
priv
);
auto
sf
=
priv
->
sf
;
auto
fileobj
=
priv
->
fileobj
;
auto
buffer
=
priv
->
buffer
;
...
...
@@ -315,9 +320,9 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
//
// NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detect EOF is reaching the end of the buffer. (null byte won't
help)
// Therefore we need to align the content at the end of buffer,
otherwise,
// libsox will keep reading the content beyond intended length.
// libsox detect EOF is reaching the end of the buffer. (null byte won't
//
help)
Therefore we need to align the content at the end of buffer,
//
otherwise,
libsox will keep reading the content beyond intended length.
//
// Before:
//
...
...
@@ -339,9 +344,10 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
const
auto
num_refill
=
py
::
len
(
chunk_
);
const
auto
offset
=
buffer_size
-
(
num_remain
+
num_refill
);
if
(
num_refill
>
num_consumed
)
{
if
(
num_refill
>
num_consumed
)
{
std
::
ostringstream
message
;
message
<<
"Tried to read up to "
<<
num_consumed
<<
" bytes but, "
message
<<
"Tried to read up to "
<<
num_consumed
<<
" bytes but, "
<<
"recieved "
<<
num_refill
<<
" bytes. "
<<
"The given object does not confirm to read protocol of file object."
;
throw
std
::
runtime_error
(
message
.
str
());
...
...
@@ -364,7 +370,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// 1.4. Set the file pointer to the new offset
sf
->
tell_off
=
offset
;
fseek
((
FILE
*
)
sf
->
fp
,
offset
,
SEEK_SET
);
fseek
((
FILE
*
)
sf
->
fp
,
offset
,
SEEK_SET
);
// 2. Perform decoding operation
// The following part is practically same as "input" effect
...
...
@@ -377,7 +383,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// store the actual number read back to *osamp
*
osamp
=
sox_read
(
sf
,
obuf
,
*
osamp
);
return
*
osamp
?
SOX_SUCCESS
:
SOX_EOF
;
return
*
osamp
?
SOX_SUCCESS
:
SOX_EOF
;
}
int
fileobj_output_flow
(
...
...
@@ -420,7 +426,8 @@ int fileobj_output_flow(
}
sox_effect_handler_t
*
get_fileobj_input_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"input_fileobj_object"
,
static
sox_effect_handler_t
handler
{
/*name=*/
"input_fileobj_object"
,
/*usage=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*getopts=*/
NULL
,
...
...
@@ -434,7 +441,8 @@ sox_effect_handler_t* get_fileobj_input_handler() {
}
sox_effect_handler_t
*
get_fileobj_output_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"output_fileobj_object"
,
static
sox_effect_handler_t
handler
{
/*name=*/
"output_fileobj_object"
,
/*usage=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*getopts=*/
NULL
,
...
...
@@ -464,7 +472,8 @@ void SoxEffectsChain::addInputFileObj(
priv
->
buffer
=
buffer
;
priv
->
buffer_size
=
buffer_size
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input fileobj"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input fileobj"
);
}
}
...
...
@@ -481,7 +490,8 @@ void SoxEffectsChain::addOutputFileObj(
priv
->
buffer
=
buffer
;
priv
->
buffer_size
=
buffer_size
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
out_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output fileobj"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output fileobj"
);
}
}
...
...
torchaudio/csrc/sox/io.cpp
View file @
b33c539c
...
...
@@ -112,8 +112,9 @@ void save_audio_file(
auto
signal
=
TensorSignal
(
tensor
,
sample_rate
,
channels_first
);
const
auto
filetype
=
[
&
](){
if
(
format
.
has_value
())
return
format
.
value
();
const
auto
filetype
=
[
&
]()
{
if
(
format
.
has_value
())
return
format
.
value
();
return
get_filetype
(
path
);
}();
if
(
filetype
==
"amr-nb"
)
{
...
...
@@ -123,7 +124,8 @@ void save_audio_file(
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
SoxFormat
sf
(
sox_open_write
(
path
.
c_str
(),
...
...
@@ -161,7 +163,8 @@ std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
namespace
{
// helper class to automatically release buffer, to be used by save_audio_fileobj
// helper class to automatically release buffer, to be used by
// save_audio_fileobj
struct
AutoReleaseBuffer
{
char
*
ptr
;
size_t
size
;
...
...
@@ -194,12 +197,14 @@ void save_audio_fileobj(
if
(
filetype
==
"amr-nb"
)
{
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
if
(
num_channels
!=
1
)
{
throw
std
::
runtime_error
(
"amr-nb format only supports single channel audio."
);
throw
std
::
runtime_error
(
"amr-nb format only supports single channel audio."
);
}
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
AutoReleaseBuffer
buffer
;
...
...
@@ -212,7 +217,8 @@ void save_audio_fileobj(
/*oob=*/
nullptr
));
if
(
static_cast
<
sox_format_t
*>
(
sf
)
==
nullptr
)
{
throw
std
::
runtime_error
(
"Error saving audio file: failed to open memory stream."
);
throw
std
::
runtime_error
(
"Error saving audio file: failed to open memory stream."
);
}
torchaudio
::
sox_effects_chain
::
SoxEffectsChain
chain
(
...
...
@@ -222,7 +228,8 @@ void save_audio_fileobj(
chain
.
addOutputFileObj
(
sf
,
&
buffer
.
ptr
,
&
buffer
.
size
,
&
fileobj
);
chain
.
run
();
// Closing the sox_format_t is necessary for flushing the last chunk to the buffer
// Closing the sox_format_t is necessary for flushing the last chunk to the
// buffer
sf
.
close
();
fileobj
.
attr
(
"write"
)(
py
::
bytes
(
buffer
.
ptr
,
buffer
.
size
));
...
...
torchaudio/csrc/sox/legacy.cpp
View file @
b33c539c
...
...
@@ -40,10 +40,7 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
return
samples_written
;
}
void
read_audio
(
SoxDescriptor
&
fd
,
at
::
Tensor
output
,
int64_t
buffer_length
)
{
void
read_audio
(
SoxDescriptor
&
fd
,
at
::
Tensor
output
,
int64_t
buffer_length
)
{
std
::
vector
<
sox_sample_t
>
buffer
(
buffer_length
);
int
number_of_channels
=
fd
->
signal
.
channels
;
...
...
@@ -64,8 +61,7 @@ void read_audio(
}
// namespace
std
::
tuple
<
sox_signalinfo_t
,
sox_encodinginfo_t
>
get_info
(
const
std
::
string
&
file_name
)
{
const
std
::
string
&
file_name
)
{
SoxDescriptor
fd
(
sox_open_read
(
file_name
.
c_str
(),
/*signal=*/
nullptr
,
...
...
@@ -86,7 +82,6 @@ int read_audio_file(
sox_signalinfo_t
*
si
,
sox_encodinginfo_t
*
ei
,
const
char
*
ft
)
{
SoxDescriptor
fd
(
sox_open_read
(
file_name
.
c_str
(),
si
,
ei
,
ft
));
if
(
fd
.
get
()
==
nullptr
)
{
throw
std
::
runtime_error
(
"Error opening audio file"
);
...
...
@@ -120,7 +115,8 @@ int read_audio_file(
// seek to offset point before reading data
if
(
sox_seek
(
fd
.
get
(),
offset
,
0
)
==
SOX_EOF
)
{
throw
std
::
runtime_error
(
"sox_seek reached EOF, try reducing offset or num_samples"
);
throw
std
::
runtime_error
(
"sox_seek reached EOF, try reducing offset or num_samples"
);
}
// read data and fill output tensor
...
...
torchaudio/csrc/sox/legacy.h
View file @
b33c539c
#include <sox.h>
#include <torch/torch.h>
namespace
torch
{
namespace
audio
{
namespace
torch
{
namespace
audio
{
/// Reads an audio file from the given `path` into the `output` `Tensor` and
/// returns the sample rate of the audio file.
...
...
@@ -30,9 +31,10 @@ void write_audio_file(
/// Reads an audio file from the given `path` and returns a tuple of
/// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
/// the audio file such as sample rate, length, bit precision, encoding and
more.
/// Throws `std::runtime_error` if the audio file could not be opened, or
an
/// error occurred during reading of the audio data.
/// the audio file such as sample rate, length, bit precision, encoding and
///
more.
Throws `std::runtime_error` if the audio file could not be opened, or
///
an
error occurred during reading of the audio data.
std
::
tuple
<
sox_signalinfo_t
,
sox_encodinginfo_t
>
get_info
(
const
std
::
string
&
file_name
);
}}
// namespace torch::audio
}
// namespace audio
}
// namespace torch
torchaudio/csrc/sox/register.cpp
View file @
b33c539c
...
...
@@ -43,7 +43,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.
def
(
"get_sample_rate"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getSampleRate
)
.
def
(
"get_num_channels"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
)
.
def
(
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
.
def
(
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info
);
m
.
def
(
...
...
torchaudio/csrc/sox/utils.cpp
View file @
b33c539c
...
...
@@ -80,7 +80,9 @@ bool TensorSignal::getChannelsFirst() const {
}
SoxFormat
::
SoxFormat
(
sox_format_t
*
fd
)
noexcept
:
fd_
(
fd
)
{}
SoxFormat
::~
SoxFormat
()
{
close
();
}
SoxFormat
::~
SoxFormat
()
{
close
();
}
sox_format_t
*
SoxFormat
::
operator
->
()
const
noexcept
{
return
fd_
;
...
...
@@ -291,7 +293,8 @@ sox_signalinfo_t get_signalinfo(
sox_encodinginfo_t
get_encodinginfo
(
const
std
::
string
filetype
,
const
caffe2
::
TypeMeta
dtype
)
{
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*compression=*/
HUGE_VAL
,
/*reverse_bytes=*/
sox_option_default
,
...
...
@@ -304,7 +307,8 @@ sox_encodinginfo_t get_encodinginfo(
const
std
::
string
filetype
,
const
caffe2
::
TypeMeta
dtype
,
c10
::
optional
<
double
>&
compression
)
{
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*compression=*/
compression
.
value_or
(
HUGE_VAL
),
/*reverse_bytes=*/
sox_option_default
,
...
...
torchaudio/csrc/sox/utils.h
View file @
b33c539c
...
...
@@ -69,7 +69,7 @@ struct SoxFormat {
///
/// Verify that input file is found, has known encoding, and not empty
void
validate_input_file
(
const
SoxFormat
&
sf
,
bool
check_length
=
true
);
void
validate_input_file
(
const
SoxFormat
&
sf
,
bool
check_length
=
true
);
///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
...
...
torchaudio/csrc/transducer.cpp
View file @
b33c539c
...
...
@@ -8,7 +8,8 @@
namespace
{
int64_t
cpu_rnnt_loss
(
torch
::
Tensor
acts
,
int64_t
cpu_rnnt_loss
(
torch
::
Tensor
acts
,
torch
::
Tensor
labels
,
torch
::
Tensor
input_lengths
,
torch
::
Tensor
label_lengths
,
...
...
@@ -16,7 +17,6 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
torch
::
Tensor
grads
,
int64_t
blank_label
,
int64_t
num_threads
)
{
int
maxT
=
acts
.
size
(
1
);
int
maxU
=
acts
.
size
(
2
);
int
minibatch_size
=
acts
.
size
(
0
);
...
...
@@ -32,45 +32,54 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
options
.
num_threads
=
num_threads
;
// have to use at least one
options
.
num_threads
=
std
::
max
(
options
.
num_threads
,
(
unsigned
int
)
1
);
options
.
num_threads
=
std
::
max
(
options
.
num_threads
,
(
unsigned
int
)
1
);
size_t
cpu_size_bytes
=
0
;
switch
(
acts
.
scalar_type
())
{
case
torch
::
ScalarType
::
Float
:
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
);
case
torch
::
ScalarType
::
Float
:
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
);
std
::
vector
<
float
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
float
),
0
);
compute_rnnt_loss
(
acts
.
data_ptr
<
float
>
(),
grads
.
data_ptr
<
float
>
(),
labels
.
data_ptr
<
int
>
(),
label_lengths
.
data_ptr
<
int
>
(),
input_lengths
.
data_ptr
<
int
>
(),
alphabet_size
,
minibatch_size
,
costs
.
data_ptr
<
float
>
(),
cpu_workspace
.
data
(),
options
);
compute_rnnt_loss
(
acts
.
data_ptr
<
float
>
(),
grads
.
data_ptr
<
float
>
(),
labels
.
data_ptr
<
int
>
(),
label_lengths
.
data_ptr
<
int
>
(),
input_lengths
.
data_ptr
<
int
>
(),
alphabet_size
,
minibatch_size
,
costs
.
data_ptr
<
float
>
(),
cpu_workspace
.
data
(),
options
);
return
0
;
}
case
torch
::
ScalarType
::
Double
:
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
,
sizeof
(
double
));
case
torch
::
ScalarType
::
Double
:
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
,
sizeof
(
double
));
std
::
vector
<
double
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
double
),
0
);
compute_rnnt_loss_fp64
(
acts
.
data_ptr
<
double
>
(),
grads
.
data_ptr
<
double
>
(),
labels
.
data_ptr
<
int
>
(),
label_lengths
.
data_ptr
<
int
>
(),
input_lengths
.
data_ptr
<
int
>
(),
alphabet_size
,
minibatch_size
,
costs
.
data_ptr
<
double
>
(),
cpu_workspace
.
data
(),
options
);
compute_rnnt_loss_fp64
(
acts
.
data_ptr
<
double
>
(),
grads
.
data_ptr
<
double
>
(),
labels
.
data_ptr
<
int
>
(),
label_lengths
.
data_ptr
<
int
>
(),
input_lengths
.
data_ptr
<
int
>
(),
alphabet_size
,
minibatch_size
,
costs
.
data_ptr
<
double
>
(),
cpu_workspace
.
data
(),
options
);
return
0
;
}
default:
TORCH_CHECK
(
false
,
std
::
string
(
__func__
)
+
" not implemented for '"
+
toString
(
acts
.
scalar_type
())
+
"'"
);
TORCH_CHECK
(
false
,
std
::
string
(
__func__
)
+
" not implemented for '"
+
toString
(
acts
.
scalar_type
())
+
"'"
);
}
return
-
1
;
}
...
...
@@ -82,7 +91,8 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
}
TORCH_LIBRARY_FRAGMENT
(
torchaudio
,
m
)
{
m
.
def
(
"rnnt_loss(Tensor acts,"
m
.
def
(
"rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment