Commit 000878e0 authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Introduce packet passthrough feature to streaming api (#3220)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3220

Introduces methods to `StreamReader` and `StreamWriter` that allow for reading and writing `AVPacket` instances rather than tensors. Useful for efficiently remuxing data pulled as is from source.

Reviewed By: mthrok

Differential Revision: D44271536

fbshipit-source-id: 9b9d743c0119a5eb564fa628fd6a67806d120985
parent 9da92cdb
......@@ -13,11 +13,13 @@ set(
stream_reader/buffer/chunked_buffer.cpp
stream_reader/buffer/unchunked_buffer.cpp
stream_reader/conversion.cpp
stream_reader/packet_buffer.cpp
stream_reader/post_process.cpp
stream_reader/stream_processor.cpp
stream_reader/stream_reader.cpp
stream_writer/encode_process.cpp
stream_writer/encoder.cpp
stream_writer/packet_writer.cpp
stream_writer/stream_writer.cpp
stream_writer/tensor_converter.cpp
compat.cpp
......
......@@ -158,5 +158,23 @@ AVFilterGraphPtr::AVFilterGraphPtr()
void AVFilterGraphPtr::reset() {
ptr.reset(get_filter_graph());
}
////////////////////////////////////////////////////////////////////////////////
// AVCodecParameters
////////////////////////////////////////////////////////////////////////////////
void AVCodecParametersDeleter::operator()(AVCodecParameters* codecpar) {
avcodec_parameters_free(&codecpar);
}
namespace {
AVCodecParameters* get_codecpar() {
AVCodecParameters* ptr = avcodec_parameters_alloc();
TORCH_CHECK(ptr, "Failed to allocate resource.");
return ptr;
}
} // namespace
AVCodecParametersPtr::AVCodecParametersPtr()
: Wrapper<AVCodecParameters, AVCodecParametersDeleter>(get_codecpar()) {}
} // namespace io
} // namespace torchaudio
......@@ -189,6 +189,24 @@ struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> {
AVFilterGraphPtr();
void reset();
};
////////////////////////////////////////////////////////////////////////////////
// AVCodecParameters
////////////////////////////////////////////////////////////////////////////////
struct AVCodecParametersDeleter {
void operator()(AVCodecParameters* p);
};
struct AVCodecParametersPtr
: public Wrapper<AVCodecParameters, AVCodecParametersDeleter> {
AVCodecParametersPtr();
};
struct StreamParams {
AVCodecParametersPtr codec_params;
AVRational time_base{};
int stream_index{};
};
} // namespace io
} // namespace torchaudio
......
#include <torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.h>
namespace torchaudio {
namespace io {
void PacketBuffer::push_packet(AVPacket* packet) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(packet, "Packet is null.");
AVPacketPtr pPacket;
av_packet_ref(pPacket, packet);
packets.push_back(std::move(pPacket));
}
std::vector<AVPacketPtr> PacketBuffer::pop_packets() {
std::vector<AVPacketPtr> ret{
std::make_move_iterator(packets.begin()),
std::make_move_iterator(packets.end())};
packets.clear();
return ret;
}
bool PacketBuffer::has_packets() {
return packets.size() > 0;
}
} // namespace io
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio {
namespace io {
class PacketBuffer {
public:
void push_packet(AVPacket* packet);
std::vector<AVPacketPtr> pop_packets();
bool has_packets();
private:
std::deque<AVPacketPtr> packets;
};
} // namespace io
} // namespace torchaudio
......@@ -179,6 +179,21 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
return ret;
}
StreamParams StreamReader::get_src_stream_params(int i) {
StreamParams params;
validate_src_stream_index(pFormatContext, i);
AVStream* stream = pFormatContext->streams[i];
int ret = avcodec_parameters_copy(params.codec_params, stream->codecpar);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream's codec parameters. (",
av_err2string(ret),
")");
params.time_base = stream->time_base;
params.stream_index = i;
return params;
}
int64_t StreamReader::num_out_streams() const {
return static_cast<int64_t>(stream_indices.size());
}
......@@ -222,11 +237,19 @@ int64_t StreamReader::find_best_video_stream() const {
}
bool StreamReader::is_buffer_ready() const {
if (processors.empty()) {
// If no decoding output streams exist, then determine overall readiness
// from the readiness of packet buffer.
return packet_buffer->has_packets();
} else {
// Otherwise, determine readiness solely from the readiness of the decoding
// output streams.
for (const auto& it : processors) {
if (it && !it->is_buffer_ready()) {
return false;
}
}
}
return true;
}
......@@ -326,6 +349,14 @@ void StreamReader::add_video_stream(
device);
}
void StreamReader::add_packet_stream(int i) {
validate_src_stream_index(pFormatContext, i);
if (!packet_buffer) {
packet_buffer = std::make_unique<PacketBuffer>();
}
packet_stream_indices.emplace(i);
}
void StreamReader::add_stream(
int i,
AVMediaType media_type,
......@@ -338,8 +369,8 @@ void StreamReader::add_stream(
validate_src_stream_type(pFormatContext, i, media_type);
AVStream* stream = pFormatContext->streams[i];
// When media source is file-like object, it is possible that source codec is
// not detected properly.
// When media source is file-like object, it is possible that source codec
// is not detected properly.
TORCH_CHECK(
stream->codecpar->format != -1,
"Failed to detect the source stream format.");
......@@ -417,7 +448,14 @@ int StreamReader::process_packet() {
return ret;
}
AutoPacketUnref packet{pPacket};
auto& processor = processors[pPacket->stream_index];
int stream_index = pPacket->stream_index;
if (packet_stream_indices.count(stream_index)) {
packet_buffer->push_packet(packet);
}
auto& processor = processors[stream_index];
if (!processor) {
return 0;
}
......@@ -517,5 +555,8 @@ std::vector<c10::optional<Chunk>> StreamReader::pop_chunks() {
return ret;
}
std::vector<AVPacketPtr> StreamReader::pop_packets() {
return packet_buffer->pop_packets();
}
} // namespace io
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <vector>
......@@ -20,6 +21,11 @@ class StreamReader {
// the second is the map key inside of processor.
std::vector<std::pair<int, int>> stream_indices;
// For supporting reading raw packets.
std::unique_ptr<PacketBuffer> packet_buffer;
// Set of source stream indices to read packets for.
std::unordered_set<int> packet_stream_indices;
// timestamp to seek to expressed in AV_TIME_BASE
//
// 0 : No seek
......@@ -128,6 +134,14 @@ class StreamReader {
/// Check if all the buffers of the output streams have enough decoded frames.
bool is_buffer_ready() const;
/// @cond
/// Get source stream parameters. Necessary on the write side for packet
/// passthrough.
///
/// @param i Source stream index.
StreamParams get_src_stream_params(int i);
/// @endcond
///@}
//////////////////////////////////////////////////////////////////////////////
......@@ -215,6 +229,16 @@ class StreamReader {
const c10::optional<std::string>& decoder = c10::nullopt,
const c10::optional<OptionDict>& decoder_option = c10::nullopt,
const c10::optional<std::string>& hw_accel = c10::nullopt);
/// @cond
/// Add a output packet stream.
/// Allows for passing packets directly from the source stream, bypassing
/// the decode path, to ``StreamWriter`` for remuxing.
///
/// @param i The index of the source stream.
void add_packet_stream(int i);
/// @endcond
/// Remove an output stream.
///
/// @param i The index of the output stream to be removed.
......@@ -306,6 +330,10 @@ class StreamReader {
/// Pop one chunk from each output stream if it is available.
std::vector<c10::optional<Chunk>> pop_chunks();
/// @cond
/// Pop packets from buffer, if available.
std::vector<AVPacketPtr> pop_packets();
/// @endcond
///@}
};
......
#include <torchaudio/csrc/ffmpeg/stream_writer/packet_writer.h>
namespace torchaudio::io {
namespace {
AVStream* add_stream(
AVFormatContext* format_ctx,
const StreamParams& stream_params) {
AVStream* stream = avformat_new_stream(format_ctx, nullptr);
int ret =
avcodec_parameters_copy(stream->codecpar, stream_params.codec_params);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream's codec parameters. (",
av_err2string(ret),
")");
stream->time_base = stream_params.time_base;
return stream;
}
} // namespace
PacketWriter::PacketWriter(
AVFormatContext* format_ctx_,
const StreamParams& stream_params_)
: format_ctx(format_ctx_),
stream(add_stream(format_ctx_, stream_params_)),
original_time_base(stream_params_.time_base) {}
void PacketWriter::write_packet(const AVPacketPtr& packet) {
AVPacket dst_packet;
int ret = av_packet_ref(&dst_packet, packet);
TORCH_CHECK(ret >= 0, "Failed to copy packet.");
av_packet_rescale_ts(&dst_packet, original_time_base, stream->time_base);
dst_packet.stream_index = stream->index;
ret = av_interleaved_write_frame(format_ctx, &dst_packet);
TORCH_CHECK(ret >= 0, "Failed to write packet to destination.");
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio::io {
class PacketWriter {
AVFormatContext* format_ctx;
AVStream* stream;
AVRational original_time_base;
public:
PacketWriter(
AVFormatContext* format_ctx_,
const StreamParams& stream_params_);
void write_packet(const AVPacketPtr& packet);
};
} // namespace torchaudio::io
......@@ -66,9 +66,12 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
pFormatContext->nb_streams == num_output_streams(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_audio_encode_process(
processes.emplace(
std::piecewise_construct,
std::forward_as_tuple(current_key),
std::forward_as_tuple(get_audio_encode_process(
pFormatContext,
sample_rate,
num_channels,
......@@ -79,7 +82,8 @@ void StreamWriter::add_audio_stream(
encoder_sample_rate,
encoder_num_channels,
codec_config,
filter_desc));
filter_desc)));
current_key++;
}
void StreamWriter::add_video_stream(
......@@ -98,9 +102,12 @@ void StreamWriter::add_video_stream(
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
pFormatContext->nb_streams == num_output_streams(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_video_encode_process(
processes.emplace(
std::piecewise_construct,
std::forward_as_tuple(current_key),
std::forward_as_tuple(get_video_encode_process(
pFormatContext,
frame_rate,
width,
......@@ -114,7 +121,16 @@ void StreamWriter::add_video_stream(
encoder_height,
hw_accel,
codec_config,
filter_desc));
filter_desc)));
current_key++;
}
void StreamWriter::add_packet_stream(const StreamParams& stream_params) {
packet_writers.emplace(
std::piecewise_construct,
std::forward_as_tuple(stream_params.stream_index),
std::forward_as_tuple(pFormatContext, stream_params));
current_key++;
}
void StreamWriter::add_audio_frame_stream(
......@@ -130,9 +146,12 @@ void StreamWriter::add_audio_frame_stream(
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
pFormatContext->nb_streams == num_output_streams(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_audio_encode_process(
processes.emplace(
std::piecewise_construct,
std::forward_as_tuple(current_key),
std::forward_as_tuple(get_audio_encode_process(
pFormatContext,
sample_rate,
num_channels,
......@@ -144,7 +163,8 @@ void StreamWriter::add_audio_frame_stream(
encoder_num_channels,
codec_config,
filter_desc,
true));
true)));
current_key++;
}
void StreamWriter::add_video_frame_stream(
......@@ -163,9 +183,12 @@ void StreamWriter::add_video_frame_stream(
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
pFormatContext->nb_streams == num_output_streams(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_video_encode_process(
processes.emplace(
std::piecewise_construct,
std::forward_as_tuple(current_key),
std::forward_as_tuple(get_video_encode_process(
pFormatContext,
frame_rate,
width,
......@@ -180,7 +203,8 @@ void StreamWriter::add_video_frame_stream(
hw_accel,
codec_config,
filter_desc,
true));
true)));
current_key++;
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......@@ -196,7 +220,7 @@ void StreamWriter::dump_format(int64_t i) {
void StreamWriter::open(const c10::optional<OptionDict>& option) {
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
pFormatContext->nb_streams == num_output_streams(),
"The number of encode process and the number of output streams do not match.");
int ret = 0;
......@@ -270,7 +294,7 @@ void StreamWriter::write_audio_chunk(
"Stream ",
i,
" is not audio type.");
processes[i].process(waveform, pts);
processes.at(i).process(waveform, pts);
}
void StreamWriter::write_video_chunk(
......@@ -289,7 +313,17 @@ void StreamWriter::write_video_chunk(
"Stream ",
i,
" is not video type.");
processes[i].process(frames, pts);
processes.at(i).process(frames, pts);
}
void StreamWriter::write_packet(const AVPacketPtr& packet) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
int src_stream_index = packet->stream_index;
TORCH_CHECK(
packet_writers.count(src_stream_index),
"Invalid packet stream source index ",
src_stream_index);
packet_writers.at(src_stream_index).write_packet(packet);
}
void StreamWriter::write_frame(int i, AVFrame* frame) {
......@@ -300,15 +334,18 @@ void StreamWriter::write_frame(int i, AVFrame* frame) {
pFormatContext->nb_streams,
"). Found: ",
i);
processes[i].process_frame(frame);
processes.at(i).process_frame(frame);
}
void StreamWriter::flush() {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
for (auto& p : processes) {
p.flush();
p.second.flush();
}
}
int StreamWriter::num_output_streams() {
return static_cast<int>(processes.size() + packet_writers.size());
}
} // namespace io
} // namespace torchaudio
......@@ -4,6 +4,7 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/packet_writer.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/types.h>
namespace torchaudio {
......@@ -15,9 +16,12 @@ namespace io {
class StreamWriter {
AVFormatOutputContextPtr pFormatContext;
AVBufferRefPtr pHWBufferRef;
std::vector<EncodeProcess> processes;
std::map<int, EncodeProcess> processes;
std::map<int, PacketWriter> packet_writers;
AVPacketPtr pkt;
bool is_open = false;
int current_key = 0;
protected:
/// @cond
......@@ -195,7 +199,16 @@ class StreamWriter {
const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// Add packet stream. Intended to be used in conjunction with
/// ``StreamReader`` to perform packet passthrough.
/// @param stream_params Stream parameters returned by
/// ``StreamReader::get_src_stream_params()`` for the packet stream to pass
/// through.
void add_packet_stream(const StreamParams& stream_params);
/// @endcond
/// Set file-level metadata
/// @param metadata metadata.
void set_metadata(const OptionDict& metadata);
......@@ -256,9 +269,16 @@ class StreamWriter {
/// @param i Stream index.
/// @param frame Frame to write.
void write_frame(int i, AVFrame* frame);
/// Write packet.
/// @param packet Packet to write, passed from ``StreamReader``.
void write_packet(const AVPacketPtr& packet);
/// @endcond
/// Flush the frames from encoders and write the frames to the destination.
void flush();
private:
int num_output_streams();
};
} // namespace io
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment