Unverified Commit b33c539c authored by moto's avatar moto Committed by GitHub
Browse files

Fix clang-format CI job (#1198)

parent 99ed7183
#!/usr/bin/env python
"""A wrapper script around clang-format, suitable for linting multiple files
and to use for continuous integration.
This is an alternative API for the clang-format command line.
It runs over multiple files and directories in parallel.
A diff output is produced and a sensible exit code is returned.
"""
import argparse
import codecs
import difflib
import fnmatch
import io
import multiprocessing
import os
import signal
import subprocess
import sys
import traceback
from functools import partial
# subprocess.DEVNULL is not available on every interpreter; when the import
# fails, fall back to an explicitly opened handle on the null device.
try:
    from subprocess import DEVNULL  # py3k
except ImportError:
    DEVNULL = open(os.devnull, "wb")

# Comma-separated file extensions (without the leading dot) that are checked
# when --extensions is not given on the command line.
DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu'
class ExitStatus:
    """Exit codes returned by ``main()`` and passed to ``sys.exit()``."""

    SUCCESS = 0  # no diff was produced and nothing failed
    DIFF = 1  # at least one file produced a non-empty diff
    TROUBLE = 2  # clang-format could not be run, or a file failed
def list_files(files, recursive=False, extensions=None, exclude=None):
    """Expand *files* into the list of paths to check.

    Args:
        files: File or directory paths.
        recursive: When true, each directory in *files* is walked and
            replaced by the matching files found below it.
        extensions: Extensions (without the leading dot) that a file found
            during a recursive walk must have to be kept. ``None`` is
            treated as an empty list, i.e. nothing found by a walk is kept.
        exclude: ``fnmatch`` patterns. During a recursive walk, matching
            directories are not descended into and matching files are
            dropped. ``None`` is treated as an empty list.

    Returns:
        A list of paths. Entries that are not walked (anything that is not
        a directory, or every entry when *recursive* is false) are returned
        unchanged and are filtered by neither *extensions* nor *exclude*.
    """
    if extensions is None:
        extensions = []
    if exclude is None:
        exclude = []

    def is_excluded(path):
        return any(fnmatch.fnmatch(path, pattern) for pattern in exclude)

    out = []
    for entry in files:
        if not (recursive and os.path.isdir(entry)):
            out.append(entry)
            continue
        for dirpath, dnames, fnames in os.walk(entry):
            # os.walk() supports trimming down the dnames list by modifying
            # it in-place, to avoid unnecessary directory listings.
            dnames[:] = [
                name for name in dnames
                if not is_excluded(os.path.join(dirpath, name))
            ]
            for fname in fnames:
                fpath = os.path.join(dirpath, fname)
                if is_excluded(fpath):
                    continue
                # splitext() keeps the dot; strip it before comparing.
                if os.path.splitext(fpath)[1][1:] in extensions:
                    out.append(fpath)
    return out
def make_diff(file, original, reformatted):
    """Return the unified diff between two versions of *file* as a list.

    Args:
        file: Path used to label both sides of the diff header.
        original: Lines of the file as read from disk.
        reformatted: Lines produced by clang-format.

    Returns:
        The diff lines with three lines of context; an empty list when
        *original* and *reformatted* are equal.
    """
    diff = difflib.unified_diff(
        original,
        reformatted,
        fromfile='{}\t(original)'.format(file),
        tofile='{}\t(reformatted)'.format(file),
        n=3,
    )
    return list(diff)
class DiffError(Exception):
    """Raised when a file cannot be diffed against clang-format's output.

    Attributes:
        errs: Lines to show alongside the message (the caller passes
            clang-format's stderr lines); an empty list when none are given.
    """

    def __init__(self, message, errs=None):
        super(DiffError, self).__init__(message)
        self.errs = errs or []
class UnexpectedError(Exception):
    """Wraps an exception that is not a ``DiffError``.

    Attributes:
        formatted_traceback: Result of ``traceback.format_exc()`` captured
            when this error is constructed, so construct it inside the
            ``except`` block that handles the original exception.
        exc: The original exception, or ``None`` when not given.
    """

    def __init__(self, message, exc=None):
        super(UnexpectedError, self).__init__(message)
        # Captured eagerly as a string so it can be printed later.
        self.formatted_traceback = traceback.format_exc()
        self.exc = exc
def run_clang_format_diff_wrapper(args, file):
    """Call ``run_clang_format_diff`` and normalise the exceptions it raises.

    ``DiffError`` is passed through unchanged. Any other exception is
    re-raised as ``UnexpectedError`` with a message of the form
    ``"<file>: <exception class>: <exception>"``.

    Returns:
        Whatever ``run_clang_format_diff(args, file)`` returns.
    """
    try:
        return run_clang_format_diff(args, file)
    except DiffError:
        raise
    except Exception as e:
        raise UnexpectedError(
            '{}: {}: {}'.format(file, e.__class__.__name__, e), e)
def run_clang_format_diff(args, file):
    """Run clang-format on *file* and diff its output against the file.

    Args:
        args: Parsed command-line namespace; only
            ``clang_format_executable`` is read.
        file: Path of the source file to check.

    Returns:
        A ``(diff_lines, stderr_lines)`` tuple: the unified diff between the
        file on disk and clang-format's stdout (empty when they are equal),
        and the lines clang-format wrote to stderr.

    Raises:
        DiffError: If the file cannot be read, the executable cannot be
            started, or it exits with a non-zero status.
    """
    try:
        with io.open(file, 'r', encoding='utf-8') as f:
            original = f.readlines()
    except IOError as exc:
        raise DiffError(str(exc))
    invocation = [args.clang_format_executable, file]
    # Use of utf-8 to decode the process output.
    #
    # Hopefully, this is the correct thing to do.
    #
    # It's done due to the following assumptions (which may be incorrect):
    # - clang-format will return the bytes read from the files as-is,
    #   without conversion, and it is already assumed that the files use
    #   utf-8.
    # - if the diagnostics were internationalized, they would use utf-8:
    #   > Adding Translations to Clang
    #   >
    #   > Not possible yet!
    #   > Diagnostic strings should be written in UTF-8,
    #   > the client can translate to the relevant code page if needed.
    #   > Each translation completely replaces the format string
    #   > for the diagnostic.
    #   > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
    try:
        proc = subprocess.Popen(
            invocation,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            encoding='utf-8')
    except OSError as exc:
        raise DiffError(
            "Command '{}' failed to start: {}".format(
                subprocess.list2cmdline(invocation), exc
            )
        )
    # communicate() reads stdout and stderr together and waits for the
    # process. Reading stdout to EOF first could deadlock if the child
    # blocked on a full stderr pipe.
    stdout_text, stderr_text = proc.communicate()
    # Split on '\n' only, keeping line endings, like readlines() on the pipe.
    outs = io.StringIO(stdout_text).readlines()
    errs = io.StringIO(stderr_text).readlines()
    if proc.returncode:
        raise DiffError(
            "Command '{}' returned non-zero exit status {}".format(
                subprocess.list2cmdline(invocation), proc.returncode
            ),
            errs,
        )
    return make_diff(file, original, outs), errs
def bold_red(s):
    """Return *s* wrapped in the ANSI escapes for bold red, then reset."""
    return '\x1b[1m\x1b[31m' + s + '\x1b[0m'
def colorize(diff_lines):
    """Yield each unified-diff line wrapped in ANSI colour escapes.

    File headers (``--- `` / ``+++ ``) become bold, hunk headers (``@@ ``)
    cyan, added lines green and removed lines red. Any other line is
    yielded unchanged.
    """
    reset = '\x1b[0m'

    def wrap(code, text):
        return code + text + reset

    for line in diff_lines:
        # File headers are tested first: they also start with '+' or '-'.
        if line.startswith(('--- ', '+++ ')):
            yield wrap('\x1b[1m', line)
        elif line.startswith('@@ '):
            yield wrap('\x1b[36m', line)
        elif line.startswith('+'):
            yield wrap('\x1b[32m', line)
        elif line.startswith('-'):
            yield wrap('\x1b[31m', line)
        else:
            yield line
def print_diff(diff_lines, use_color):
    """Write *diff_lines* to stdout, colourised when *use_color* is true."""
    if use_color:
        diff_lines = colorize(diff_lines)
    sys.stdout.writelines(diff_lines)
def print_trouble(prog, message, use_colors):
    """Print ``"<prog>: error: <message>"`` to stderr.

    The ``error:`` label is rendered in bold red when *use_colors* is true.
    """
    error_text = 'error:'
    if use_colors:
        error_text = bold_red(error_text)
    print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
def main():
    """Parse the command line, run clang-format over the files, print diffs.

    Returns:
        ``ExitStatus.SUCCESS`` when no file produced a diff (or no file was
        selected), ``ExitStatus.DIFF`` when at least one did and nothing
        failed, and ``ExitStatus.TROUBLE`` when clang-format could not be
        run or any file raised an error.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--clang-format-executable',
        metavar='EXECUTABLE',
        help='path to the clang-format executable',
        default='clang-format')
    parser.add_argument(
        '--extensions',
        help='comma separated list of file extensions (default: {})'.format(
            DEFAULT_EXTENSIONS),
        default=DEFAULT_EXTENSIONS)
    parser.add_argument(
        '-r',
        '--recursive',
        action='store_true',
        help='run recursively over directories')
    parser.add_argument('files', metavar='file', nargs='+')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true')
    parser.add_argument(
        '-j',
        metavar='N',
        type=int,
        default=0,
        help='run N clang-format jobs in parallel'
        ' (default number of cpus + 1)')
    parser.add_argument(
        '--color',
        default='auto',
        choices=['auto', 'always', 'never'],
        help='show colored diff (default: auto)')
    parser.add_argument(
        '-e',
        '--exclude',
        metavar='PATTERN',
        action='append',
        default=[],
        help='exclude paths matching the given glob-like pattern(s)'
        ' from recursive search')
    args = parser.parse_args()

    # use default signal handling, like diff return SIGINT value on ^C
    # https://bugs.python.org/issue14229#msg156446
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    try:
        signal.SIGPIPE
    except AttributeError:
        # compatibility, SIGPIPE does not exist on Windows
        pass
    else:
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)

    colored_stdout = False
    colored_stderr = False
    if args.color == 'always':
        colored_stdout = True
        colored_stderr = True
    elif args.color == 'auto':
        colored_stdout = sys.stdout.isatty()
        colored_stderr = sys.stderr.isatty()

    # Check up front that the executable can be started at all, so a missing
    # binary is reported once rather than once per file.
    version_invocation = [args.clang_format_executable, str("--version")]
    try:
        subprocess.check_call(version_invocation, stdout=DEVNULL)
    except subprocess.CalledProcessError as e:
        print_trouble(parser.prog, str(e), use_colors=colored_stderr)
        return ExitStatus.TROUBLE
    except OSError as e:
        print_trouble(
            parser.prog,
            "Command '{}' failed to start: {}".format(
                subprocess.list2cmdline(version_invocation), e
            ),
            use_colors=colored_stderr,
        )
        return ExitStatus.TROUBLE

    retcode = ExitStatus.SUCCESS
    files = list_files(
        args.files,
        recursive=args.recursive,
        exclude=args.exclude,
        extensions=args.extensions.split(','))
    if not files:
        # Nothing to check; state the exit code explicitly.
        return ExitStatus.SUCCESS

    njobs = args.j
    if njobs == 0:
        njobs = multiprocessing.cpu_count() + 1
    njobs = min(len(files), njobs)

    if njobs == 1:
        # execute directly instead of in a pool,
        # less overhead, simpler stacktraces
        it = (run_clang_format_diff_wrapper(args, file) for file in files)
        pool = None
    else:
        pool = multiprocessing.Pool(njobs)
        it = pool.imap_unordered(
            partial(run_clang_format_diff_wrapper, args), files)

    try:
        while True:
            try:
                outs, errs = next(it)
            except StopIteration:
                break
            except DiffError as e:
                print_trouble(parser.prog, str(e), use_colors=colored_stderr)
                retcode = ExitStatus.TROUBLE
                sys.stderr.writelines(e.errs)
            except UnexpectedError as e:
                print_trouble(parser.prog, str(e), use_colors=colored_stderr)
                sys.stderr.write(e.formatted_traceback)
                retcode = ExitStatus.TROUBLE
                # stop at the first unexpected error,
                # something could be very wrong,
                # don't process all files unnecessarily
                if pool:
                    pool.terminate()
                break
            else:
                sys.stderr.writelines(errs)
                if outs == []:
                    continue
                if not args.quiet:
                    print_diff(outs, use_color=colored_stdout)
                # TROUBLE takes precedence over DIFF once it has been set.
                if retcode == ExitStatus.SUCCESS:
                    retcode = ExitStatus.DIFF
    finally:
        if pool:
            # Release the worker processes; this is also safe after the
            # terminate() call above.
            pool.close()
            pool.join()
    return retcode
# Run only when executed as a script; main()'s return value becomes the
# process exit status.
if __name__ == '__main__':
    sys.exit(main())
#!/usr/bin/env bash #!/usr/bin/env bash
set -u set -eux
root_dir="$(git rev-parse --show-toplevel)" root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda" conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env" env_dir="${root_dir}/env"
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
eval "$("${conda_dir}/bin/conda" shell.bash hook)" eval "$("${conda_dir}/bin/conda" shell.bash hook)"
conda activate "${env_dir}" conda activate "${env_dir}"
# 1. Install tools # 1. Install tools
conda install flake8 conda install flake8
printf "Installed flake8: "
flake8 --version
clangformat_path="${root_dir}/clang-format" clangformat_path="${root_dir}/clang-format"
curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o "${clangformat_path}" curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o "${clangformat_path}"
chmod +x "${clangformat_path}" chmod +x "${clangformat_path}"
printf "Installed clang-fortmat"
"${clangformat_path}" --version
# 2. Run style checks # 2. Run style checks
# We want to run all the style checks even if one of them fail. # We want to run all the style checks even if one of them fail.
set +e
exit_status=0 exit_status=0
printf "\x1b[34mRunning flake8: " printf "\x1b[34mRunning flake8:\x1b[0m\n"
flake8 --version
printf "\x1b[0m\n"
flake8 torchaudio test build_tools/setup_helpers flake8 torchaudio test build_tools/setup_helpers
status=$? status=$?
exit_status="$((exit_status+status))" exit_status="$((exit_status+status))"
...@@ -30,14 +36,14 @@ if [ "${status}" -ne 0 ]; then ...@@ -30,14 +36,14 @@ if [ "${status}" -ne 0 ]; then
printf "\x1b[31mflake8 failed. Check the format of Python files.\x1b[0m\n" printf "\x1b[31mflake8 failed. Check the format of Python files.\x1b[0m\n"
fi fi
printf "\x1b[34mRunning clang-format: " printf "\x1b[34mRunning clang-format:\x1b[0m\n"
./clang-format --version "${this_dir}"/run-clang-format.py \
printf "\x1b[0m\n" -r torchaudio/csrc \
git-clang-format --binary ./clang-format origin/master --clang-format-executable "${clangformat_path}" \
git diff --exit-code && git diff --exit-code
status=$? status=$?
exit_status="$((exit_status+status))" exit_status="$((exit_status+status))"
if [ "${status}" -ne 0 ]; then if [ "${status}" -ne 0 ]; then
printf "\x1b[31mC++ files are not formatted. Please use git-clang-format to format CPP files.\x1b[0m\n" printf "\x1b[31mC++ files are not formatted. Please use clang-format to format CPP files.\x1b[0m\n"
fi fi
exit $exit_status exit $exit_status
...@@ -2,88 +2,92 @@ ...@@ -2,88 +2,92 @@
#include <torchaudio/csrc/sox/io.h> #include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h> #include <torchaudio/csrc/sox/legacy.h>
PYBIND11_MODULE(_torchaudio, m) { PYBIND11_MODULE(_torchaudio, m) {
py::class_<sox_signalinfo_t>(m, "sox_signalinfo_t") py::class_<sox_signalinfo_t>(m, "sox_signalinfo_t")
.def(py::init<>()) .def(py::init<>())
.def("__repr__", [](const sox_signalinfo_t &self) { .def(
std::stringstream ss; "__repr__",
ss << "sox_signalinfo_t {\n" [](const sox_signalinfo_t& self) {
<< " rate-> " << self.rate << "\n" std::stringstream ss;
<< " channels-> " << self.channels << "\n" ss << "sox_signalinfo_t {\n"
<< " precision-> " << self.precision << "\n" << " rate-> " << self.rate << "\n"
<< " length-> " << self.length << "\n" << " channels-> " << self.channels << "\n"
<< " mult-> " << self.mult << "\n" << " precision-> " << self.precision << "\n"
<< "}\n"; << " length-> " << self.length << "\n"
return ss.str(); << " mult-> " << self.mult << "\n"
}) << "}\n";
.def_readwrite("rate", &sox_signalinfo_t::rate) return ss.str();
.def_readwrite("channels", &sox_signalinfo_t::channels) })
.def_readwrite("precision", &sox_signalinfo_t::precision) .def_readwrite("rate", &sox_signalinfo_t::rate)
.def_readwrite("length", &sox_signalinfo_t::length) .def_readwrite("channels", &sox_signalinfo_t::channels)
.def_readwrite("mult", &sox_signalinfo_t::mult); .def_readwrite("precision", &sox_signalinfo_t::precision)
.def_readwrite("length", &sox_signalinfo_t::length)
.def_readwrite("mult", &sox_signalinfo_t::mult);
py::class_<sox_encodinginfo_t>(m, "sox_encodinginfo_t") py::class_<sox_encodinginfo_t>(m, "sox_encodinginfo_t")
.def(py::init<>()) .def(py::init<>())
.def("__repr__", [](const sox_encodinginfo_t &self) { .def(
std::stringstream ss; "__repr__",
ss << "sox_encodinginfo_t {\n" [](const sox_encodinginfo_t& self) {
<< " encoding-> " << self.encoding << "\n" std::stringstream ss;
<< " bits_per_sample-> " << self.bits_per_sample << "\n" ss << "sox_encodinginfo_t {\n"
<< " compression-> " << self.compression << "\n" << " encoding-> " << self.encoding << "\n"
<< " reverse_bytes-> " << self.reverse_bytes << "\n" << " bits_per_sample-> " << self.bits_per_sample << "\n"
<< " reverse_nibbles-> " << self.reverse_nibbles << "\n" << " compression-> " << self.compression << "\n"
<< " reverse_bits-> " << self.reverse_bits << "\n" << " reverse_bytes-> " << self.reverse_bytes << "\n"
<< " opposite_endian-> " << self.opposite_endian << "\n" << " reverse_nibbles-> " << self.reverse_nibbles << "\n"
<< "}\n"; << " reverse_bits-> " << self.reverse_bits << "\n"
return ss.str(); << " opposite_endian-> " << self.opposite_endian << "\n"
}) << "}\n";
.def_readwrite("encoding", &sox_encodinginfo_t::encoding) return ss.str();
.def_readwrite("bits_per_sample", &sox_encodinginfo_t::bits_per_sample) })
.def_readwrite("compression", &sox_encodinginfo_t::compression) .def_readwrite("encoding", &sox_encodinginfo_t::encoding)
.def_readwrite("reverse_bytes", &sox_encodinginfo_t::reverse_bytes) .def_readwrite("bits_per_sample", &sox_encodinginfo_t::bits_per_sample)
.def_readwrite("reverse_nibbles", &sox_encodinginfo_t::reverse_nibbles) .def_readwrite("compression", &sox_encodinginfo_t::compression)
.def_readwrite("reverse_bits", &sox_encodinginfo_t::reverse_bits) .def_readwrite("reverse_bytes", &sox_encodinginfo_t::reverse_bytes)
.def_readwrite("opposite_endian", &sox_encodinginfo_t::opposite_endian); .def_readwrite("reverse_nibbles", &sox_encodinginfo_t::reverse_nibbles)
.def_readwrite("reverse_bits", &sox_encodinginfo_t::reverse_bits)
.def_readwrite("opposite_endian", &sox_encodinginfo_t::opposite_endian);
py::enum_<sox_encoding_t>(m, "sox_encoding_t") py::enum_<sox_encoding_t>(m, "sox_encoding_t")
.value("SOX_ENCODING_UNKNOWN", sox_encoding_t::SOX_ENCODING_UNKNOWN) .value("SOX_ENCODING_UNKNOWN", sox_encoding_t::SOX_ENCODING_UNKNOWN)
.value("SOX_ENCODING_SIGN2", sox_encoding_t::SOX_ENCODING_SIGN2) .value("SOX_ENCODING_SIGN2", sox_encoding_t::SOX_ENCODING_SIGN2)
.value("SOX_ENCODING_UNSIGNED", sox_encoding_t::SOX_ENCODING_UNSIGNED) .value("SOX_ENCODING_UNSIGNED", sox_encoding_t::SOX_ENCODING_UNSIGNED)
.value("SOX_ENCODING_FLOAT", sox_encoding_t::SOX_ENCODING_FLOAT) .value("SOX_ENCODING_FLOAT", sox_encoding_t::SOX_ENCODING_FLOAT)
.value("SOX_ENCODING_FLOAT_TEXT", sox_encoding_t::SOX_ENCODING_FLOAT_TEXT) .value("SOX_ENCODING_FLOAT_TEXT", sox_encoding_t::SOX_ENCODING_FLOAT_TEXT)
.value("SOX_ENCODING_FLAC", sox_encoding_t::SOX_ENCODING_FLAC) .value("SOX_ENCODING_FLAC", sox_encoding_t::SOX_ENCODING_FLAC)
.value("SOX_ENCODING_HCOM", sox_encoding_t::SOX_ENCODING_HCOM) .value("SOX_ENCODING_HCOM", sox_encoding_t::SOX_ENCODING_HCOM)
.value("SOX_ENCODING_WAVPACK", sox_encoding_t::SOX_ENCODING_WAVPACK) .value("SOX_ENCODING_WAVPACK", sox_encoding_t::SOX_ENCODING_WAVPACK)
.value("SOX_ENCODING_WAVPACKF", sox_encoding_t::SOX_ENCODING_WAVPACKF) .value("SOX_ENCODING_WAVPACKF", sox_encoding_t::SOX_ENCODING_WAVPACKF)
.value("SOX_ENCODING_ULAW", sox_encoding_t::SOX_ENCODING_ULAW) .value("SOX_ENCODING_ULAW", sox_encoding_t::SOX_ENCODING_ULAW)
.value("SOX_ENCODING_ALAW", sox_encoding_t::SOX_ENCODING_ALAW) .value("SOX_ENCODING_ALAW", sox_encoding_t::SOX_ENCODING_ALAW)
.value("SOX_ENCODING_G721", sox_encoding_t::SOX_ENCODING_G721) .value("SOX_ENCODING_G721", sox_encoding_t::SOX_ENCODING_G721)
.value("SOX_ENCODING_G723", sox_encoding_t::SOX_ENCODING_G723) .value("SOX_ENCODING_G723", sox_encoding_t::SOX_ENCODING_G723)
.value("SOX_ENCODING_CL_ADPCM", sox_encoding_t::SOX_ENCODING_CL_ADPCM) .value("SOX_ENCODING_CL_ADPCM", sox_encoding_t::SOX_ENCODING_CL_ADPCM)
.value("SOX_ENCODING_CL_ADPCM16", sox_encoding_t::SOX_ENCODING_CL_ADPCM16) .value("SOX_ENCODING_CL_ADPCM16", sox_encoding_t::SOX_ENCODING_CL_ADPCM16)
.value("SOX_ENCODING_MS_ADPCM", sox_encoding_t::SOX_ENCODING_MS_ADPCM) .value("SOX_ENCODING_MS_ADPCM", sox_encoding_t::SOX_ENCODING_MS_ADPCM)
.value("SOX_ENCODING_IMA_ADPCM", sox_encoding_t::SOX_ENCODING_IMA_ADPCM) .value("SOX_ENCODING_IMA_ADPCM", sox_encoding_t::SOX_ENCODING_IMA_ADPCM)
.value("SOX_ENCODING_OKI_ADPCM", sox_encoding_t::SOX_ENCODING_OKI_ADPCM) .value("SOX_ENCODING_OKI_ADPCM", sox_encoding_t::SOX_ENCODING_OKI_ADPCM)
.value("SOX_ENCODING_DPCM", sox_encoding_t::SOX_ENCODING_DPCM) .value("SOX_ENCODING_DPCM", sox_encoding_t::SOX_ENCODING_DPCM)
.value("SOX_ENCODING_DWVW", sox_encoding_t::SOX_ENCODING_DWVW) .value("SOX_ENCODING_DWVW", sox_encoding_t::SOX_ENCODING_DWVW)
.value("SOX_ENCODING_DWVWN", sox_encoding_t::SOX_ENCODING_DWVWN) .value("SOX_ENCODING_DWVWN", sox_encoding_t::SOX_ENCODING_DWVWN)
.value("SOX_ENCODING_GSM", sox_encoding_t::SOX_ENCODING_GSM) .value("SOX_ENCODING_GSM", sox_encoding_t::SOX_ENCODING_GSM)
.value("SOX_ENCODING_MP3", sox_encoding_t::SOX_ENCODING_MP3) .value("SOX_ENCODING_MP3", sox_encoding_t::SOX_ENCODING_MP3)
.value("SOX_ENCODING_VORBIS", sox_encoding_t::SOX_ENCODING_VORBIS) .value("SOX_ENCODING_VORBIS", sox_encoding_t::SOX_ENCODING_VORBIS)
.value("SOX_ENCODING_AMR_WB", sox_encoding_t::SOX_ENCODING_AMR_WB) .value("SOX_ENCODING_AMR_WB", sox_encoding_t::SOX_ENCODING_AMR_WB)
.value("SOX_ENCODING_AMR_NB", sox_encoding_t::SOX_ENCODING_AMR_NB) .value("SOX_ENCODING_AMR_NB", sox_encoding_t::SOX_ENCODING_AMR_NB)
.value("SOX_ENCODING_LPC10", sox_encoding_t::SOX_ENCODING_LPC10) .value("SOX_ENCODING_LPC10", sox_encoding_t::SOX_ENCODING_LPC10)
//.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) // creates a compile error //.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) //
.value("SOX_ENCODINGS", sox_encoding_t::SOX_ENCODINGS) // creates a compile error
.export_values(); .value("SOX_ENCODINGS", sox_encoding_t::SOX_ENCODINGS)
.export_values();
py::enum_<sox_option_t>(m, "sox_option_t") py::enum_<sox_option_t>(m, "sox_option_t")
.value("sox_option_no", sox_option_t::sox_option_no) .value("sox_option_no", sox_option_t::sox_option_no)
.value("sox_option_yes", sox_option_t::sox_option_yes) .value("sox_option_yes", sox_option_t::sox_option_yes)
.value("sox_option_default", sox_option_t::sox_option_default) .value("sox_option_default", sox_option_t::sox_option_default)
.export_values(); .export_values();
py::enum_<sox_bool>(m, "sox_bool") py::enum_<sox_bool>(m, "sox_bool")
.value("sox_false", sox_bool::sox_false) .value("sox_false", sox_bool::sox_false)
.value("sox_true", sox_bool::sox_true) .value("sox_true", sox_bool::sox_true)
.export_values(); .export_values();
m.def( m.def(
"read_audio_file", "read_audio_file",
&torch::audio::read_audio_file, &torch::audio::read_audio_file,
......
...@@ -143,23 +143,27 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj( ...@@ -143,23 +143,27 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
c10::optional<bool>& normalize, c10::optional<bool>& normalize,
c10::optional<bool>& channels_first, c10::optional<bool>& channels_first,
c10::optional<std::string>& format) { c10::optional<std::string>& format) {
// Streaming decoding over file-like object is tricky because libsox operates
// Streaming decoding over file-like object is tricky because libsox operates on FILE pointer. // on FILE pointer. The folloing is what `sox` and `play` commands do
// The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer // - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer // - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer // - stdin -> FILE pointer
// //
// We want to, instead, fetch byte strings chunk by chunk, consume them, and discard. // We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
// //
// Here is the approach // Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial chunk of byte string // 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// This will perform header-based format detection, if necessary, then fill the metadata of // chunk of byte string
// sox_format_t. Internally, sox_open_mem_read uses fmemopen, which returns FILE* which points the // This will perform header-based format detection, if necessary, then fill
// buffer of the provided byte string. // the metadata of sox_format_t. Internally, sox_open_mem_read uses
// 2. Each time sox reads a chunk from the FILE*, we update the underlying buffer in a way that it // fmemopen, which returns FILE* which points the buffer of the provided
// starts with unseen data, and append the new data read from the given fileobj. // byte string.
// This will trick libsox as if it keeps reading from the FILE* continuously. // 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.
// Prepare the buffer used throughout the lifecycle of SoxEffectChain. // Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Using std::string and let it manage memory. // Using std::string and let it manage memory.
...@@ -170,9 +174,12 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj( ...@@ -170,9 +174,12 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
auto* in_buf = const_cast<char*>(in_buffer.data()); auto* in_buf = const_cast<char*>(in_buffer.data());
// Fetch the header, and copy it to the buffer. // Fetch the header, and copy it to the buffer.
auto header = static_cast<std::string>(static_cast<py::bytes>(fileobj.attr("read")(4096))); auto header = static_cast<std::string>(
memcpy(static_cast<void*>(in_buf), static_cast<py::bytes>(fileobj.attr("read")(4096)));
static_cast<void*>(const_cast<char*>(header.data())), header.length()); memcpy(
static_cast<void*>(in_buf),
static_cast<void*>(const_cast<char*>(header.data())),
header.length());
// Open file (this starts reading the header) // Open file (this starts reading the header)
SoxFormat sf(sox_open_mem_read( SoxFormat sf(sox_open_mem_read(
...@@ -212,8 +219,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj( ...@@ -212,8 +219,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
channels_first_); channels_first_);
return std::make_tuple( return std::make_tuple(
tensor, tensor, static_cast<int64_t>(chain.getOutputSampleRate()));
static_cast<int64_t>(chain.getOutputSampleRate()));
} }
#endif // TORCH_API_INCLUDE_EXTENSION_H #endif // TORCH_API_INCLUDE_EXTENSION_H
......
...@@ -123,44 +123,47 @@ int file_output_flow( ...@@ -123,44 +123,47 @@ int file_output_flow(
} }
sox_effect_handler_t* get_tensor_input_handler() { sox_effect_handler_t* get_tensor_input_handler() {
static sox_effect_handler_t handler{/*name=*/"input_tensor", static sox_effect_handler_t handler{
/*usage=*/NULL, /*name=*/"input_tensor",
/*flags=*/SOX_EFF_MCHAN, /*usage=*/NULL,
/*getopts=*/NULL, /*flags=*/SOX_EFF_MCHAN,
/*start=*/NULL, /*getopts=*/NULL,
/*flow=*/NULL, /*start=*/NULL,
/*drain=*/tensor_input_drain, /*flow=*/NULL,
/*stop=*/NULL, /*drain=*/tensor_input_drain,
/*kill=*/NULL, /*stop=*/NULL,
/*priv_size=*/sizeof(TensorInputPriv)}; /*kill=*/NULL,
/*priv_size=*/sizeof(TensorInputPriv)};
return &handler; return &handler;
} }
sox_effect_handler_t* get_tensor_output_handler() { sox_effect_handler_t* get_tensor_output_handler() {
static sox_effect_handler_t handler{/*name=*/"output_tensor", static sox_effect_handler_t handler{
/*usage=*/NULL, /*name=*/"output_tensor",
/*flags=*/SOX_EFF_MCHAN, /*usage=*/NULL,
/*getopts=*/NULL, /*flags=*/SOX_EFF_MCHAN,
/*start=*/NULL, /*getopts=*/NULL,
/*flow=*/tensor_output_flow, /*start=*/NULL,
/*drain=*/NULL, /*flow=*/tensor_output_flow,
/*stop=*/NULL, /*drain=*/NULL,
/*kill=*/NULL, /*stop=*/NULL,
/*priv_size=*/sizeof(TensorOutputPriv)}; /*kill=*/NULL,
/*priv_size=*/sizeof(TensorOutputPriv)};
return &handler; return &handler;
} }
sox_effect_handler_t* get_file_output_handler() { sox_effect_handler_t* get_file_output_handler() {
static sox_effect_handler_t handler{/*name=*/"output_file", static sox_effect_handler_t handler{
/*usage=*/NULL, /*name=*/"output_file",
/*flags=*/SOX_EFF_MCHAN, /*usage=*/NULL,
/*getopts=*/NULL, /*flags=*/SOX_EFF_MCHAN,
/*start=*/NULL, /*getopts=*/NULL,
/*flow=*/file_output_flow, /*start=*/NULL,
/*drain=*/NULL, /*flow=*/file_output_flow,
/*stop=*/NULL, /*drain=*/NULL,
/*kill=*/NULL, /*stop=*/NULL,
/*priv_size=*/sizeof(FileOutputPriv)}; /*kill=*/NULL,
/*priv_size=*/sizeof(FileOutputPriv)};
return &handler; return &handler;
} }
...@@ -198,7 +201,8 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) { ...@@ -198,7 +201,8 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) {
priv->signal = signal; priv->signal = signal;
priv->index = 0; priv->index = 0;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: input_tensor"); throw std::runtime_error(
"Internal Error: Failed to add effect: input_tensor");
} }
} }
...@@ -207,7 +211,8 @@ void SoxEffectsChain::addOutputBuffer( ...@@ -207,7 +211,8 @@ void SoxEffectsChain::addOutputBuffer(
SoxEffect e(sox_create_effect(get_tensor_output_handler())); SoxEffect e(sox_create_effect(get_tensor_output_handler()));
static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer; static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: output_tensor"); throw std::runtime_error(
"Internal Error: Failed to add effect: output_tensor");
} }
} }
...@@ -305,7 +310,7 @@ struct FileObjOutputPriv { ...@@ -305,7 +310,7 @@ struct FileObjOutputPriv {
/// Callback function to feed byte string /// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278 /// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
auto priv = static_cast<FileObjInputPriv *>(effp->priv); auto priv = static_cast<FileObjInputPriv*>(effp->priv);
auto sf = priv->sf; auto sf = priv->sf;
auto fileobj = priv->fileobj; auto fileobj = priv->fileobj;
auto buffer = priv->buffer; auto buffer = priv->buffer;
...@@ -315,9 +320,9 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -315,9 +320,9 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// //
// NOTE: // NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way // Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detect EOF is reaching the end of the buffer. (null byte won't help) // libsox detect EOF is reaching the end of the buffer. (null byte won't
// Therefore we need to align the content at the end of buffer, otherwise, // help) Therefore we need to align the content at the end of buffer,
// libsox will keep reading the content beyond intended length. // otherwise, libsox will keep reading the content beyond intended length.
// //
// Before: // Before:
// //
...@@ -339,11 +344,12 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -339,11 +344,12 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
const auto num_refill = py::len(chunk_); const auto num_refill = py::len(chunk_);
const auto offset = buffer_size - (num_remain + num_refill); const auto offset = buffer_size - (num_remain + num_refill);
if(num_refill > num_consumed) { if (num_refill > num_consumed) {
std::ostringstream message; std::ostringstream message;
message << "Tried to read up to " << num_consumed << " bytes but, " message
<< "recieved " << num_refill << " bytes. " << "Tried to read up to " << num_consumed << " bytes but, "
<< "The given object does not confirm to read protocol of file object."; << "recieved " << num_refill << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str()); throw std::runtime_error(message.str());
} }
...@@ -364,7 +370,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -364,7 +370,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// 1.4. Set the file pointer to the new offset // 1.4. Set the file pointer to the new offset
sf->tell_off = offset; sf->tell_off = offset;
fseek ((FILE*)sf->fp, offset, SEEK_SET); fseek((FILE*)sf->fp, offset, SEEK_SET);
// 2. Perform decoding operation // 2. Perform decoding operation
// The following part is practically same as "input" effect // The following part is practically same as "input" effect
...@@ -377,7 +383,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -377,7 +383,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// store the actual number read back to *osamp // store the actual number read back to *osamp
*osamp = sox_read(sf, obuf, *osamp); *osamp = sox_read(sf, obuf, *osamp);
return *osamp? SOX_SUCCESS : SOX_EOF; return *osamp ? SOX_SUCCESS : SOX_EOF;
} }
int fileobj_output_flow( int fileobj_output_flow(
...@@ -420,30 +426,32 @@ int fileobj_output_flow( ...@@ -420,30 +426,32 @@ int fileobj_output_flow(
} }
sox_effect_handler_t* get_fileobj_input_handler() { sox_effect_handler_t* get_fileobj_input_handler() {
static sox_effect_handler_t handler{/*name=*/"input_fileobj_object", static sox_effect_handler_t handler{
/*usage=*/NULL, /*name=*/"input_fileobj_object",
/*flags=*/SOX_EFF_MCHAN, /*usage=*/NULL,
/*getopts=*/NULL, /*flags=*/SOX_EFF_MCHAN,
/*start=*/NULL, /*getopts=*/NULL,
/*flow=*/NULL, /*start=*/NULL,
/*drain=*/fileobj_input_drain, /*flow=*/NULL,
/*stop=*/NULL, /*drain=*/fileobj_input_drain,
/*kill=*/NULL, /*stop=*/NULL,
/*priv_size=*/sizeof(FileObjInputPriv)}; /*kill=*/NULL,
/*priv_size=*/sizeof(FileObjInputPriv)};
return &handler; return &handler;
} }
sox_effect_handler_t* get_fileobj_output_handler() { sox_effect_handler_t* get_fileobj_output_handler() {
static sox_effect_handler_t handler{/*name=*/"output_fileobj_object", static sox_effect_handler_t handler{
/*usage=*/NULL, /*name=*/"output_fileobj_object",
/*flags=*/SOX_EFF_MCHAN, /*usage=*/NULL,
/*getopts=*/NULL, /*flags=*/SOX_EFF_MCHAN,
/*start=*/NULL, /*getopts=*/NULL,
/*flow=*/fileobj_output_flow, /*start=*/NULL,
/*drain=*/NULL, /*flow=*/fileobj_output_flow,
/*stop=*/NULL, /*drain=*/NULL,
/*kill=*/NULL, /*stop=*/NULL,
/*priv_size=*/sizeof(FileObjOutputPriv)}; /*kill=*/NULL,
/*priv_size=*/sizeof(FileObjOutputPriv)};
return &handler; return &handler;
} }
...@@ -464,7 +472,8 @@ void SoxEffectsChain::addInputFileObj( ...@@ -464,7 +472,8 @@ void SoxEffectsChain::addInputFileObj(
priv->buffer = buffer; priv->buffer = buffer;
priv->buffer_size = buffer_size; priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: input fileobj"); throw std::runtime_error(
"Internal Error: Failed to add effect: input fileobj");
} }
} }
...@@ -481,7 +490,8 @@ void SoxEffectsChain::addOutputFileObj( ...@@ -481,7 +490,8 @@ void SoxEffectsChain::addOutputFileObj(
priv->buffer = buffer; priv->buffer = buffer;
priv->buffer_size = buffer_size; priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) { if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: output fileobj"); throw std::runtime_error(
"Internal Error: Failed to add effect: output fileobj");
} }
} }
......
...@@ -112,8 +112,9 @@ void save_audio_file( ...@@ -112,8 +112,9 @@ void save_audio_file(
auto signal = TensorSignal(tensor, sample_rate, channels_first); auto signal = TensorSignal(tensor, sample_rate, channels_first);
const auto filetype = [&](){ const auto filetype = [&]() {
if (format.has_value()) return format.value(); if (format.has_value())
return format.value();
return get_filetype(path); return get_filetype(path);
}(); }();
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
...@@ -123,7 +124,8 @@ void save_audio_file( ...@@ -123,7 +124,8 @@ void save_audio_file(
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = get_signalinfo(&signal, filetype); const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression); const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
SoxFormat sf(sox_open_write( SoxFormat sf(sox_open_write(
path.c_str(), path.c_str(),
...@@ -161,7 +163,8 @@ std::tuple<torch::Tensor, int64_t> load_audio_fileobj( ...@@ -161,7 +163,8 @@ std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
namespace { namespace {
// helper class to automatically release buffer, to be used by save_audio_fileobj // helper class to automatically release buffer, to be used by
// save_audio_fileobj
struct AutoReleaseBuffer { struct AutoReleaseBuffer {
char* ptr; char* ptr;
size_t size; size_t size;
...@@ -194,12 +197,14 @@ void save_audio_fileobj( ...@@ -194,12 +197,14 @@ void save_audio_fileobj(
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1); const auto num_channels = tensor.size(channels_first ? 0 : 1);
if (num_channels != 1) { if (num_channels != 1) {
throw std::runtime_error("amr-nb format only supports single channel audio."); throw std::runtime_error(
"amr-nb format only supports single channel audio.");
} }
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = get_signalinfo(&signal, filetype); const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression); const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
AutoReleaseBuffer buffer; AutoReleaseBuffer buffer;
...@@ -212,7 +217,8 @@ void save_audio_fileobj( ...@@ -212,7 +217,8 @@ void save_audio_fileobj(
/*oob=*/nullptr)); /*oob=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) { if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error saving audio file: failed to open memory stream."); throw std::runtime_error(
"Error saving audio file: failed to open memory stream.");
} }
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
...@@ -222,7 +228,8 @@ void save_audio_fileobj( ...@@ -222,7 +228,8 @@ void save_audio_fileobj(
chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj); chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);
chain.run(); chain.run();
// Closing the sox_format_t is necessary for flushing the last chunk to the buffer // Closing the sox_format_t is necessary for flushing the last chunk to the
// buffer
sf.close(); sf.close();
fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size)); fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size));
......
...@@ -40,10 +40,7 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) { ...@@ -40,10 +40,7 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
return samples_written; return samples_written;
} }
void read_audio( void read_audio(SoxDescriptor& fd, at::Tensor output, int64_t buffer_length) {
SoxDescriptor& fd,
at::Tensor output,
int64_t buffer_length) {
std::vector<sox_sample_t> buffer(buffer_length); std::vector<sox_sample_t> buffer(buffer_length);
int number_of_channels = fd->signal.channels; int number_of_channels = fd->signal.channels;
...@@ -64,8 +61,7 @@ void read_audio( ...@@ -64,8 +61,7 @@ void read_audio(
} // namespace } // namespace
std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info( std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info(
const std::string& file_name const std::string& file_name) {
) {
SoxDescriptor fd(sox_open_read( SoxDescriptor fd(sox_open_read(
file_name.c_str(), file_name.c_str(),
/*signal=*/nullptr, /*signal=*/nullptr,
...@@ -86,7 +82,6 @@ int read_audio_file( ...@@ -86,7 +82,6 @@ int read_audio_file(
sox_signalinfo_t* si, sox_signalinfo_t* si,
sox_encodinginfo_t* ei, sox_encodinginfo_t* ei,
const char* ft) { const char* ft) {
SoxDescriptor fd(sox_open_read(file_name.c_str(), si, ei, ft)); SoxDescriptor fd(sox_open_read(file_name.c_str(), si, ei, ft));
if (fd.get() == nullptr) { if (fd.get() == nullptr) {
throw std::runtime_error("Error opening audio file"); throw std::runtime_error("Error opening audio file");
...@@ -112,15 +107,16 @@ int read_audio_file( ...@@ -112,15 +107,16 @@ int read_audio_file(
// calculate buffer length // calculate buffer length
int64_t buffer_length = total_length; int64_t buffer_length = total_length;
if (offset > 0) { if (offset > 0) {
buffer_length -= offset; buffer_length -= offset;
} }
if (nframes > 0 && buffer_length > nframes) { if (nframes > 0 && buffer_length > nframes) {
buffer_length = nframes; buffer_length = nframes;
} }
// seek to offset point before reading data // seek to offset point before reading data
if (sox_seek(fd.get(), offset, 0) == SOX_EOF) { if (sox_seek(fd.get(), offset, 0) == SOX_EOF) {
throw std::runtime_error("sox_seek reached EOF, try reducing offset or num_samples"); throw std::runtime_error(
"sox_seek reached EOF, try reducing offset or num_samples");
} }
// read data and fill output tensor // read data and fill output tensor
......
#include <sox.h> #include <sox.h>
#include <torch/torch.h> #include <torch/torch.h>
namespace torch { namespace audio { namespace torch {
namespace audio {
/// Reads an audio file from the given `path` into the `output` `Tensor` and /// Reads an audio file from the given `path` into the `output` `Tensor` and
/// returns the sample rate of the audio file. /// returns the sample rate of the audio file.
...@@ -30,9 +31,10 @@ void write_audio_file( ...@@ -30,9 +31,10 @@ void write_audio_file(
/// Reads an audio file from the given `path` and returns a tuple of /// Reads an audio file from the given `path` and returns a tuple of
/// sox_signalinfo_t and sox_encodinginfo_t, which contain information about /// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
/// the audio file such as sample rate, length, bit precision, encoding and more. /// the audio file such as sample rate, length, bit precision, encoding and
/// Throws `std::runtime_error` if the audio file could not be opened, or an /// more. Throws `std::runtime_error` if the audio file could not be opened, or
/// error occurred during reading of the audio data. /// an error occurred during reading of the audio data.
std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info( std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info(
const std::string& file_name); const std::string& file_name);
}} // namespace torch::audio } // namespace audio
} // namespace torch
...@@ -43,17 +43,19 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -43,17 +43,19 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate) .def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels) .def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames) .def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def("get_bits_per_sample", &torchaudio::sox_io::SignalInfo::getBitsPerSample); .def(
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info); m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
m.def( m.def(
"torchaudio::sox_io_load_audio_file(" "torchaudio::sox_io_load_audio_file("
"str path," "str path,"
"int? frame_offset=None," "int? frame_offset=None,"
"int? num_frames=None," "int? num_frames=None,"
"bool? normalize=True," "bool? normalize=True,"
"bool? channels_first=False," "bool? channels_first=False,"
"str? format=None" "str? format=None"
") -> __torch__.torch.classes.torchaudio.TensorSignal", ") -> __torch__.torch.classes.torchaudio.TensorSignal",
&torchaudio::sox_io::load_audio_file); &torchaudio::sox_io::load_audio_file);
m.def( m.def(
......
...@@ -80,7 +80,9 @@ bool TensorSignal::getChannelsFirst() const { ...@@ -80,7 +80,9 @@ bool TensorSignal::getChannelsFirst() const {
} }
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {} SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
SoxFormat::~SoxFormat() { close(); } SoxFormat::~SoxFormat() {
close();
}
sox_format_t* SoxFormat::operator->() const noexcept { sox_format_t* SoxFormat::operator->() const noexcept {
return fd_; return fd_;
...@@ -291,26 +293,28 @@ sox_signalinfo_t get_signalinfo( ...@@ -291,26 +293,28 @@ sox_signalinfo_t get_signalinfo(
sox_encodinginfo_t get_encodinginfo( sox_encodinginfo_t get_encodinginfo(
const std::string filetype, const std::string filetype,
const caffe2::TypeMeta dtype) { const caffe2::TypeMeta dtype) {
return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype), return sox_encodinginfo_t{
/*bits_per_sample=*/get_precision(filetype, dtype), /*encoding=*/get_encoding(filetype, dtype),
/*compression=*/HUGE_VAL, /*bits_per_sample=*/get_precision(filetype, dtype),
/*reverse_bytes=*/sox_option_default, /*compression=*/HUGE_VAL,
/*reverse_nibbles=*/sox_option_default, /*reverse_bytes=*/sox_option_default,
/*reverse_bits=*/sox_option_default, /*reverse_nibbles=*/sox_option_default,
/*opposite_endian=*/sox_false}; /*reverse_bits=*/sox_option_default,
/*opposite_endian=*/sox_false};
} }
sox_encodinginfo_t get_encodinginfo( sox_encodinginfo_t get_encodinginfo(
const std::string filetype, const std::string filetype,
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression) { c10::optional<double>& compression) {
return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype), return sox_encodinginfo_t{
/*bits_per_sample=*/get_precision(filetype, dtype), /*encoding=*/get_encoding(filetype, dtype),
/*compression=*/compression.value_or(HUGE_VAL), /*bits_per_sample=*/get_precision(filetype, dtype),
/*reverse_bytes=*/sox_option_default, /*compression=*/compression.value_or(HUGE_VAL),
/*reverse_nibbles=*/sox_option_default, /*reverse_bytes=*/sox_option_default,
/*reverse_bits=*/sox_option_default, /*reverse_nibbles=*/sox_option_default,
/*opposite_endian=*/sox_false}; /*reverse_bits=*/sox_option_default,
/*opposite_endian=*/sox_false};
} }
} // namespace sox_utils } // namespace sox_utils
......
...@@ -69,7 +69,7 @@ struct SoxFormat { ...@@ -69,7 +69,7 @@ struct SoxFormat {
/// ///
/// Verify that input file is found, has known encoding, and not empty /// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf, bool check_length=true); void validate_input_file(const SoxFormat& sf, bool check_length = true);
/// ///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32 /// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
......
...@@ -8,71 +8,80 @@ ...@@ -8,71 +8,80 @@
namespace { namespace {
int64_t cpu_rnnt_loss(torch::Tensor acts, int64_t cpu_rnnt_loss(
torch::Tensor labels, torch::Tensor acts,
torch::Tensor input_lengths, torch::Tensor labels,
torch::Tensor label_lengths, torch::Tensor input_lengths,
torch::Tensor costs, torch::Tensor label_lengths,
torch::Tensor grads, torch::Tensor costs,
int64_t blank_label, torch::Tensor grads,
int64_t num_threads) { int64_t blank_label,
int64_t num_threads) {
int maxT = acts.size(1); int maxT = acts.size(1);
int maxU = acts.size(2); int maxU = acts.size(2);
int minibatch_size = acts.size(0); int minibatch_size = acts.size(0);
int alphabet_size = acts.size(3); int alphabet_size = acts.size(3);
rnntOptions options; rnntOptions options;
memset(&options, 0, sizeof(options)); memset(&options, 0, sizeof(options));
options.maxT = maxT; options.maxT = maxT;
options.maxU = maxU; options.maxU = maxU;
options.blank_label = blank_label; options.blank_label = blank_label;
options.batch_first = true; options.batch_first = true;
options.loc = RNNT_CPU; options.loc = RNNT_CPU;
options.num_threads = num_threads; options.num_threads = num_threads;
// have to use at least one // have to use at least one
options.num_threads = std::max(options.num_threads, (unsigned int) 1); options.num_threads = std::max(options.num_threads, (unsigned int)1);
size_t cpu_size_bytes = 0; size_t cpu_size_bytes = 0;
switch (acts.scalar_type()) { switch (acts.scalar_type()) {
case torch::ScalarType::Float: case torch::ScalarType::Float: {
{ get_workspace_size(maxT, maxU, minibatch_size, false, &cpu_size_bytes);
get_workspace_size(maxT, maxU, minibatch_size,
false, &cpu_size_bytes); std::vector<float> cpu_workspace(cpu_size_bytes / sizeof(float), 0);
std::vector<float> cpu_workspace(cpu_size_bytes / sizeof(float), 0); compute_rnnt_loss(
acts.data_ptr<float>(),
compute_rnnt_loss(acts.data_ptr<float>(), grads.data_ptr<float>(), grads.data_ptr<float>(),
labels.data_ptr<int>(), label_lengths.data_ptr<int>(), labels.data_ptr<int>(),
input_lengths.data_ptr<int>(), alphabet_size, label_lengths.data_ptr<int>(),
minibatch_size, costs.data_ptr<float>(), input_lengths.data_ptr<int>(),
cpu_workspace.data(), options); alphabet_size,
minibatch_size,
return 0; costs.data_ptr<float>(),
} cpu_workspace.data(),
case torch::ScalarType::Double: options);
{
get_workspace_size(maxT, maxU, minibatch_size, return 0;
false, &cpu_size_bytes, }
sizeof(double)); case torch::ScalarType::Double: {
get_workspace_size(
std::vector<double> cpu_workspace(cpu_size_bytes / sizeof(double), 0); maxT, maxU, minibatch_size, false, &cpu_size_bytes, sizeof(double));
compute_rnnt_loss_fp64(acts.data_ptr<double>(), grads.data_ptr<double>(), std::vector<double> cpu_workspace(cpu_size_bytes / sizeof(double), 0);
labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(), alphabet_size, compute_rnnt_loss_fp64(
minibatch_size, costs.data_ptr<double>(), acts.data_ptr<double>(),
cpu_workspace.data(), options); grads.data_ptr<double>(),
labels.data_ptr<int>(),
return 0; label_lengths.data_ptr<int>(),
} input_lengths.data_ptr<int>(),
default: alphabet_size,
TORCH_CHECK(false, minibatch_size,
std::string(__func__) + " not implemented for '" + toString(acts.scalar_type()) + "'" costs.data_ptr<double>(),
); cpu_workspace.data(),
options);
return 0;
} }
return -1; default:
TORCH_CHECK(
false,
std::string(__func__) + " not implemented for '" +
toString(acts.scalar_type()) + "'");
}
return -1;
} }
} // namespace } // namespace
...@@ -82,12 +91,13 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { ...@@ -82,12 +91,13 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
} }
TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("rnnt_loss(Tensor acts," m.def(
"Tensor labels," "rnnt_loss(Tensor acts,"
"Tensor input_lengths," "Tensor labels,"
"Tensor label_lengths," "Tensor input_lengths,"
"Tensor costs," "Tensor label_lengths,"
"Tensor grads," "Tensor costs,"
"int blank_label," "Tensor grads,"
"int num_threads) -> int"); "int blank_label,"
"int num_threads) -> int");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment