Unverified Commit b33c539c authored by moto's avatar moto Committed by GitHub
Browse files

Fix clang-format CI job (#1198)

parent 99ed7183
#!/usr/bin/env python
"""A wrapper script around clang-format, suitable for linting multiple files
and to use for continuous integration.
This is an alternative API for the clang-format command line.
It runs over multiple files and directories in parallel.
A diff output is produced and a sensible exit code is returned.
"""
import argparse
import codecs
import difflib
import fnmatch
import io
import multiprocessing
import os
import signal
import subprocess
import sys
import traceback
from functools import partial
try:
from subprocess import DEVNULL # py3k
except ImportError:
DEVNULL = open(os.devnull, "wb")
DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu'
class ExitStatus:
SUCCESS = 0
DIFF = 1
TROUBLE = 2
def list_files(files, recursive=False, extensions=None, exclude=None):
if extensions is None:
extensions = []
if exclude is None:
exclude = []
out = []
for file in files:
if recursive and os.path.isdir(file):
for dirpath, dnames, fnames in os.walk(file):
fpaths = [os.path.join(dirpath, fname) for fname in fnames]
for pattern in exclude:
# os.walk() supports trimming down the dnames list
# by modifying it in-place,
# to avoid unnecessary directory listings.
dnames[:] = [
x for x in dnames
if
not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
]
fpaths = [
x for x in fpaths if not fnmatch.fnmatch(x, pattern)
]
for f in fpaths:
ext = os.path.splitext(f)[1][1:]
if ext in extensions:
out.append(f)
else:
out.append(file)
return out
def make_diff(file, original, reformatted):
return list(
difflib.unified_diff(
original,
reformatted,
fromfile='{}\t(original)'.format(file),
tofile='{}\t(reformatted)'.format(file),
n=3))
class DiffError(Exception):
def __init__(self, message, errs=None):
super(DiffError, self).__init__(message)
self.errs = errs or []
class UnexpectedError(Exception):
def __init__(self, message, exc=None):
super(UnexpectedError, self).__init__(message)
self.formatted_traceback = traceback.format_exc()
self.exc = exc
def run_clang_format_diff_wrapper(args, file):
try:
ret = run_clang_format_diff(args, file)
return ret
except DiffError:
raise
except Exception as e:
raise UnexpectedError('{}: {}: {}'.format(file, e.__class__.__name__,
e), e)
def run_clang_format_diff(args, file):
try:
with io.open(file, 'r', encoding='utf-8') as f:
original = f.readlines()
except IOError as exc:
raise DiffError(str(exc))
invocation = [args.clang_format_executable, file]
# Use of utf-8 to decode the process output.
#
# Hopefully, this is the correct thing to do.
#
# It's done due to the following assumptions (which may be incorrect):
# - clang-format will returns the bytes read from the files as-is,
# without conversion, and it is already assumed that the files use utf-8.
# - if the diagnostics were internationalized, they would use utf-8:
# > Adding Translations to Clang
# >
# > Not possible yet!
# > Diagnostic strings should be written in UTF-8,
# > the client can translate to the relevant code page if needed.
# > Each translation completely replaces the format string
# > for the diagnostic.
# > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
try:
proc = subprocess.Popen(
invocation,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
encoding='utf-8')
except OSError as exc:
raise DiffError(
"Command '{}' failed to start: {}".format(
subprocess.list2cmdline(invocation), exc
)
)
proc_stdout = proc.stdout
proc_stderr = proc.stderr
# hopefully the stderr pipe won't get full and block the process
outs = list(proc_stdout.readlines())
errs = list(proc_stderr.readlines())
proc.wait()
if proc.returncode:
raise DiffError(
"Command '{}' returned non-zero exit status {}".format(
subprocess.list2cmdline(invocation), proc.returncode
),
errs,
)
return make_diff(file, original, outs), errs
def bold_red(s):
return '\x1b[1m\x1b[31m' + s + '\x1b[0m'
def colorize(diff_lines):
def bold(s):
return '\x1b[1m' + s + '\x1b[0m'
def cyan(s):
return '\x1b[36m' + s + '\x1b[0m'
def green(s):
return '\x1b[32m' + s + '\x1b[0m'
def red(s):
return '\x1b[31m' + s + '\x1b[0m'
for line in diff_lines:
if line[:4] in ['--- ', '+++ ']:
yield bold(line)
elif line.startswith('@@ '):
yield cyan(line)
elif line.startswith('+'):
yield green(line)
elif line.startswith('-'):
yield red(line)
else:
yield line
def print_diff(diff_lines, use_color):
if use_color:
diff_lines = colorize(diff_lines)
sys.stdout.writelines(diff_lines)
def print_trouble(prog, message, use_colors):
error_text = 'error:'
if use_colors:
error_text = bold_red(error_text)
print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--clang-format-executable',
metavar='EXECUTABLE',
help='path to the clang-format executable',
default='clang-format')
parser.add_argument(
'--extensions',
help='comma separated list of file extensions (default: {})'.format(
DEFAULT_EXTENSIONS),
default=DEFAULT_EXTENSIONS)
parser.add_argument(
'-r',
'--recursive',
action='store_true',
help='run recursively over directories')
parser.add_argument('files', metavar='file', nargs='+')
parser.add_argument(
'-q',
'--quiet',
action='store_true')
parser.add_argument(
'-j',
metavar='N',
type=int,
default=0,
help='run N clang-format jobs in parallel'
' (default number of cpus + 1)')
parser.add_argument(
'--color',
default='auto',
choices=['auto', 'always', 'never'],
help='show colored diff (default: auto)')
parser.add_argument(
'-e',
'--exclude',
metavar='PATTERN',
action='append',
default=[],
help='exclude paths matching the given glob-like pattern(s)'
' from recursive search')
args = parser.parse_args()
# use default signal handling, like diff return SIGINT value on ^C
# https://bugs.python.org/issue14229#msg156446
signal.signal(signal.SIGINT, signal.SIG_DFL)
try:
signal.SIGPIPE
except AttributeError:
# compatibility, SIGPIPE does not exist on Windows
pass
else:
signal.signal(signal.SIGPIPE, signal.SIG_DFL)
colored_stdout = False
colored_stderr = False
if args.color == 'always':
colored_stdout = True
colored_stderr = True
elif args.color == 'auto':
colored_stdout = sys.stdout.isatty()
colored_stderr = sys.stderr.isatty()
version_invocation = [args.clang_format_executable, str("--version")]
try:
subprocess.check_call(version_invocation, stdout=DEVNULL)
except subprocess.CalledProcessError as e:
print_trouble(parser.prog, str(e), use_colors=colored_stderr)
return ExitStatus.TROUBLE
except OSError as e:
print_trouble(
parser.prog,
"Command '{}' failed to start: {}".format(
subprocess.list2cmdline(version_invocation), e
),
use_colors=colored_stderr,
)
return ExitStatus.TROUBLE
retcode = ExitStatus.SUCCESS
files = list_files(
args.files,
recursive=args.recursive,
exclude=args.exclude,
extensions=args.extensions.split(','))
if not files:
return
njobs = args.j
if njobs == 0:
njobs = multiprocessing.cpu_count() + 1
njobs = min(len(files), njobs)
if njobs == 1:
# execute directly instead of in a pool,
# less overhead, simpler stacktraces
it = (run_clang_format_diff_wrapper(args, file) for file in files)
pool = None
else:
pool = multiprocessing.Pool(njobs)
it = pool.imap_unordered(
partial(run_clang_format_diff_wrapper, args), files)
while True:
try:
outs, errs = next(it)
except StopIteration:
break
except DiffError as e:
print_trouble(parser.prog, str(e), use_colors=colored_stderr)
retcode = ExitStatus.TROUBLE
sys.stderr.writelines(e.errs)
except UnexpectedError as e:
print_trouble(parser.prog, str(e), use_colors=colored_stderr)
sys.stderr.write(e.formatted_traceback)
retcode = ExitStatus.TROUBLE
# stop at the first unexpected error,
# something could be very wrong,
# don't process all files unnecessarily
if pool:
pool.terminate()
break
else:
sys.stderr.writelines(errs)
if outs == []:
continue
if not args.quiet:
print_diff(outs, use_color=colored_stdout)
if retcode == ExitStatus.SUCCESS:
retcode = ExitStatus.DIFF
return retcode
if __name__ == '__main__':
sys.exit(main())
#!/usr/bin/env bash
set -u
set -eux
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
eval "$("${conda_dir}/bin/conda" shell.bash hook)"
conda activate "${env_dir}"
# 1. Install tools
conda install flake8
printf "Installed flake8: "
flake8 --version
clangformat_path="${root_dir}/clang-format"
curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o "${clangformat_path}"
chmod +x "${clangformat_path}"
printf "Installed clang-fortmat"
"${clangformat_path}" --version
# 2. Run style checks
# We want to run all the style checks even if one of them fail.
set +e
exit_status=0
printf "\x1b[34mRunning flake8: "
flake8 --version
printf "\x1b[0m\n"
printf "\x1b[34mRunning flake8:\x1b[0m\n"
flake8 torchaudio test build_tools/setup_helpers
status=$?
exit_status="$((exit_status+status))"
......@@ -30,14 +36,14 @@ if [ "${status}" -ne 0 ]; then
printf "\x1b[31mflake8 failed. Check the format of Python files.\x1b[0m\n"
fi
printf "\x1b[34mRunning clang-format: "
./clang-format --version
printf "\x1b[0m\n"
git-clang-format --binary ./clang-format origin/master
git diff --exit-code
printf "\x1b[34mRunning clang-format:\x1b[0m\n"
"${this_dir}"/run-clang-format.py \
-r torchaudio/csrc \
--clang-format-executable "${clangformat_path}" \
&& git diff --exit-code
status=$?
exit_status="$((exit_status+status))"
if [ "${status}" -ne 0 ]; then
printf "\x1b[31mC++ files are not formatted. Please use git-clang-format to format CPP files.\x1b[0m\n"
printf "\x1b[31mC++ files are not formatted. Please use clang-format to format CPP files.\x1b[0m\n"
fi
exit $exit_status
......@@ -2,11 +2,12 @@
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>
PYBIND11_MODULE(_torchaudio, m) {
py::class_<sox_signalinfo_t>(m, "sox_signalinfo_t")
.def(py::init<>())
.def("__repr__", [](const sox_signalinfo_t &self) {
.def(
"__repr__",
[](const sox_signalinfo_t& self) {
std::stringstream ss;
ss << "sox_signalinfo_t {\n"
<< " rate-> " << self.rate << "\n"
......@@ -24,7 +25,9 @@ PYBIND11_MODULE(_torchaudio, m) {
.def_readwrite("mult", &sox_signalinfo_t::mult);
py::class_<sox_encodinginfo_t>(m, "sox_encodinginfo_t")
.def(py::init<>())
.def("__repr__", [](const sox_encodinginfo_t &self) {
.def(
"__repr__",
[](const sox_encodinginfo_t& self) {
std::stringstream ss;
ss << "sox_encodinginfo_t {\n"
<< " encoding-> " << self.encoding << "\n"
......@@ -72,7 +75,8 @@ PYBIND11_MODULE(_torchaudio, m) {
.value("SOX_ENCODING_AMR_WB", sox_encoding_t::SOX_ENCODING_AMR_WB)
.value("SOX_ENCODING_AMR_NB", sox_encoding_t::SOX_ENCODING_AMR_NB)
.value("SOX_ENCODING_LPC10", sox_encoding_t::SOX_ENCODING_LPC10)
//.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) // creates a compile error
//.value("SOX_ENCODING_OPUS", sox_encoding_t::SOX_ENCODING_OPUS) //
// creates a compile error
.value("SOX_ENCODINGS", sox_encoding_t::SOX_ENCODINGS)
.export_values();
py::enum_<sox_option_t>(m, "sox_option_t")
......
......@@ -143,23 +143,27 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
// Streaming decoding over file-like object is tricky because libsox operates on FILE pointer.
// The folloing is what `sox` and `play` commands do
// Streaming decoding over file-like object is tricky because libsox operates
// on FILE pointer. The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and discard.
// We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
//
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial chunk of byte string
// This will perform header-based format detection, if necessary, then fill the metadata of
// sox_format_t. Internally, sox_open_mem_read uses fmemopen, which returns FILE* which points the
// buffer of the provided byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying buffer in a way that it
// starts with unseen data, and append the new data read from the given fileobj.
// This will trick libsox as if it keeps reading from the FILE* continuously.
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// chunk of byte string
// This will perform header-based format detection, if necessary, then fill
// the metadata of sox_format_t. Internally, sox_open_mem_read uses
// fmemopen, which returns FILE* which points the buffer of the provided
// byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Using std::string and let it manage memory.
......@@ -170,9 +174,12 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
auto* in_buf = const_cast<char*>(in_buffer.data());
// Fetch the header, and copy it to the buffer.
auto header = static_cast<std::string>(static_cast<py::bytes>(fileobj.attr("read")(4096)));
memcpy(static_cast<void*>(in_buf),
static_cast<void*>(const_cast<char*>(header.data())), header.length());
auto header = static_cast<std::string>(
static_cast<py::bytes>(fileobj.attr("read")(4096)));
memcpy(
static_cast<void*>(in_buf),
static_cast<void*>(const_cast<char*>(header.data())),
header.length());
// Open file (this starts reading the header)
SoxFormat sf(sox_open_mem_read(
......@@ -212,8 +219,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
channels_first_);
return std::make_tuple(
tensor,
static_cast<int64_t>(chain.getOutputSampleRate()));
tensor, static_cast<int64_t>(chain.getOutputSampleRate()));
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
......
......@@ -123,7 +123,8 @@ int file_output_flow(
}
sox_effect_handler_t* get_tensor_input_handler() {
static sox_effect_handler_t handler{/*name=*/"input_tensor",
static sox_effect_handler_t handler{
/*name=*/"input_tensor",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
......@@ -137,7 +138,8 @@ sox_effect_handler_t* get_tensor_input_handler() {
}
sox_effect_handler_t* get_tensor_output_handler() {
static sox_effect_handler_t handler{/*name=*/"output_tensor",
static sox_effect_handler_t handler{
/*name=*/"output_tensor",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
......@@ -151,7 +153,8 @@ sox_effect_handler_t* get_tensor_output_handler() {
}
sox_effect_handler_t* get_file_output_handler() {
static sox_effect_handler_t handler{/*name=*/"output_file",
static sox_effect_handler_t handler{
/*name=*/"output_file",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
......@@ -198,7 +201,8 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) {
priv->signal = signal;
priv->index = 0;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: input_tensor");
throw std::runtime_error(
"Internal Error: Failed to add effect: input_tensor");
}
}
......@@ -207,7 +211,8 @@ void SoxEffectsChain::addOutputBuffer(
SoxEffect e(sox_create_effect(get_tensor_output_handler()));
static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: output_tensor");
throw std::runtime_error(
"Internal Error: Failed to add effect: output_tensor");
}
}
......@@ -305,7 +310,7 @@ struct FileObjOutputPriv {
/// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
auto priv = static_cast<FileObjInputPriv *>(effp->priv);
auto priv = static_cast<FileObjInputPriv*>(effp->priv);
auto sf = priv->sf;
auto fileobj = priv->fileobj;
auto buffer = priv->buffer;
......@@ -315,9 +320,9 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
//
// NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detect EOF is reaching the end of the buffer. (null byte won't help)
// Therefore we need to align the content at the end of buffer, otherwise,
// libsox will keep reading the content beyond intended length.
// libsox detect EOF is reaching the end of the buffer. (null byte won't
// help) Therefore we need to align the content at the end of buffer,
// otherwise, libsox will keep reading the content beyond intended length.
//
// Before:
//
......@@ -339,9 +344,10 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
const auto num_refill = py::len(chunk_);
const auto offset = buffer_size - (num_remain + num_refill);
if(num_refill > num_consumed) {
if (num_refill > num_consumed) {
std::ostringstream message;
message << "Tried to read up to " << num_consumed << " bytes but, "
message
<< "Tried to read up to " << num_consumed << " bytes but, "
<< "recieved " << num_refill << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
......@@ -364,7 +370,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// 1.4. Set the file pointer to the new offset
sf->tell_off = offset;
fseek ((FILE*)sf->fp, offset, SEEK_SET);
fseek((FILE*)sf->fp, offset, SEEK_SET);
// 2. Perform decoding operation
// The following part is practically same as "input" effect
......@@ -377,7 +383,7 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// store the actual number read back to *osamp
*osamp = sox_read(sf, obuf, *osamp);
return *osamp? SOX_SUCCESS : SOX_EOF;
return *osamp ? SOX_SUCCESS : SOX_EOF;
}
int fileobj_output_flow(
......@@ -420,7 +426,8 @@ int fileobj_output_flow(
}
sox_effect_handler_t* get_fileobj_input_handler() {
static sox_effect_handler_t handler{/*name=*/"input_fileobj_object",
static sox_effect_handler_t handler{
/*name=*/"input_fileobj_object",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
......@@ -434,7 +441,8 @@ sox_effect_handler_t* get_fileobj_input_handler() {
}
sox_effect_handler_t* get_fileobj_output_handler() {
static sox_effect_handler_t handler{/*name=*/"output_fileobj_object",
static sox_effect_handler_t handler{
/*name=*/"output_fileobj_object",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
......@@ -464,7 +472,8 @@ void SoxEffectsChain::addInputFileObj(
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: input fileobj");
throw std::runtime_error(
"Internal Error: Failed to add effect: input fileobj");
}
}
......@@ -481,7 +490,8 @@ void SoxEffectsChain::addOutputFileObj(
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: output fileobj");
throw std::runtime_error(
"Internal Error: Failed to add effect: output fileobj");
}
}
......
......@@ -112,8 +112,9 @@ void save_audio_file(
auto signal = TensorSignal(tensor, sample_rate, channels_first);
const auto filetype = [&](){
if (format.has_value()) return format.value();
const auto filetype = [&]() {
if (format.has_value())
return format.value();
return get_filetype(path);
}();
if (filetype == "amr-nb") {
......@@ -123,7 +124,8 @@ void save_audio_file(
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
SoxFormat sf(sox_open_write(
path.c_str(),
......@@ -161,7 +163,8 @@ std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
namespace {
// helper class to automatically release buffer, to be used by save_audio_fileobj
// helper class to automatically release buffer, to be used by
// save_audio_fileobj
struct AutoReleaseBuffer {
char* ptr;
size_t size;
......@@ -194,12 +197,14 @@ void save_audio_fileobj(
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error("amr-nb format only supports single channel audio.");
throw std::runtime_error(
"amr-nb format only supports single channel audio.");
}
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
AutoReleaseBuffer buffer;
......@@ -212,7 +217,8 @@ void save_audio_fileobj(
/*oob=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error saving audio file: failed to open memory stream.");
throw std::runtime_error(
"Error saving audio file: failed to open memory stream.");
}
torchaudio::sox_effects_chain::SoxEffectsChain chain(
......@@ -222,7 +228,8 @@ void save_audio_fileobj(
chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);
chain.run();
// Closing the sox_format_t is necessary for flushing the last chunk to the buffer
// Closing the sox_format_t is necessary for flushing the last chunk to the
// buffer
sf.close();
fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size));
......
......@@ -40,10 +40,7 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
return samples_written;
}
void read_audio(
SoxDescriptor& fd,
at::Tensor output,
int64_t buffer_length) {
void read_audio(SoxDescriptor& fd, at::Tensor output, int64_t buffer_length) {
std::vector<sox_sample_t> buffer(buffer_length);
int number_of_channels = fd->signal.channels;
......@@ -64,8 +61,7 @@ void read_audio(
} // namespace
std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info(
const std::string& file_name
) {
const std::string& file_name) {
SoxDescriptor fd(sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
......@@ -86,7 +82,6 @@ int read_audio_file(
sox_signalinfo_t* si,
sox_encodinginfo_t* ei,
const char* ft) {
SoxDescriptor fd(sox_open_read(file_name.c_str(), si, ei, ft));
if (fd.get() == nullptr) {
throw std::runtime_error("Error opening audio file");
......@@ -120,7 +115,8 @@ int read_audio_file(
// seek to offset point before reading data
if (sox_seek(fd.get(), offset, 0) == SOX_EOF) {
throw std::runtime_error("sox_seek reached EOF, try reducing offset or num_samples");
throw std::runtime_error(
"sox_seek reached EOF, try reducing offset or num_samples");
}
// read data and fill output tensor
......
#include <sox.h>
#include <torch/torch.h>
namespace torch { namespace audio {
namespace torch {
namespace audio {
/// Reads an audio file from the given `path` into the `output` `Tensor` and
/// returns the sample rate of the audio file.
......@@ -30,9 +31,10 @@ void write_audio_file(
/// Reads an audio file from the given `path` and returns a tuple of
/// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
/// the audio file such as sample rate, length, bit precision, encoding and more.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error occurred during reading of the audio data.
/// the audio file such as sample rate, length, bit precision, encoding and
/// more. Throws `std::runtime_error` if the audio file could not be opened, or
/// an error occurred during reading of the audio data.
std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info(
const std::string& file_name);
}} // namespace torch::audio
} // namespace audio
} // namespace torch
......@@ -43,7 +43,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def("get_bits_per_sample", &torchaudio::sox_io::SignalInfo::getBitsPerSample);
.def(
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
m.def(
......
......@@ -80,7 +80,9 @@ bool TensorSignal::getChannelsFirst() const {
}
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
SoxFormat::~SoxFormat() { close(); }
SoxFormat::~SoxFormat() {
close();
}
sox_format_t* SoxFormat::operator->() const noexcept {
return fd_;
......@@ -291,7 +293,8 @@ sox_signalinfo_t get_signalinfo(
sox_encodinginfo_t get_encodinginfo(
const std::string filetype,
const caffe2::TypeMeta dtype) {
return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
return sox_encodinginfo_t{
/*encoding=*/get_encoding(filetype, dtype),
/*bits_per_sample=*/get_precision(filetype, dtype),
/*compression=*/HUGE_VAL,
/*reverse_bytes=*/sox_option_default,
......@@ -304,7 +307,8 @@ sox_encodinginfo_t get_encodinginfo(
const std::string filetype,
const caffe2::TypeMeta dtype,
c10::optional<double>& compression) {
return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
return sox_encodinginfo_t{
/*encoding=*/get_encoding(filetype, dtype),
/*bits_per_sample=*/get_precision(filetype, dtype),
/*compression=*/compression.value_or(HUGE_VAL),
/*reverse_bytes=*/sox_option_default,
......
......@@ -69,7 +69,7 @@ struct SoxFormat {
///
/// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf, bool check_length=true);
void validate_input_file(const SoxFormat& sf, bool check_length = true);
///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
......
......@@ -8,7 +8,8 @@
namespace {
int64_t cpu_rnnt_loss(torch::Tensor acts,
int64_t cpu_rnnt_loss(
torch::Tensor acts,
torch::Tensor labels,
torch::Tensor input_lengths,
torch::Tensor label_lengths,
......@@ -16,7 +17,6 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
torch::Tensor grads,
int64_t blank_label,
int64_t num_threads) {
int maxT = acts.size(1);
int maxU = acts.size(2);
int minibatch_size = acts.size(0);
......@@ -32,45 +32,54 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
options.num_threads = num_threads;
// have to use at least one
options.num_threads = std::max(options.num_threads, (unsigned int) 1);
options.num_threads = std::max(options.num_threads, (unsigned int)1);
size_t cpu_size_bytes = 0;
switch (acts.scalar_type()) {
case torch::ScalarType::Float:
{
get_workspace_size(maxT, maxU, minibatch_size,
false, &cpu_size_bytes);
case torch::ScalarType::Float: {
get_workspace_size(maxT, maxU, minibatch_size, false, &cpu_size_bytes);
std::vector<float> cpu_workspace(cpu_size_bytes / sizeof(float), 0);
compute_rnnt_loss(acts.data_ptr<float>(), grads.data_ptr<float>(),
labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(), alphabet_size,
minibatch_size, costs.data_ptr<float>(),
cpu_workspace.data(), options);
compute_rnnt_loss(
acts.data_ptr<float>(),
grads.data_ptr<float>(),
labels.data_ptr<int>(),
label_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(),
alphabet_size,
minibatch_size,
costs.data_ptr<float>(),
cpu_workspace.data(),
options);
return 0;
}
case torch::ScalarType::Double:
{
get_workspace_size(maxT, maxU, minibatch_size,
false, &cpu_size_bytes,
sizeof(double));
case torch::ScalarType::Double: {
get_workspace_size(
maxT, maxU, minibatch_size, false, &cpu_size_bytes, sizeof(double));
std::vector<double> cpu_workspace(cpu_size_bytes / sizeof(double), 0);
compute_rnnt_loss_fp64(acts.data_ptr<double>(), grads.data_ptr<double>(),
labels.data_ptr<int>(), label_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(), alphabet_size,
minibatch_size, costs.data_ptr<double>(),
cpu_workspace.data(), options);
compute_rnnt_loss_fp64(
acts.data_ptr<double>(),
grads.data_ptr<double>(),
labels.data_ptr<int>(),
label_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(),
alphabet_size,
minibatch_size,
costs.data_ptr<double>(),
cpu_workspace.data(),
options);
return 0;
}
default:
TORCH_CHECK(false,
std::string(__func__) + " not implemented for '" + toString(acts.scalar_type()) + "'"
);
TORCH_CHECK(
false,
std::string(__func__) + " not implemented for '" +
toString(acts.scalar_type()) + "'");
}
return -1;
}
......@@ -82,7 +91,8 @@ TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("rnnt_loss(Tensor acts,"
m.def(
"rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment