Commit f234e51f authored by Javier Cardenete Morales's avatar Javier Cardenete Morales Committed by Facebook GitHub Bot
Browse files

Replace 'runtime_error' exception with 'TORCH_CHECK' in TorchAudio sox (#2592)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2592

std::runtime_error does not preserve the C++ stack trace, so it is unclear to users what went wrong internally.

PyTorch's TORCH_CHECK macro allows to print C++ stack trace when TORCH_SHOW_CPP_STACKTRACES environment variable is set to 1.

Reviewed By: mthrok

Differential Revision: D38219331

fbshipit-source-id: f51c27111077e927f97127f73f83a31b8e74f61f
parent d6267031
......@@ -20,16 +20,15 @@ void initialize_sox_effects() {
switch (SOX_RESOURCE_STATE) {
case NotInitialized:
if (sox_init() != SOX_SUCCESS) {
throw std::runtime_error("Failed to initialize sox effects.");
};
TORCH_CHECK(
sox_init() == SOX_SUCCESS, "Failed to initialize sox effects.");
SOX_RESOURCE_STATE = Initialized;
break;
case Initialized:
break;
case ShutDown:
throw std::runtime_error(
"SoX Effects has been shut down. Cannot initialize again.");
TORCH_CHECK(
false, "SoX Effects has been shut down. Cannot initialize again.");
}
};
......@@ -38,12 +37,10 @@ void shutdown_sox_effects() {
switch (SOX_RESOURCE_STATE) {
case NotInitialized:
throw std::runtime_error(
"SoX Effects is not initialized. Cannot shutdown.");
TORCH_CHECK(false, "SoX Effects is not initialized. Cannot shutdown.");
case Initialized:
if (sox_quit() != SOX_SUCCESS) {
throw std::runtime_error("Failed to initialize sox effects.");
};
TORCH_CHECK(
sox_quit() == SOX_SUCCESS, "Failed to initialize sox effects.");
SOX_RESOURCE_STATE = ShutDown;
break;
case ShutDown:
......
#include <torchaudio/csrc/sox/effects_chain.h>
#include <torchaudio/csrc/sox/utils.h>
#include "c10/util/Exception.h"
using namespace torch::indexing;
using namespace torchaudio::sox_utils;
......@@ -80,7 +81,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
break;
}
default:
throw std::runtime_error("Unexpected dtype.");
TORCH_CHECK(false, "Unexpected dtype: ", chunk.dtype());
}
// Write to buffer
chunk = chunk.contiguous();
......@@ -114,12 +115,13 @@ int file_output_flow(
if (*isamp) {
auto sf = static_cast<FileOutputPriv*>(effp->priv)->sf;
if (sox_write(sf, ibuf, *isamp) != *isamp) {
if (sf->sox_errno) {
std::ostringstream stream;
stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
<< sf->filename;
throw std::runtime_error(stream.str());
}
TORCH_CHECK(
!sf->sox_errno,
sf->sox_errstr,
" ",
sox_strerror(sf->sox_errno),
" ",
sf->filename);
return SOX_EOF;
}
}
......@@ -198,9 +200,7 @@ SoxEffectsChain::SoxEffectsChain(
interm_sig_(),
out_sig_(),
sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) {
if (!sec_) {
throw std::runtime_error("Failed to create effect chain.");
}
TORCH_CHECK(sec_, "Failed to create effect chain.");
}
SoxEffectsChain::~SoxEffectsChain() {
......@@ -225,20 +225,18 @@ void SoxEffectsChain::addInputTensor(
priv->waveform = waveform;
priv->sample_rate = sample_rate;
priv->channels_first = channels_first;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
TORCH_CHECK(
sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
"Internal Error: Failed to add effect: input_tensor");
}
}
void SoxEffectsChain::addOutputBuffer(
std::vector<sox_sample_t>* output_buffer) {
SoxEffect e(sox_create_effect(get_tensor_output_handler()));
static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
TORCH_CHECK(
sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
"Internal Error: Failed to add effect: output_tensor");
}
}
void SoxEffectsChain::addInputFile(sox_format_t* sf) {
......@@ -247,42 +245,34 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) {
SoxEffect e(sox_create_effect(sox_find_effect("input")));
char* opts[] = {(char*)sf};
sox_effect_options(e, 1, opts);
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Internal Error: Failed to add effect: input " << sf->filename;
throw std::runtime_error(stream.str());
}
TORCH_CHECK(
sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
"Internal Error: Failed to add effect: input ",
sf->filename);
}
void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
out_sig_ = sf->signal;
SoxEffect e(sox_create_effect(get_file_output_handler()));
static_cast<FileOutputPriv*>(e->priv)->sf = sf;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Internal Error: Failed to add effect: output " << sf->filename;
throw std::runtime_error(stream.str());
}
TORCH_CHECK(
sox_add_effect(sec_, e, &interm_sig_, &out_sig_) == SOX_SUCCESS,
"Internal Error: Failed to add effect: output ",
sf->filename);
}
void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
const auto num_args = effect.size();
if (num_args == 0) {
throw std::runtime_error("Invalid argument: empty effect.");
}
TORCH_CHECK(num_args != 0, "Invalid argument: empty effect.");
const auto name = effect[0];
if (UNSUPPORTED_EFFECTS.find(name) != UNSUPPORTED_EFFECTS.end()) {
std::ostringstream stream;
stream << "Unsupported effect: " << name;
throw std::runtime_error(stream.str());
}
TORCH_CHECK(
UNSUPPORTED_EFFECTS.find(name) == UNSUPPORTED_EFFECTS.end(),
"Unsupported effect: ",
name)
auto returned_effect = sox_find_effect(name.c_str());
if (!returned_effect) {
std::ostringstream stream;
stream << "Unsupported effect: " << name;
throw std::runtime_error(stream.str());
}
TORCH_CHECK(returned_effect, "Unsupported effect: ", name)
SoxEffect e(sox_create_effect(returned_effect));
const auto num_options = num_args - 1;
......@@ -290,25 +280,16 @@ void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
for (size_t i = 1; i < num_args; ++i) {
opts.push_back((char*)effect[i].c_str());
}
if (sox_effect_options(e, num_options, num_options ? opts.data() : nullptr) !=
SOX_SUCCESS) {
std::ostringstream stream;
stream << "Invalid effect option:";
for (const auto& v : effect) {
stream << " " << v;
}
throw std::runtime_error(stream.str());
}
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Internal Error: Failed to add effect: \"" << name;
for (size_t i = 1; i < num_args; ++i) {
stream << " " << effect[i];
}
stream << "\"";
throw std::runtime_error(stream.str());
}
TORCH_CHECK(
sox_effect_options(e, num_options, num_options ? opts.data() : nullptr) ==
SOX_SUCCESS,
"Invalid effect option: ",
c10::Join(" ", effect))
TORCH_CHECK(
sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
"Internal Error: Failed to add effect: \"",
c10::Join(" ", effect),
"\"");
}
int64_t SoxEffectsChain::getOutputNumChannels() {
......
......@@ -36,15 +36,14 @@ std::vector<std::vector<std::string>> get_effects(
const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames) {
const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error(
"Invalid argument: frame_offset must be non-negative.");
}
TORCH_CHECK(
offset >= 0,
"Invalid argument: frame_offset must be non-negative. Found: ",
offset);
const auto frames = num_frames.value_or(-1);
if (frames == 0 || frames < -1) {
throw std::runtime_error(
TORCH_CHECK(
frames > 0 || frames == -1,
"Invalid argument: num_frames must be -1 or greater than 0.");
}
std::vector<std::vector<std::string>> effects;
if (frames != -1) {
......@@ -119,10 +118,10 @@ void save_audio_file(
/*oob=*/nullptr,
/*overwrite_permitted=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error saving audio file: failed to open file " + path);
}
TORCH_CHECK(
static_cast<sox_format_t*>(sf) != nullptr,
"Error saving audio file: failed to open file ",
path);
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
......
......@@ -24,9 +24,7 @@ Format get_format_from_string(const std::string& format) {
return Format::HTK;
if (format == "gsm")
return Format::GSM;
std::ostringstream stream;
stream << "Internal Error: unexpected format value: " << format;
throw std::runtime_error(stream.str());
TORCH_CHECK(false, "Internal Error: unexpected format value: ", format);
}
std::string to_string(Encoding v) {
......@@ -56,7 +54,7 @@ std::string to_string(Encoding v) {
case Encoding::OPUS:
return "OPUS";
default:
throw std::runtime_error("Internal Error: unexpected encoding.");
TORCH_CHECK(false, "Internal Error: unexpected encoding.");
}
}
......@@ -74,9 +72,7 @@ Encoding get_encoding_from_option(const c10::optional<std::string> encoding) {
return Encoding::ULAW;
if (v == "ALAW")
return Encoding::ALAW;
std::ostringstream stream;
stream << "Internal Error: unexpected encoding value: " << v;
throw std::runtime_error(stream.str());
TORCH_CHECK(false, "Internal Error: unexpected encoding value: ", v);
}
BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth) {
......@@ -95,9 +91,7 @@ BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth) {
case 64:
return BitDepth::B64;
default: {
std::ostringstream s;
s << "Internal Error: unexpected bit depth value: " << v;
throw std::runtime_error(s.str());
TORCH_CHECK(false, "Internal Error: unexpected bit depth value: ", v);
}
}
}
......
......@@ -86,23 +86,18 @@ void SoxFormat::close() {
}
void validate_input_file(const SoxFormat& sf, const std::string& path) {
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
TORCH_CHECK(
static_cast<sox_format_t*>(sf) != nullptr,
"Error loading audio file: failed to open file " + path);
}
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
throw std::runtime_error("Error loading audio file: unknown encoding.");
}
TORCH_CHECK(
sf->encoding.encoding != SOX_ENCODING_UNKNOWN,
"Error loading audio file: unknown encoding.");
}
void validate_input_tensor(const torch::Tensor tensor) {
if (!tensor.device().is_cpu()) {
throw std::runtime_error("Input tensor has to be on CPU.");
}
TORCH_CHECK(tensor.device().is_cpu(), "Input tensor has to be on CPU.");
if (tensor.ndimension() != 2) {
throw std::runtime_error("Input tensor has to be 2D.");
}
TORCH_CHECK(tensor.ndimension() == 2, "Input tensor has to be 2D.");
switch (tensor.dtype().toScalarType()) {
case c10::ScalarType::Byte:
......@@ -111,7 +106,8 @@ void validate_input_tensor(const torch::Tensor tensor) {
case c10::ScalarType::Float:
break;
default:
throw std::runtime_error(
TORCH_CHECK(
false,
"Input tensor has to be one of float32, int32, int16 or uint8 type.");
}
}
......@@ -131,7 +127,8 @@ caffe2::TypeMeta get_dtype(
case 32:
return torch::kInt32;
default:
throw std::runtime_error(
TORCH_CHECK(
false,
"Only 16, 24, and 32 bits are supported for signed PCM.");
}
default:
......@@ -180,7 +177,7 @@ torch::Tensor convert_to_tensor(
ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
}
} else {
throw std::runtime_error("Unsupported dtype.");
TORCH_CHECK(false, "Unsupported dtype: ", dtype);
}
if (channels_first) {
t = t.transpose(1, 0);
......@@ -215,7 +212,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case c10::ScalarType::Byte:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error("Internal Error: Unexpected dtype.");
TORCH_CHECK(false, "Internal Error: Unexpected dtype: ", dtype);
}
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
......@@ -228,8 +225,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case BitDepth::B8:
throw std::runtime_error(
format + " does not support 8-bit signed PCM encoding.");
TORCH_CHECK(
false, format, " does not support 8-bit signed PCM encoding.");
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
......@@ -240,8 +237,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for unsigned PCM encoding.");
TORCH_CHECK(
false, format, " only supports 8-bit for unsigned PCM encoding.");
}
case Encoding::PCM_FLOAT:
switch (bits_per_sample) {
......@@ -251,8 +248,9 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B64:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
default:
throw std::runtime_error(
format +
TORCH_CHECK(
false,
format,
" only supports 32-bit or 64-bit for floating-point PCM encoding.");
}
case Encoding::ULAW:
......@@ -261,8 +259,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for mu-law encoding.");
TORCH_CHECK(
false, format, " only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bits_per_sample) {
......@@ -270,12 +268,12 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for a-law encoding.");
TORCH_CHECK(
false, format, " only supports 8-bit for a-law encoding.");
}
default:
throw std::runtime_error(
format + " does not support encoding: " + to_string(encoding));
TORCH_CHECK(
false, format, " does not support encoding: " + to_string(encoding));
}
}
......@@ -293,41 +291,46 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
case Format::AMB:
return get_save_encoding_for_wav(format, dtype, enc, bps);
case Format::MP3:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("mp3 does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
TORCH_CHECK(
enc == Encoding::NOT_PROVIDED,
"mp3 does not support `encoding` option.");
TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::HTK:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("htk does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
TORCH_CHECK(
enc == Encoding::NOT_PROVIDED,
"htk does not support `encoding` option.");
TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"htk does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case Format::VORBIS:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("vorbis does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
TORCH_CHECK(
enc == Encoding::NOT_PROVIDED,
"vorbis does not support `encoding` option.");
TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"vorbis does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
case Format::AMR_NB:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("amr-nb does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
TORCH_CHECK(
enc == Encoding::NOT_PROVIDED,
"amr-nb does not support `encoding` option.");
TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"amr-nb does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
case Format::FLAC:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("flac does not support `encoding` option.");
TORCH_CHECK(
enc == Encoding::NOT_PROVIDED,
"flac does not support `encoding` option.");
switch (bps) {
case BitDepth::B32:
case BitDepth::B64:
throw std::runtime_error(
"flac does not support `bits_per_sample` larger than 24.");
TORCH_CHECK(
false, "flac does not support `bits_per_sample` larger than 24.");
default:
return std::make_tuple<>(
SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
......@@ -344,18 +347,17 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
}
case Encoding::PCM_UNSIGNED:
throw std::runtime_error(
"sph does not support unsigned integer PCM.");
TORCH_CHECK(false, "sph does not support unsigned integer PCM.");
case Encoding::PCM_FLOAT:
throw std::runtime_error("sph does not support floating point PCM.");
TORCH_CHECK(false, "sph does not support floating point PCM.");
case Encoding::ULAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
"sph only supports 8-bit for mu-law encoding.");
TORCH_CHECK(
false, "sph only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bps) {
......@@ -367,19 +369,20 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
}
default:
throw std::runtime_error(
"sph does not support encoding: " + encoding.value());
TORCH_CHECK(
false, "sph does not support encoding: ", encoding.value());
}
case Format::GSM:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("gsm does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
TORCH_CHECK(
enc == Encoding::NOT_PROVIDED,
"gsm does not support `encoding` option.");
TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"gsm does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_GSM, 16);
default:
throw std::runtime_error("Unsupported format: " + format);
TORCH_CHECK(false, "Unsupported format: " + format);
}
}
......@@ -401,7 +404,7 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
case c10::ScalarType::Float:
return 32;
default:
throw std::runtime_error("Unsupported dtype.");
TORCH_CHECK(false, "Unsupported dtype: ", dtype);
}
}
if (filetype == "sph")
......@@ -415,7 +418,7 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
if (filetype == "htk") {
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype);
TORCH_CHECK(false, "Unsupported file type: ", filetype);
}
} // namespace
......@@ -445,7 +448,7 @@ sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
case c10::ScalarType::Float:
return SOX_ENCODING_FLOAT;
default:
throw std::runtime_error("Unsupported dtype.");
TORCH_CHECK(false, "Unsupported dtype: ", dtype);
}
}();
unsigned bits_per_sample = [&]() {
......@@ -459,7 +462,7 @@ sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
case c10::ScalarType::Float:
return 32;
default:
throw std::runtime_error("Unsupported dtype.");
TORCH_CHECK(false, "Unsupported dtype: ", dtype);
}
}();
return sox_encodinginfo_t{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment