Commit 9ef6c23d authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor Streamer to StreamReader in C++ codebase (#2403)

Summary:
* `Streamer` was renamed to `StreamReader` when it moved from prototype to beta.
This commit applies the same rename to the C++ source code.

* Fix miscellaneous lint issues

* Make the code compilable on FFmpeg 5

Pull Request resolved: https://github.com/pytorch/audio/pull/2403

Reviewed By: carolineechen

Differential Revision: D36613053

Pulled By: mthrok

fbshipit-source-id: 69fedd6720d488dadf4dfe7d375ee76d216b215d
parent 752de3e4
......@@ -179,15 +179,15 @@ endif()
if(USE_FFMPEG)
set(
LIBTORCHAUDIO_FFMPEG_SOURCES
ffmpeg/prototype.cpp
ffmpeg/decoder.cpp
ffmpeg/ffmpeg.cpp
ffmpeg/filter_graph.cpp
ffmpeg/buffer.cpp
ffmpeg/sink.cpp
ffmpeg/stream_processor.cpp
ffmpeg/streamer.cpp
ffmpeg/stream_reader.cpp
ffmpeg/stream_reader_wrapper.cpp
ffmpeg/stream_reader_binding.cpp
)
message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}")
find_package(FFMPEG 4.1 REQUIRED COMPONENTS avdevice avfilter avformat avcodec avutil)
......
......@@ -13,14 +13,14 @@ Practically all the code is re-organization of examples;
https://ffmpeg.org/doxygen/4.1/examples.html
## Streamer Architecture
## StreamReader Architecture
The top level class is `Streamer` class. This class handles the input (via `AVFormatContext*`), and manages `StreamProcessor`s for each stream in the input.
The top-level class is the `StreamReader` class. This class handles the input (via `AVFormatContext*`) and manages a `StreamProcessor` for each stream in the input.
The `Streamer` object slices the input data into a series of `AVPacket` objects and it feeds the objects to corresponding `StreamProcessor`s.
The `StreamReader` object slices the input data into a series of `AVPacket` objects and feeds them to the corresponding `StreamProcessor`s.
```
Streamer
StreamReader
┌─────────────────────────────────────────────────┐
│ │
│ AVFormatContext* ┌──► StreamProcessor[0] │
......
......@@ -15,6 +15,7 @@ extern "C" {
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
#include <libavutil/avutil.h>
#include <libavutil/channel_layout.h>
#include <libavutil/frame.h>
#include <libavutil/imgutils.h>
#include <libavutil/log.h>
......
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/streamer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader.h>
#include <chrono>
#include <sstream>
#include <stdexcept>
......@@ -13,23 +13,26 @@ using KeyType = StreamProcessor::KeyType;
//////////////////////////////////////////////////////////////////////////////
// Helper methods
//////////////////////////////////////////////////////////////////////////////
void Streamer::validate_open_stream() const {
if (!pFormatContext)
void StreamReader::validate_open_stream() const {
if (!pFormatContext) {
throw std::runtime_error("Stream is not open.");
}
}
void Streamer::validate_src_stream_index(int i) const {
void StreamReader::validate_src_stream_index(int i) const {
validate_open_stream();
if (i < 0 || i >= static_cast<int>(pFormatContext->nb_streams))
if (i < 0 || i >= static_cast<int>(pFormatContext->nb_streams)) {
throw std::runtime_error("Source stream index out of range");
}
}
void Streamer::validate_output_stream_index(int i) const {
if (i < 0 || i >= static_cast<int>(stream_indices.size()))
void StreamReader::validate_output_stream_index(int i) const {
if (i < 0 || i >= static_cast<int>(stream_indices.size())) {
throw std::runtime_error("Output stream index out of range");
}
}
void Streamer::validate_src_stream_type(int i, AVMediaType type) {
void StreamReader::validate_src_stream_type(int i, AVMediaType type) {
validate_src_stream_index(i);
if (pFormatContext->streams[i]->codecpar->codec_type != type) {
std::ostringstream oss;
......@@ -42,7 +45,8 @@ void Streamer::validate_src_stream_type(int i, AVMediaType type) {
//////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations
//////////////////////////////////////////////////////////////////////////////
Streamer::Streamer(AVFormatContextPtr&& p) : pFormatContext(std::move(p)) {
StreamReader::StreamReader(AVFormatContextPtr&& p)
: pFormatContext(std::move(p)) {
if (avformat_find_stream_info(pFormatContext, nullptr) < 0) {
throw std::runtime_error("Failed to find stream information.");
}
......@@ -63,11 +67,11 @@ Streamer::Streamer(AVFormatContextPtr&& p) : pFormatContext(std::move(p)) {
////////////////////////////////////////////////////////////////////////////////
// Query methods
////////////////////////////////////////////////////////////////////////////////
int64_t Streamer::num_src_streams() const {
int64_t StreamReader::num_src_streams() const {
return pFormatContext->nb_streams;
}
SrcStreamInfo Streamer::get_src_stream_info(int i) const {
SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
validate_src_stream_index(i);
AVStream* stream = pFormatContext->streams[i];
AVCodecParameters* codecpar = stream->codecpar;
......@@ -105,11 +109,11 @@ SrcStreamInfo Streamer::get_src_stream_info(int i) const {
return ret;
}
int64_t Streamer::num_out_streams() const {
return stream_indices.size();
int64_t StreamReader::num_out_streams() const {
return static_cast<int64_t>(stream_indices.size());
}
OutputStreamInfo Streamer::get_out_stream_info(int i) const {
OutputStreamInfo StreamReader::get_out_stream_info(int i) const {
validate_output_stream_index(i);
OutputStreamInfo ret;
int i_src = stream_indices[i].first;
......@@ -119,17 +123,17 @@ OutputStreamInfo Streamer::get_out_stream_info(int i) const {
return ret;
}
int64_t Streamer::find_best_audio_stream() const {
int64_t StreamReader::find_best_audio_stream() const {
return av_find_best_stream(
pFormatContext, AVMEDIA_TYPE_AUDIO, -1, -1, NULL, 0);
pFormatContext, AVMEDIA_TYPE_AUDIO, -1, -1, nullptr, 0);
}
int64_t Streamer::find_best_video_stream() const {
int64_t StreamReader::find_best_video_stream() const {
return av_find_best_stream(
pFormatContext, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0);
pFormatContext, AVMEDIA_TYPE_VIDEO, -1, -1, nullptr, 0);
}
bool Streamer::is_buffer_ready() const {
bool StreamReader::is_buffer_ready() const {
for (const auto& it : processors) {
if (it && !it->is_buffer_ready()) {
return false;
......@@ -141,7 +145,7 @@ bool Streamer::is_buffer_ready() const {
////////////////////////////////////////////////////////////////////////////////
// Configure methods
////////////////////////////////////////////////////////////////////////////////
void Streamer::seek(double timestamp) {
void StreamReader::seek(double timestamp) {
if (timestamp < 0) {
throw std::runtime_error("timestamp must be non-negative.");
}
......@@ -158,7 +162,7 @@ void Streamer::seek(double timestamp) {
}
}
void Streamer::add_audio_stream(
void StreamReader::add_audio_stream(
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
......@@ -166,17 +170,17 @@ void Streamer::add_audio_stream(
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option) {
add_stream(
i,
static_cast<int>(i),
AVMEDIA_TYPE_AUDIO,
frames_per_chunk,
num_chunks,
static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks),
filter_desc,
decoder,
decoder_option,
torch::Device(torch::DeviceType::CPU));
}
void Streamer::add_video_stream(
void StreamReader::add_video_stream(
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
......@@ -204,17 +208,17 @@ void Streamer::add_video_stream(
}();
add_stream(
i,
static_cast<int>(i),
AVMEDIA_TYPE_VIDEO,
frames_per_chunk,
num_chunks,
static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks),
filter_desc,
decoder,
decoder_option,
device);
}
void Streamer::add_stream(
void StreamReader::add_stream(
int i,
AVMediaType media_type,
int frames_per_chunk,
......@@ -233,10 +237,11 @@ void Streamer::add_stream(
"Failed to detect the source stream format. Please provide the decoder to use.");
}
stream->discard = AVDISCARD_DEFAULT;
if (!processors[i])
if (!processors[i]) {
processors[i] = std::make_unique<StreamProcessor>(
stream->codecpar, decoder, decoder_option, device);
}
stream->discard = AVDISCARD_DEFAULT;
int key = processors[i]->add_stream(
stream->time_base,
stream->codecpar,
......@@ -247,8 +252,8 @@ void Streamer::add_stream(
stream_indices.push_back(std::make_pair<>(i, key));
}
void Streamer::remove_stream(int64_t i) {
validate_output_stream_index(i);
void StreamReader::remove_stream(int64_t i) {
validate_output_stream_index(static_cast<int>(i));
auto it = stream_indices.begin() + i;
int iP = it->first;
processors[iP]->remove_stream(it->second);
......@@ -258,11 +263,13 @@ void Streamer::remove_stream(int64_t i) {
bool still_used = false;
for (auto& p : stream_indices) {
still_used |= (iP == p.first);
if (still_used)
if (still_used) {
break;
}
}
if (!still_used) {
processors[iP].reset(nullptr);
}
if (!still_used)
processors[iP].reset(NULL);
}
////////////////////////////////////////////////////////////////////////////////
......@@ -273,18 +280,20 @@ void Streamer::remove_stream(int64_t i) {
// 0: caller should keep calling this function
// 1: It's done, caller should stop calling
// <0: Some error happened
int Streamer::process_packet() {
int StreamReader::process_packet() {
int ret = av_read_frame(pFormatContext, pPacket);
if (ret == AVERROR_EOF) {
ret = drain();
return (ret < 0) ? ret : 1;
}
if (ret < 0)
if (ret < 0) {
return ret;
}
AutoPacketUnref packet{pPacket};
auto& processor = processors[pPacket->stream_index];
if (!processor)
if (!processor) {
return 0;
}
ret = processor->process_packet(packet);
return (ret < 0) ? ret : 0;
}
......@@ -293,7 +302,7 @@ int Streamer::process_packet() {
// it keeps retrying until timeout happens,
//
// timeout and backoff are given in milliseconds
int Streamer::process_packet_block(double timeout, double backoff) {
int StreamReader::process_packet_block(double timeout, double backoff) {
auto dead_line = [&]() {
// If timeout < 0, then it repeats forever
if (timeout < 0) {
......@@ -324,19 +333,20 @@ int Streamer::process_packet_block(double timeout, double backoff) {
}
// <0: Some error happened.
int Streamer::drain() {
int StreamReader::drain() {
int ret = 0, tmp = 0;
for (auto& p : processors) {
if (p) {
tmp = p->process_packet(NULL);
if (tmp < 0)
tmp = p->process_packet(nullptr);
if (tmp < 0) {
ret = tmp;
}
}
}
return ret;
}
std::vector<c10::optional<torch::Tensor>> Streamer::pop_chunks() {
std::vector<c10::optional<torch::Tensor>> StreamReader::pop_chunks() {
std::vector<c10::optional<torch::Tensor>> ret;
for (auto& i : stream_indices) {
ret.push_back(processors[i.first]->pop_chunk(i.second));
......
......@@ -8,7 +8,7 @@
namespace torchaudio {
namespace ffmpeg {
class Streamer {
class StreamReader {
AVFormatContextPtr pFormatContext;
AVPacketPtr pPacket;
......@@ -19,14 +19,14 @@ class Streamer {
std::vector<std::pair<int, int>> stream_indices;
public:
explicit Streamer(AVFormatContextPtr&& p);
~Streamer() = default;
explicit StreamReader(AVFormatContextPtr&& p);
~StreamReader() = default;
// Non-copyable
Streamer(const Streamer&) = delete;
Streamer& operator=(const Streamer&) = delete;
StreamReader(const StreamReader&) = delete;
StreamReader& operator=(const StreamReader&) = delete;
// Movable
Streamer(Streamer&&) = default;
Streamer& operator=(Streamer&&) = default;
StreamReader(StreamReader&&) = default;
StreamReader& operator=(StreamReader&&) = default;
//////////////////////////////////////////////////////////////////////////////
// Helper methods
......
......@@ -28,12 +28,13 @@ c10::intrusive_ptr<StreamReaderBinding> init(
std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) {
StreamReaderBinding s{get_input_format_context(src, {}, {})};
int i = s.find_best_audio_stream();
auto sinfo = s.Streamer::get_src_stream_info(i);
int i = static_cast<int>(s.find_best_audio_stream());
auto sinfo = s.StreamReader::get_src_stream_info(i);
int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate);
s.add_audio_stream(i, -1, -1, {}, {}, {});
s.process_all_packets();
auto tensors = s.pop_chunks();
assert(tensors.size() > 0);
return std::make_tuple<>(tensors[0], sample_rate);
}
......@@ -42,11 +43,12 @@ using S = const c10::intrusive_ptr<StreamReaderBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::ffmpeg_init", []() {
avdevice_register_all();
if (av_log_get_level() == AV_LOG_INFO)
if (av_log_get_level() == AV_LOG_INFO) {
av_log_set_level(AV_LOG_ERROR);
}
});
m.def("torchaudio::ffmpeg_load", load);
m.class_<StreamReaderBinding>("ffmpeg_Streamer")
m.class_<StreamReaderBinding>("ffmpeg_StreamReader")
.def(torch::init<>(init))
.def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); })
......
......@@ -25,14 +25,14 @@ OutInfo convert(OutputStreamInfo osi) {
} // namespace
StreamReaderBinding::StreamReaderBinding(AVFormatContextPtr&& p)
: Streamer(std::move(p)) {}
: StreamReader(std::move(p)) {}
SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
return convert(Streamer::get_src_stream_info(i));
return convert(StreamReader::get_src_stream_info(static_cast<int>(i)));
}
OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
return convert(Streamer::get_out_stream_info(i));
return convert(StreamReader::get_out_stream_info(static_cast<int>(i)));
}
int64_t StreamReaderBinding::process_packet(
......@@ -40,9 +40,9 @@ int64_t StreamReaderBinding::process_packet(
const double backoff) {
int64_t code = [&]() {
if (timeout.has_value()) {
return Streamer::process_packet_block(timeout.value(), backoff);
return StreamReader::process_packet_block(timeout.value(), backoff);
}
return Streamer::process_packet();
return StreamReader::process_packet();
}();
if (code < 0) {
throw std::runtime_error(
......
#pragma once
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/streamer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader.h>
namespace torchaudio {
namespace ffmpeg {
......@@ -25,9 +25,10 @@ using OutInfo = std::tuple<
std::string // filter description
>;
// Structure to implement wrapper API around Streamer, which is more suitable
// for Binding the code (i.e. it receives/returns pritimitves)
struct StreamReaderBinding : public Streamer, public torch::CustomClassHolder {
// Structure to implement a wrapper API around StreamReader, which is more
// suitable for binding the code (i.e. it receives/returns primitives)
struct StreamReaderBinding : public StreamReader,
public torch::CustomClassHolder {
explicit StreamReaderBinding(AVFormatContextPtr&& p);
SrcInfo get_src_stream_info(int64_t i);
OutInfo get_out_stream_info(int64_t i);
......
......@@ -11,7 +11,7 @@ struct SrcStreamInfo {
const char* codec_name = "N/A";
const char* codec_long_name = "N/A";
const char* fmt_name = "N/A";
int bit_rate = 0;
int64_t bit_rate = 0;
// Audio
double sample_rate = 0;
int num_channels = 0;
......
......@@ -322,7 +322,7 @@ class StreamReader:
buffer_size: int = 4096,
):
if isinstance(src, str):
self._be = torch.classes.torchaudio.ffmpeg_Streamer(src, format, option)
self._be = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, option)
elif hasattr(src, "read"):
self._be = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, option, buffer_size)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment