Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b33c539c
Unverified
Commit
b33c539c
authored
Jan 26, 2021
by
moto
Committed by
GitHub
Jan 26, 2021
Browse files
Fix clang-format CI job (#1198)
parent
99ed7183
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
671 additions
and
284 deletions
+671
-284
.circleci/unittest/linux/scripts/run-clang-format.py
.circleci/unittest/linux/scripts/run-clang-format.py
+340
-0
.circleci/unittest/linux/scripts/run_style_checks.sh
.circleci/unittest/linux/scripts/run_style_checks.sh
+16
-10
torchaudio/csrc/pybind.cpp
torchaudio/csrc/pybind.cpp
+80
-76
torchaudio/csrc/sox/effects.cpp
torchaudio/csrc/sox/effects.cpp
+22
-16
torchaudio/csrc/sox/effects_chain.cpp
torchaudio/csrc/sox/effects_chain.cpp
+74
-64
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+15
-8
torchaudio/csrc/sox/legacy.cpp
torchaudio/csrc/sox/legacy.cpp
+6
-10
torchaudio/csrc/sox/legacy.h
torchaudio/csrc/sox/legacy.h
+7
-5
torchaudio/csrc/sox/register.cpp
torchaudio/csrc/sox/register.cpp
+9
-7
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+19
-15
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+1
-1
torchaudio/csrc/transducer.cpp
torchaudio/csrc/transducer.cpp
+82
-72
No files found.
.circleci/unittest/linux/scripts/run-clang-format.py
0 → 100755
View file @
b33c539c
#!/usr/bin/env python
"""A wrapper script around clang-format, suitable for linting multiple files
and to use for continuous integration.
This is an alternative API for the clang-format command line.
It runs over multiple files and directories in parallel.
A diff output is produced and a sensible exit code is returned.
"""
import
argparse
import
codecs
import
difflib
import
fnmatch
import
io
import
multiprocessing
import
os
import
signal
import
subprocess
import
sys
import
traceback
from
functools
import
partial
try
:
from
subprocess
import
DEVNULL
# py3k
except
ImportError
:
DEVNULL
=
open
(
os
.
devnull
,
"wb"
)
DEFAULT_EXTENSIONS
=
'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu'
class
ExitStatus
:
SUCCESS
=
0
DIFF
=
1
TROUBLE
=
2
def
list_files
(
files
,
recursive
=
False
,
extensions
=
None
,
exclude
=
None
):
if
extensions
is
None
:
extensions
=
[]
if
exclude
is
None
:
exclude
=
[]
out
=
[]
for
file
in
files
:
if
recursive
and
os
.
path
.
isdir
(
file
):
for
dirpath
,
dnames
,
fnames
in
os
.
walk
(
file
):
fpaths
=
[
os
.
path
.
join
(
dirpath
,
fname
)
for
fname
in
fnames
]
for
pattern
in
exclude
:
# os.walk() supports trimming down the dnames list
# by modifying it in-place,
# to avoid unnecessary directory listings.
dnames
[:]
=
[
x
for
x
in
dnames
if
not
fnmatch
.
fnmatch
(
os
.
path
.
join
(
dirpath
,
x
),
pattern
)
]
fpaths
=
[
x
for
x
in
fpaths
if
not
fnmatch
.
fnmatch
(
x
,
pattern
)
]
for
f
in
fpaths
:
ext
=
os
.
path
.
splitext
(
f
)[
1
][
1
:]
if
ext
in
extensions
:
out
.
append
(
f
)
else
:
out
.
append
(
file
)
return
out
def
make_diff
(
file
,
original
,
reformatted
):
return
list
(
difflib
.
unified_diff
(
original
,
reformatted
,
fromfile
=
'{}
\t
(original)'
.
format
(
file
),
tofile
=
'{}
\t
(reformatted)'
.
format
(
file
),
n
=
3
))
class
DiffError
(
Exception
):
def
__init__
(
self
,
message
,
errs
=
None
):
super
(
DiffError
,
self
).
__init__
(
message
)
self
.
errs
=
errs
or
[]
class
UnexpectedError
(
Exception
):
def
__init__
(
self
,
message
,
exc
=
None
):
super
(
UnexpectedError
,
self
).
__init__
(
message
)
self
.
formatted_traceback
=
traceback
.
format_exc
()
self
.
exc
=
exc
def
run_clang_format_diff_wrapper
(
args
,
file
):
try
:
ret
=
run_clang_format_diff
(
args
,
file
)
return
ret
except
DiffError
:
raise
except
Exception
as
e
:
raise
UnexpectedError
(
'{}: {}: {}'
.
format
(
file
,
e
.
__class__
.
__name__
,
e
),
e
)
def
run_clang_format_diff
(
args
,
file
):
try
:
with
io
.
open
(
file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
original
=
f
.
readlines
()
except
IOError
as
exc
:
raise
DiffError
(
str
(
exc
))
invocation
=
[
args
.
clang_format_executable
,
file
]
# Use of utf-8 to decode the process output.
#
# Hopefully, this is the correct thing to do.
#
# It's done due to the following assumptions (which may be incorrect):
# - clang-format will returns the bytes read from the files as-is,
# without conversion, and it is already assumed that the files use utf-8.
# - if the diagnostics were internationalized, they would use utf-8:
# > Adding Translations to Clang
# >
# > Not possible yet!
# > Diagnostic strings should be written in UTF-8,
# > the client can translate to the relevant code page if needed.
# > Each translation completely replaces the format string
# > for the diagnostic.
# > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
try
:
proc
=
subprocess
.
Popen
(
invocation
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
universal_newlines
=
True
,
encoding
=
'utf-8'
)
except
OSError
as
exc
:
raise
DiffError
(
"Command '{}' failed to start: {}"
.
format
(
subprocess
.
list2cmdline
(
invocation
),
exc
)
)
proc_stdout
=
proc
.
stdout
proc_stderr
=
proc
.
stderr
# hopefully the stderr pipe won't get full and block the process
outs
=
list
(
proc_stdout
.
readlines
())
errs
=
list
(
proc_stderr
.
readlines
())
proc
.
wait
()
if
proc
.
returncode
:
raise
DiffError
(
"Command '{}' returned non-zero exit status {}"
.
format
(
subprocess
.
list2cmdline
(
invocation
),
proc
.
returncode
),
errs
,
)
return
make_diff
(
file
,
original
,
outs
),
errs
def
bold_red
(
s
):
return
'
\x1b
[1m
\x1b
[31m'
+
s
+
'
\x1b
[0m'
def
colorize
(
diff_lines
):
def
bold
(
s
):
return
'
\x1b
[1m'
+
s
+
'
\x1b
[0m'
def
cyan
(
s
):
return
'
\x1b
[36m'
+
s
+
'
\x1b
[0m'
def
green
(
s
):
return
'
\x1b
[32m'
+
s
+
'
\x1b
[0m'
def
red
(
s
):
return
'
\x1b
[31m'
+
s
+
'
\x1b
[0m'
for
line
in
diff_lines
:
if
line
[:
4
]
in
[
'--- '
,
'+++ '
]:
yield
bold
(
line
)
elif
line
.
startswith
(
'@@ '
):
yield
cyan
(
line
)
elif
line
.
startswith
(
'+'
):
yield
green
(
line
)
elif
line
.
startswith
(
'-'
):
yield
red
(
line
)
else
:
yield
line
def
print_diff
(
diff_lines
,
use_color
):
if
use_color
:
diff_lines
=
colorize
(
diff_lines
)
sys
.
stdout
.
writelines
(
diff_lines
)
def
print_trouble
(
prog
,
message
,
use_colors
):
error_text
=
'error:'
if
use_colors
:
error_text
=
bold_red
(
error_text
)
print
(
"{}: {} {}"
.
format
(
prog
,
error_text
,
message
),
file
=
sys
.
stderr
)
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--clang-format-executable'
,
metavar
=
'EXECUTABLE'
,
help
=
'path to the clang-format executable'
,
default
=
'clang-format'
)
parser
.
add_argument
(
'--extensions'
,
help
=
'comma separated list of file extensions (default: {})'
.
format
(
DEFAULT_EXTENSIONS
),
default
=
DEFAULT_EXTENSIONS
)
parser
.
add_argument
(
'-r'
,
'--recursive'
,
action
=
'store_true'
,
help
=
'run recursively over directories'
)
parser
.
add_argument
(
'files'
,
metavar
=
'file'
,
nargs
=
'+'
)
parser
.
add_argument
(
'-q'
,
'--quiet'
,
action
=
'store_true'
)
parser
.
add_argument
(
'-j'
,
metavar
=
'N'
,
type
=
int
,
default
=
0
,
help
=
'run N clang-format jobs in parallel'
' (default number of cpus + 1)'
)
parser
.
add_argument
(
'--color'
,
default
=
'auto'
,
choices
=
[
'auto'
,
'always'
,
'never'
],
help
=
'show colored diff (default: auto)'
)
parser
.
add_argument
(
'-e'
,
'--exclude'
,
metavar
=
'PATTERN'
,
action
=
'append'
,
default
=
[],
help
=
'exclude paths matching the given glob-like pattern(s)'
' from recursive search'
)
args
=
parser
.
parse_args
()
# use default signal handling, like diff return SIGINT value on ^C
# https://bugs.python.org/issue14229#msg156446
signal
.
signal
(
signal
.
SIGINT
,
signal
.
SIG_DFL
)
try
:
signal
.
SIGPIPE
except
AttributeError
:
# compatibility, SIGPIPE does not exist on Windows
pass
else
:
signal
.
signal
(
signal
.
SIGPIPE
,
signal
.
SIG_DFL
)
colored_stdout
=
False
colored_stderr
=
False
if
args
.
color
==
'always'
:
colored_stdout
=
True
colored_stderr
=
True
elif
args
.
color
==
'auto'
:
colored_stdout
=
sys
.
stdout
.
isatty
()
colored_stderr
=
sys
.
stderr
.
isatty
()
version_invocation
=
[
args
.
clang_format_executable
,
str
(
"--version"
)]
try
:
subprocess
.
check_call
(
version_invocation
,
stdout
=
DEVNULL
)
except
subprocess
.
CalledProcessError
as
e
:
print_trouble
(
parser
.
prog
,
str
(
e
),
use_colors
=
colored_stderr
)
return
ExitStatus
.
TROUBLE
except
OSError
as
e
:
print_trouble
(
parser
.
prog
,
"Command '{}' failed to start: {}"
.
format
(
subprocess
.
list2cmdline
(
version_invocation
),
e
),
use_colors
=
colored_stderr
,
)
return
ExitStatus
.
TROUBLE
retcode
=
ExitStatus
.
SUCCESS
files
=
list_files
(
args
.
files
,
recursive
=
args
.
recursive
,
exclude
=
args
.
exclude
,
extensions
=
args
.
extensions
.
split
(
','
))
if
not
files
:
return
njobs
=
args
.
j
if
njobs
==
0
:
njobs
=
multiprocessing
.
cpu_count
()
+
1
njobs
=
min
(
len
(
files
),
njobs
)
if
njobs
==
1
:
# execute directly instead of in a pool,
# less overhead, simpler stacktraces
it
=
(
run_clang_format_diff_wrapper
(
args
,
file
)
for
file
in
files
)
pool
=
None
else
:
pool
=
multiprocessing
.
Pool
(
njobs
)
it
=
pool
.
imap_unordered
(
partial
(
run_clang_format_diff_wrapper
,
args
),
files
)
while
True
:
try
:
outs
,
errs
=
next
(
it
)
except
StopIteration
:
break
except
DiffError
as
e
:
print_trouble
(
parser
.
prog
,
str
(
e
),
use_colors
=
colored_stderr
)
retcode
=
ExitStatus
.
TROUBLE
sys
.
stderr
.
writelines
(
e
.
errs
)
except
UnexpectedError
as
e
:
print_trouble
(
parser
.
prog
,
str
(
e
),
use_colors
=
colored_stderr
)
sys
.
stderr
.
write
(
e
.
formatted_traceback
)
retcode
=
ExitStatus
.
TROUBLE
# stop at the first unexpected error,
# something could be very wrong,
# don't process all files unnecessarily
if
pool
:
pool
.
terminate
()
break
else
:
sys
.
stderr
.
writelines
(
errs
)
if
outs
==
[]:
continue
if
not
args
.
quiet
:
print_diff
(
outs
,
use_color
=
colored_stdout
)
if
retcode
==
ExitStatus
.
SUCCESS
:
retcode
=
ExitStatus
.
DIFF
return
retcode
if
__name__
==
'__main__'
:
sys
.
exit
(
main
())
.circleci/unittest/linux/scripts/run_style_checks.sh
View file @
b33c539c
#!/usr/bin/env bash
#!/usr/bin/env bash
set
-
u
set
-
eux
root_dir
=
"
$(
git rev-parse
--show-toplevel
)
"
root_dir
=
"
$(
git rev-parse
--show-toplevel
)
"
conda_dir
=
"
${
root_dir
}
/conda"
conda_dir
=
"
${
root_dir
}
/conda"
env_dir
=
"
${
root_dir
}
/env"
env_dir
=
"
${
root_dir
}
/env"
this_dir
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
>
/dev/null 2>&1
&&
pwd
)
"
eval
"
$(
"
${
conda_dir
}
/bin/conda"
shell.bash hook
)
"
eval
"
$(
"
${
conda_dir
}
/bin/conda"
shell.bash hook
)
"
conda activate
"
${
env_dir
}
"
conda activate
"
${
env_dir
}
"
# 1. Install tools
# 1. Install tools
conda
install
flake8
conda
install
flake8
printf
"Installed flake8: "
flake8
--version
clangformat_path
=
"
${
root_dir
}
/clang-format"
clangformat_path
=
"
${
root_dir
}
/clang-format"
curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64
-o
"
${
clangformat_path
}
"
curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64
-o
"
${
clangformat_path
}
"
chmod
+x
"
${
clangformat_path
}
"
chmod
+x
"
${
clangformat_path
}
"
printf
"Installed clang-fortmat"
"
${
clangformat_path
}
"
--version
# 2. Run style checks
# 2. Run style checks
# We want to run all the style checks even if one of them fail.
# We want to run all the style checks even if one of them fail.
set
+e
exit_status
=
0
exit_status
=
0
printf
"
\x
1b[34mRunning flake8: "
printf
"
\x
1b[34mRunning flake8:
\x
1b[0m
\n
"
flake8
--version
printf
"
\x
1b[0m
\n
"
flake8 torchaudio
test
build_tools/setup_helpers
flake8 torchaudio
test
build_tools/setup_helpers
status
=
$?
status
=
$?
exit_status
=
"
$((
exit_status+status
))
"
exit_status
=
"
$((
exit_status+status
))
"
...
@@ -30,14 +36,14 @@ if [ "${status}" -ne 0 ]; then
...
@@ -30,14 +36,14 @@ if [ "${status}" -ne 0 ]; then
printf
"
\x
1b[31mflake8 failed. Check the format of Python files.
\x
1b[0m
\n
"
printf
"
\x
1b[31mflake8 failed. Check the format of Python files.
\x
1b[0m
\n
"
fi
fi
printf
"
\x
1b[34mRunning clang-format:
"
printf
"
\x
1b[34mRunning clang-format:
\x
1b[0m
\n
"
./
clang-format
--version
"
${
this_dir
}
"
/run-
clang-format
.py
\
printf
"
\x
1b[0m
\n
"
-r
torchaudio/csrc
\
git
-clang-format
--binary
./
clang
-
format
origin/master
-
-clang-format
-executable
"
${
clangformat
_path
}
"
\
git diff
--exit-code
&&
git diff
--exit-code
status
=
$?
status
=
$?
exit_status
=
"
$((
exit_status+status
))
"
exit_status
=
"
$((
exit_status+status
))
"
if
[
"
${
status
}
"
-ne
0
]
;
then
if
[
"
${
status
}
"
-ne
0
]
;
then
printf
"
\x
1b[31mC++ files are not formatted. Please use
git-
clang-format to format CPP files.
\x
1b[0m
\n
"
printf
"
\x
1b[31mC++ files are not formatted. Please use clang-format to format CPP files.
\x
1b[0m
\n
"
fi
fi
exit
$exit_status
exit
$exit_status
torchaudio/csrc/pybind.cpp
View file @
b33c539c
...
@@ -2,88 +2,92 @@
...
@@ -2,88 +2,92 @@
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>
#include <torchaudio/csrc/sox/legacy.h>
PYBIND11_MODULE
(
_torchaudio
,
m
)
{
PYBIND11_MODULE
(
_torchaudio
,
m
)
{
py
::
class_
<
sox_signalinfo_t
>
(
m
,
"sox_signalinfo_t"
)
py
::
class_
<
sox_signalinfo_t
>
(
m
,
"sox_signalinfo_t"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"__repr__"
,
[](
const
sox_signalinfo_t
&
self
)
{
.
def
(
std
::
stringstream
ss
;
"__repr__"
,
ss
<<
"sox_signalinfo_t {
\n
"
[](
const
sox_signalinfo_t
&
self
)
{
<<
" rate-> "
<<
self
.
rate
<<
"
\n
"
std
::
stringstream
ss
;
<<
" channels-> "
<<
self
.
channels
<<
"
\n
"
ss
<<
"sox_signalinfo_t {
\n
"
<<
" precision-> "
<<
self
.
precision
<<
"
\n
"
<<
" rate-> "
<<
self
.
rate
<<
"
\n
"
<<
" length-> "
<<
self
.
length
<<
"
\n
"
<<
" channels-> "
<<
self
.
channels
<<
"
\n
"
<<
" mult-> "
<<
self
.
mult
<<
"
\n
"
<<
" precision-> "
<<
self
.
precision
<<
"
\n
"
<<
"}
\n
"
;
<<
" length-> "
<<
self
.
length
<<
"
\n
"
return
ss
.
str
();
<<
" mult-> "
<<
self
.
mult
<<
"
\n
"
})
<<
"}
\n
"
;
.
def_readwrite
(
"rate"
,
&
sox_signalinfo_t
::
rate
)
return
ss
.
str
();
.
def_readwrite
(
"channels"
,
&
sox_signalinfo_t
::
channels
)
})
.
def_readwrite
(
"precision"
,
&
sox_signalinfo_t
::
precision
)
.
def_readwrite
(
"rate"
,
&
sox_signalinfo_t
::
rate
)
.
def_readwrite
(
"length"
,
&
sox_signalinfo_t
::
length
)
.
def_readwrite
(
"channels"
,
&
sox_signalinfo_t
::
channels
)
.
def_readwrite
(
"mult"
,
&
sox_signalinfo_t
::
mult
);
.
def_readwrite
(
"precision"
,
&
sox_signalinfo_t
::
precision
)
.
def_readwrite
(
"length"
,
&
sox_signalinfo_t
::
length
)
.
def_readwrite
(
"mult"
,
&
sox_signalinfo_t
::
mult
);
py
::
class_
<
sox_encodinginfo_t
>
(
m
,
"sox_encodinginfo_t"
)
py
::
class_
<
sox_encodinginfo_t
>
(
m
,
"sox_encodinginfo_t"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"__repr__"
,
[](
const
sox_encodinginfo_t
&
self
)
{
.
def
(
std
::
stringstream
ss
;
"__repr__"
,
ss
<<
"sox_encodinginfo_t {
\n
"
[](
const
sox_encodinginfo_t
&
self
)
{
<<
" encoding-> "
<<
self
.
encoding
<<
"
\n
"
std
::
stringstream
ss
;
<<
" bits_per_sample-> "
<<
self
.
bits_per_sample
<<
"
\n
"
ss
<<
"sox_encodinginfo_t {
\n
"
<<
" compression-> "
<<
self
.
compression
<<
"
\n
"
<<
" encoding-> "
<<
self
.
encoding
<<
"
\n
"
<<
" reverse_bytes-> "
<<
self
.
reverse_bytes
<<
"
\n
"
<<
" bits_per_sample-> "
<<
self
.
bits_per_sample
<<
"
\n
"
<<
" reverse_nibbles-> "
<<
self
.
reverse_nibbles
<<
"
\n
"
<<
" compression-> "
<<
self
.
compression
<<
"
\n
"
<<
" reverse_bits-> "
<<
self
.
reverse_bits
<<
"
\n
"
<<
" reverse_bytes-> "
<<
self
.
reverse_bytes
<<
"
\n
"
<<
" opposite_endian-> "
<<
self
.
opposite_endian
<<
"
\n
"
<<
" reverse_nibbles-> "
<<
self
.
reverse_nibbles
<<
"
\n
"
<<
"}
\n
"
;
<<
" reverse_bits-> "
<<
self
.
reverse_bits
<<
"
\n
"
return
ss
.
str
();
<<
" opposite_endian-> "
<<
self
.
opposite_endian
<<
"
\n
"
})
<<
"}
\n
"
;
.
def_readwrite
(
"encoding"
,
&
sox_encodinginfo_t
::
encoding
)
return
ss
.
str
();
.
def_readwrite
(
"bits_per_sample"
,
&
sox_encodinginfo_t
::
bits_per_sample
)
})
.
def_readwrite
(
"compression"
,
&
sox_encodinginfo_t
::
compression
)
.
def_readwrite
(
"encoding"
,
&
sox_encodinginfo_t
::
encoding
)
.
def_readwrite
(
"reverse_bytes"
,
&
sox_encodinginfo_t
::
reverse_bytes
)
.
def_readwrite
(
"bits_per_sample"
,
&
sox_encodinginfo_t
::
bits_per_sample
)
.
def_readwrite
(
"reverse_nibbles"
,
&
sox_encodinginfo_t
::
reverse_nibbles
)
.
def_readwrite
(
"compression"
,
&
sox_encodinginfo_t
::
compression
)
.
def_readwrite
(
"reverse_bits"
,
&
sox_encodinginfo_t
::
reverse_bits
)
.
def_readwrite
(
"reverse_bytes"
,
&
sox_encodinginfo_t
::
reverse_bytes
)
.
def_readwrite
(
"opposite_endian"
,
&
sox_encodinginfo_t
::
opposite_endian
);
.
def_readwrite
(
"reverse_nibbles"
,
&
sox_encodinginfo_t
::
reverse_nibbles
)
.
def_readwrite
(
"reverse_bits"
,
&
sox_encodinginfo_t
::
reverse_bits
)
.
def_readwrite
(
"opposite_endian"
,
&
sox_encodinginfo_t
::
opposite_endian
);
py
::
enum_
<
sox_encoding_t
>
(
m
,
"sox_encoding_t"
)
py
::
enum_
<
sox_encoding_t
>
(
m
,
"sox_encoding_t"
)
.
value
(
"SOX_ENCODING_UNKNOWN"
,
sox_encoding_t
::
SOX_ENCODING_UNKNOWN
)
.
value
(
"SOX_ENCODING_UNKNOWN"
,
sox_encoding_t
::
SOX_ENCODING_UNKNOWN
)
.
value
(
"SOX_ENCODING_SIGN2"
,
sox_encoding_t
::
SOX_ENCODING_SIGN2
)
.
value
(
"SOX_ENCODING_SIGN2"
,
sox_encoding_t
::
SOX_ENCODING_SIGN2
)
.
value
(
"SOX_ENCODING_UNSIGNED"
,
sox_encoding_t
::
SOX_ENCODING_UNSIGNED
)
.
value
(
"SOX_ENCODING_UNSIGNED"
,
sox_encoding_t
::
SOX_ENCODING_UNSIGNED
)
.
value
(
"SOX_ENCODING_FLOAT"
,
sox_encoding_t
::
SOX_ENCODING_FLOAT
)
.
value
(
"SOX_ENCODING_FLOAT"
,
sox_encoding_t
::
SOX_ENCODING_FLOAT
)
.
value
(
"SOX_ENCODING_FLOAT_TEXT"
,
sox_encoding_t
::
SOX_ENCODING_FLOAT_TEXT
)
.
value
(
"SOX_ENCODING_FLOAT_TEXT"
,
sox_encoding_t
::
SOX_ENCODING_FLOAT_TEXT
)
.
value
(
"SOX_ENCODING_FLAC"
,
sox_encoding_t
::
SOX_ENCODING_FLAC
)
.
value
(
"SOX_ENCODING_FLAC"
,
sox_encoding_t
::
SOX_ENCODING_FLAC
)
.
value
(
"SOX_ENCODING_HCOM"
,
sox_encoding_t
::
SOX_ENCODING_HCOM
)
.
value
(
"SOX_ENCODING_HCOM"
,
sox_encoding_t
::
SOX_ENCODING_HCOM
)
.
value
(
"SOX_ENCODING_WAVPACK"
,
sox_encoding_t
::
SOX_ENCODING_WAVPACK
)
.
value
(
"SOX_ENCODING_WAVPACK"
,
sox_encoding_t
::
SOX_ENCODING_WAVPACK
)
.
value
(
"SOX_ENCODING_WAVPACKF"
,
sox_encoding_t
::
SOX_ENCODING_WAVPACKF
)
.
value
(
"SOX_ENCODING_WAVPACKF"
,
sox_encoding_t
::
SOX_ENCODING_WAVPACKF
)
.
value
(
"SOX_ENCODING_ULAW"
,
sox_encoding_t
::
SOX_ENCODING_ULAW
)
.
value
(
"SOX_ENCODING_ULAW"
,
sox_encoding_t
::
SOX_ENCODING_ULAW
)
.
value
(
"SOX_ENCODING_ALAW"
,
sox_encoding_t
::
SOX_ENCODING_ALAW
)
.
value
(
"SOX_ENCODING_ALAW"
,
sox_encoding_t
::
SOX_ENCODING_ALAW
)
.
value
(
"SOX_ENCODING_G721"
,
sox_encoding_t
::
SOX_ENCODING_G721
)
.
value
(
"SOX_ENCODING_G721"
,
sox_encoding_t
::
SOX_ENCODING_G721
)
.
value
(
"SOX_ENCODING_G723"
,
sox_encoding_t
::
SOX_ENCODING_G723
)
.
value
(
"SOX_ENCODING_G723"
,
sox_encoding_t
::
SOX_ENCODING_G723
)
.
value
(
"SOX_ENCODING_CL_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_CL_ADPCM
)
.
value
(
"SOX_ENCODING_CL_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_CL_ADPCM
)
.
value
(
"SOX_ENCODING_CL_ADPCM16"
,
sox_encoding_t
::
SOX_ENCODING_CL_ADPCM16
)
.
value
(
"SOX_ENCODING_CL_ADPCM16"
,
sox_encoding_t
::
SOX_ENCODING_CL_ADPCM16
)
.
value
(
"SOX_ENCODING_MS_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_MS_ADPCM
)
.
value
(
"SOX_ENCODING_MS_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_MS_ADPCM
)
.
value
(
"SOX_ENCODING_IMA_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_IMA_ADPCM
)
.
value
(
"SOX_ENCODING_IMA_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_IMA_ADPCM
)
.
value
(
"SOX_ENCODING_OKI_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_OKI_ADPCM
)
.
value
(
"SOX_ENCODING_OKI_ADPCM"
,
sox_encoding_t
::
SOX_ENCODING_OKI_ADPCM
)
.
value
(
"SOX_ENCODING_DPCM"
,
sox_encoding_t
::
SOX_ENCODING_DPCM
)
.
value
(
"SOX_ENCODING_DPCM"
,
sox_encoding_t
::
SOX_ENCODING_DPCM
)
.
value
(
"SOX_ENCODING_DWVW"
,
sox_encoding_t
::
SOX_ENCODING_DWVW
)
.
value
(
"SOX_ENCODING_DWVW"
,
sox_encoding_t
::
SOX_ENCODING_DWVW
)
.
value
(
"SOX_ENCODING_DWVWN"
,
sox_encoding_t
::
SOX_ENCODING_DWVWN
)
.
value
(
"SOX_ENCODING_DWVWN"
,
sox_encoding_t
::
SOX_ENCODING_DWVWN
)
.
value
(
"SOX_ENCODING_GSM"
,
sox_encoding_t
::
SOX_ENCODING_GSM
)
.
value
(
"SOX_ENCODING_GSM"
,
sox_encoding_t
::
SOX_ENCODING_GSM
)
.
value
(
"SOX_ENCODING_MP3"
,
sox_encoding_t
::
SOX_ENCODING_MP3
)
.
value
(
"SOX_ENCODING_MP3"
,
sox_encoding_t
::
SOX_ENCODING_MP3
)
.
value
(
"SOX_ENCODING_VORBIS"
,
sox_encoding_t
::
SOX_ENCODING_VORBIS
)
.
value
(
"SOX_ENCODING_VORBIS"
,
sox_encoding_t
::
SOX_ENCODING_VORBIS
)
.
value
(
"SOX_ENCODING_AMR_WB"
,
sox_encoding_t
::
SOX_ENCODING_AMR_WB
)
.
value
(
"SOX_ENCODING_AMR_WB"
,
sox_encoding_t
::
SOX_ENCODING_AMR_WB
)
.
value
(
"SOX_ENCODING_AMR_NB"
,
sox_encoding_t
::
SOX_ENCODING_AMR_NB
)
.
value
(
"SOX_ENCODING_AMR_NB"
,
sox_encoding_t
::
SOX_ENCODING_AMR_NB
)
.
value
(
"SOX_ENCODING_LPC10"
,
sox_encoding_t
::
SOX_ENCODING_LPC10
)
.
value
(
"SOX_ENCODING_LPC10"
,
sox_encoding_t
::
SOX_ENCODING_LPC10
)
//.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) // creates a compile error
//.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) //
.
value
(
"SOX_ENCODINGS"
,
sox_encoding_t
::
SOX_ENCODINGS
)
// creates a compile error
.
export_values
();
.
value
(
"SOX_ENCODINGS"
,
sox_encoding_t
::
SOX_ENCODINGS
)
.
export_values
();
py
::
enum_
<
sox_option_t
>
(
m
,
"sox_option_t"
)
py
::
enum_
<
sox_option_t
>
(
m
,
"sox_option_t"
)
.
value
(
"sox_option_no"
,
sox_option_t
::
sox_option_no
)
.
value
(
"sox_option_no"
,
sox_option_t
::
sox_option_no
)
.
value
(
"sox_option_yes"
,
sox_option_t
::
sox_option_yes
)
.
value
(
"sox_option_yes"
,
sox_option_t
::
sox_option_yes
)
.
value
(
"sox_option_default"
,
sox_option_t
::
sox_option_default
)
.
value
(
"sox_option_default"
,
sox_option_t
::
sox_option_default
)
.
export_values
();
.
export_values
();
py
::
enum_
<
sox_bool
>
(
m
,
"sox_bool"
)
py
::
enum_
<
sox_bool
>
(
m
,
"sox_bool"
)
.
value
(
"sox_false"
,
sox_bool
::
sox_false
)
.
value
(
"sox_false"
,
sox_bool
::
sox_false
)
.
value
(
"sox_true"
,
sox_bool
::
sox_true
)
.
value
(
"sox_true"
,
sox_bool
::
sox_true
)
.
export_values
();
.
export_values
();
m
.
def
(
m
.
def
(
"read_audio_file"
,
"read_audio_file"
,
&
torch
::
audio
::
read_audio_file
,
&
torch
::
audio
::
read_audio_file
,
...
...
torchaudio/csrc/sox/effects.cpp
View file @
b33c539c
...
@@ -143,23 +143,27 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
...
@@ -143,23 +143,27 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
normalize
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
bool
>&
channels_first
,
c10
::
optional
<
std
::
string
>&
format
)
{
c10
::
optional
<
std
::
string
>&
format
)
{
// Streaming decoding over file-like object is tricky because libsox operates
// Streaming decoding over file-like object is tricky because libsox operates on FILE pointer.
// on FILE pointer. The folloing is what `sox` and `play` commands do
// The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
// - stdin -> FILE pointer
//
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and discard.
// We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
//
//
// Here is the approach
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial chunk of byte string
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// This will perform header-based format detection, if necessary, then fill the metadata of
// chunk of byte string
// sox_format_t. Internally, sox_open_mem_read uses fmemopen, which returns FILE* which points the
// This will perform header-based format detection, if necessary, then fill
// buffer of the provided byte string.
// the metadata of sox_format_t. Internally, sox_open_mem_read uses
// 2. Each time sox reads a chunk from the FILE*, we update the underlying buffer in a way that it
// fmemopen, which returns FILE* which points the buffer of the provided
// starts with unseen data, and append the new data read from the given fileobj.
// byte string.
// This will trick libsox as if it keeps reading from the FILE* continuously.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Using std::string and let it manage memory.
// Using std::string and let it manage memory.
...
@@ -170,9 +174,12 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
...
@@ -170,9 +174,12 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
auto
*
in_buf
=
const_cast
<
char
*>
(
in_buffer
.
data
());
auto
*
in_buf
=
const_cast
<
char
*>
(
in_buffer
.
data
());
// Fetch the header, and copy it to the buffer.
// Fetch the header, and copy it to the buffer.
auto
header
=
static_cast
<
std
::
string
>
(
static_cast
<
py
::
bytes
>
(
fileobj
.
attr
(
"read"
)(
4096
)));
auto
header
=
static_cast
<
std
::
string
>
(
memcpy
(
static_cast
<
void
*>
(
in_buf
),
static_cast
<
py
::
bytes
>
(
fileobj
.
attr
(
"read"
)(
4096
)));
static_cast
<
void
*>
(
const_cast
<
char
*>
(
header
.
data
())),
header
.
length
());
memcpy
(
static_cast
<
void
*>
(
in_buf
),
static_cast
<
void
*>
(
const_cast
<
char
*>
(
header
.
data
())),
header
.
length
());
// Open file (this starts reading the header)
// Open file (this starts reading the header)
SoxFormat
sf
(
sox_open_mem_read
(
SoxFormat
sf
(
sox_open_mem_read
(
...
@@ -212,8 +219,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
...
@@ -212,8 +219,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
channels_first_
);
channels_first_
);
return
std
::
make_tuple
(
return
std
::
make_tuple
(
tensor
,
tensor
,
static_cast
<
int64_t
>
(
chain
.
getOutputSampleRate
()));
static_cast
<
int64_t
>
(
chain
.
getOutputSampleRate
()));
}
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
#endif // TORCH_API_INCLUDE_EXTENSION_H
...
...
torchaudio/csrc/sox/effects_chain.cpp
View file @
b33c539c
...
@@ -123,44 +123,47 @@ int file_output_flow(
...
@@ -123,44 +123,47 @@ int file_output_flow(
}
}
sox_effect_handler_t
*
get_tensor_input_handler
()
{
sox_effect_handler_t
*
get_tensor_input_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"input_tensor"
,
static
sox_effect_handler_t
handler
{
/*usage=*/
NULL
,
/*name=*/
"input_tensor"
,
/*flags=*/
SOX_EFF_MCHAN
,
/*usage=*/
NULL
,
/*getopts=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*start=*/
NULL
,
/*getopts=*/
NULL
,
/*flow=*/
NULL
,
/*start=*/
NULL
,
/*drain=*/
tensor_input_drain
,
/*flow=*/
NULL
,
/*stop=*/
NULL
,
/*drain=*/
tensor_input_drain
,
/*kill=*/
NULL
,
/*stop=*/
NULL
,
/*priv_size=*/
sizeof
(
TensorInputPriv
)};
/*kill=*/
NULL
,
/*priv_size=*/
sizeof
(
TensorInputPriv
)};
return
&
handler
;
return
&
handler
;
}
}
sox_effect_handler_t
*
get_tensor_output_handler
()
{
sox_effect_handler_t
*
get_tensor_output_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"output_tensor"
,
static
sox_effect_handler_t
handler
{
/*usage=*/
NULL
,
/*name=*/
"output_tensor"
,
/*flags=*/
SOX_EFF_MCHAN
,
/*usage=*/
NULL
,
/*getopts=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*start=*/
NULL
,
/*getopts=*/
NULL
,
/*flow=*/
tensor_output_flow
,
/*start=*/
NULL
,
/*drain=*/
NULL
,
/*flow=*/
tensor_output_flow
,
/*stop=*/
NULL
,
/*drain=*/
NULL
,
/*kill=*/
NULL
,
/*stop=*/
NULL
,
/*priv_size=*/
sizeof
(
TensorOutputPriv
)};
/*kill=*/
NULL
,
/*priv_size=*/
sizeof
(
TensorOutputPriv
)};
return
&
handler
;
return
&
handler
;
}
}
sox_effect_handler_t
*
get_file_output_handler
()
{
sox_effect_handler_t
*
get_file_output_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"output_file"
,
static
sox_effect_handler_t
handler
{
/*usage=*/
NULL
,
/*name=*/
"output_file"
,
/*flags=*/
SOX_EFF_MCHAN
,
/*usage=*/
NULL
,
/*getopts=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*start=*/
NULL
,
/*getopts=*/
NULL
,
/*flow=*/
file_output_flow
,
/*start=*/
NULL
,
/*drain=*/
NULL
,
/*flow=*/
file_output_flow
,
/*stop=*/
NULL
,
/*drain=*/
NULL
,
/*kill=*/
NULL
,
/*stop=*/
NULL
,
/*priv_size=*/
sizeof
(
FileOutputPriv
)};
/*kill=*/
NULL
,
/*priv_size=*/
sizeof
(
FileOutputPriv
)};
return
&
handler
;
return
&
handler
;
}
}
...
@@ -198,7 +201,8 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) {
...
@@ -198,7 +201,8 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) {
priv
->
signal
=
signal
;
priv
->
signal
=
signal
;
priv
->
index
=
0
;
priv
->
index
=
0
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input_tensor"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input_tensor"
);
}
}
}
}
...
@@ -207,7 +211,8 @@ void SoxEffectsChain::addOutputBuffer(
...
@@ -207,7 +211,8 @@ void SoxEffectsChain::addOutputBuffer(
SoxEffect
e
(
sox_create_effect
(
get_tensor_output_handler
()));
SoxEffect
e
(
sox_create_effect
(
get_tensor_output_handler
()));
static_cast
<
TensorOutputPriv
*>
(
e
->
priv
)
->
buffer
=
output_buffer
;
static_cast
<
TensorOutputPriv
*>
(
e
->
priv
)
->
buffer
=
output_buffer
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output_tensor"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output_tensor"
);
}
}
}
}
...
@@ -305,7 +310,7 @@ struct FileObjOutputPriv {
...
@@ -305,7 +310,7 @@ struct FileObjOutputPriv {
/// Callback function to feed byte string
/// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
int
fileobj_input_drain
(
sox_effect_t
*
effp
,
sox_sample_t
*
obuf
,
size_t
*
osamp
)
{
int
fileobj_input_drain
(
sox_effect_t
*
effp
,
sox_sample_t
*
obuf
,
size_t
*
osamp
)
{
auto
priv
=
static_cast
<
FileObjInputPriv
*>
(
effp
->
priv
);
auto
priv
=
static_cast
<
FileObjInputPriv
*>
(
effp
->
priv
);
auto
sf
=
priv
->
sf
;
auto
sf
=
priv
->
sf
;
auto
fileobj
=
priv
->
fileobj
;
auto
fileobj
=
priv
->
fileobj
;
auto
buffer
=
priv
->
buffer
;
auto
buffer
=
priv
->
buffer
;
...
@@ -315,9 +320,9 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
...
@@ -315,9 +320,9 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
//
//
// NOTE:
// NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way
// Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detect EOF is reaching the end of the buffer. (null byte won't
help)
// libsox detect EOF is reaching the end of the buffer. (null byte won't
// Therefore we need to align the content at the end of buffer,
otherwise,
//
help)
Therefore we need to align the content at the end of buffer,
// libsox will keep reading the content beyond intended length.
//
otherwise,
libsox will keep reading the content beyond intended length.
//
//
// Before:
// Before:
//
//
...
@@ -339,11 +344,12 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
...
@@ -339,11 +344,12 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
const
auto
num_refill
=
py
::
len
(
chunk_
);
const
auto
num_refill
=
py
::
len
(
chunk_
);
const
auto
offset
=
buffer_size
-
(
num_remain
+
num_refill
);
const
auto
offset
=
buffer_size
-
(
num_remain
+
num_refill
);
if
(
num_refill
>
num_consumed
)
{
if
(
num_refill
>
num_consumed
)
{
std
::
ostringstream
message
;
std
::
ostringstream
message
;
message
<<
"Tried to read up to "
<<
num_consumed
<<
" bytes but, "
message
<<
"recieved "
<<
num_refill
<<
" bytes. "
<<
"Tried to read up to "
<<
num_consumed
<<
" bytes but, "
<<
"The given object does not confirm to read protocol of file object."
;
<<
"recieved "
<<
num_refill
<<
" bytes. "
<<
"The given object does not confirm to read protocol of file object."
;
throw
std
::
runtime_error
(
message
.
str
());
throw
std
::
runtime_error
(
message
.
str
());
}
}
...
@@ -364,7 +370,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
...
@@ -364,7 +370,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// 1.4. Set the file pointer to the new offset
// 1.4. Set the file pointer to the new offset
sf
->
tell_off
=
offset
;
sf
->
tell_off
=
offset
;
fseek
((
FILE
*
)
sf
->
fp
,
offset
,
SEEK_SET
);
fseek
((
FILE
*
)
sf
->
fp
,
offset
,
SEEK_SET
);
// 2. Perform decoding operation
// 2. Perform decoding operation
// The following part is practically same as "input" effect
// The following part is practically same as "input" effect
...
@@ -377,7 +383,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
...
@@ -377,7 +383,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// store the actual number read back to *osamp
// store the actual number read back to *osamp
*
osamp
=
sox_read
(
sf
,
obuf
,
*
osamp
);
*
osamp
=
sox_read
(
sf
,
obuf
,
*
osamp
);
return
*
osamp
?
SOX_SUCCESS
:
SOX_EOF
;
return
*
osamp
?
SOX_SUCCESS
:
SOX_EOF
;
}
}
int
fileobj_output_flow
(
int
fileobj_output_flow
(
...
@@ -420,30 +426,32 @@ int fileobj_output_flow(
...
@@ -420,30 +426,32 @@ int fileobj_output_flow(
}
}
sox_effect_handler_t
*
get_fileobj_input_handler
()
{
sox_effect_handler_t
*
get_fileobj_input_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"input_fileobj_object"
,
static
sox_effect_handler_t
handler
{
/*usage=*/
NULL
,
/*name=*/
"input_fileobj_object"
,
/*flags=*/
SOX_EFF_MCHAN
,
/*usage=*/
NULL
,
/*getopts=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*start=*/
NULL
,
/*getopts=*/
NULL
,
/*flow=*/
NULL
,
/*start=*/
NULL
,
/*drain=*/
fileobj_input_drain
,
/*flow=*/
NULL
,
/*stop=*/
NULL
,
/*drain=*/
fileobj_input_drain
,
/*kill=*/
NULL
,
/*stop=*/
NULL
,
/*priv_size=*/
sizeof
(
FileObjInputPriv
)};
/*kill=*/
NULL
,
/*priv_size=*/
sizeof
(
FileObjInputPriv
)};
return
&
handler
;
return
&
handler
;
}
}
sox_effect_handler_t
*
get_fileobj_output_handler
()
{
sox_effect_handler_t
*
get_fileobj_output_handler
()
{
static
sox_effect_handler_t
handler
{
/*name=*/
"output_fileobj_object"
,
static
sox_effect_handler_t
handler
{
/*usage=*/
NULL
,
/*name=*/
"output_fileobj_object"
,
/*flags=*/
SOX_EFF_MCHAN
,
/*usage=*/
NULL
,
/*getopts=*/
NULL
,
/*flags=*/
SOX_EFF_MCHAN
,
/*start=*/
NULL
,
/*getopts=*/
NULL
,
/*flow=*/
fileobj_output_flow
,
/*start=*/
NULL
,
/*drain=*/
NULL
,
/*flow=*/
fileobj_output_flow
,
/*stop=*/
NULL
,
/*drain=*/
NULL
,
/*kill=*/
NULL
,
/*stop=*/
NULL
,
/*priv_size=*/
sizeof
(
FileObjOutputPriv
)};
/*kill=*/
NULL
,
/*priv_size=*/
sizeof
(
FileObjOutputPriv
)};
return
&
handler
;
return
&
handler
;
}
}
...
@@ -464,7 +472,8 @@ void SoxEffectsChain::addInputFileObj(
...
@@ -464,7 +472,8 @@ void SoxEffectsChain::addInputFileObj(
priv
->
buffer
=
buffer
;
priv
->
buffer
=
buffer
;
priv
->
buffer_size
=
buffer_size
;
priv
->
buffer_size
=
buffer_size
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
in_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input fileobj"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: input fileobj"
);
}
}
}
}
...
@@ -481,7 +490,8 @@ void SoxEffectsChain::addOutputFileObj(
...
@@ -481,7 +490,8 @@ void SoxEffectsChain::addOutputFileObj(
priv
->
buffer
=
buffer
;
priv
->
buffer
=
buffer
;
priv
->
buffer_size
=
buffer_size
;
priv
->
buffer_size
=
buffer_size
;
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
out_sig_
)
!=
SOX_SUCCESS
)
{
if
(
sox_add_effect
(
sec_
,
e
,
&
interm_sig_
,
&
out_sig_
)
!=
SOX_SUCCESS
)
{
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output fileobj"
);
throw
std
::
runtime_error
(
"Internal Error: Failed to add effect: output fileobj"
);
}
}
}
}
...
...
torchaudio/csrc/sox/io.cpp
View file @
b33c539c
...
@@ -112,8 +112,9 @@ void save_audio_file(
...
@@ -112,8 +112,9 @@ void save_audio_file(
auto
signal
=
TensorSignal
(
tensor
,
sample_rate
,
channels_first
);
auto
signal
=
TensorSignal
(
tensor
,
sample_rate
,
channels_first
);
const
auto
filetype
=
[
&
](){
const
auto
filetype
=
[
&
]()
{
if
(
format
.
has_value
())
return
format
.
value
();
if
(
format
.
has_value
())
return
format
.
value
();
return
get_filetype
(
path
);
return
get_filetype
(
path
);
}();
}();
if
(
filetype
==
"amr-nb"
)
{
if
(
filetype
==
"amr-nb"
)
{
...
@@ -123,7 +124,8 @@ void save_audio_file(
...
@@ -123,7 +124,8 @@ void save_audio_file(
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
}
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
SoxFormat
sf
(
sox_open_write
(
SoxFormat
sf
(
sox_open_write
(
path
.
c_str
(),
path
.
c_str
(),
...
@@ -161,7 +163,8 @@ std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
...
@@ -161,7 +163,8 @@ std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
namespace
{
namespace
{
// helper class to automatically release buffer, to be used by save_audio_fileobj
// helper class to automatically release buffer, to be used by
// save_audio_fileobj
struct
AutoReleaseBuffer
{
struct
AutoReleaseBuffer
{
char
*
ptr
;
char
*
ptr
;
size_t
size
;
size_t
size
;
...
@@ -194,12 +197,14 @@ void save_audio_fileobj(
...
@@ -194,12 +197,14 @@ void save_audio_fileobj(
if
(
filetype
==
"amr-nb"
)
{
if
(
filetype
==
"amr-nb"
)
{
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
const
auto
num_channels
=
tensor
.
size
(
channels_first
?
0
:
1
);
if
(
num_channels
!=
1
)
{
if
(
num_channels
!=
1
)
{
throw
std
::
runtime_error
(
"amr-nb format only supports single channel audio."
);
throw
std
::
runtime_error
(
"amr-nb format only supports single channel audio."
);
}
}
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
tensor
=
(
unnormalize_wav
(
tensor
)
/
65536
).
to
(
torch
::
kInt16
);
}
}
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
signal_info
=
get_signalinfo
(
&
signal
,
filetype
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
const
auto
encoding_info
=
get_encodinginfo
(
filetype
,
tensor
.
dtype
(),
compression
);
AutoReleaseBuffer
buffer
;
AutoReleaseBuffer
buffer
;
...
@@ -212,7 +217,8 @@ void save_audio_fileobj(
...
@@ -212,7 +217,8 @@ void save_audio_fileobj(
/*oob=*/
nullptr
));
/*oob=*/
nullptr
));
if
(
static_cast
<
sox_format_t
*>
(
sf
)
==
nullptr
)
{
if
(
static_cast
<
sox_format_t
*>
(
sf
)
==
nullptr
)
{
throw
std
::
runtime_error
(
"Error saving audio file: failed to open memory stream."
);
throw
std
::
runtime_error
(
"Error saving audio file: failed to open memory stream."
);
}
}
torchaudio
::
sox_effects_chain
::
SoxEffectsChain
chain
(
torchaudio
::
sox_effects_chain
::
SoxEffectsChain
chain
(
...
@@ -222,7 +228,8 @@ void save_audio_fileobj(
...
@@ -222,7 +228,8 @@ void save_audio_fileobj(
chain
.
addOutputFileObj
(
sf
,
&
buffer
.
ptr
,
&
buffer
.
size
,
&
fileobj
);
chain
.
addOutputFileObj
(
sf
,
&
buffer
.
ptr
,
&
buffer
.
size
,
&
fileobj
);
chain
.
run
();
chain
.
run
();
// Closing the sox_format_t is necessary for flushing the last chunk to the buffer
// Closing the sox_format_t is necessary for flushing the last chunk to the
// buffer
sf
.
close
();
sf
.
close
();
fileobj
.
attr
(
"write"
)(
py
::
bytes
(
buffer
.
ptr
,
buffer
.
size
));
fileobj
.
attr
(
"write"
)(
py
::
bytes
(
buffer
.
ptr
,
buffer
.
size
));
...
...
torchaudio/csrc/sox/legacy.cpp
View file @
b33c539c
...
@@ -40,10 +40,7 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
...
@@ -40,10 +40,7 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
return
samples_written
;
return
samples_written
;
}
}
void
read_audio
(
void
read_audio
(
SoxDescriptor
&
fd
,
at
::
Tensor
output
,
int64_t
buffer_length
)
{
SoxDescriptor
&
fd
,
at
::
Tensor
output
,
int64_t
buffer_length
)
{
std
::
vector
<
sox_sample_t
>
buffer
(
buffer_length
);
std
::
vector
<
sox_sample_t
>
buffer
(
buffer_length
);
int
number_of_channels
=
fd
->
signal
.
channels
;
int
number_of_channels
=
fd
->
signal
.
channels
;
...
@@ -64,8 +61,7 @@ void read_audio(
...
@@ -64,8 +61,7 @@ void read_audio(
}
// namespace
}
// namespace
std
::
tuple
<
sox_signalinfo_t
,
sox_encodinginfo_t
>
get_info
(
std
::
tuple
<
sox_signalinfo_t
,
sox_encodinginfo_t
>
get_info
(
const
std
::
string
&
file_name
const
std
::
string
&
file_name
)
{
)
{
SoxDescriptor
fd
(
sox_open_read
(
SoxDescriptor
fd
(
sox_open_read
(
file_name
.
c_str
(),
file_name
.
c_str
(),
/*signal=*/
nullptr
,
/*signal=*/
nullptr
,
...
@@ -86,7 +82,6 @@ int read_audio_file(
...
@@ -86,7 +82,6 @@ int read_audio_file(
sox_signalinfo_t
*
si
,
sox_signalinfo_t
*
si
,
sox_encodinginfo_t
*
ei
,
sox_encodinginfo_t
*
ei
,
const
char
*
ft
)
{
const
char
*
ft
)
{
SoxDescriptor
fd
(
sox_open_read
(
file_name
.
c_str
(),
si
,
ei
,
ft
));
SoxDescriptor
fd
(
sox_open_read
(
file_name
.
c_str
(),
si
,
ei
,
ft
));
if
(
fd
.
get
()
==
nullptr
)
{
if
(
fd
.
get
()
==
nullptr
)
{
throw
std
::
runtime_error
(
"Error opening audio file"
);
throw
std
::
runtime_error
(
"Error opening audio file"
);
...
@@ -112,15 +107,16 @@ int read_audio_file(
...
@@ -112,15 +107,16 @@ int read_audio_file(
// calculate buffer length
// calculate buffer length
int64_t
buffer_length
=
total_length
;
int64_t
buffer_length
=
total_length
;
if
(
offset
>
0
)
{
if
(
offset
>
0
)
{
buffer_length
-=
offset
;
buffer_length
-=
offset
;
}
}
if
(
nframes
>
0
&&
buffer_length
>
nframes
)
{
if
(
nframes
>
0
&&
buffer_length
>
nframes
)
{
buffer_length
=
nframes
;
buffer_length
=
nframes
;
}
}
// seek to offset point before reading data
// seek to offset point before reading data
if
(
sox_seek
(
fd
.
get
(),
offset
,
0
)
==
SOX_EOF
)
{
if
(
sox_seek
(
fd
.
get
(),
offset
,
0
)
==
SOX_EOF
)
{
throw
std
::
runtime_error
(
"sox_seek reached EOF, try reducing offset or num_samples"
);
throw
std
::
runtime_error
(
"sox_seek reached EOF, try reducing offset or num_samples"
);
}
}
// read data and fill output tensor
// read data and fill output tensor
...
...
torchaudio/csrc/sox/legacy.h
View file @
b33c539c
#include <sox.h>
#include <sox.h>
#include <torch/torch.h>
#include <torch/torch.h>
namespace
torch
{
namespace
audio
{
namespace
torch
{
namespace
audio
{
/// Reads an audio file from the given `path` into the `output` `Tensor` and
/// Reads an audio file from the given `path` into the `output` `Tensor` and
/// returns the sample rate of the audio file.
/// returns the sample rate of the audio file.
...
@@ -30,9 +31,10 @@ void write_audio_file(
...
@@ -30,9 +31,10 @@ void write_audio_file(
/// Reads an audio file from the given `path` and returns a tuple of
/// Reads an audio file from the given `path` and returns a tuple of
/// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
/// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
/// the audio file such as sample rate, length, bit precision, encoding and
more.
/// the audio file such as sample rate, length, bit precision, encoding and
/// Throws `std::runtime_error` if the audio file could not be opened, or
an
///
more.
Throws `std::runtime_error` if the audio file could not be opened, or
/// error occurred during reading of the audio data.
///
an
error occurred during reading of the audio data.
std
::
tuple
<
sox_signalinfo_t
,
sox_encodinginfo_t
>
get_info
(
std
::
tuple
<
sox_signalinfo_t
,
sox_encodinginfo_t
>
get_info
(
const
std
::
string
&
file_name
);
const
std
::
string
&
file_name
);
}}
// namespace torch::audio
}
// namespace audio
}
// namespace torch
torchaudio/csrc/sox/register.cpp
View file @
b33c539c
...
@@ -43,17 +43,19 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...
@@ -43,17 +43,19 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.
def
(
"get_sample_rate"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getSampleRate
)
.
def
(
"get_sample_rate"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getSampleRate
)
.
def
(
"get_num_channels"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_channels"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
)
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
)
.
def
(
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
.
def
(
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info
);
m
.
def
(
m
.
def
(
"torchaudio::sox_io_load_audio_file("
"torchaudio::sox_io_load_audio_file("
"str path,"
"str path,"
"int? frame_offset=None,"
"int? frame_offset=None,"
"int? num_frames=None,"
"int? num_frames=None,"
"bool? normalize=True,"
"bool? normalize=True,"
"bool? channels_first=False,"
"bool? channels_first=False,"
"str? format=None"
"str? format=None"
") -> __torch__.torch.classes.torchaudio.TensorSignal"
,
") -> __torch__.torch.classes.torchaudio.TensorSignal"
,
&
torchaudio
::
sox_io
::
load_audio_file
);
&
torchaudio
::
sox_io
::
load_audio_file
);
m
.
def
(
m
.
def
(
...
...
torchaudio/csrc/sox/utils.cpp
View file @
b33c539c
...
@@ -80,7 +80,9 @@ bool TensorSignal::getChannelsFirst() const {
...
@@ -80,7 +80,9 @@ bool TensorSignal::getChannelsFirst() const {
}
}
SoxFormat
::
SoxFormat
(
sox_format_t
*
fd
)
noexcept
:
fd_
(
fd
)
{}
SoxFormat
::
SoxFormat
(
sox_format_t
*
fd
)
noexcept
:
fd_
(
fd
)
{}
SoxFormat
::~
SoxFormat
()
{
close
();
}
SoxFormat
::~
SoxFormat
()
{
close
();
}
sox_format_t
*
SoxFormat
::
operator
->
()
const
noexcept
{
sox_format_t
*
SoxFormat
::
operator
->
()
const
noexcept
{
return
fd_
;
return
fd_
;
...
@@ -291,26 +293,28 @@ sox_signalinfo_t get_signalinfo(
...
@@ -291,26 +293,28 @@ sox_signalinfo_t get_signalinfo(
sox_encodinginfo_t
get_encodinginfo
(
sox_encodinginfo_t
get_encodinginfo
(
const
std
::
string
filetype
,
const
std
::
string
filetype
,
const
caffe2
::
TypeMeta
dtype
)
{
const
caffe2
::
TypeMeta
dtype
)
{
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
return
sox_encodinginfo_t
{
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
/*compression=*/
HUGE_VAL
,
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*reverse_bytes=*/
sox_option_default
,
/*compression=*/
HUGE_VAL
,
/*reverse_nibbles=*/
sox_option_default
,
/*reverse_bytes=*/
sox_option_default
,
/*reverse_bits=*/
sox_option_default
,
/*reverse_nibbles=*/
sox_option_default
,
/*opposite_endian=*/
sox_false
};
/*reverse_bits=*/
sox_option_default
,
/*opposite_endian=*/
sox_false
};
}
}
sox_encodinginfo_t
get_encodinginfo
(
sox_encodinginfo_t
get_encodinginfo
(
const
std
::
string
filetype
,
const
std
::
string
filetype
,
const
caffe2
::
TypeMeta
dtype
,
const
caffe2
::
TypeMeta
dtype
,
c10
::
optional
<
double
>&
compression
)
{
c10
::
optional
<
double
>&
compression
)
{
return
sox_encodinginfo_t
{
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
return
sox_encodinginfo_t
{
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*encoding=*/
get_encoding
(
filetype
,
dtype
),
/*compression=*/
compression
.
value_or
(
HUGE_VAL
),
/*bits_per_sample=*/
get_precision
(
filetype
,
dtype
),
/*reverse_bytes=*/
sox_option_default
,
/*compression=*/
compression
.
value_or
(
HUGE_VAL
),
/*reverse_nibbles=*/
sox_option_default
,
/*reverse_bytes=*/
sox_option_default
,
/*reverse_bits=*/
sox_option_default
,
/*reverse_nibbles=*/
sox_option_default
,
/*opposite_endian=*/
sox_false
};
/*reverse_bits=*/
sox_option_default
,
/*opposite_endian=*/
sox_false
};
}
}
}
// namespace sox_utils
}
// namespace sox_utils
...
...
torchaudio/csrc/sox/utils.h
View file @
b33c539c
...
@@ -69,7 +69,7 @@ struct SoxFormat {
...
@@ -69,7 +69,7 @@ struct SoxFormat {
///
///
/// Verify that input file is found, has known encoding, and not empty
/// Verify that input file is found, has known encoding, and not empty
void
validate_input_file
(
const
SoxFormat
&
sf
,
bool
check_length
=
true
);
void
validate_input_file
(
const
SoxFormat
&
sf
,
bool
check_length
=
true
);
///
///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
...
...
torchaudio/csrc/transducer.cpp
View file @
b33c539c
...
@@ -8,71 +8,80 @@
...
@@ -8,71 +8,80 @@
namespace
{
namespace
{
int64_t
cpu_rnnt_loss
(
torch
::
Tensor
acts
,
int64_t
cpu_rnnt_loss
(
torch
::
Tensor
labels
,
torch
::
Tensor
acts
,
torch
::
Tensor
input_lengths
,
torch
::
Tensor
labels
,
torch
::
Tensor
label_lengths
,
torch
::
Tensor
input_lengths
,
torch
::
Tensor
costs
,
torch
::
Tensor
label_lengths
,
torch
::
Tensor
grads
,
torch
::
Tensor
costs
,
int64_t
blank_label
,
torch
::
Tensor
grads
,
int64_t
num_threads
)
{
int64_t
blank_label
,
int64_t
num_threads
)
{
int
maxT
=
acts
.
size
(
1
);
int
maxT
=
acts
.
size
(
1
);
int
maxU
=
acts
.
size
(
2
);
int
maxU
=
acts
.
size
(
2
);
int
minibatch_size
=
acts
.
size
(
0
);
int
minibatch_size
=
acts
.
size
(
0
);
int
alphabet_size
=
acts
.
size
(
3
);
int
alphabet_size
=
acts
.
size
(
3
);
rnntOptions
options
;
rnntOptions
options
;
memset
(
&
options
,
0
,
sizeof
(
options
));
memset
(
&
options
,
0
,
sizeof
(
options
));
options
.
maxT
=
maxT
;
options
.
maxT
=
maxT
;
options
.
maxU
=
maxU
;
options
.
maxU
=
maxU
;
options
.
blank_label
=
blank_label
;
options
.
blank_label
=
blank_label
;
options
.
batch_first
=
true
;
options
.
batch_first
=
true
;
options
.
loc
=
RNNT_CPU
;
options
.
loc
=
RNNT_CPU
;
options
.
num_threads
=
num_threads
;
options
.
num_threads
=
num_threads
;
// have to use at least one
// have to use at least one
options
.
num_threads
=
std
::
max
(
options
.
num_threads
,
(
unsigned
int
)
1
);
options
.
num_threads
=
std
::
max
(
options
.
num_threads
,
(
unsigned
int
)
1
);
size_t
cpu_size_bytes
=
0
;
size_t
cpu_size_bytes
=
0
;
switch
(
acts
.
scalar_type
())
{
switch
(
acts
.
scalar_type
())
{
case
torch
::
ScalarType
::
Float
:
case
torch
::
ScalarType
::
Float
:
{
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
);
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
);
std
::
vector
<
float
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
float
),
0
);
std
::
vector
<
float
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
float
),
0
);
compute_rnnt_loss
(
acts
.
data_ptr
<
float
>
(),
compute_rnnt_loss
(
acts
.
data_ptr
<
float
>
(),
grads
.
data_ptr
<
float
>
(),
grads
.
data_ptr
<
float
>
(),
labels
.
data_ptr
<
int
>
(),
label_lengths
.
data_ptr
<
int
>
(),
labels
.
data_ptr
<
int
>
(),
input_lengths
.
data_ptr
<
int
>
(),
alphabet_size
,
label_lengths
.
data_ptr
<
int
>
(),
minibatch_size
,
costs
.
data_ptr
<
float
>
(),
input_lengths
.
data_ptr
<
int
>
(),
cpu_workspace
.
data
(),
options
);
alphabet_size
,
minibatch_size
,
return
0
;
costs
.
data_ptr
<
float
>
(),
}
cpu_workspace
.
data
(),
case
torch
::
ScalarType
::
Double
:
options
);
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
return
0
;
false
,
&
cpu_size_bytes
,
}
sizeof
(
double
));
case
torch
::
ScalarType
::
Double
:
{
get_workspace_size
(
std
::
vector
<
double
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
double
),
0
);
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
,
sizeof
(
double
));
compute_rnnt_loss_fp64
(
acts
.
data_ptr
<
double
>
(),
grads
.
data_ptr
<
double
>
(),
std
::
vector
<
double
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
double
),
0
);
labels
.
data_ptr
<
int
>
(),
label_lengths
.
data_ptr
<
int
>
(),
input_lengths
.
data_ptr
<
int
>
(),
alphabet_size
,
compute_rnnt_loss_fp64
(
minibatch_size
,
costs
.
data_ptr
<
double
>
(),
acts
.
data_ptr
<
double
>
(),
cpu_workspace
.
data
(),
options
);
grads
.
data_ptr
<
double
>
(),
labels
.
data_ptr
<
int
>
(),
return
0
;
label_lengths
.
data_ptr
<
int
>
(),
}
input_lengths
.
data_ptr
<
int
>
(),
default:
alphabet_size
,
TORCH_CHECK
(
false
,
minibatch_size
,
std
::
string
(
__func__
)
+
" not implemented for '"
+
toString
(
acts
.
scalar_type
())
+
"'"
costs
.
data_ptr
<
double
>
(),
);
cpu_workspace
.
data
(),
options
);
return
0
;
}
}
return
-
1
;
default:
TORCH_CHECK
(
false
,
std
::
string
(
__func__
)
+
" not implemented for '"
+
toString
(
acts
.
scalar_type
())
+
"'"
);
}
return
-
1
;
}
}
}
// namespace
}
// namespace
...
@@ -82,12 +91,13 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
...
@@ -82,12 +91,13 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
}
}
TORCH_LIBRARY_FRAGMENT
(
torchaudio
,
m
)
{
TORCH_LIBRARY_FRAGMENT
(
torchaudio
,
m
)
{
m
.
def
(
"rnnt_loss(Tensor acts,"
m
.
def
(
"Tensor labels,"
"rnnt_loss(Tensor acts,"
"Tensor input_lengths,"
"Tensor labels,"
"Tensor label_lengths,"
"Tensor input_lengths,"
"Tensor costs,"
"Tensor label_lengths,"
"Tensor grads,"
"Tensor costs,"
"int blank_label,"
"Tensor grads,"
"int num_threads) -> int"
);
"int blank_label,"
"int num_threads) -> int"
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment