Commit f234e51f authored by Javier Cardenete Morales's avatar Javier Cardenete Morales Committed by Facebook GitHub Bot
Browse files

Replace 'runtime_error' exception with 'TORCH_CHECK' in TorchAudio sox (#2592)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2592

std::runtime_error does not preserve the C++ stack trace, so it is unclear to users what went wrong internally.

PyTorch's TORCH_CHECK macro allows to print C++ stack trace when TORCH_SHOW_CPP_STACKTRACES environment variable is set to 1.

Reviewed By: mthrok

Differential Revision: D38219331

fbshipit-source-id: f51c27111077e927f97127f73f83a31b8e74f61f
parent d6267031
...@@ -20,16 +20,15 @@ void initialize_sox_effects() { ...@@ -20,16 +20,15 @@ void initialize_sox_effects() {
switch (SOX_RESOURCE_STATE) { switch (SOX_RESOURCE_STATE) {
case NotInitialized: case NotInitialized:
if (sox_init() != SOX_SUCCESS) { TORCH_CHECK(
throw std::runtime_error("Failed to initialize sox effects."); sox_init() == SOX_SUCCESS, "Failed to initialize sox effects.");
};
SOX_RESOURCE_STATE = Initialized; SOX_RESOURCE_STATE = Initialized;
break; break;
case Initialized: case Initialized:
break; break;
case ShutDown: case ShutDown:
throw std::runtime_error( TORCH_CHECK(
"SoX Effects has been shut down. Cannot initialize again."); false, "SoX Effects has been shut down. Cannot initialize again.");
} }
}; };
...@@ -38,12 +37,10 @@ void shutdown_sox_effects() { ...@@ -38,12 +37,10 @@ void shutdown_sox_effects() {
switch (SOX_RESOURCE_STATE) { switch (SOX_RESOURCE_STATE) {
case NotInitialized: case NotInitialized:
throw std::runtime_error( TORCH_CHECK(false, "SoX Effects is not initialized. Cannot shutdown.");
"SoX Effects is not initialized. Cannot shutdown.");
case Initialized: case Initialized:
if (sox_quit() != SOX_SUCCESS) { TORCH_CHECK(
throw std::runtime_error("Failed to initialize sox effects."); sox_quit() == SOX_SUCCESS, "Failed to initialize sox effects.");
};
SOX_RESOURCE_STATE = ShutDown; SOX_RESOURCE_STATE = ShutDown;
break; break;
case ShutDown: case ShutDown:
......
#include <torchaudio/csrc/sox/effects_chain.h> #include <torchaudio/csrc/sox/effects_chain.h>
#include <torchaudio/csrc/sox/utils.h> #include <torchaudio/csrc/sox/utils.h>
#include "c10/util/Exception.h"
using namespace torch::indexing; using namespace torch::indexing;
using namespace torchaudio::sox_utils; using namespace torchaudio::sox_utils;
...@@ -80,7 +81,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -80,7 +81,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
break; break;
} }
default: default:
throw std::runtime_error("Unexpected dtype."); TORCH_CHECK(false, "Unexpected dtype: ", chunk.dtype());
} }
// Write to buffer // Write to buffer
chunk = chunk.contiguous(); chunk = chunk.contiguous();
...@@ -114,12 +115,13 @@ int file_output_flow( ...@@ -114,12 +115,13 @@ int file_output_flow(
if (*isamp) { if (*isamp) {
auto sf = static_cast<FileOutputPriv*>(effp->priv)->sf; auto sf = static_cast<FileOutputPriv*>(effp->priv)->sf;
if (sox_write(sf, ibuf, *isamp) != *isamp) { if (sox_write(sf, ibuf, *isamp) != *isamp) {
if (sf->sox_errno) { TORCH_CHECK(
std::ostringstream stream; !sf->sox_errno,
stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " " sf->sox_errstr,
<< sf->filename; " ",
throw std::runtime_error(stream.str()); sox_strerror(sf->sox_errno),
} " ",
sf->filename);
return SOX_EOF; return SOX_EOF;
} }
} }
...@@ -198,9 +200,7 @@ SoxEffectsChain::SoxEffectsChain( ...@@ -198,9 +200,7 @@ SoxEffectsChain::SoxEffectsChain(
interm_sig_(), interm_sig_(),
out_sig_(), out_sig_(),
sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) { sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) {
if (!sec_) { TORCH_CHECK(sec_, "Failed to create effect chain.");
throw std::runtime_error("Failed to create effect chain.");
}
} }
SoxEffectsChain::~SoxEffectsChain() { SoxEffectsChain::~SoxEffectsChain() {
...@@ -225,20 +225,18 @@ void SoxEffectsChain::addInputTensor( ...@@ -225,20 +225,18 @@ void SoxEffectsChain::addInputTensor(
priv->waveform = waveform; priv->waveform = waveform;
priv->sample_rate = sample_rate; priv->sample_rate = sample_rate;
priv->channels_first = channels_first; priv->channels_first = channels_first;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { TORCH_CHECK(
throw std::runtime_error( sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
"Internal Error: Failed to add effect: input_tensor"); "Internal Error: Failed to add effect: input_tensor");
}
} }
void SoxEffectsChain::addOutputBuffer( void SoxEffectsChain::addOutputBuffer(
std::vector<sox_sample_t>* output_buffer) { std::vector<sox_sample_t>* output_buffer) {
SoxEffect e(sox_create_effect(get_tensor_output_handler())); SoxEffect e(sox_create_effect(get_tensor_output_handler()));
static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer; static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { TORCH_CHECK(
throw std::runtime_error( sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
"Internal Error: Failed to add effect: output_tensor"); "Internal Error: Failed to add effect: output_tensor");
}
} }
void SoxEffectsChain::addInputFile(sox_format_t* sf) { void SoxEffectsChain::addInputFile(sox_format_t* sf) {
...@@ -247,42 +245,34 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) { ...@@ -247,42 +245,34 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) {
SoxEffect e(sox_create_effect(sox_find_effect("input"))); SoxEffect e(sox_create_effect(sox_find_effect("input")));
char* opts[] = {(char*)sf}; char* opts[] = {(char*)sf};
sox_effect_options(e, 1, opts); sox_effect_options(e, 1, opts);
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { TORCH_CHECK(
std::ostringstream stream; sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
stream << "Internal Error: Failed to add effect: input " << sf->filename; "Internal Error: Failed to add effect: input ",
throw std::runtime_error(stream.str()); sf->filename);
}
} }
void SoxEffectsChain::addOutputFile(sox_format_t* sf) { void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
out_sig_ = sf->signal; out_sig_ = sf->signal;
SoxEffect e(sox_create_effect(get_file_output_handler())); SoxEffect e(sox_create_effect(get_file_output_handler()));
static_cast<FileOutputPriv*>(e->priv)->sf = sf; static_cast<FileOutputPriv*>(e->priv)->sf = sf;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) { TORCH_CHECK(
std::ostringstream stream; sox_add_effect(sec_, e, &interm_sig_, &out_sig_) == SOX_SUCCESS,
stream << "Internal Error: Failed to add effect: output " << sf->filename; "Internal Error: Failed to add effect: output ",
throw std::runtime_error(stream.str()); sf->filename);
}
} }
void SoxEffectsChain::addEffect(const std::vector<std::string> effect) { void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
const auto num_args = effect.size(); const auto num_args = effect.size();
if (num_args == 0) { TORCH_CHECK(num_args != 0, "Invalid argument: empty effect.");
throw std::runtime_error("Invalid argument: empty effect.");
}
const auto name = effect[0]; const auto name = effect[0];
if (UNSUPPORTED_EFFECTS.find(name) != UNSUPPORTED_EFFECTS.end()) { TORCH_CHECK(
std::ostringstream stream; UNSUPPORTED_EFFECTS.find(name) == UNSUPPORTED_EFFECTS.end(),
stream << "Unsupported effect: " << name; "Unsupported effect: ",
throw std::runtime_error(stream.str()); name)
}
auto returned_effect = sox_find_effect(name.c_str()); auto returned_effect = sox_find_effect(name.c_str());
if (!returned_effect) { TORCH_CHECK(returned_effect, "Unsupported effect: ", name)
std::ostringstream stream;
stream << "Unsupported effect: " << name;
throw std::runtime_error(stream.str());
}
SoxEffect e(sox_create_effect(returned_effect)); SoxEffect e(sox_create_effect(returned_effect));
const auto num_options = num_args - 1; const auto num_options = num_args - 1;
...@@ -290,25 +280,16 @@ void SoxEffectsChain::addEffect(const std::vector<std::string> effect) { ...@@ -290,25 +280,16 @@ void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
for (size_t i = 1; i < num_args; ++i) { for (size_t i = 1; i < num_args; ++i) {
opts.push_back((char*)effect[i].c_str()); opts.push_back((char*)effect[i].c_str());
} }
if (sox_effect_options(e, num_options, num_options ? opts.data() : nullptr) != TORCH_CHECK(
SOX_SUCCESS) { sox_effect_options(e, num_options, num_options ? opts.data() : nullptr) ==
std::ostringstream stream; SOX_SUCCESS,
stream << "Invalid effect option:"; "Invalid effect option: ",
for (const auto& v : effect) { c10::Join(" ", effect))
stream << " " << v; TORCH_CHECK(
} sox_add_effect(sec_, e, &interm_sig_, &in_sig_) == SOX_SUCCESS,
throw std::runtime_error(stream.str()); "Internal Error: Failed to add effect: \"",
} c10::Join(" ", effect),
"\"");
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Internal Error: Failed to add effect: \"" << name;
for (size_t i = 1; i < num_args; ++i) {
stream << " " << effect[i];
}
stream << "\"";
throw std::runtime_error(stream.str());
}
} }
int64_t SoxEffectsChain::getOutputNumChannels() { int64_t SoxEffectsChain::getOutputNumChannels() {
......
...@@ -36,15 +36,14 @@ std::vector<std::vector<std::string>> get_effects( ...@@ -36,15 +36,14 @@ std::vector<std::vector<std::string>> get_effects(
const c10::optional<int64_t>& frame_offset, const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames) { const c10::optional<int64_t>& num_frames) {
const auto offset = frame_offset.value_or(0); const auto offset = frame_offset.value_or(0);
if (offset < 0) { TORCH_CHECK(
throw std::runtime_error( offset >= 0,
"Invalid argument: frame_offset must be non-negative."); "Invalid argument: frame_offset must be non-negative. Found: ",
} offset);
const auto frames = num_frames.value_or(-1); const auto frames = num_frames.value_or(-1);
if (frames == 0 || frames < -1) { TORCH_CHECK(
throw std::runtime_error( frames > 0 || frames == -1,
"Invalid argument: num_frames must be -1 or greater than 0."); "Invalid argument: num_frames must be -1 or greater than 0.");
}
std::vector<std::vector<std::string>> effects; std::vector<std::vector<std::string>> effects;
if (frames != -1) { if (frames != -1) {
...@@ -119,10 +118,10 @@ void save_audio_file( ...@@ -119,10 +118,10 @@ void save_audio_file(
/*oob=*/nullptr, /*oob=*/nullptr,
/*overwrite_permitted=*/nullptr)); /*overwrite_permitted=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) { TORCH_CHECK(
throw std::runtime_error( static_cast<sox_format_t*>(sf) != nullptr,
"Error saving audio file: failed to open file " + path); "Error saving audio file: failed to open file ",
} path);
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
......
...@@ -24,9 +24,7 @@ Format get_format_from_string(const std::string& format) { ...@@ -24,9 +24,7 @@ Format get_format_from_string(const std::string& format) {
return Format::HTK; return Format::HTK;
if (format == "gsm") if (format == "gsm")
return Format::GSM; return Format::GSM;
std::ostringstream stream; TORCH_CHECK(false, "Internal Error: unexpected format value: ", format);
stream << "Internal Error: unexpected format value: " << format;
throw std::runtime_error(stream.str());
} }
std::string to_string(Encoding v) { std::string to_string(Encoding v) {
...@@ -56,7 +54,7 @@ std::string to_string(Encoding v) { ...@@ -56,7 +54,7 @@ std::string to_string(Encoding v) {
case Encoding::OPUS: case Encoding::OPUS:
return "OPUS"; return "OPUS";
default: default:
throw std::runtime_error("Internal Error: unexpected encoding."); TORCH_CHECK(false, "Internal Error: unexpected encoding.");
} }
} }
...@@ -74,9 +72,7 @@ Encoding get_encoding_from_option(const c10::optional<std::string> encoding) { ...@@ -74,9 +72,7 @@ Encoding get_encoding_from_option(const c10::optional<std::string> encoding) {
return Encoding::ULAW; return Encoding::ULAW;
if (v == "ALAW") if (v == "ALAW")
return Encoding::ALAW; return Encoding::ALAW;
std::ostringstream stream; TORCH_CHECK(false, "Internal Error: unexpected encoding value: ", v);
stream << "Internal Error: unexpected encoding value: " << v;
throw std::runtime_error(stream.str());
} }
BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth) { BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth) {
...@@ -95,9 +91,7 @@ BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth) { ...@@ -95,9 +91,7 @@ BitDepth get_bit_depth_from_option(const c10::optional<int64_t> bit_depth) {
case 64: case 64:
return BitDepth::B64; return BitDepth::B64;
default: { default: {
std::ostringstream s; TORCH_CHECK(false, "Internal Error: unexpected bit depth value: ", v);
s << "Internal Error: unexpected bit depth value: " << v;
throw std::runtime_error(s.str());
} }
} }
} }
......
...@@ -86,23 +86,18 @@ void SoxFormat::close() { ...@@ -86,23 +86,18 @@ void SoxFormat::close() {
} }
void validate_input_file(const SoxFormat& sf, const std::string& path) { void validate_input_file(const SoxFormat& sf, const std::string& path) {
if (static_cast<sox_format_t*>(sf) == nullptr) { TORCH_CHECK(
throw std::runtime_error( static_cast<sox_format_t*>(sf) != nullptr,
"Error loading audio file: failed to open file " + path); "Error loading audio file: failed to open file " + path);
} TORCH_CHECK(
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { sf->encoding.encoding != SOX_ENCODING_UNKNOWN,
throw std::runtime_error("Error loading audio file: unknown encoding."); "Error loading audio file: unknown encoding.");
}
} }
void validate_input_tensor(const torch::Tensor tensor) { void validate_input_tensor(const torch::Tensor tensor) {
if (!tensor.device().is_cpu()) { TORCH_CHECK(tensor.device().is_cpu(), "Input tensor has to be on CPU.");
throw std::runtime_error("Input tensor has to be on CPU.");
}
if (tensor.ndimension() != 2) { TORCH_CHECK(tensor.ndimension() == 2, "Input tensor has to be 2D.");
throw std::runtime_error("Input tensor has to be 2D.");
}
switch (tensor.dtype().toScalarType()) { switch (tensor.dtype().toScalarType()) {
case c10::ScalarType::Byte: case c10::ScalarType::Byte:
...@@ -111,7 +106,8 @@ void validate_input_tensor(const torch::Tensor tensor) { ...@@ -111,7 +106,8 @@ void validate_input_tensor(const torch::Tensor tensor) {
case c10::ScalarType::Float: case c10::ScalarType::Float:
break; break;
default: default:
throw std::runtime_error( TORCH_CHECK(
false,
"Input tensor has to be one of float32, int32, int16 or uint8 type."); "Input tensor has to be one of float32, int32, int16 or uint8 type.");
} }
} }
...@@ -131,7 +127,8 @@ caffe2::TypeMeta get_dtype( ...@@ -131,7 +127,8 @@ caffe2::TypeMeta get_dtype(
case 32: case 32:
return torch::kInt32; return torch::kInt32;
default: default:
throw std::runtime_error( TORCH_CHECK(
false,
"Only 16, 24, and 32 bits are supported for signed PCM."); "Only 16, 24, and 32 bits are supported for signed PCM.");
} }
default: default:
...@@ -180,7 +177,7 @@ torch::Tensor convert_to_tensor( ...@@ -180,7 +177,7 @@ torch::Tensor convert_to_tensor(
ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy); ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
} }
} else { } else {
throw std::runtime_error("Unsupported dtype."); TORCH_CHECK(false, "Unsupported dtype: ", dtype);
} }
if (channels_first) { if (channels_first) {
t = t.transpose(1, 0); t = t.transpose(1, 0);
...@@ -215,7 +212,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( ...@@ -215,7 +212,7 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case c10::ScalarType::Byte: case c10::ScalarType::Byte:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default: default:
throw std::runtime_error("Internal Error: Unexpected dtype."); TORCH_CHECK(false, "Internal Error: Unexpected dtype: ", dtype);
} }
case BitDepth::B8: case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
...@@ -228,8 +225,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( ...@@ -228,8 +225,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::NOT_PROVIDED: case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case BitDepth::B8: case BitDepth::B8:
throw std::runtime_error( TORCH_CHECK(
format + " does not support 8-bit signed PCM encoding."); false, format, " does not support 8-bit signed PCM encoding.");
default: default:
return std::make_tuple<>( return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample)); SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
...@@ -240,8 +237,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( ...@@ -240,8 +237,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B8: case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default: default:
throw std::runtime_error( TORCH_CHECK(
format + " only supports 8-bit for unsigned PCM encoding."); false, format, " only supports 8-bit for unsigned PCM encoding.");
} }
case Encoding::PCM_FLOAT: case Encoding::PCM_FLOAT:
switch (bits_per_sample) { switch (bits_per_sample) {
...@@ -251,8 +248,9 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( ...@@ -251,8 +248,9 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B64: case BitDepth::B64:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 64); return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
default: default:
throw std::runtime_error( TORCH_CHECK(
format + false,
format,
" only supports 32-bit or 64-bit for floating-point PCM encoding."); " only supports 32-bit or 64-bit for floating-point PCM encoding.");
} }
case Encoding::ULAW: case Encoding::ULAW:
...@@ -261,8 +259,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( ...@@ -261,8 +259,8 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B8: case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8); return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default: default:
throw std::runtime_error( TORCH_CHECK(
format + " only supports 8-bit for mu-law encoding."); false, format, " only supports 8-bit for mu-law encoding.");
} }
case Encoding::ALAW: case Encoding::ALAW:
switch (bits_per_sample) { switch (bits_per_sample) {
...@@ -270,12 +268,12 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav( ...@@ -270,12 +268,12 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
case BitDepth::B8: case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8); return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default: default:
throw std::runtime_error( TORCH_CHECK(
format + " only supports 8-bit for a-law encoding."); false, format, " only supports 8-bit for a-law encoding.");
} }
default: default:
throw std::runtime_error( TORCH_CHECK(
format + " does not support encoding: " + to_string(encoding)); false, format, " does not support encoding: " + to_string(encoding));
} }
} }
...@@ -293,41 +291,46 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding( ...@@ -293,41 +291,46 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
case Format::AMB: case Format::AMB:
return get_save_encoding_for_wav(format, dtype, enc, bps); return get_save_encoding_for_wav(format, dtype, enc, bps);
case Format::MP3: case Format::MP3:
if (enc != Encoding::NOT_PROVIDED) TORCH_CHECK(
throw std::runtime_error("mp3 does not support `encoding` option."); enc == Encoding::NOT_PROVIDED,
if (bps != BitDepth::NOT_PROVIDED) "mp3 does not support `encoding` option.");
throw std::runtime_error( TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"mp3 does not support `bits_per_sample` option."); "mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16); return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::HTK: case Format::HTK:
if (enc != Encoding::NOT_PROVIDED) TORCH_CHECK(
throw std::runtime_error("htk does not support `encoding` option."); enc == Encoding::NOT_PROVIDED,
if (bps != BitDepth::NOT_PROVIDED) "htk does not support `encoding` option.");
throw std::runtime_error( TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"htk does not support `bits_per_sample` option."); "htk does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case Format::VORBIS: case Format::VORBIS:
if (enc != Encoding::NOT_PROVIDED) TORCH_CHECK(
throw std::runtime_error("vorbis does not support `encoding` option."); enc == Encoding::NOT_PROVIDED,
if (bps != BitDepth::NOT_PROVIDED) "vorbis does not support `encoding` option.");
throw std::runtime_error( TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"vorbis does not support `bits_per_sample` option."); "vorbis does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
case Format::AMR_NB: case Format::AMR_NB:
if (enc != Encoding::NOT_PROVIDED) TORCH_CHECK(
throw std::runtime_error("amr-nb does not support `encoding` option."); enc == Encoding::NOT_PROVIDED,
if (bps != BitDepth::NOT_PROVIDED) "amr-nb does not support `encoding` option.");
throw std::runtime_error( TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"amr-nb does not support `bits_per_sample` option."); "amr-nb does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
case Format::FLAC: case Format::FLAC:
if (enc != Encoding::NOT_PROVIDED) TORCH_CHECK(
throw std::runtime_error("flac does not support `encoding` option."); enc == Encoding::NOT_PROVIDED,
"flac does not support `encoding` option.");
switch (bps) { switch (bps) {
case BitDepth::B32: case BitDepth::B32:
case BitDepth::B64: case BitDepth::B64:
throw std::runtime_error( TORCH_CHECK(
"flac does not support `bits_per_sample` larger than 24."); false, "flac does not support `bits_per_sample` larger than 24.");
default: default:
return std::make_tuple<>( return std::make_tuple<>(
SOX_ENCODING_FLAC, static_cast<unsigned>(bps)); SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
...@@ -344,18 +347,17 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding( ...@@ -344,18 +347,17 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bps)); SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
} }
case Encoding::PCM_UNSIGNED: case Encoding::PCM_UNSIGNED:
throw std::runtime_error( TORCH_CHECK(false, "sph does not support unsigned integer PCM.");
"sph does not support unsigned integer PCM.");
case Encoding::PCM_FLOAT: case Encoding::PCM_FLOAT:
throw std::runtime_error("sph does not support floating point PCM."); TORCH_CHECK(false, "sph does not support floating point PCM.");
case Encoding::ULAW: case Encoding::ULAW:
switch (bps) { switch (bps) {
case BitDepth::NOT_PROVIDED: case BitDepth::NOT_PROVIDED:
case BitDepth::B8: case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8); return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default: default:
throw std::runtime_error( TORCH_CHECK(
"sph only supports 8-bit for mu-law encoding."); false, "sph only supports 8-bit for mu-law encoding.");
} }
case Encoding::ALAW: case Encoding::ALAW:
switch (bps) { switch (bps) {
...@@ -367,19 +369,20 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding( ...@@ -367,19 +369,20 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
SOX_ENCODING_ALAW, static_cast<unsigned>(bps)); SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
} }
default: default:
throw std::runtime_error( TORCH_CHECK(
"sph does not support encoding: " + encoding.value()); false, "sph does not support encoding: ", encoding.value());
} }
case Format::GSM: case Format::GSM:
if (enc != Encoding::NOT_PROVIDED) TORCH_CHECK(
throw std::runtime_error("gsm does not support `encoding` option."); enc == Encoding::NOT_PROVIDED,
if (bps != BitDepth::NOT_PROVIDED) "gsm does not support `encoding` option.");
throw std::runtime_error( TORCH_CHECK(
bps == BitDepth::NOT_PROVIDED,
"gsm does not support `bits_per_sample` option."); "gsm does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_GSM, 16); return std::make_tuple<>(SOX_ENCODING_GSM, 16);
default: default:
throw std::runtime_error("Unsupported format: " + format); TORCH_CHECK(false, "Unsupported format: " + format);
} }
} }
...@@ -401,7 +404,7 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { ...@@ -401,7 +404,7 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
case c10::ScalarType::Float: case c10::ScalarType::Float:
return 32; return 32;
default: default:
throw std::runtime_error("Unsupported dtype."); TORCH_CHECK(false, "Unsupported dtype: ", dtype);
} }
} }
if (filetype == "sph") if (filetype == "sph")
...@@ -415,7 +418,7 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { ...@@ -415,7 +418,7 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
if (filetype == "htk") { if (filetype == "htk") {
return 16; return 16;
} }
throw std::runtime_error("Unsupported file type: " + filetype); TORCH_CHECK(false, "Unsupported file type: ", filetype);
} }
} // namespace } // namespace
...@@ -445,7 +448,7 @@ sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) { ...@@ -445,7 +448,7 @@ sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
case c10::ScalarType::Float: case c10::ScalarType::Float:
return SOX_ENCODING_FLOAT; return SOX_ENCODING_FLOAT;
default: default:
throw std::runtime_error("Unsupported dtype."); TORCH_CHECK(false, "Unsupported dtype: ", dtype);
} }
}(); }();
unsigned bits_per_sample = [&]() { unsigned bits_per_sample = [&]() {
...@@ -459,7 +462,7 @@ sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) { ...@@ -459,7 +462,7 @@ sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
case c10::ScalarType::Float: case c10::ScalarType::Float:
return 32; return 32;
default: default:
throw std::runtime_error("Unsupported dtype."); TORCH_CHECK(false, "Unsupported dtype: ", dtype);
} }
}(); }();
return sox_encodinginfo_t{ return sox_encodinginfo_t{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment