Commit 8a9ab2a4 authored by Moto Hira, committed by Facebook GitHub Bot

Refactor encoding process (#3146)

Summary:
After the series of simplifications, the audio and video encoding processes
can be merged, which allows us to get rid of the boilerplate code.
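
For illustration, both write paths now reduce to the same call (a sketch; StreamWriter keeps a std::vector<EncodeProcess> and write_audio_chunk / write_video_chunk dispatch to it):

    // i is the stream index registered via add_audio_stream / add_video_stream
    processes[i].process(AVMEDIA_TYPE_AUDIO, waveform);
    processes[i].process(AVMEDIA_TYPE_VIDEO, frames);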

Pull Request resolved: https://github.com/pytorch/audio/pull/3146

(Note: this ignores all push blocking failures!)

Reviewed By: xiaohui-zhang

Differential Revision: D43815640

fbshipit-source-id: 2a14e372b2cc75db7eeabc27d855a24c3f7d5063
parent b96a7ebb
@@ -16,14 +16,10 @@ set(
stream_reader/sink.cpp
stream_reader/stream_processor.cpp
stream_reader/stream_reader.cpp
stream_writer/encode_process.cpp
stream_writer/encoder.cpp
stream_writer/converter.cpp
stream_writer/output_stream.cpp
stream_writer/audio_converter.cpp
stream_writer/audio_output_stream.cpp
stream_writer/video_converter.cpp
stream_writer/video_output_stream.cpp
stream_writer/stream_writer.cpp
stream_writer/tensor_converter.cpp
compat.cpp
utils.cpp
)
// One-stop header for all ffmpeg needs
#pragma once
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <map>
#include <memory>
#include <torchaudio/csrc/ffmpeg/stream_writer/audio_converter.h>
namespace torchaudio::io {
namespace {
void validate_audio_input(AVFrame* buffer, const torch::Tensor& t) {
auto dtype = t.dtype().toScalarType();
switch (static_cast<AVSampleFormat>(buffer->format)) {
case AV_SAMPLE_FMT_U8:
TORCH_CHECK(
dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
break;
case AV_SAMPLE_FMT_S16:
TORCH_CHECK(
dtype == c10::ScalarType::Short, "Expected Tensor of int16 type.");
break;
case AV_SAMPLE_FMT_S32:
TORCH_CHECK(
dtype == c10::ScalarType::Int, "Expected Tensor of int32 type.");
break;
case AV_SAMPLE_FMT_S64:
TORCH_CHECK(
dtype == c10::ScalarType::Long, "Expected Tensor of int64 type.");
break;
case AV_SAMPLE_FMT_FLT:
TORCH_CHECK(
dtype == c10::ScalarType::Float, "Expected Tensor of float32 type.");
break;
case AV_SAMPLE_FMT_DBL:
TORCH_CHECK(
dtype == c10::ScalarType::Double, "Expected Tensor of float64 type.");
break;
default:
TORCH_CHECK(
false,
"Internal error: Audio encoding stream is not properly configured.");
}
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
TORCH_CHECK(t.dim() == 2, "Input Tensor has to be 2D.");
const auto num_channels = t.size(1);
TORCH_CHECK(
num_channels == buffer->channels,
"Expected waveform with ",
buffer->channels,
" channels. Found ",
num_channels);
}
// 2D (time, channel) and contiguous.
void convert_func_(const torch::Tensor& chunk, AVFrame* buffer) {
auto num_frames = chunk.size(0);
auto byte_size = chunk.numel() * chunk.element_size();
// TODO: make writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
TORCH_CHECK(av_frame_is_writable(buffer), "frame is not writable.");
memcpy(buffer->data[0], chunk.data_ptr(), byte_size);
buffer->nb_samples = static_cast<int>(num_frames);
}
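// Illustrative numbers for the copy above: a chunk of 1024 frames with
// 2 float32 channels has numel() == 2048 and element_size() == 4, so
// byte_size == 8192 bytes are copied into buffer->data[0].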
} // namespace
AudioTensorConverter::AudioTensorConverter(
AVFrame* buffer_,
const int64_t buffer_size_)
: buffer(buffer_), buffer_size(buffer_size_), convert_func(convert_func_) {}
SlicingTensorConverter AudioTensorConverter::convert(
const torch::Tensor& frames) {
validate_audio_input(buffer, frames);
return SlicingTensorConverter{
frames.contiguous(),
buffer,
convert_func,
buffer_size,
};
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/converter.h>
namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
// AudioTensorConverter
////////////////////////////////////////////////////////////////////////////////
// AudioTensorConverter is responsible for picking the right set of
// conversion functions (InitFunc and ConvertFunc) based on the input sample
// format information, and for owning them.
class AudioTensorConverter {
AVFrame* buffer;
const int64_t buffer_size;
SlicingTensorConverter::ConvertFunc convert_func;
public:
AudioTensorConverter(AVFrame* buffer, const int64_t buffer_size);
SlicingTensorConverter convert(const torch::Tensor& frames);
};
} // namespace torchaudio::io
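// Usage sketch (illustrative; assumes `buffer` was allocated with
// nb_samples equal to the encoder frame size, as in audio_output_stream.cpp):
//
//   AudioTensorConverter converter{buffer, buffer->nb_samples};
//   for (AVFrame* f : converter.convert(waveform)) {
//     // `f` aliases `buffer`, refilled with the next frame-sized slice.
//     process_frame(f);  // hypothetical consumer
//   }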
#include <torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.h>
namespace torchaudio::io {
namespace {
FilterGraph get_audio_filter(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->sample_fmt) {
return "anull";
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
return ss.str();
}
}();
FilterGraph p{AVMEDIA_TYPE_AUDIO};
p.add_audio_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->sample_rate,
codec_ctx->channel_layout);
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
}
AVFramePtr get_audio_frame(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx,
int default_frame_size = 10000) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
}
return frame;
}
} // namespace
AudioOutputStream::AudioOutputStream(
AVFormatContext* format_ctx,
AVSampleFormat src_fmt,
AVCodecContextPtr&& codec_ctx_)
: OutputStream(
format_ctx,
codec_ctx_,
get_audio_filter(src_fmt, codec_ctx_)),
buffer(get_audio_frame(src_fmt, codec_ctx_)),
converter(buffer, buffer->nb_samples),
codec_ctx(std::move(codec_ctx_)) {}
void AudioOutputStream::write_chunk(const torch::Tensor& waveform) {
AVRational time_base{1, codec_ctx->sample_rate};
for (const auto& frame : converter.convert(waveform)) {
process_frame(frame);
frame->pts +=
av_rescale_q(frame->nb_samples, time_base, codec_ctx->time_base);
}
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/audio_converter.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/output_stream.h>
namespace torchaudio::io {
struct AudioOutputStream : OutputStream {
AVFramePtr buffer;
AudioTensorConverter converter;
AVCodecContextPtr codec_ctx;
AudioOutputStream(
AVFormatContext* format_ctx,
AVSampleFormat src_fmt,
AVCodecContextPtr&& codec_ctx);
void write_chunk(const torch::Tensor& waveform) override;
~AudioOutputStream() override = default;
};
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/stream_writer/converter.h>
namespace torchaudio::io {
using Iterator = SlicingTensorConverter::Iterator;
using ConvertFunc = SlicingTensorConverter::ConvertFunc;
////////////////////////////////////////////////////////////////////////////////
// SlicingTensorConverter
////////////////////////////////////////////////////////////////////////////////
SlicingTensorConverter::SlicingTensorConverter(
torch::Tensor frames_,
AVFrame* buff,
ConvertFunc& func,
int64_t step_)
: frames(std::move(frames_)),
buffer(buff),
convert_func(func),
step(step_) {}
Iterator SlicingTensorConverter::begin() const {
return Iterator{frames, buffer, convert_func, step};
}
int64_t SlicingTensorConverter::end() const {
return frames.size(0);
}
////////////////////////////////////////////////////////////////////////////////
// Iterator
////////////////////////////////////////////////////////////////////////////////
Iterator::Iterator(
const torch::Tensor frames_,
AVFrame* buffer_,
ConvertFunc& convert_func_,
int64_t step_)
: frames(frames_),
buffer(buffer_),
convert_func(convert_func_),
step(step_) {}
Iterator& Iterator::operator++() {
i += step;
return *this;
}
AVFrame* Iterator::operator*() const {
using namespace torch::indexing;
convert_func(frames.index({Slice{i, i + step}}), buffer);
return buffer;
}
bool Iterator::operator!=(const int64_t end) const {
// This is used for detecting the end of iteration.
// Iteration ends when the index reaches the number of input frames.
return i < end;
}
} // namespace torchaudio::io
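// The range-based for loop over SlicingTensorConverter desugars roughly to
// the following (sketch; `conv` is a SlicingTensorConverter):
//
//   auto it = conv.begin();          // Iterator{frames, buffer, convert_func, step}
//   const int64_t end = conv.end();  // frames.size(0)
//   while (it != end) {              // Iterator::operator!=, i.e. i < end
//     AVFrame* f = *it;              // convert_func(frames[i:i+step], buffer)
//     /* consume f */
//     ++it;                          // i += step
//   }
//
// Note that begin() and end() return different types, which range-based for
// loops permit since C++17.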
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio::io {
//////////////////////////////////////////////////////////////////////////////
// SlicingTensorConverter
//////////////////////////////////////////////////////////////////////////////
// The SlicingTensorConverter class implements an interface compatible with
// range-based for loops (begin and end).
class SlicingTensorConverter {
public:
// A convert function writes the input frame Tensor to the destination
// AVFrame. Both the input tensor and the AVFrame are expected to be valid
// and properly allocated (i.e. it is a glorified copy). It is used in
// Iterator.
using ConvertFunc = std::function<void(const torch::Tensor&, AVFrame*)>;
////////////////////////////////////////////////////////////////////////////
// Iterator
////////////////////////////////////////////////////////////////////////////
// The Iterator class implements the iterator protocol, that is, increment,
// comparison, and dereference (which applies the conversion function).
class Iterator {
// Tensor to be sliced
// - audio: NC, CPU, uint8|int16|int32|int64|float32|float64
// - video: NCHW or NHWC, CPU or CUDA, uint8
// It will be sliced at dereference time.
const torch::Tensor frames;
// Output buffer (not owned, but modified by Iterator)
AVFrame* buffer;
// Function that converts one frame Tensor into AVFrame.
ConvertFunc& convert_func;
// Index
int64_t step;
int64_t i = 0;
public:
Iterator(
const torch::Tensor tensor,
AVFrame* buffer,
ConvertFunc& convert_func,
int64_t step);
Iterator& operator++();
AVFrame* operator*() const;
bool operator!=(const int64_t other) const;
};
private:
// Input Tensor:
// - video: NCHW, CPU|CUDA, uint8
// - audio: NC, CPU, uint8|int16|int32|int64|float32|float64
torch::Tensor frames;
// Output buffer (not owned, passed to iterator)
AVFrame* buffer;
// ops: not owned.
ConvertFunc& convert_func;
int64_t step;
public:
SlicingTensorConverter(
torch::Tensor frames,
AVFrame* buffer,
ConvertFunc& convert_func,
int64_t step = 1);
[[nodiscard]] Iterator begin() const;
[[nodiscard]] int64_t end() const;
};
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h>
namespace torchaudio::io {
namespace {
AVCodecContextPtr get_codec_ctx(
enum AVMediaType type,
AVFORMAT_CONST AVOutputFormat* oformat,
const c10::optional<std::string>& encoder) {
enum AVCodecID default_codec = [&]() {
switch (type) {
case AVMEDIA_TYPE_AUDIO:
return oformat->audio_codec;
case AVMEDIA_TYPE_VIDEO:
return oformat->video_codec;
default:
TORCH_CHECK(
false, "Unsupported media type: ", av_get_media_type_string(type));
}
}();
TORCH_CHECK(
default_codec != AV_CODEC_ID_NONE,
"Format \"",
oformat->name,
"\" does not support ",
av_get_media_type_string(type),
".");
const AVCodec* codec = [&]() {
if (encoder) {
const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str());
TORCH_CHECK(c, "Unexpected codec: ", encoder.value());
return c;
}
const AVCodec* c = avcodec_find_encoder(default_codec);
TORCH_CHECK(
c, "Encoder not found for codec: ", avcodec_get_name(default_codec));
return c;
}();
AVCodecContext* ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (oformat->flags & AVFMT_GLOBALHEADER) {
ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
}
return AVCodecContextPtr(ctx);
}
std::vector<int> get_supported_sample_rates(const AVCodec* codec) {
std::vector<int> ret;
if (codec->supported_samplerates) {
const int* t = codec->supported_samplerates;
while (*t) {
ret.push_back(*t);
++t;
}
}
return ret;
}
std::vector<std::string> get_supported_sample_fmts(const AVCodec* codec) {
std::vector<std::string> ret;
if (codec->sample_fmts) {
const enum AVSampleFormat* t = codec->sample_fmts;
while (*t != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*t));
++t;
}
}
return ret;
}
std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) {
std::vector<uint64_t> ret;
if (codec->channel_layouts) {
const uint64_t* t = codec->channel_layouts;
while (*t) {
ret.push_back(*t);
++t;
}
}
return ret;
}
void configure_audio_codec(
AVCodecContextPtr& ctx,
int64_t sample_rate,
int64_t num_channels,
const c10::optional<std::string>& format) {
// TODO: Review options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122
// - bit_rate
// - bit_rate_tolerance
ctx->sample_rate = [&]() -> int {
auto rates = get_supported_sample_rates(ctx->codec);
if (rates.empty()) {
return static_cast<int>(sample_rate);
}
for (const auto& it : rates) {
if (it == sample_rate) {
return static_cast<int>(sample_rate);
}
}
TORCH_CHECK(
false,
ctx->codec->name,
" does not support sample rate ",
sample_rate,
". Supported sample rates are: ",
c10::Join(", ", rates));
}();
ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
ctx->sample_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->sample_fmts,
ctx->codec->name,
" does not have default sample format. Please specify one.");
return ctx->codec->sample_fmts[0];
}
// Use the given one.
auto fmt = format.value();
auto ret = av_get_sample_fmt(fmt.c_str());
auto fmts = get_supported_sample_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(
ret != AV_SAMPLE_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
TORCH_CHECK(
std::count(fmts.begin(), fmts.end(), fmt),
"Unsupported sample format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
return ret;
}();
// validate and set channels
ctx->channels = static_cast<int>(num_channels);
auto layout = av_get_default_channel_layout(ctx->channels);
auto layouts = get_supported_channel_layouts(ctx->codec);
if (!layouts.empty()) {
if (!std::count(layouts.begin(), layouts.end(), layout)) {
std::vector<std::string> tmp;
for (const auto& it : layouts) {
tmp.push_back(std::to_string(av_get_channel_layout_nb_channels(it)));
}
TORCH_CHECK(
false,
"Unsupported channels: ",
num_channels,
". Supported channels are: ",
c10::Join(", ", tmp));
}
}
ctx->channel_layout = static_cast<uint64_t>(layout);
}
void open_codec(
AVCodecContextPtr& codec_ctx,
const c10::optional<OptionDict>& option) {
AVDictionary* opt = get_option_dict(option);
int ret = avcodec_open2(codec_ctx, codec_ctx->codec, &opt);
clean_up_dict(opt);
TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")");
}
AVCodecContextPtr get_audio_codec(
AVFORMAT_CONST AVOutputFormat* oformat,
int64_t sample_rate,
int64_t num_channels,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format);
open_codec(ctx, encoder_option);
return ctx;
}
FilterGraph get_audio_filter(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->sample_fmt) {
return "anull";
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
return ss.str();
}
}();
FilterGraph p{AVMEDIA_TYPE_AUDIO};
p.add_audio_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->sample_rate,
codec_ctx->channel_layout);
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
}
AVFramePtr get_audio_frame(
AVSampleFormat src_fmt,
int sample_rate,
int num_channels,
AVCodecContext* codec_ctx,
int default_frame_size = 10000) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
// note: channels attribute is not required for encoding, but TensorConverter
// refers to it
frame->channels = num_channels;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
}
return frame;
}
std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) {
std::vector<std::string> ret;
if (codec->pix_fmts) {
const enum AVPixelFormat* t = codec->pix_fmts;
while (*t != AV_PIX_FMT_NONE) {
ret.emplace_back(av_get_pix_fmt_name(*t));
++t;
}
}
return ret;
}
std::vector<AVRational> get_supported_frame_rates(const AVCodec* codec) {
std::vector<AVRational> ret;
if (codec->supported_framerates) {
const AVRational* t = codec->supported_framerates;
while (!(t->num == 0 && t->den == 0)) {
ret.push_back(*t);
++t;
}
}
return ret;
}
// used to compare frame rate / sample rate.
// not a general purpose float comparison
bool is_rate_close(double rate, AVRational rational) {
double ref =
static_cast<double>(rational.num) / static_cast<double>(rational.den);
// frame rates / sample rates
static const double threshold = 0.001;
return fabs(rate - ref) < threshold;
}
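// Worked example: comparing the NTSC rate 29.97 against AVRational{30000, 1001}:
// 30000.0 / 1001.0 ≈ 29.97003, and |29.97 - 29.97003| ≈ 3.0e-5 < 0.001, so
// the rates are treated as equal. The fixed threshold is adequate for typical
// frame/sample rates, but too coarse for general floating point comparison.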
void configure_video_codec(
AVCodecContextPtr& ctx,
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& format) {
// TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate
// - bit_rate_tolerance
// - gop_size
// - max_b_frames
// - mb_decisions
ctx->width = static_cast<int>(width);
ctx->height = static_cast<int>(height);
ctx->time_base = [&]() {
AVRational ret = av_inv_q(av_d2q(frame_rate, 1 << 24));
auto rates = get_supported_frame_rates(ctx->codec);
// The codec does not constrain the frame rate.
if (rates.empty()) {
return ret;
}
// The codec has a list of supported frame rates.
for (const auto& t : rates) {
if (is_rate_close(frame_rate, t)) {
return ret;
}
}
// The given one is not supported.
std::vector<std::string> tmp;
for (const auto& t : rates) {
tmp.emplace_back(
t.den == 1 ? std::to_string(t.num)
: std::to_string(t.num) + "/" + std::to_string(t.den));
}
TORCH_CHECK(
false,
"Unsupported frame rate: ",
frame_rate,
". Supported values are ",
c10::Join(", ", tmp));
}();
ctx->pix_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->pix_fmts,
ctx->codec->name,
" does not have defaut pixel format. Please specify one.");
return ctx->codec->pix_fmts[0];
}
// Use the given one,
auto fmt = format.value();
auto ret = av_get_pix_fmt(fmt.c_str());
auto fmts = get_supported_pix_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(ret != AV_PIX_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
if (!std::count(fmts.begin(), fmts.end(), fmt)) {
TORCH_CHECK(
false,
"Unsupported pixel format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
}
return ret;
}();
}
void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
torch::Device device{hw_accel};
TORCH_CHECK(
device.type() == c10::DeviceType::CUDA,
"Only CUDA is supported for hardware acceleration. Found: ",
device.str());
// NOTES:
// 1. Examples like
// https://ffmpeg.org/doxygen/4.1/hw_decode_8c-example.html#a9 wrap the HW
// device context and the HW frames context with av_buffer_ref. This
// increments the reference count, and the resource won't be automatically
// deallocated when the AVCodecContext is destructed (we would have to
// decrement it once ourselves), so we do not do it. When adding support for
// sharing context objects, this needs to be reviewed.
//
// 2. When encoding, it is technically not necessary to attach the HW device
// context to AVCodecContext, but doing so lets it be deallocated
// automatically when the AVCodecContext is freed, so we do that.
int ret = av_hwdevice_ctx_create(
&ctx->hw_device_ctx,
AV_HWDEVICE_TYPE_CUDA,
std::to_string(device.index()).c_str(),
nullptr,
0);
TORCH_CHECK(
ret >= 0, "Failed to create CUDA device context: ", av_err2string(ret));
assert(ctx->hw_device_ctx);
ctx->sw_pix_fmt = ctx->pix_fmt;
ctx->pix_fmt = AV_PIX_FMT_CUDA;
ctx->hw_frames_ctx = av_hwframe_ctx_alloc(ctx->hw_device_ctx);
TORCH_CHECK(ctx->hw_frames_ctx, "Failed to create CUDA frame context.");
auto frames_ctx = (AVHWFramesContext*)(ctx->hw_frames_ctx->data);
frames_ctx->format = ctx->pix_fmt;
frames_ctx->sw_format = ctx->sw_pix_fmt;
frames_ctx->width = ctx->width;
frames_ctx->height = ctx->height;
frames_ctx->initial_pool_size = 5;
ret = av_hwframe_ctx_init(ctx->hw_frames_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to initialize CUDA frame context: ",
av_err2string(ret));
}
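// Resulting arrangement (illustrative): for a CUDA encoder configured with
// yuv444p, ctx->pix_fmt becomes AV_PIX_FMT_CUDA (frames live in device
// memory) while ctx->sw_pix_fmt and frames_ctx->sw_format keep the yuv444p
// layout; encode-side frames are then drawn from the pool of 5 hardware
// frames allocated above.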
AVCodecContextPtr get_video_codec(
AVFORMAT_CONST AVOutputFormat* oformat,
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format);
if (hw_accel) {
#ifdef USE_CUDA
configure_hw_accel(ctx, hw_accel.value());
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. ",
"Hardware acceleration is not available.");
#endif
}
open_codec(ctx, encoder_option);
return ctx;
}
FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt ||
codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
return "null";
} else {
std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
return ss.str();
}
}();
FilterGraph p{AVMEDIA_TYPE_VIDEO};
p.add_video_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->width,
codec_ctx->height,
codec_ctx->sample_aspect_ratio);
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
}
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
AVFramePtr frame{};
if (codec_ctx->hw_frames_ctx) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
} else {
frame->format = src_fmt;
frame->width = codec_ctx->width;
frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
frame->pts = 0;
return frame;
}
} // namespace
EncodeProcess::EncodeProcess(
AVFormatContext* format_ctx,
int sample_rate,
int num_channels,
const enum AVSampleFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format)
: codec_ctx(get_audio_codec(
format_ctx->oformat,
sample_rate,
num_channels,
encoder,
encoder_option,
encoder_format)),
encoder(format_ctx, codec_ctx),
filter(get_audio_filter(format, codec_ctx)),
src_frame(get_audio_frame(format, sample_rate, num_channels, codec_ctx)),
converter(AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples) {}
EncodeProcess::EncodeProcess(
AVFormatContext* format_ctx,
double frame_rate,
int width,
int height,
const enum AVPixelFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel)
: codec_ctx(get_video_codec(
format_ctx->oformat,
frame_rate,
width,
height,
encoder,
encoder_option,
encoder_format,
hw_accel)),
encoder(format_ctx, codec_ctx),
filter(get_video_filter(format, codec_ctx)),
src_frame(get_video_frame(format, codec_ctx)),
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}
void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
TORCH_CHECK(
codec_ctx->codec_type == type,
"Attempted to write ",
av_get_media_type_string(type),
" to ",
av_get_media_type_string(codec_ctx->codec_type),
" stream.");
AVRational codec_tb = codec_ctx->time_base;
for (const auto& frame : converter.convert(tensor)) {
process_frame(frame);
if (type == AVMEDIA_TYPE_VIDEO) {
frame->pts += 1;
} else {
AVRational sr_tb{1, codec_ctx->sample_rate};
frame->pts += av_rescale_q(frame->nb_samples, sr_tb, codec_tb);
}
}
}
void EncodeProcess::process_frame(AVFrame* src) {
int ret = filter.add_frame(src);
while (ret >= 0) {
ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encoder.encode(nullptr);
}
break;
}
if (ret >= 0) {
encoder.encode(dst_frame);
}
av_frame_unref(dst_frame);
}
}
void EncodeProcess::flush() {
process_frame(nullptr);
}
} // namespace torchaudio::io
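// PTS arithmetic in process(), with illustrative numbers: for video, the
// codec time_base is 1/frame_rate, so each converted frame advances pts by 1.
// For audio, e.g. an encoder with frame_size 1024 at 44100 Hz and codec
// time_base {1, 44100}: av_rescale_q(1024, {1, 44100}, {1, 44100}) == 1024,
// so pts advances by nb_samples per processed frame.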
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.h>
namespace torchaudio::io {
class EncodeProcess {
// Members are declared in the reverse order of the processing flow.
AVCodecContextPtr codec_ctx;
Encoder encoder;
AVFramePtr dst_frame{};
FilterGraph filter;
AVFramePtr src_frame;
TensorConverter converter;
public:
// Constructor for audio
EncodeProcess(
AVFormatContext* format_ctx,
int sample_rate,
int num_channels,
const enum AVSampleFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format);
// constructor for video
EncodeProcess(
AVFormatContext* format_ctx,
double frame_rate,
int width,
int height,
const enum AVPixelFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel);
void process(AVMediaType type, const torch::Tensor& tensor);
void process_frame(AVFrame* src);
void flush();
};
} // namespace torchaudio::io
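// Usage sketch (illustrative; assumes `format_ctx` is a valid output
// AVFormatContext and `waveform` a (time, channel) float32 CPU tensor):
//
//   EncodeProcess p{
//       format_ctx, /*sample_rate=*/44100, /*num_channels=*/2,
//       AV_SAMPLE_FMT_FLT, c10::nullopt, c10::nullopt, c10::nullopt};
//   p.process(AVMEDIA_TYPE_AUDIO, waveform);
//   p.flush();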
#include <torchaudio/csrc/ffmpeg/stream_writer/output_stream.h>
namespace torchaudio::io {
OutputStream::OutputStream(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx_,
FilterGraph&& filter_)
: codec_ctx(codec_ctx_),
encoder(format_ctx, codec_ctx),
filter(std::move(filter_)),
dst_frame() {}
void OutputStream::process_frame(AVFrame* src) {
int ret = filter.add_frame(src);
while (ret >= 0) {
ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encoder.encode(nullptr);
}
break;
}
if (ret >= 0) {
encoder.encode(dst_frame);
}
av_frame_unref(dst_frame);
}
}
void OutputStream::flush() {
process_frame(nullptr);
}
} // namespace torchaudio::io
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
namespace torchaudio::io {
struct OutputStream {
// Reference to codec context
AVCodecContext* codec_ctx;
// Encoder + Muxer
Encoder encoder;
// Filter for additional processing
FilterGraph filter;
// Frame to which the output of FilterGraph is written
AVFramePtr dst_frame;
OutputStream(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
FilterGraph&& filter);
virtual void write_chunk(const torch::Tensor& input) = 0;
void process_frame(AVFrame* src);
void flush();
virtual ~OutputStream() = default;
};
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/stream_writer/audio_output_stream.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.h>
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
@@ -56,368 +54,6 @@ StreamWriter::StreamWriter(
: StreamWriter(get_output_format_context(dst, format, nullptr)) {}
namespace {
std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) {
std::vector<std::string> ret;
if (codec->pix_fmts) {
const enum AVPixelFormat* t = codec->pix_fmts;
while (*t != AV_PIX_FMT_NONE) {
ret.emplace_back(av_get_pix_fmt_name(*t));
++t;
}
}
return ret;
}
std::vector<AVRational> get_supported_frame_rates(const AVCodec* codec) {
std::vector<AVRational> ret;
if (codec->supported_framerates) {
const AVRational* t = codec->supported_framerates;
while (!(t->num == 0 && t->den == 0)) {
ret.push_back(*t);
++t;
}
}
return ret;
}
// used to compare frame rate / sample rate.
// not a general purpose float comparison
bool is_rate_close(double rate, AVRational rational) {
double ref =
static_cast<double>(rational.num) / static_cast<double>(rational.den);
// frame rates / sample rates
static const double threshold = 0.001;
return fabs(rate - ref) < threshold;
}
std::vector<std::string> get_supported_sample_fmts(const AVCodec* codec) {
std::vector<std::string> ret;
if (codec->sample_fmts) {
const enum AVSampleFormat* t = codec->sample_fmts;
while (*t != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*t));
++t;
}
}
return ret;
}
std::vector<int> get_supported_sample_rates(const AVCodec* codec) {
std::vector<int> ret;
if (codec->supported_samplerates) {
const int* t = codec->supported_samplerates;
while (*t) {
ret.push_back(*t);
++t;
}
}
return ret;
}
std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) {
std::vector<uint64_t> ret;
if (codec->channel_layouts) {
const uint64_t* t = codec->channel_layouts;
while (*t) {
ret.push_back(*t);
++t;
}
}
return ret;
}
void configure_audio_codec(
AVCodecContextPtr& ctx,
int64_t sample_rate,
int64_t num_channels,
const c10::optional<std::string>& format) {
// TODO: Review options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122
// - bit_rate
// - bit_rate_tolerance
ctx->sample_rate = [&]() -> int {
auto rates = get_supported_sample_rates(ctx->codec);
if (rates.empty()) {
return static_cast<int>(sample_rate);
}
for (const auto& it : rates) {
if (it == sample_rate) {
return static_cast<int>(sample_rate);
}
}
TORCH_CHECK(
false,
ctx->codec->name,
" does not support sample rate ",
sample_rate,
". Supported sample rates are: ",
c10::Join(", ", rates));
}();
ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
ctx->sample_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->sample_fmts,
ctx->codec->name,
" does not have default sample format. Please specify one.");
return ctx->codec->sample_fmts[0];
}
// Use the given one.
auto fmt = format.value();
auto ret = av_get_sample_fmt(fmt.c_str());
auto fmts = get_supported_sample_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(
ret != AV_SAMPLE_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
TORCH_CHECK(
std::count(fmts.begin(), fmts.end(), fmt),
"Unsupported sample format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
return ret;
}();
// validate and set channels
ctx->channels = static_cast<int>(num_channels);
auto layout = av_get_default_channel_layout(ctx->channels);
auto layouts = get_supported_channel_layouts(ctx->codec);
if (!layouts.empty()) {
if (!std::count(layouts.begin(), layouts.end(), layout)) {
std::vector<std::string> tmp;
for (const auto& it : layouts) {
tmp.push_back(std::to_string(av_get_channel_layout_nb_channels(it)));
}
TORCH_CHECK(
false,
"Unsupported channels: ",
num_channels,
". Supported channels are: ",
c10::Join(", ", tmp));
}
}
ctx->channel_layout = static_cast<uint64_t>(layout);
}
void configure_video_codec(
AVCodecContextPtr& ctx,
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& format) {
// TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate
// - bit_rate_tolerance
// - gop_size
// - max_b_frames
// - mb_decisions
ctx->width = static_cast<int>(width);
ctx->height = static_cast<int>(height);
ctx->time_base = [&]() {
AVRational ret = av_inv_q(av_d2q(frame_rate, 1 << 24));
auto rates = get_supported_frame_rates(ctx->codec);
// The codec does not constrain the frame rate.
if (rates.empty()) {
return ret;
}
// The codec has a list of supported frame rates.
for (const auto& t : rates) {
if (is_rate_close(frame_rate, t)) {
return ret;
}
}
// The given one is not supported.
std::vector<std::string> tmp;
for (const auto& t : rates) {
tmp.emplace_back(
t.den == 1 ? std::to_string(t.num)
: std::to_string(t.num) + "/" + std::to_string(t.den));
}
TORCH_CHECK(
false,
"Unsupported frame rate: ",
frame_rate,
". Supported values are ",
c10::Join(", ", tmp));
}();
ctx->pix_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->pix_fmts,
ctx->codec->name,
" does not have defaut pixel format. Please specify one.");
return ctx->codec->pix_fmts[0];
}
// Use the given one,
auto fmt = format.value();
auto ret = av_get_pix_fmt(fmt.c_str());
auto fmts = get_supported_pix_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(ret != AV_PIX_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
if (!std::count(fmts.begin(), fmts.end(), fmt)) {
TORCH_CHECK(
false,
"Unsupported pixel format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
}
return ret;
}();
}
void open_codec(
AVCodecContextPtr& codec_ctx,
const c10::optional<OptionDict>& option) {
AVDictionary* opt = get_option_dict(option);
int ret = avcodec_open2(codec_ctx, codec_ctx->codec, &opt);
clean_up_dict(opt);
TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")");
}
AVCodecContextPtr get_codec_ctx(
enum AVMediaType type,
AVFORMAT_CONST AVOutputFormat* oformat,
const c10::optional<std::string>& encoder) {
enum AVCodecID default_codec = [&]() {
switch (type) {
case AVMEDIA_TYPE_AUDIO:
return oformat->audio_codec;
case AVMEDIA_TYPE_VIDEO:
return oformat->video_codec;
default:
TORCH_CHECK(
false, "Unsupported media type: ", av_get_media_type_string(type));
}
}();
TORCH_CHECK(
default_codec != AV_CODEC_ID_NONE,
"Format \"",
oformat->name,
"\" does not support ",
av_get_media_type_string(type),
".");
const AVCodec* codec = [&]() {
if (encoder) {
const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str());
TORCH_CHECK(c, "Unexpected codec: ", encoder.value());
return c;
}
const AVCodec* c = avcodec_find_encoder(default_codec);
TORCH_CHECK(
c, "Encoder not found for codec: ", avcodec_get_name(default_codec));
return c;
}();
AVCodecContext* ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (oformat->flags & AVFMT_GLOBALHEADER) {
ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
}
return AVCodecContextPtr(ctx);
}
AVCodecContextPtr get_audio_codec(
AVFORMAT_CONST AVOutputFormat* oformat,
int64_t sample_rate,
int64_t num_channels,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format);
open_codec(ctx, encoder_option);
return ctx;
}
void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
torch::Device device{hw_accel};
TORCH_CHECK(
device.type() == c10::DeviceType::CUDA,
"Only CUDA is supported for hardware acceleration. Found: ",
device.str());
// NOTES:
// 1. Examples like
// https://ffmpeg.org/doxygen/4.1/hw_decode_8c-example.html#a9 wrap the HW
// device context and the HW frames context with av_buffer_ref. This
// increments the reference count, and the resource won't be automatically
// deallocated when the AVCodecContext is destructed (we would have to
// decrement it once ourselves), so we do not do it. When adding support for
// sharing context objects, this needs to be reviewed.
//
// 2. When encoding, it is technically not necessary to attach the HW device
// context to AVCodecContext, but doing so lets it be deallocated
// automatically when the AVCodecContext is freed, so we do that.
int ret = av_hwdevice_ctx_create(
&ctx->hw_device_ctx,
AV_HWDEVICE_TYPE_CUDA,
std::to_string(device.index()).c_str(),
nullptr,
0);
TORCH_CHECK(
ret >= 0, "Failed to create CUDA device context: ", av_err2string(ret));
assert(ctx->hw_device_ctx);
ctx->sw_pix_fmt = ctx->pix_fmt;
ctx->pix_fmt = AV_PIX_FMT_CUDA;
ctx->hw_frames_ctx = av_hwframe_ctx_alloc(ctx->hw_device_ctx);
TORCH_CHECK(ctx->hw_frames_ctx, "Failed to create CUDA frame context.");
auto frames_ctx = (AVHWFramesContext*)(ctx->hw_frames_ctx->data);
frames_ctx->format = ctx->pix_fmt;
frames_ctx->sw_format = ctx->sw_pix_fmt;
frames_ctx->width = ctx->width;
frames_ctx->height = ctx->height;
frames_ctx->initial_pool_size = 5;
ret = av_hwframe_ctx_init(ctx->hw_frames_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to initialize CUDA frame context: ",
av_err2string(ret));
}
AVCodecContextPtr get_video_codec(
AVFORMAT_CONST AVOutputFormat* oformat,
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format);
if (hw_accel) {
#ifdef USE_CUDA
configure_hw_accel(ctx, hw_accel.value());
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. ",
"Hardware acceleration is not available.");
#endif
}
open_codec(ctx, encoder_option);
return ctx;
}
enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
@@ -466,16 +102,14 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) {
processes.emplace_back(
pFormatContext,
sample_rate,
num_channels,
get_src_sample_fmt(format),
encoder,
encoder_option,
encoder_format);
}
void StreamWriter::add_video_stream(
@@ -487,18 +121,16 @@ void StreamWriter::add_video_stream(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) {
processes.emplace_back(
pFormatContext,
frame_rate,
width,
height,
get_src_pixel_fmt(format),
encoder,
encoder_option,
encoder_format,
hw_accel);
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
@@ -566,35 +198,29 @@ void StreamWriter::close() {
}
}
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform);
}
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames);
}
void StreamWriter::flush() {
for (auto& p : processes) {
p.flush();
}
}
@@ -3,7 +3,7 @@
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h>
namespace torchaudio {
namespace io {
@@ -14,7 +14,7 @@ namespace io {
class StreamWriter {
AVFormatOutputContextPtr pFormatContext;
AVBufferRefPtr pHWBufferRef;
std::vector<EncodeProcess> processes;
AVPacketPtr pkt;
protected:
@@ -171,9 +171,6 @@ class StreamWriter {
void write_video_chunk(int i, const torch::Tensor& chunk);
/// Flush the frames from encoders and write the frames to the destination.
void flush();
};
} // namespace io
#include <torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.h>
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
@@ -6,44 +6,151 @@
namespace torchaudio::io {
namespace {
using InitFunc = TensorConverter::InitFunc;
using ConvertFunc = TensorConverter::ConvertFunc;
////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////
void validate_audio_input(
const torch::Tensor& t,
AVFrame* buffer,
c10::ScalarType dtype) {
TORCH_CHECK(
t.dtype().toScalarType() == dtype,
"Expected ",
dtype,
" type. Found: ",
t.dtype().toScalarType());
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
TORCH_CHECK(t.dim() == 2, "Input Tensor has to be 2D.");
TORCH_CHECK(
t.size(1) == buffer->channels,
"Expected waveform with ",
buffer->channels,
" channels. Found ",
t.size(1));
}
// 2D (time, channel) and contiguous.
void convert_func_(const torch::Tensor& chunk, AVFrame* buffer) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.dim() == 2);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.size(1) == buffer->channels);
// TODO: make writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
TORCH_CHECK(av_frame_is_writable(buffer), "frame is not writable.");
auto byte_size = chunk.numel() * chunk.element_size();
memcpy(buffer->data[0], chunk.data_ptr(), byte_size);
buffer->nb_samples = static_cast<int>(chunk.size(0));
}
std::pair<InitFunc, ConvertFunc> get_audio_func(AVFrame* buffer) {
auto dtype = [&]() -> c10::ScalarType {
switch (static_cast<AVSampleFormat>(buffer->format)) {
case AV_SAMPLE_FMT_U8:
return c10::ScalarType::Byte;
case AV_SAMPLE_FMT_S16:
return c10::ScalarType::Short;
case AV_SAMPLE_FMT_S32:
return c10::ScalarType::Int;
case AV_SAMPLE_FMT_S64:
return c10::ScalarType::Long;
case AV_SAMPLE_FMT_FLT:
return c10::ScalarType::Float;
case AV_SAMPLE_FMT_DBL:
return c10::ScalarType::Double;
default:
TORCH_INTERNAL_ASSERT(
false, "Audio encoding process is not properly configured.");
}
}();
InitFunc init_func = [=](const torch::Tensor& tensor, AVFrame* buffer) {
validate_audio_input(tensor, buffer, dtype);
return tensor.contiguous();
};
return {init_func, convert_func_};
}
////////////////////////////////////////////////////////////////////////////////
// Video
////////////////////////////////////////////////////////////////////////////////
void validate_video_input(
const torch::Tensor& t,
AVFrame* buffer,
int num_channels) {
if (buffer->hw_frames_ctx) {
TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
} else {
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
}
TORCH_CHECK(
t.dtype().toScalarType() == c10::ScalarType::Byte,
"Expected Tensor of uint8 type.");
TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
TORCH_CHECK(
t.size(1) == num_channels && t.size(2) == buffer->height &&
t.size(3) == buffer->width,
"Expected tensor with shape (N, ",
num_channels,
", ",
buffer->height,
", ",
buffer->width,
") (NCHW format). Found ",
t.sizes());
}
// NCHW ->NHWC, ensure contiguous
torch::Tensor init_interlaced(const torch::Tensor& tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.dim() == 4);
return tensor.permute({0, 2, 3, 1}).contiguous();
}
// Keep NCHW, ensure contiguous
torch::Tensor init_planar(const torch::Tensor& tensor) {
return tensor.contiguous();
}
// Interlaced video
// Each frame is composed of one plane, and color components for each pixel are
// collocated.
// The memory layout is 1D linear, interpreted as follows.
//
// |<----- linesize[0] ----->|
// |<-- stride -->|
// 0 1 ... W
// 0: RGB RGB ... RGB PAD ... PAD
// 1: RGB RGB ... RGB PAD ... PAD
// ...
// H: RGB RGB ... RGB PAD ... PAD
void write_interlaced_video(
const torch::Tensor& frame,
AVFrame* buffer,
int num_channels) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.dim() == 4);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(0) == 1);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(1) == buffer->height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2) == buffer->width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3) == num_channels);
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
size_t stride = buffer->width * num_channels;
uint8_t* src = frame.data_ptr<uint8_t>();
uint8_t* dst = buffer->data[0];
for (int h = 0; h < buffer->height; ++h) {
std::memcpy(dst, src, stride);
src += stride;
dst += buffer->linesize[0];
}
}
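// Illustrative numbers for the loop above: for RGB24 at width 100,
// stride == 100 * 3 == 300 bytes of pixel data per row, while
// buffer->linesize[0] may be rounded up (e.g. to 320) for alignment;
// advancing src by stride and dst by linesize[0] skips that padding.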
@@ -73,22 +180,24 @@ void write_planar_video(
const torch::Tensor& frame,
AVFrame* buffer,
int num_planes) {
const auto num_colors =
av_pix_fmt_desc_get((AVPixelFormat)buffer->format)->nb_components;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.dim() == 4);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(0) == 1);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(1) == num_colors);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2) == buffer->height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3) == buffer->width);
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
for (int j = 0; j < num_colors; ++j) {
uint8_t* src = frame.index({0, j}).data_ptr<uint8_t>();
uint8_t* dst = buffer->data[j];
for (int h = 0; h < buffer->height; ++h) {
memcpy(dst, src, buffer->width);
src += buffer->width;
dst += buffer->linesize[j];
}
}
@@ -97,19 +206,18 @@ void write_planar_video(
void write_interlaced_video_cuda(
const torch::Tensor& frame,
AVFrame* buffer,
int num_channels) {
#ifndef USE_CUDA
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.dim() == 4);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(0) == 1);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(1) == buffer->height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2) == buffer->width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3) == num_channels);
size_t spitch = buffer->width * num_channels;
if (cudaSuccess !=
cudaMemcpy2D(
(void*)(buffer->data[0]),
@@ -117,7 +225,7 @@ void write_interlaced_video_cuda(
(const void*)(frame.data_ptr<uint8_t>()),
spitch,
spitch,
buffer->height,
cudaMemcpyDeviceToDevice)) {
TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
}
@@ -133,20 +241,20 @@ void write_planar_video_cuda(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.dim() == 4);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(0) == 1);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(1) == num_planes);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2) == buffer->height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3) == buffer->width);
for (int j = 0; j < num_planes; ++j) {
if (cudaSuccess !=
cudaMemcpy2D(
(void*)(buffer->data[j]),
buffer->linesize[j],
(const void*)(frame.index({0, j}).data_ptr<uint8_t>()),
buffer->width,
buffer->width,
buffer->height,
cudaMemcpyDeviceToDevice)) {
TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
}
@@ -154,37 +262,38 @@ void write_planar_video_cuda(
#endif
}
std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
if (buffer->hw_frames_ctx) {
auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
auto sw_pix_fmt = frames_ctx->sw_format;
switch (sw_pix_fmt) {
// Note:
// RGB0 / BGR0 expects 4 channel, but neither
// av_pix_fmt_desc_get(pix_fmt)->nb_components
// or av_pix_fmt_count_planes(pix_fmt) returns 4.
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: {
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_interlaced_video_cuda(t, f, 4);
};
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
validate_video_input(t, f, 4);
return init_interlaced(t);
};
return {init_func, convert_func};
}
case AV_PIX_FMT_GBRP:
case AV_PIX_FMT_GBRP16LE:
case AV_PIX_FMT_YUV444P:
case AV_PIX_FMT_YUV444P16LE: {
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_planar_video_cuda(t, f, 3);
};
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
validate_video_input(t, f, 3);
return init_planar(t);
};
return {init_func, convert_func};
}
default:
TORCH_CHECK(
@@ -199,14 +308,25 @@ std::pair<InitFunc, ConvertFunc> get_func(AVFrame* buffer) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: {
int channels = av_pix_fmt_desc_get(pix_fmt)->nb_components;
InitFunc init_func = [=](const torch::Tensor& t, AVFrame* f) {
validate_video_input(t, f, channels);
return init_interlaced(t);
};
ConvertFunc convert_func = [=](const torch::Tensor& t, AVFrame* f) {
write_interlaced_video(t, f, channels);
};
return {init_func, convert_func};
}
case AV_PIX_FMT_YUV444P: {
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
validate_video_input(t, f, 3);
return init_planar(t);
};
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_planar_video(t, f, 3);
};
return {init_func, convert_func};
}
default:
TORCH_CHECK(
@@ -214,48 +334,86 @@ std::pair<InitFunc, ConvertFunc> get_func(AVFrame* buffer) {
}
}
} // namespace
////////////////////////////////////////////////////////////////////////////////
// TensorConverter
////////////////////////////////////////////////////////////////////////////////
TensorConverter::TensorConverter(AVMediaType type, AVFrame* buf, int buf_size)
: buffer(buf), buffer_size(buf_size) {
switch (type) {
case AVMEDIA_TYPE_AUDIO:
std::tie(init_func, convert_func) = get_audio_func(buffer);
break;
case AVMEDIA_TYPE_VIDEO:
std::tie(init_func, convert_func) = get_video_func(buffer);
break;
default:
TORCH_INTERNAL_ASSERT(
false, "Unsupported media type: ", av_get_media_type_string(type));
}
}
using Generator = TensorConverter::Generator;
Generator TensorConverter::convert(const torch::Tensor& t) {
return Generator{init_func(t, buffer), buffer, convert_func, buffer_size};
}
////////////////////////////////////////////////////////////////////////////////
// Generator
////////////////////////////////////////////////////////////////////////////////
using Iterator = Generator::Iterator;
Generator::Generator(
torch::Tensor frames_,
AVFrame* buff,
ConvertFunc& func,
int64_t step_)
: frames(std::move(frames_)),
buffer(buff),
convert_func(func),
step(step_) {}
Iterator Generator::begin() const {
return Iterator{frames, buffer, convert_func, step};
}
int64_t Generator::end() const {
return frames.size(0);
}
////////////////////////////////////////////////////////////////////////////////
// Iterator
////////////////////////////////////////////////////////////////////////////////
Iterator::Iterator(
const torch::Tensor frames_,
AVFrame* buffer_,
ConvertFunc& convert_func_,
int64_t step_)
: frames(frames_),
buffer(buffer_),
convert_func(convert_func_),
step(step_) {}
Iterator& Iterator::operator++() {
i += step;
return *this;
}
AVFrame* Iterator::operator*() const {
using namespace torch::indexing;
convert_func(frames.index({Slice{i, i + step}}), buffer);
return buffer;
}
bool Iterator::operator!=(const int64_t end) const {
// This is used for detecting the end of iteration.
// Iteration ends when the index reaches the number of input frames.
return i < end;
}
} // namespace torchaudio::io
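// Usage sketch (illustrative): one Generator drives both media types; only
// the slicing step differs.
//
//   TensorConverter audio{AVMEDIA_TYPE_AUDIO, frame, frame->nb_samples};
//   for (AVFrame* f : audio.convert(waveform)) { /* frame-sized slices */ }
//
//   TensorConverter video{AVMEDIA_TYPE_VIDEO, frame};  // step defaults to 1
//   for (AVFrame* f : video.convert(clip)) { /* one image per iteration */ }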
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio::io {
class TensorConverter {
public:
// Initialization is a one-time process applied to frames before the
// iteration starts, e.g. converting the layout to NHWC.
using InitFunc = std::function<torch::Tensor(const torch::Tensor&, AVFrame*)>;
// A convert function writes the input frame Tensor to the destination
// AVFrame. Both the input tensor and the AVFrame are expected to be valid
// and properly allocated (i.e. it is a glorified copy). It is used in
// Iterator.
using ConvertFunc = std::function<void(const torch::Tensor&, AVFrame*)>;
//////////////////////////////////////////////////////////////////////////////
// Generator
//////////////////////////////////////////////////////////////////////////////
// The Generator class implements an interface compatible with range-based
// for loops (begin and end).
class Generator {
public:
////////////////////////////////////////////////////////////////////////////
// Iterator
////////////////////////////////////////////////////////////////////////////
// The Iterator class implements the iterator protocol, that is, increment,
// comparison, and dereference (which applies the conversion function).
class Iterator {
// Tensor to be sliced
// - audio: NC, CPU, uint8|int16|int32|int64|float32|float64
// - video: NCHW or NHWC, CPU or CUDA, uint8
// It will be sliced at dereference time.
const torch::Tensor frames;
// Output buffer (not owned, but modified by Iterator)
AVFrame* buffer;
// Function that converts one frame Tensor into AVFrame.
ConvertFunc& convert_func;
// Index
int64_t step;
int64_t i = 0;
public:
Iterator(
const torch::Tensor tensor,
AVFrame* buffer,
ConvertFunc& convert_func,
int64_t step);
Iterator& operator++();
AVFrame* operator*() const;
bool operator!=(const int64_t other) const;
};
private:
// Input Tensor:
// - video: NCHW, CPU|CUDA, uint8
// - audio: NC, CPU, uint8|int16|int32|int64|float32|float64
torch::Tensor frames;
// Output buffer (not owned, passed to iterator)
AVFrame* buffer;
// ops: not owned.
ConvertFunc& convert_func;
int64_t step;
public:
Generator(
torch::Tensor frames,
AVFrame* buffer,
ConvertFunc& convert_func,
int64_t step = 1);
[[nodiscard]] Iterator begin() const;
[[nodiscard]] int64_t end() const;
};
private:
AVFrame* buffer;
const int buffer_size = 1;
InitFunc init_func{};
ConvertFunc convert_func{};
public:
TensorConverter(AVMediaType type, AVFrame* buffer, int buffer_size = 1);
Generator convert(const torch::Tensor& t);
};
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/converter.h>
namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
// VideoTensorConverter
////////////////////////////////////////////////////////////////////////////////
// VideoTensorConverter is responsible for picking the right set of
// conversion functions (InitFunc and ConvertFunc) based on the input pixel
// format information, and for owning them.
class VideoTensorConverter {
public:
// Initialization is a one-time process applied to frames before the
// iteration starts, e.g. converting the layout to NHWC.
using InitFunc = std::function<torch::Tensor(const torch::Tensor&)>;
private:
AVFrame* buffer;
InitFunc init_func{};
SlicingTensorConverter::ConvertFunc convert_func{};
public:
explicit VideoTensorConverter(AVFrame* buffer);
SlicingTensorConverter convert(const torch::Tensor& frames);
};
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/stream_writer/video_output_stream.h>
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
namespace torchaudio::io {
namespace {
FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt ||
codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
return "null";
} else {
std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
return ss.str();
}
}();
FilterGraph p{AVMEDIA_TYPE_VIDEO};
p.add_video_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->width,
codec_ctx->height,
codec_ctx->sample_aspect_ratio);
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
}
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
AVFramePtr frame{};
if (codec_ctx->hw_frames_ctx) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
} else {
frame->format = src_fmt;
frame->width = codec_ctx->width;
frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
frame->pts = 0;
return frame;
}
} // namespace
VideoOutputStream::VideoOutputStream(
AVFormatContext* format_ctx,
AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx_)
: OutputStream(
format_ctx,
codec_ctx_,
get_video_filter(src_fmt, codec_ctx_)),
buffer(get_video_frame(src_fmt, codec_ctx_)),
converter(buffer),
codec_ctx(std::move(codec_ctx_)) {}
void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
for (const auto& frame : converter.convert(frames)) {
process_frame(frame);
frame->pts += 1;
}
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/output_stream.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/video_converter.h>
namespace torchaudio::io {
struct VideoOutputStream : OutputStream {
AVFramePtr buffer;
VideoTensorConverter converter;
AVCodecContextPtr codec_ctx;
VideoOutputStream(
AVFormatContext* format_ctx,
AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx);
void write_chunk(const torch::Tensor& frames) override;
~VideoOutputStream() override = default;
};
} // namespace torchaudio::io