Commit a2d6fee2 authored by John Lu's avatar John Lu Committed by Facebook GitHub Bot
Browse files

Replace `runtime_error` exception with `TORCH_CHECK` in TorchAudio ffmpeg dir (2/2) (#2551)

Summary:
`std::runtime_error` does not preserve the C++ stack trace, so it is unclear to users what went wrong internally.

PyTorch's `TORCH_CHECK` macro allows to print C++ stack trace when `TORCH_SHOW_CPP_STACKTRACES` environment variable is set to 1.

Pull Request resolved: https://github.com/pytorch/audio/pull/2551

Improve assertion for TorchAudio ffmpeg directory

Reviewed By: mthrok

Differential Revision: D37915732

fbshipit-source-id: 9f597eb00cadd0dc6a1bbf8f7d5c8092804ef685
parent ee631d6b
...@@ -14,32 +14,31 @@ using KeyType = StreamProcessor::KeyType; ...@@ -14,32 +14,31 @@ using KeyType = StreamProcessor::KeyType;
// Helper methods // Helper methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
void StreamReader::validate_open_stream() const { void StreamReader::validate_open_stream() const {
if (!pFormatContext) { TORCH_CHECK(pFormatContext, "Stream is not open.");
throw std::runtime_error("Stream is not open.");
}
} }
void StreamReader::validate_src_stream_index(int i) const { void StreamReader::validate_src_stream_index(int i) const {
validate_open_stream(); validate_open_stream();
if (i < 0 || i >= static_cast<int>(pFormatContext->nb_streams)) { TORCH_CHECK(
throw std::runtime_error("Source stream index out of range"); i >= 0 && i < static_cast<int>(pFormatContext->nb_streams),
} "Source stream index out of range");
} }
void StreamReader::validate_output_stream_index(int i) const { void StreamReader::validate_output_stream_index(int i) const {
if (i < 0 || i >= static_cast<int>(stream_indices.size())) { TORCH_CHECK(
throw std::runtime_error("Output stream index out of range"); i >= 0 && i < static_cast<int>(stream_indices.size()),
} "Output stream index out of range");
} }
void StreamReader::validate_src_stream_type(int i, AVMediaType type) { void StreamReader::validate_src_stream_type(int i, AVMediaType type) {
validate_src_stream_index(i); validate_src_stream_index(i);
if (pFormatContext->streams[i]->codecpar->codec_type != type) { TORCH_CHECK(
std::ostringstream oss; pFormatContext->streams[i]->codecpar->codec_type == type,
oss << "Stream " << i << " is not " << av_get_media_type_string(type) "Stream ",
<< " stream."; i,
throw std::runtime_error(oss.str()); " is not ",
} av_get_media_type_string(type),
" stream.");
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
...@@ -47,9 +46,9 @@ void StreamReader::validate_src_stream_type(int i, AVMediaType type) { ...@@ -47,9 +46,9 @@ void StreamReader::validate_src_stream_type(int i, AVMediaType type) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
StreamReader::StreamReader(AVFormatInputContextPtr&& p) StreamReader::StreamReader(AVFormatInputContextPtr&& p)
: pFormatContext(std::move(p)) { : pFormatContext(std::move(p)) {
if (avformat_find_stream_info(pFormatContext, nullptr) < 0) { int ret = avformat_find_stream_info(pFormatContext, nullptr);
throw std::runtime_error("Failed to find stream information."); TORCH_CHECK(
} ret >= 0, "Failed to find stream information: ", av_err2string(ret));
processors = processors =
std::vector<std::unique_ptr<StreamProcessor>>(pFormatContext->nb_streams); std::vector<std::unique_ptr<StreamProcessor>>(pFormatContext->nb_streams);
...@@ -165,15 +164,11 @@ bool StreamReader::is_buffer_ready() const { ...@@ -165,15 +164,11 @@ bool StreamReader::is_buffer_ready() const {
// Configure methods // Configure methods
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void StreamReader::seek(double timestamp) { void StreamReader::seek(double timestamp) {
if (timestamp < 0) { TORCH_CHECK(timestamp >= 0, "timestamp must be non-negative.");
throw std::runtime_error("timestamp must be non-negative.");
}
int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE); int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0); int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0);
if (ret < 0) { TORCH_CHECK(ret >= 0, "Failed to seek. (" + av_err2string(ret) + ".)");
throw std::runtime_error("Failed to seek. (" + av_err2string(ret) + ".)");
}
for (const auto& it : processors) { for (const auto& it : processors) {
if (it) { if (it) {
it->flush(); it->flush();
...@@ -213,15 +208,14 @@ void StreamReader::add_video_stream( ...@@ -213,15 +208,14 @@ void StreamReader::add_video_stream(
} }
#ifdef USE_CUDA #ifdef USE_CUDA
torch::Device d{hw_accel.value()}; torch::Device d{hw_accel.value()};
if (d.type() != c10::DeviceType::CUDA) { TORCH_CHECK(
std::stringstream ss; d.type() == c10::DeviceType::CUDA,
ss << "Only CUDA is supported for hardware acceleration. Found: " "Only CUDA is supported for hardware acceleration. Found: ",
<< device.str(); device.str());
throw std::runtime_error(ss.str());
}
return d; return d;
#else #else
throw std::runtime_error( TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available."); "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif #endif
}(); }();
...@@ -251,9 +245,9 @@ void StreamReader::add_stream( ...@@ -251,9 +245,9 @@ void StreamReader::add_stream(
AVStream* stream = pFormatContext->streams[i]; AVStream* stream = pFormatContext->streams[i];
// When media source is file-like object, it is possible that source codec is // When media source is file-like object, it is possible that source codec is
// not detected properly. // not detected properly.
if (stream->codecpar->format == -1) { TORCH_CHECK(
throw std::runtime_error("Failed to detect the source stream format."); stream->codecpar->format != -1,
} "Failed to detect the source stream format.");
if (!processors[i]) { if (!processors[i]) {
processors[i] = std::make_unique<StreamProcessor>( processors[i] = std::make_unique<StreamProcessor>(
......
...@@ -33,9 +33,7 @@ AVFormatInputContextPtr get_input_format_context( ...@@ -33,9 +33,7 @@ AVFormatInputContextPtr get_input_format_context(
const c10::optional<OptionDict>& option, const c10::optional<OptionDict>& option,
AVIOContext* io_ctx) { AVIOContext* io_ctx) {
AVFormatContext* pFormat = avformat_alloc_context(); AVFormatContext* pFormat = avformat_alloc_context();
if (!pFormat) { TORCH_CHECK(pFormat, "Failed to allocate AVFormatContext.");
throw std::runtime_error("Failed to allocate AVFormatContext.");
}
if (io_ctx) { if (io_ctx) {
pFormat->pb = io_ctx; pFormat->pb = io_ctx;
} }
...@@ -45,11 +43,7 @@ AVFormatInputContextPtr get_input_format_context( ...@@ -45,11 +43,7 @@ AVFormatInputContextPtr get_input_format_context(
std::string device_str = device.value(); std::string device_str = device.value();
AVFORMAT_CONST AVInputFormat* p = AVFORMAT_CONST AVInputFormat* p =
av_find_input_format(device_str.c_str()); av_find_input_format(device_str.c_str());
if (!p) { TORCH_CHECK(p, "Unsupported device/format: \"", device_str, "\"");
std::ostringstream msg;
msg << "Unsupported device/format: \"" << device_str << "\"";
throw std::runtime_error(msg.str());
}
return p; return p;
} }
return nullptr; return nullptr;
...@@ -59,10 +53,9 @@ AVFormatInputContextPtr get_input_format_context( ...@@ -59,10 +53,9 @@ AVFormatInputContextPtr get_input_format_context(
int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt); int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt);
clean_up_dict(opt); clean_up_dict(opt);
if (ret < 0) TORCH_CHECK(
throw std::runtime_error( ret >= 0,
"Failed to open the input \"" + src + "\" (" + av_err2string(ret) + "Failed to open the input \"" + src + "\" (" + av_err2string(ret) + ").");
").");
return AVFormatInputContextPtr(pFormat); return AVFormatInputContextPtr(pFormat);
} }
...@@ -86,10 +79,8 @@ int64_t StreamReaderBinding::process_packet( ...@@ -86,10 +79,8 @@ int64_t StreamReaderBinding::process_packet(
} }
return StreamReader::process_packet(); return StreamReader::process_packet();
}(); }();
if (code < 0) { TORCH_CHECK(
throw std::runtime_error( code >= 0, "Failed to process a packet. (" + av_err2string(code) + "). ");
"Failed to process a packet. (" + av_err2string(code) + "). ");
}
return code; return code;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment