Commit 72404de9 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add StreamWriter (#2628)

Summary:
This commit adds FFmpeg-based encoder StreamWriter class.
StreamWriter is pretty much the opposite of StreamReader class, and
it supports:

* Encoding audio / still image / video
* Exporting to local file / streaming protocol / devices etc...
* File-like object support (in later commit)
* HW video encoding (in later commit)

See also: https://fburl.com/gslide/z85kn5a9 (Meta internal)

Pull Request resolved: https://github.com/pytorch/audio/pull/2628

Reviewed By: nateanl

Differential Revision: D38816650

Pulled By: mthrok

fbshipit-source-id: a9343b0d55755e186971dc96fb86eb52daa003c8
parent 068fc29c
...@@ -80,7 +80,7 @@ fi ...@@ -80,7 +80,7 @@ fi
( (
set -x set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20' conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag
) )
# Install fairseq # Install fairseq
git clone https://github.com/pytorch/fairseq git clone https://github.com/pytorch/fairseq
......
...@@ -88,7 +88,8 @@ esac ...@@ -88,7 +88,8 @@ esac
transformers \ transformers \
unidecode \ unidecode \
'protobuf<4.21.0' \ 'protobuf<4.21.0' \
demucs demucs \
tinytag
) )
# Install fairseq # Install fairseq
git clone https://github.com/pytorch/fairseq git clone https://github.com/pytorch/fairseq
......
...@@ -33,3 +33,10 @@ StreamReaderOutputStream ...@@ -33,3 +33,10 @@ StreamReaderOutputStream
.. autoclass:: StreamReaderOutputStream .. autoclass:: StreamReaderOutputStream
:members: :members:
StreamWriter
------------
.. autoclass:: StreamWriter
:members:
import torch
import torchaudio
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
get_asset_path,
is_ffmpeg_available,
nested_params,
rgb_to_yuv_ccir,
skipIfNoFFmpeg,
skipIfNoModule,
TempDirMixin,
TorchaudioTestCase,
)
if is_ffmpeg_available():
from torchaudio.io import StreamReader, StreamWriter
# TODO:
# Get rid of StreamReader and use synthetic data.
def get_audio_chunk(fmt, sample_rate, num_channels):
    """Decode the bundled test clip and return a (time, num_channels) tensor."""
    reader = StreamReader(get_asset_path("nasa_13013.mp4"))
    # Configure one decoded output stream per requested channel; only the
    # first channel of each decoded chunk is kept.
    for _ in range(num_channels):
        reader.add_basic_audio_stream(-1, -1, format=fmt, sample_rate=sample_rate)
    reader.stream()
    reader.process_all_packets()
    columns = []
    for decoded in reader.pop_chunks():
        columns.append(decoded[:, :1])
    return torch.cat(columns, 1)
def get_video_chunk(fmt, frame_rate, *, width, height):
    """Decode the bundled (audio-less) test clip as a single video tensor."""
    reader = StreamReader(get_asset_path("nasa_13013_no_audio.mp4"))
    reader.add_basic_video_stream(-1, -1, format=fmt, frame_rate=frame_rate, width=width, height=height)
    reader.stream()
    reader.process_all_packets()
    (frames,) = reader.pop_chunks()
    return frames
@skipIfNoFFmpeg
class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
    """Interface tests for StreamWriter.

    Encodes audio / video tensors into temporary files through various
    muxers and codecs and exercises the writer's public API.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Raise FFmpeg log verbosity while this suite runs
        # (32 corresponds to AV_LOG_INFO in FFmpeg -- TODO confirm).
        torchaudio.utils.ffmpeg_utils.set_log_level(32)

    @classmethod
    def tearDownClass(cls):
        # Restore a quieter log level (8 corresponds to AV_LOG_FATAL).
        torchaudio.utils.ffmpeg_utils.set_log_level(8)
        super().tearDownClass()

    def get_dst(self, path):
        # Destination path inside the per-test temporary directory.
        return self.get_temp_path(path)

    def get_buf(self, path):
        # Read back the raw bytes previously written under the temp directory.
        with open(self.get_temp_path(path), "rb") as fileobj:
            return fileobj.read()

    @skipIfNoModule("tinytag")
    def test_metadata_overwrite(self):
        """When set_metadata is called multiple times, only entries from the last call are saved"""
        from tinytag import TinyTag

        src_fmt = "s16"
        sample_rate = 8000
        num_channels = 1

        path = self.get_dst("test.mp3")
        s = StreamWriter(path, format="mp3")
        s.set_metadata(metadata={"artist": "torchaudio", "title": "foo"})
        # The second call replaces (does not merge with) the first one.
        s.set_metadata(metadata={"title": self.id()})
        s.add_audio_stream(sample_rate, num_channels, format=src_fmt)

        chunk = get_audio_chunk(src_fmt, sample_rate, num_channels)

        with s.open():
            s.write_audio_chunk(0, chunk)

        # "artist" from the first call must have been discarded.
        tag = TinyTag.get(path)
        assert tag.artist is None
        assert tag.title == self.id()

    @nested_params(
        # Note: "s64" causes UB (left shift of 1 by 63 places cannot be represented in type 'long')
        # thus it's omitted.
        ["u8", "s16", "s32", "flt", "dbl"],
        [8000, 16000, 44100],
        [1, 2, 4],
    )
    def test_valid_audio_muxer_and_codecs_wav(self, src_fmt, sample_rate, num_channels):
        """Tensor of various dtypes can be saved as wav format."""
        path = self.get_dst("test.wav")
        s = StreamWriter(path, format="wav")
        s.set_metadata(metadata={"artist": "torchaudio", "title": self.id()})
        s.add_audio_stream(sample_rate, num_channels, format=src_fmt)

        chunk = get_audio_chunk(src_fmt, sample_rate, num_channels)
        with s.open():
            s.write_audio_chunk(0, chunk)

    @parameterized.expand(
        [
            ("mp3", 8000, 1, "s32p", None),
            ("mp3", 16000, 2, "fltp", None),
            ("mp3", 44100, 1, "s16p", {"abr": "true"}),
            ("flac", 8000, 1, "s16", None),
            ("flac", 16000, 2, "s32", None),
            ("opus", 48000, 2, None, {"strict": "experimental"}),
            ("adts", 8000, 1, "fltp", None),  # AAC format
        ]
    )
    def test_valid_audio_muxer_and_codecs(self, ext, sample_rate, num_channels, encoder_format, encoder_option):
        """Tensor of various dtypes can be saved as given format."""
        path = self.get_dst(f"test.{ext}")
        s = StreamWriter(path, format=ext)
        s.set_metadata(metadata={"artist": "torchaudio", "title": self.id()})
        s.add_audio_stream(sample_rate, num_channels, encoder_option=encoder_option, encoder_format=encoder_format)

        # Source data is always float32; the writer converts to encoder_format.
        chunk = get_audio_chunk("flt", sample_rate, num_channels)
        with s.open():
            s.write_audio_chunk(0, chunk)

    @nested_params(
        [
            "gray8",
            "rgb24",
            "bgr24",
            "yuv444p",
        ],
        [(128, 64), (720, 576)],
    )
    def test_valid_video_muxer_and_codecs(self, src_format, size):
        """Image tensors of various formats can be saved as mp4"""
        ext = "mp4"
        frame_rate = 10
        width, height = size

        path = self.get_dst(f"test.{ext}")
        s = StreamWriter(path, format=ext)
        s.add_video_stream(frame_rate, width, height, format=src_format)

        chunk = get_video_chunk(src_format, frame_rate, width=width, height=height)
        with s.open():
            s.write_video_chunk(0, chunk)

    def test_valid_audio_video_muxer(self):
        """Audio/image tensors are saved as single video"""
        ext = "mp4"

        sample_rate = 16000
        num_channels = 3

        frame_rate = 30000 / 1001
        width, height = 720, 576
        video_fmt = "yuv444p"

        path = self.get_dst(f"test.{ext}")
        s = StreamWriter(path, format=ext)
        s.set_metadata({"artist": "torchaudio", "title": self.id()})
        s.add_audio_stream(sample_rate, num_channels)
        s.add_video_stream(frame_rate, width, height, format=video_fmt)

        audio = get_audio_chunk("flt", sample_rate, num_channels)
        video = get_video_chunk(video_fmt, frame_rate, height=height, width=width)

        # Stream 0 is audio, stream 1 is video (registration order above).
        with s.open():
            s.write_audio_chunk(0, audio)
            s.write_video_chunk(1, video)

    @nested_params(
        [
            ("gray8", "gray8"),
            ("rgb24", "rgb24"),
            ("bgr24", "bgr24"),
            ("yuv444p", "yuv444p"),
            ("rgb24", "yuv444p"),
            ("bgr24", "yuv444p"),
        ],
    )
    def test_video_raw_out(self, formats):
        """Verify that video out is correct with/without color space conversion"""
        filename = "test.rawvideo"
        frame_rate = 30000 / 1001
        width, height = 720, 576
        src_fmt, encoder_fmt = formats
        frames = int(frame_rate * 2)
        channels = 1 if src_fmt == "gray8" else 3

        # Generate data
        src_size = (frames, channels, height, width)
        chunk = torch.randint(low=0, high=255, size=src_size, dtype=torch.uint8)

        # Write data
        dst = self.get_dst(filename)
        s = StreamWriter(dst, format="rawvideo")
        s.add_video_stream(frame_rate, width, height, format=src_fmt, encoder_format=encoder_fmt)
        with s.open():
            s.write_video_chunk(0, chunk)

        # Fetch the written data
        buf = self.get_buf(filename)
        result = torch.frombuffer(buf, dtype=torch.uint8)
        if encoder_fmt.endswith("p"):
            # Planar output: raw bytes are already laid out as NCHW.
            result = result.reshape(src_size)
        else:
            # Interlaced output: raw bytes are NHWC; convert back to NCHW.
            result = result.reshape(frames, height, width, channels).permute(0, 3, 1, 2)

        # check that they are same
        if src_fmt == encoder_fmt:
            expected = chunk
        else:
            # Conversion path: compare against a reference RGB -> YUV (CCIR)
            # conversion; BGR input is first reordered to RGB.
            if src_fmt == "bgr24":
                chunk = chunk[:, [2, 1, 0], :, :]
            expected = rgb_to_yuv_ccir(chunk)
        self.assertEqual(expected, result, atol=1, rtol=0)
...@@ -149,6 +149,9 @@ if(USE_FFMPEG) ...@@ -149,6 +149,9 @@ if(USE_FFMPEG)
ffmpeg/stream_reader/stream_reader.cpp ffmpeg/stream_reader/stream_reader.cpp
ffmpeg/stream_reader/stream_reader_wrapper.cpp ffmpeg/stream_reader/stream_reader_wrapper.cpp
ffmpeg/stream_reader/stream_reader_binding.cpp ffmpeg/stream_reader/stream_reader_binding.cpp
ffmpeg/stream_writer/stream_writer.cpp
ffmpeg/stream_writer/stream_writer_wrapper.cpp
ffmpeg/stream_writer/stream_writer_binding.cpp
ffmpeg/utils.cpp ffmpeg/utils.cpp
) )
message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}") message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}")
......
...@@ -47,10 +47,18 @@ void AVFormatInputContextDeleter::operator()(AVFormatContext* p) { ...@@ -47,10 +47,18 @@ void AVFormatInputContextDeleter::operator()(AVFormatContext* p) {
AVFormatInputContextPtr::AVFormatInputContextPtr(AVFormatContext* p) AVFormatInputContextPtr::AVFormatInputContextPtr(AVFormatContext* p)
: Wrapper<AVFormatContext, AVFormatInputContextDeleter>(p) {} : Wrapper<AVFormatContext, AVFormatInputContextDeleter>(p) {}
// Deleter for an output (muxing) AVFormatContext: frees the context itself.
// Note: it does not close pb; StreamWriter::close handles the AVIO layer.
void AVFormatOutputContextDeleter::operator()(AVFormatContext* p) {
  avformat_free_context(p);
};
// Wrap a raw output AVFormatContext with the deleter above.
AVFormatOutputContextPtr::AVFormatOutputContextPtr(AVFormatContext* p)
    : Wrapper<AVFormatContext, AVFormatOutputContextDeleter>(p) {}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVIO // AVIO
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void AVIOContextDeleter::operator()(AVIOContext* p) { void AVIOContextDeleter::operator()(AVIOContext* p) {
avio_flush(p);
av_freep(&p->buffer); av_freep(&p->buffer);
av_freep(&p); av_freep(&p);
}; };
......
...@@ -93,6 +93,15 @@ struct AVFormatInputContextPtr ...@@ -93,6 +93,15 @@ struct AVFormatInputContextPtr
explicit AVFormatInputContextPtr(AVFormatContext* p); explicit AVFormatInputContextPtr(AVFormatContext* p);
}; };
// Deleter for output AVFormatContext (defined in the .cpp; calls
// avformat_free_context).
struct AVFormatOutputContextDeleter {
  void operator()(AVFormatContext* p);
};
// Smart-pointer wrapper owning an output AVFormatContext.
struct AVFormatOutputContextPtr
    : public Wrapper<AVFormatContext, AVFormatOutputContextDeleter> {
  explicit AVFormatOutputContextPtr(AVFormatContext* p);
};
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVIO // AVIO
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
// Enumerate the pixel formats advertised by the encoder.
// Returns an empty vector when the encoder publishes no list.
std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) {
  std::vector<std::string> names;
  if (codec->pix_fmts) {
    for (const enum AVPixelFormat* it = codec->pix_fmts;
         *it != AV_PIX_FMT_NONE;
         ++it) {
      names.emplace_back(av_get_pix_fmt_name(*it));
    }
  }
  return names;
}
std::vector<AVRational> get_supported_frame_rates(const AVCodec* codec) {
std::vector<AVRational> ret;
if (codec->supported_framerates) {
const AVRational* t = codec->supported_framerates;
while (!(t->num == 0 && t->den == 0)) {
ret.push_back(*t);
++t;
}
}
return ret;
}
// Compare a frame rate / sample rate against a rational within a fixed
// tolerance. Not a general-purpose floating point comparison.
bool is_rate_close(double rate, AVRational rational) {
  static const double threshold = 0.001;
  const double reference =
      static_cast<double>(rational.num) / static_cast<double>(rational.den);
  return fabs(rate - reference) < threshold;
}
// Enumerate the sample formats advertised by the encoder.
// Returns an empty vector when the encoder publishes no list.
std::vector<std::string> get_supported_sample_fmts(const AVCodec* codec) {
  std::vector<std::string> names;
  if (codec->sample_fmts) {
    for (const enum AVSampleFormat* it = codec->sample_fmts;
         *it != AV_SAMPLE_FMT_NONE;
         ++it) {
      names.emplace_back(av_get_sample_fmt_name(*it));
    }
  }
  return names;
}
// Enumerate the sample rates advertised by the encoder.
// Returns an empty vector when the encoder publishes no list.
std::vector<int> get_supported_sample_rates(const AVCodec* codec) {
  std::vector<int> rates;
  if (codec->supported_samplerates) {
    // The list is zero-terminated.
    for (const int* it = codec->supported_samplerates; *it != 0; ++it) {
      rates.push_back(*it);
    }
  }
  return rates;
}
// Enumerate the channel layouts advertised by the encoder.
// Returns an empty vector when the encoder publishes no list.
std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) {
  std::vector<uint64_t> layouts;
  if (codec->channel_layouts) {
    // The list is zero-terminated.
    for (const uint64_t* it = codec->channel_layouts; *it != 0; ++it) {
      layouts.push_back(*it);
    }
  }
  return layouts;
}
} // namespace
// Take ownership of a pre-allocated output format context.
// Streams are registered later via add_audio_stream / add_video_stream.
StreamWriter::StreamWriter(AVFormatOutputContextPtr&& p)
    : pFormatContext(std::move(p)), streams(), pkt() {}
namespace {
// Configure the audio encoder context (sample rate, sample format, channel
// layout), validating each value against the lists the encoder advertises.
// An empty advertised list means the encoder places no restriction.
void configure_audio_codec(
    AVCodecContextPtr& ctx,
    int64_t sample_rate,
    int64_t num_channels,
    const c10::optional<std::string>& format) {
  // TODO: Review options and make them configurable?
  // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122
  // - bit_rate
  // - bit_rate_tolerance

  ctx->sample_rate = [&]() -> int {
    auto rates = get_supported_sample_rates(ctx->codec);
    if (rates.empty()) {
      // Encoder does not restrict sample rates; accept as-is.
      return static_cast<int>(sample_rate);
    }
    for (const auto& it : rates) {
      if (it == sample_rate) {
        return static_cast<int>(sample_rate);
      }
    }
    TORCH_CHECK(
        false,
        ctx->codec->name,
        " does not support sample rate ",
        sample_rate,
        ". Supported sample rates are: ",
        c10::Join(", ", rates));
  }();
  // Timestamps are expressed in sample counts.
  ctx->time_base = AVRational{1, static_cast<int>(sample_rate)};
  ctx->sample_fmt = [&]() {
    // Use default
    if (!format) {
      TORCH_CHECK(
          ctx->codec->sample_fmts,
          ctx->codec->name,
          " does not have default sample format. Please specify one.");
      return ctx->codec->sample_fmts[0];
    }
    // Use the given one.
    auto fmt = format.value();
    auto ret = av_get_sample_fmt(fmt.c_str());
    auto fmts = get_supported_sample_fmts(ctx->codec);
    if (fmts.empty()) {
      // No advertised list; only validate that the name is recognized.
      TORCH_CHECK(
          ret != AV_SAMPLE_FMT_NONE, "Unrecognized format: ", fmt, ". ");
      return ret;
    }
    TORCH_CHECK(
        std::count(fmts.begin(), fmts.end(), fmt),
        "Unsupported sample format: ",
        fmt,
        ". Supported values are ",
        c10::Join(", ", fmts));
    return ret;
  }();

  // validate and set channels
  ctx->channels = static_cast<int>(num_channels);
  auto layout = av_get_default_channel_layout(ctx->channels);
  auto layouts = get_supported_channel_layouts(ctx->codec);
  if (!layouts.empty()) {
    if (!std::count(layouts.begin(), layouts.end(), layout)) {
      std::vector<std::string> tmp;
      for (const auto& it : layouts) {
        tmp.push_back(std::to_string(av_get_channel_layout_nb_channels(it)));
      }
      TORCH_CHECK(
          false,
          "Unsupported channels: ",
          num_channels,
          ". Supported channels are: ",
          c10::Join(", ", tmp));
    }
  }
  ctx->channel_layout = static_cast<uint64_t>(layout);
}
// Configure the video encoder context (resolution, time base, pixel format),
// validating each value against the lists the encoder advertises.
// An empty advertised list means the encoder places no restriction.
void configure_video_codec(
    AVCodecContextPtr& ctx,
    double frame_rate,
    int64_t width,
    int64_t height,
    const c10::optional<std::string>& format) {
  // TODO: Review other options and make them configurable?
  // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
  // - bit_rate
  // - bit_rate_tolerance
  // - gop_size
  // - max_b_frames
  // - mb_decisions

  ctx->width = static_cast<int>(width);
  ctx->height = static_cast<int>(height);
  ctx->time_base = [&]() {
    // NOTE(review): non-integer frame rates (e.g. 30000/1001) are truncated
    // here; confirm whether a rational time base is intended.
    AVRational ret = AVRational{1, static_cast<int>(frame_rate)};
    auto rates = get_supported_frame_rates(ctx->codec);
    // Codec does not have constraint on frame rate
    if (rates.empty()) {
      return ret;
    }
    // Codec has list of supported frame rate.
    for (const auto& t : rates) {
      if (is_rate_close(frame_rate, t)) {
        return ret;
      }
    }
    // Given one is not supported.
    std::vector<std::string> tmp;
    for (const auto& t : rates) {
      tmp.emplace_back(
          t.den == 1 ? std::to_string(t.num)
                     : std::to_string(t.num) + "/" + std::to_string(t.den));
    }
    TORCH_CHECK(
        false,
        "Unsupported frame rate: ",
        frame_rate,
        ". Supported values are ",
        c10::Join(", ", tmp));
  }();
  ctx->pix_fmt = [&]() {
    // Use default
    if (!format) {
      // Fixed typo in the error message ("defaut" -> "default").
      TORCH_CHECK(
          ctx->codec->pix_fmts,
          ctx->codec->name,
          " does not have default pixel format. Please specify one.");
      return ctx->codec->pix_fmts[0];
    }
    // Use the given one,
    auto fmt = format.value();
    auto ret = av_get_pix_fmt(fmt.c_str());
    auto fmts = get_supported_pix_fmts(ctx->codec);
    if (fmts.empty()) {
      // No advertised list; only validate that the name is recognized.
      TORCH_CHECK(ret != AV_PIX_FMT_NONE, "Unrecognized format: ", fmt, ". ");
      return ret;
    }
    if (!std::count(fmts.begin(), fmts.end(), fmt)) {
      TORCH_CHECK(
          false,
          "Unsupported pixel format: ",
          fmt,
          ". Supported values are ",
          c10::Join(", ", fmts));
    }
    return ret;
  }();
}
// Open the encoder context with the given (optional) encoder options.
// Used by both add_audio_stream and add_video_stream, so the error message
// is media-type neutral (the original said "audio codec" unconditionally).
void open_codec(
    AVCodecContextPtr& codec_ctx,
    const c10::optional<OptionDict>& option) {
  AVDictionary* opt = get_option_dict(option);
  int ret = avcodec_open2(codec_ctx, codec_ctx->codec, &opt);
  clean_up_dict(opt);
  TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")");
}
// Allocate an audio frame matching the encoder's layout and rate, with a
// sample buffer for frame_size samples (no buffer when frame_size is 0).
AVFramePtr get_audio_frame(
    enum AVSampleFormat fmt,
    AVCodecContextPtr& codec_ctx,
    int frame_size) {
  AVFramePtr frame{};
  frame->format = fmt;
  frame->channel_layout = codec_ctx->channel_layout;
  frame->sample_rate = codec_ctx->sample_rate;
  frame->nb_samples = frame_size;
  if (frame->nb_samples) {
    int ret = av_frame_get_buffer(frame, 0);
    TORCH_CHECK(
        ret >= 0,
        "Error allocating an audio buffer (",
        av_err2string(ret),
        ").");
  }
  return frame;
}
// Allocate a video frame (with pixel buffer) matching the encoder's
// configured width and height.
AVFramePtr get_video_frame(
    enum AVPixelFormat fmt,
    AVCodecContextPtr& codec_ctx) {
  AVFramePtr frame{};
  frame->format = fmt;
  frame->width = codec_ctx->width;
  frame->height = codec_ctx->height;
  int ret = av_frame_get_buffer(frame, 0);
  TORCH_CHECK(
      ret >= 0, "Error allocating a video buffer (", av_err2string(ret), ").");
  return frame;
}
// Resolve the encoder -- by explicit name when given, otherwise the muxer's
// default codec for the media type -- and allocate a codec context for it.
AVCodecContextPtr get_codec_ctx(
    enum AVMediaType type,
    AVFORMAT_CONST AVOutputFormat* oformat,
    const c10::optional<std::string>& encoder) {
  enum AVCodecID default_codec = [&]() {
    switch (type) {
      case AVMEDIA_TYPE_AUDIO:
        return oformat->audio_codec;
      case AVMEDIA_TYPE_VIDEO:
        return oformat->video_codec;
      default:
        TORCH_CHECK(
            false, "Unsupported media type: ", av_get_media_type_string(type));
    }
  }();

  TORCH_CHECK(
      default_codec != AV_CODEC_ID_NONE,
      "Format \"",
      oformat->name,
      "\" does not support ",
      av_get_media_type_string(type),
      ".");

  const AVCodec* codec = [&]() {
    if (encoder) {
      const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str());
      TORCH_CHECK(c, "Unexpected codec: ", encoder.value());
      return c;
    }
    const AVCodec* c = avcodec_find_encoder(default_codec);
    TORCH_CHECK(
        c, "Encoder not found for codec: ", avcodec_get_name(default_codec));
    return c;
  }();

  AVCodecContext* ctx = avcodec_alloc_context3(codec);
  TORCH_CHECK(ctx, "Failed to allocate CodecContext.");

  // Some containers require codec parameters in the container-level global
  // header rather than in each packet.
  if (oformat->flags & AVFMT_GLOBALHEADER) {
    ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
  }
  return AVCodecContextPtr(ctx);
}
// Parse and validate the source sample-format name. Planar formats are
// rejected: write_audio_chunk copies the (time, channel) tensor into a
// single interleaved plane. Fixed typo in the error message
// ("fotmat" -> "format").
enum AVSampleFormat _get_src_sample_fmt(const std::string& src) {
  auto fmt = av_get_sample_fmt(src.c_str());
  TORCH_CHECK(fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", src);
  TORCH_CHECK(
      !av_sample_fmt_is_planar(fmt),
      "Unexpected sample format value. Valid values are ",
      av_get_sample_fmt_name(AV_SAMPLE_FMT_U8),
      ", ",
      av_get_sample_fmt_name(AV_SAMPLE_FMT_S16),
      ", ",
      av_get_sample_fmt_name(AV_SAMPLE_FMT_S32),
      ", ",
      av_get_sample_fmt_name(AV_SAMPLE_FMT_S64),
      ", ",
      av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT),
      ", ",
      av_get_sample_fmt_name(AV_SAMPLE_FMT_DBL),
      ". ",
      "Found: ",
      src);
  return fmt;
}
// Parse and validate the source pixel-format name. Only the formats that the
// tensor-copy paths handle (gray8 / rgb24 / bgr24 / yuv444p) are accepted.
enum AVPixelFormat _get_src_pixel_fmt(const std::string& src) {
  auto fmt = av_get_pix_fmt(src.c_str());
  switch (fmt) {
    case AV_PIX_FMT_GRAY8:
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
    case AV_PIX_FMT_YUV444P:
      return fmt;
    case AV_PIX_FMT_NONE:
      TORCH_CHECK(false, "Unknown pixel format: ", src);
    default:
      TORCH_CHECK(false, "Unsupported pixel format: ", src);
  }
}
// Build an "aformat" filter graph converting from the source sample format
// to the encoder's sample format.
std::unique_ptr<FilterGraph> _get_audio_filter(
    enum AVSampleFormat fmt,
    AVCodecContextPtr& ctx) {
  std::stringstream desc;
  desc << "aformat=" << av_get_sample_fmt_name(ctx->sample_fmt);
  auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO);
  p->add_audio_src(fmt, ctx->time_base, ctx->sample_rate, ctx->channel_layout);
  p->add_sink();
  p->add_process(desc.str());
  p->create_filter();
  return p;
}
// Build a "format" filter graph converting from the source pixel format to
// the encoder's pixel format.
std::unique_ptr<FilterGraph> _get_video_filter(
    enum AVPixelFormat fmt,
    AVCodecContextPtr& ctx) {
  std::stringstream desc;
  desc << "format=" << av_get_pix_fmt_name(ctx->pix_fmt);
  auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_VIDEO);
  p->add_video_src(
      fmt, ctx->time_base, ctx->width, ctx->height, ctx->sample_aspect_ratio);
  p->add_sink();
  p->add_process(desc.str());
  p->create_filter();
  return p;
}
} // namespace
// Register an audio output stream: configure and open the encoder, create
// the muxer stream, and pre-allocate conversion state (filter + frames).
void StreamWriter::add_audio_stream(
    int64_t sample_rate,
    int64_t num_channels,
    const std::string& format,
    const c10::optional<std::string>& encoder,
    const c10::optional<OptionDict>& encoder_option,
    const c10::optional<std::string>& encoder_format) {
  enum AVSampleFormat src_fmt = _get_src_sample_fmt(format);

  AVCodecContextPtr ctx =
      get_codec_ctx(AVMEDIA_TYPE_AUDIO, pFormatContext->oformat, encoder);
  configure_audio_codec(ctx, sample_rate, num_channels, encoder_format);
  open_codec(ctx, encoder_option);
  AVStream* stream = add_stream(ctx);

  // When the source format already matches the encoder's, no conversion
  // filter is needed and src_frame is fed to the encoder directly.
  std::unique_ptr<FilterGraph> filter = src_fmt == ctx->sample_fmt
      ? std::unique_ptr<FilterGraph>(nullptr)
      : _get_audio_filter(src_fmt, ctx);
  // Encoders with a fixed frame size dictate the chunking unit; otherwise
  // fall back to an arbitrary capacity.
  static const int default_capacity = 10000;
  int frame_capacity = ctx->frame_size ? ctx->frame_size : default_capacity;
  AVFramePtr src_frame = get_audio_frame(src_fmt, ctx, frame_capacity);
  AVFramePtr dst_frame = filter
      ? AVFramePtr{}
      : get_audio_frame(ctx->sample_fmt, ctx, frame_capacity);
  streams.emplace_back(OutputStream{
      stream,
      std::move(ctx),
      std::move(filter),
      std::move(src_frame),
      std::move(dst_frame),
      0,
      frame_capacity});
}
// Register a video output stream: configure and open the encoder, create
// the muxer stream, and pre-allocate conversion state (filter + frames).
void StreamWriter::add_video_stream(
    double frame_rate,
    int64_t width,
    int64_t height,
    const std::string& format,
    const c10::optional<std::string>& encoder,
    const c10::optional<OptionDict>& encoder_option,
    const c10::optional<std::string>& encoder_format) {
  enum AVPixelFormat src_fmt = _get_src_pixel_fmt(format);

  AVCodecContextPtr ctx =
      get_codec_ctx(AVMEDIA_TYPE_VIDEO, pFormatContext->oformat, encoder);
  configure_video_codec(ctx, frame_rate, width, height, encoder_format);
  open_codec(ctx, encoder_option);
  AVStream* stream = add_stream(ctx);

  // When the source format already matches the encoder's, no conversion
  // filter is needed and src_frame is fed to the encoder directly.
  std::unique_ptr<FilterGraph> filter = src_fmt == ctx->pix_fmt
      ? std::unique_ptr<FilterGraph>(nullptr)
      : _get_video_filter(src_fmt, ctx);
  AVFramePtr src_frame = get_video_frame(src_fmt, ctx);
  AVFramePtr dst_frame =
      filter ? AVFramePtr{} : get_video_frame(ctx->pix_fmt, ctx);
  // frame_capacity is audio-only; -1 marks it unused for video.
  streams.emplace_back(OutputStream{
      stream,
      std::move(ctx),
      std::move(filter),
      std::move(src_frame),
      std::move(dst_frame),
      0,
      -1});
}
// Create a muxer stream and copy the encoder parameters into it.
// The returned stream is owned by pFormatContext.
AVStream* StreamWriter::add_stream(AVCodecContextPtr& codec_ctx) {
  AVStream* stream = avformat_new_stream(pFormatContext, nullptr);
  TORCH_CHECK(stream, "Failed to allocate stream.");

  stream->time_base = codec_ctx->time_base;
  int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
  TORCH_CHECK(
      ret >= 0,
      "Failed to copy the stream parameter. (",
      av_err2string(ret),
      ")");
  return stream;
}
// Replace (not merge) the container-level metadata with the given entries.
// Any previously set metadata is freed first.
void StreamWriter::set_metadata(const OptionDict& metadata) {
  av_dict_free(&pFormatContext->metadata);
  for (const auto& it : metadata) {
    av_dict_set(
        &pFormatContext->metadata, it.key().c_str(), it.value().c_str(), 0);
  }
}
// Print the configuration of output stream `i` via FFmpeg's av_dump_format.
void StreamWriter::dump_format(int64_t i) {
  av_dump_format(pFormatContext, (int)i, pFormatContext->url, 1);
}
// Open the output destination (unless the muxer does its own I/O, i.e.
// AVFMT_NOFILE) and write the container header.
void StreamWriter::open(const c10::optional<OptionDict>& option) {
  int ret = 0;

  // Open the file if it was not provided by client code (i.e. when not
  // file-like object)
  AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat;
  AVDictionary* opt = get_option_dict(option);
  if (!(fmt->flags & AVFMT_NOFILE)) {
    ret = avio_open2(
        &pFormatContext->pb,
        pFormatContext->url,
        AVIO_FLAG_WRITE,
        nullptr,
        &opt);
    if (ret < 0) {
      // Free the options dict before raising to avoid leaking it.
      av_dict_free(&opt);
      TORCH_CHECK(
          false,
          "Failed to open dst: ",
          pFormatContext->url,
          " (",
          av_err2string(ret),
          ")");
    }
  }

  ret = avformat_write_header(pFormatContext, &opt);
  clean_up_dict(opt);
  TORCH_CHECK(
      ret >= 0,
      "Failed to write header: ",
      pFormatContext->url,
      " (",
      av_err2string(ret),
      ")");
}
// Write the container trailer and close the destination (if this class
// opened it). A failed trailer write is logged, not raised.
void StreamWriter::close() {
  int ret = av_write_trailer(pFormatContext);
  if (ret < 0) {
    LOG(WARNING) << "Failed to write trailer. (" << av_err2string(ret) << ").";
  }

  // Close the file if it was not provided by client code (i.e. when not
  // file-like object)
  AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat;
  if (!(fmt->flags & AVFMT_NOFILE)) {
    // avio_closep can be only applied to AVIOContext opened by avio_open
    avio_closep(&(pFormatContext->pb));
  }
}
// Check that `i` is a valid stream index and that the stream carries the
// expected media type; raises otherwise.
void StreamWriter::validate_stream(int i, enum AVMediaType type) {
  TORCH_CHECK(
      0 <= i && i < static_cast<int>(streams.size()),
      "Invalid stream index. Index must be in range of [0, ",
      streams.size(),
      "). Found: ",
      i);

  TORCH_CHECK(
      streams[i].stream->codecpar->codec_type == type,
      "Stream ",
      i,
      " is not ",
      av_get_media_type_string(type));
}
// Push src_frame through the conversion filter and encode every frame the
// filter yields. Passing nullptr as src_frame flushes the filter.
void StreamWriter::process_frame(
    AVFrame* src_frame,
    std::unique_ptr<FilterGraph>& filter,
    AVFrame* dst_frame,
    AVCodecContextPtr& c,
    AVStream* st) {
  int ret = filter->add_frame(src_frame);
  while (ret >= 0) {
    ret = filter->get_frame(dst_frame);
    // EAGAIN: filter needs more input; EOF: filter fully drained.
    if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
      break;
    }
    if (ret >= 0) {
      encode_frame(dst_frame, c, st);
    }
    av_frame_unref(dst_frame);
  }
}
// Send one frame to the encoder and write out every packet it produces.
// Passing nullptr as frame puts the encoder into drain (flush) mode.
void StreamWriter::encode_frame(
    AVFrame* frame,
    AVCodecContextPtr& c,
    AVStream* st) {
  int ret = avcodec_send_frame(c, frame);
  TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ").");
  while (ret >= 0) {
    ret = avcodec_receive_packet(c, pkt);
    // EAGAIN: encoder needs more input; EOF: encoder fully drained.
    if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
      break;
    } else {
      TORCH_CHECK(
          ret >= 0,
          "Failed to fetch encoded packet (",
          av_err2string(ret),
          ").");
    }
    // Convert timestamps from the codec time base to the stream time base.
    av_packet_rescale_ts(pkt, c->time_base, st->time_base);
    pkt->stream_index = st->index;

    ret = av_interleaved_write_frame(pFormatContext, pkt);
    TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ").");
  }
}
namespace {
// Check that the waveform tensor matches the configured stream:
// dtype corresponds to the source sample format, tensor is on CPU, is 2D
// (time, channel), and the channel count matches the encoder context.
void validate_audio_input(
    enum AVSampleFormat fmt,
    AVCodecContext* ctx,
    const torch::Tensor& t) {
  auto dtype = t.dtype().toScalarType();
  switch (fmt) {
    case AV_SAMPLE_FMT_U8:
      TORCH_CHECK(
          dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
      break;
    case AV_SAMPLE_FMT_S16:
      TORCH_CHECK(
          dtype == c10::ScalarType::Short, "Expected Tensor of int16 type.");
      break;
    case AV_SAMPLE_FMT_S32:
      TORCH_CHECK(
          dtype == c10::ScalarType::Int, "Expected Tensor of int32 type.");
      break;
    case AV_SAMPLE_FMT_S64:
      TORCH_CHECK(
          dtype == c10::ScalarType::Long, "Expected Tensor of int64 type.");
      break;
    case AV_SAMPLE_FMT_FLT:
      TORCH_CHECK(
          dtype == c10::ScalarType::Float, "Expected Tensor of float32 type.");
      break;
    case AV_SAMPLE_FMT_DBL:
      TORCH_CHECK(
          dtype == c10::ScalarType::Double, "Expected Tensor of float64 type.");
      break;
    default:
      // fmt was validated at stream-registration time; reaching here means
      // the stream state is corrupted.
      TORCH_CHECK(
          false,
          "Internal error: Audio encoding stream is not properly configured.");
  }
  TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
  TORCH_CHECK(t.dim() == 2, "Input Tensor has to be 2D.");
  const auto num_channels = t.size(1);
  TORCH_CHECK(
      num_channels == ctx->channels,
      "Expected waveform with ",
      ctx->channels,
      " channels. Found ",
      num_channels);
}
// Check that the frames tensor matches the configured stream: uint8 dtype,
// 4D (N, C, H, W), and channel/height/width match the encoder context.
void validate_video_input(
    enum AVPixelFormat fmt,
    AVCodecContext* ctx,
    const torch::Tensor& t) {
  auto dtype = t.dtype().toScalarType();
  TORCH_CHECK(dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
  TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");

  // Note: the number of color components is not same as the number of planes.
  // For example, YUV420P has only two planes. U and V are in the second plane.
  int num_color_components = av_pix_fmt_desc_get(fmt)->nb_components;

  const auto channels = t.size(1);
  const auto height = t.size(2);
  const auto width = t.size(3);
  TORCH_CHECK(
      channels == num_color_components && height == ctx->height &&
          width == ctx->width,
      "Expected tensor with shape (N, ",
      num_color_components,
      ", ",
      ctx->height,
      ", ",
      ctx->width,
      ") (NCHW format). Found ",
      t.sizes());
}
} // namespace
// Encode a 2D waveform tensor (time, channel) on stream `i`, slicing it
// into encoder-sized frames. Fix: the inner loop variable previously
// shadowed the stream-index parameter `i`; renamed to `start`.
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
  validate_stream(i, AVMEDIA_TYPE_AUDIO);
  OutputStream& os = streams[i];
  validate_audio_input(
      static_cast<AVSampleFormat>(os.src_frame->format),
      os.codec_ctx,
      waveform);

  const auto num_frames = waveform.size(0);
  int64_t num_unit_frames = os.frame_capacity;
  // Input timestamps are sample counts at the encoder's sample rate.
  AVRational time_base{1, os.codec_ctx->sample_rate};

  using namespace torch::indexing;
  AT_DISPATCH_ALL_TYPES(waveform.scalar_type(), "write_audio_frames", [&] {
    for (int64_t start = 0; start < num_frames; start += num_unit_frames) {
      // Slicing clamps at the end, so the last chunk may be shorter.
      auto chunk = waveform.index({Slice(start, start + num_unit_frames), Slice()});
      auto num_valid_frames = chunk.size(0);
      auto byte_size = chunk.numel() * chunk.element_size();
      chunk = chunk.reshape({-1}).contiguous();

      // TODO: make writable
      // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
      TORCH_CHECK(
          av_frame_is_writable(os.src_frame),
          "Internal Error: frame is not writable.");

      memcpy(
          os.src_frame->data[0],
          static_cast<void*>(chunk.data_ptr<scalar_t>()),
          byte_size);
      // pts is derived from the running sample count, rescaled into the
      // codec's time base.
      os.src_frame->pts =
          av_rescale_q(os.num_frames, time_base, os.codec_ctx->time_base);
      os.src_frame->nb_samples = num_valid_frames;
      os.num_frames += num_valid_frames;
      if (os.filter) {
        process_frame(
            os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream);
      } else {
        encode_frame(os.src_frame, os.codec_ctx, os.stream);
      }
    }
  });
}
// Encode a 4D (N, C, H, W) uint8 tensor on stream `i`, dispatching to the
// interlaced or planar copy path based on the source pixel format.
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
  validate_stream(i, AVMEDIA_TYPE_VIDEO);
  OutputStream& os = streams[i];
  enum AVPixelFormat fmt = static_cast<AVPixelFormat>(os.src_frame->format);
  TORCH_CHECK(frames.device().is_cpu(), "Input tensor has to be on CPU.");
  validate_video_input(fmt, os.codec_ctx, frames);
  switch (fmt) {
    case AV_PIX_FMT_GRAY8:
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
      write_interlaced_video(os, frames);
      return;
    case AV_PIX_FMT_YUV444P:
      write_planar_video(os, frames, av_pix_fmt_count_planes(fmt));
      return;
    default:
      // _get_src_pixel_fmt rejects anything else at registration time.
      TORCH_CHECK(false, "Unexpected pixel format: ", av_get_pix_fmt_name(fmt));
  }
}
// Interlaced video
// Each frame is composed of one plane, and color components for each pixel are
// collocated.
// The memory layout is 1D linear, interpreted as follows.
//
// |<----- linesize[0] ----->|
// 0 1 ... W
// 0: RGB RGB ... RGB PAD ... PAD
// 1: RGB RGB ... RGB PAD ... PAD
// ...
// H: RGB RGB ... RGB PAD ... PAD
// Copy each NCHW frame into the single interlaced plane (HWC byte order)
// of src_frame, row by row to honor linesize padding, then encode it.
void StreamWriter::write_interlaced_video(
    OutputStream& os,
    const torch::Tensor& frames) {
  const auto num_frames = frames.size(0);
  const auto num_channels = frames.size(1);
  const auto height = frames.size(2);
  const auto width = frames.size(3);

  using namespace torch::indexing;
  size_t stride = width * num_channels;
  for (int i = 0; i < num_frames; ++i) {
    // TODO: writable
    // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
    TORCH_CHECK(
        av_frame_is_writable(os.src_frame),
        "Internal Error: frame is not writable.");

    // CHW -> HWC
    auto chunk =
        frames.index({i}).permute({1, 2, 0}).reshape({-1}).contiguous();
    uint8_t* src = chunk.data_ptr<uint8_t>();
    uint8_t* dst = os.src_frame->data[0];
    for (int h = 0; h < height; ++h) {
      // Row-wise copy: the destination may be padded per linesize[0].
      std::memcpy(dst, src, stride);
      src += width * num_channels;
      dst += os.src_frame->linesize[0];
    }
    // Video pts is a simple running frame count (stream time base is 1/fps).
    os.src_frame->pts = os.num_frames;
    os.num_frames += 1;

    if (os.filter) {
      process_frame(
          os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream);
    } else {
      encode_frame(os.src_frame, os.codec_ctx, os.stream);
    }
  }
}
// Planar video
// Each frame is composed of multiple planes.
// One plane can contain one or more color components.
// (but at the moment only accept formats without subsampled color components)
//
// The memory layout is interpreted as follows
//
// |<----- linesize[0] ----->|
// 0 1 ... W1
// 0: Y Y ... Y PAD ... PAD
// 1: Y Y ... Y PAD ... PAD
// ...
// H1: Y Y ... Y PAD ... PAD
//
// |<--- linesize[1] ---->|
// 0 ... W2
// 0: UV ... UV PAD ... PAD
// 1: UV ... UV PAD ... PAD
// ...
// H2: UV ... UV PAD ... PAD
//
// Copy each NCHW frame plane-by-plane into src_frame (one tensor channel per
// plane), row by row to honor per-plane linesize padding, then encode it.
void StreamWriter::write_planar_video(
    OutputStream& os,
    const torch::Tensor& frames,
    int num_planes) {
  const auto num_frames = frames.size(0);
  const auto height = frames.size(2);
  const auto width = frames.size(3);

  using namespace torch::indexing;
  for (int i = 0; i < num_frames; ++i) {
    // TODO: writable
    // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
    TORCH_CHECK(
        av_frame_is_writable(os.src_frame),
        "Internal Error: frame is not writable.");

    for (int j = 0; j < num_planes; ++j) {
      auto chunk = frames.index({i, j}).contiguous();
      uint8_t* src = chunk.data_ptr<uint8_t>();
      uint8_t* dst = os.src_frame->data[j];
      for (int h = 0; h < height; ++h) {
        // Row-wise copy: the plane may be padded per linesize[j].
        memcpy(dst, src, width);
        src += width;
        dst += os.src_frame->linesize[j];
      }
    }
    // Video pts is a simple running frame count (stream time base is 1/fps).
    os.src_frame->pts = os.num_frames;
    os.num_frames += 1;

    if (os.filter) {
      process_frame(
          os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream);
    } else {
      encode_frame(os.src_frame, os.codec_ctx, os.stream);
    }
  }
}
// TODO: probably better to flush output streams in interweaving manner.
void StreamWriter::flush() {
  // Drain every registered output stream in the order they were added.
  for (size_t i = 0; i < streams.size(); ++i) {
    flush_stream(streams[i]);
  }
}
// Drain one output stream. Passing a null frame signals end-of-stream,
// flushing any frames buffered in the filter graph (when present) first,
// and then in the encoder. The order matters: the filter graph may still
// emit frames that must go through the encoder before it is drained.
void StreamWriter::flush_stream(OutputStream& os) {
  if (os.filter) {
    process_frame(nullptr, os.filter, os.dst_frame, os.codec_ctx, os.stream);
  }
  encode_frame(nullptr, os.codec_ctx, os.stream);
}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
namespace torchaudio {
namespace ffmpeg {
// Per-stream state held by StreamWriter for one output (audio or video)
// stream: the muxer entry, the encoder, an optional conversion filter
// graph, and the reusable frame buffers.
struct OutputStream {
  // Muxer's stream entry. Raw pointer; not freed here (presumably owned by
  // the enclosing AVFormatContext — TODO confirm).
  AVStream* stream;
  // Encoder context for this stream.
  AVCodecContextPtr codec_ctx;
  // Optional filter graph applied to input frames before encoding.
  // When null, frames are sent directly to the encoder.
  std::unique_ptr<FilterGraph> filter;
  // Buffer into which input tensor data is copied before filtering/encoding.
  AVFramePtr src_frame;
  // Buffer receiving the output of the filter graph.
  AVFramePtr dst_frame;
  // The number of samples written so far
  // (also used as the PTS of the next frame to be written).
  int64_t num_frames;
  // Audio-only: The maximum frames that frame can hold
  int64_t frame_capacity;
};
// FFmpeg-based encoder/muxer. Streams are configured with add_*_stream,
// then `open` writes the header, `write_*_chunk` encode tensor data, and
// `flush`/`close` finalize the output. Non-copyable.
class StreamWriter {
  // Muxer context; the destination (file/protocol/device) is fixed at
  // construction time.
  AVFormatOutputContextPtr pFormatContext;
  // One entry per configured output stream, in registration order.
  std::vector<OutputStream> streams;
  // Packet buffer (presumably reused by encode_frame when retrieving
  // encoded data — defined elsewhere).
  AVPacketPtr pkt;

 public:
  explicit StreamWriter(AVFormatOutputContextPtr&& p);
  // Non-copyable
  StreamWriter(const StreamWriter&) = delete;
  StreamWriter& operator=(const StreamWriter&) = delete;

  //////////////////////////////////////////////////////////////////////////////
  // Query methods
  //////////////////////////////////////////////////////////////////////////////
 public:
  // Print the configured outputs
  void dump_format(int64_t i);

  //////////////////////////////////////////////////////////////////////////////
  // Configure methods
  //////////////////////////////////////////////////////////////////////////////
 public:
  // Register an output audio stream with the given sample format and
  // optional encoder override.
  void add_audio_stream(
      int64_t sample_rate,
      int64_t num_channels,
      const std::string& format,
      const c10::optional<std::string>& encoder,
      const c10::optional<OptionDict>& encoder_option,
      const c10::optional<std::string>& encoder_format);
  // Register an output video stream with the given pixel format and
  // optional encoder override.
  void add_video_stream(
      double frame_rate,
      int64_t width,
      int64_t height,
      const std::string& format,
      const c10::optional<std::string>& encoder,
      const c10::optional<OptionDict>& encoder_option,
      const c10::optional<std::string>& encoder_format);
  // Set file-level metadata (e.g. artist/title) on the output.
  void set_metadata(const OptionDict& metadata);

 private:
  // Helper adding a muxer stream for the given encoder context.
  AVStream* add_stream(AVCodecContextPtr& ctx);

  //////////////////////////////////////////////////////////////////////////////
  // Write methods
  //////////////////////////////////////////////////////////////////////////////
 public:
  // Open the destination and write the header. `opt` carries private
  // protocol/device/muxer options.
  void open(const c10::optional<OptionDict>& opt);
  // Write the trailer and close the destination.
  void close();
  // Encode an audio tensor into stream `i`.
  void write_audio_chunk(int i, const torch::Tensor& chunk);
  // Encode a video tensor into stream `i`.
  void write_video_chunk(int i, const torch::Tensor& chunk);
  // Drain buffered frames from all filter graphs and encoders.
  void flush();

 private:
  // Check that stream `i` exists and has the expected media type.
  void validate_stream(int i, enum AVMediaType);
  // Copy/encode planar frames (one tensor channel per plane).
  void write_planar_video(
      OutputStream& os,
      const torch::Tensor& chunk,
      int num_planes);
  // Copy/encode packed (interlaced) frames (channels interleaved per pixel).
  void write_interlaced_video(OutputStream& os, const torch::Tensor& chunk);
  // Push a frame through the filter graph, then encode the result.
  void process_frame(
      AVFrame* src_frame,
      std::unique_ptr<FilterGraph>& filter,
      AVFrame* dst_frame,
      AVCodecContextPtr& c,
      AVStream* st);
  // Encode one frame (nullptr drains the encoder) and mux the packets.
  void encode_frame(AVFrame* dst_frame, AVCodecContextPtr& c, AVStream* st);
  // Drain one stream's filter graph (if any) and encoder.
  void flush_stream(OutputStream& os);
};
} // namespace ffmpeg
} // namespace torchaudio
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
// Factory used by the TorchBind registration below: builds a
// StreamWriterBinding writing to `dst`, with an optional output
// format/device override.
c10::intrusive_ptr<StreamWriterBinding> init(
    const std::string& dst,
    const c10::optional<std::string>& format) {
  return c10::make_intrusive<StreamWriterBinding>(
      get_output_format_context(dst, format));
}
// Shorthand for the bound object reference in the lambdas below.
using S = const c10::intrusive_ptr<StreamWriterBinding>&;

// Register the TorchBind class `torchaudio.ffmpeg_StreamWriter`.
// Each `.def` forwards a Python-visible method to the corresponding
// StreamWriter member; the lambdas only adapt argument types.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.class_<StreamWriterBinding>("ffmpeg_StreamWriter")
      .def(torch::init<>(init))
      .def(
          "add_audio_stream",
          [](S s,
             int64_t sample_rate,
             int64_t num_channels,
             const std::string& format,
             const c10::optional<std::string>& encoder,
             const c10::optional<OptionDict>& encoder_option,
             const c10::optional<std::string>& encoder_format) {
            s->add_audio_stream(
                sample_rate,
                num_channels,
                format,
                encoder,
                encoder_option,
                encoder_format);
          })
      .def(
          "add_video_stream",
          [](S s,
             double frame_rate,
             int64_t width,
             int64_t height,
             const std::string& format,
             const c10::optional<std::string>& encoder,
             const c10::optional<OptionDict>& encoder_option,
             const c10::optional<std::string>& encoder_format) {
            s->add_video_stream(
                frame_rate,
                width,
                height,
                format,
                encoder,
                encoder_option,
                encoder_format);
          })
      .def(
          "set_metadata",
          [](S s, const OptionDict& metadata) { s->set_metadata(metadata); })
      .def("dump_format", [](S s, int64_t i) { s->dump_format(i); })
      .def(
          "open",
          [](S s, const c10::optional<OptionDict>& option) { s->open(option); })
      .def("close", [](S s) { s->close(); })
      .def(
          "write_audio_chunk",
          [](S s, int64_t i, const torch::Tensor& chunk) {
            s->write_audio_chunk(static_cast<int>(i), chunk);
          })
      .def(
          "write_video_chunk",
          [](S s, int64_t i, const torch::Tensor& chunk) {
            s->write_video_chunk(static_cast<int>(i), chunk);
          })
      .def("flush", [](S s) { s->flush(); });
}
} // namespace
} // namespace ffmpeg
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
// Create an output AVFormatContext for writing media to `dst`.
//
// Args:
//   dst: destination path / URL passed to the muxer.
//   format: optional output format (or device) name; when absent, FFmpeg
//     infers the format from `dst`.
//
// Raises (TORCH_CHECK) when FFmpeg cannot allocate/resolve the output.
AVFormatOutputContextPtr get_output_format_context(
    const std::string& dst,
    const c10::optional<std::string>& format) {
  // NOTE: avformat_alloc_output_context2 allocates the context itself and
  // overwrites `p`. Pre-allocating one with avformat_alloc_context() (as a
  // previous revision did) leaks that context, so we start from nullptr.
  AVFormatContext* p = nullptr;
  int ret = avformat_alloc_output_context2(
      &p, nullptr, format ? format.value().c_str() : nullptr, dst.c_str());
  TORCH_CHECK(
      ret >= 0,
      "Failed to open output \"",
      dst,
      "\" (",
      av_err2string(ret),
      ").");
  return AVFormatOutputContextPtr(p);
}
// Thin constructor: forwards the output format context to StreamWriter.
StreamWriterBinding::StreamWriterBinding(AVFormatOutputContextPtr&& p)
    : StreamWriter(std::move(p)) {}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio {
namespace ffmpeg {
// Create an output format context for writing media to `dst`, optionally
// overriding the output format / device name.
AVFormatOutputContextPtr get_output_format_context(
    const std::string& dst,
    const c10::optional<std::string>& format);

// Adapter exposing StreamWriter as a TorchBind custom class
// (CustomClassHolder makes it usable via torch.classes).
class StreamWriterBinding : public StreamWriter,
                            public torch::CustomClassHolder {
 public:
  explicit StreamWriterBinding(AVFormatOutputContextPtr&& p);
};
} // namespace ffmpeg
} // namespace torchaudio
import torchaudio import torchaudio
_LAZILY_IMPORTED = [ _STREAM_READER = [
"StreamReader", "StreamReader",
"StreamReaderSourceStream", "StreamReaderSourceStream",
"StreamReaderSourceAudioStream", "StreamReaderSourceAudioStream",
...@@ -8,15 +8,28 @@ _LAZILY_IMPORTED = [ ...@@ -8,15 +8,28 @@ _LAZILY_IMPORTED = [
"StreamReaderOutputStream", "StreamReaderOutputStream",
] ]
_STREAM_WRITER = [
"StreamWriter",
]
_LAZILY_IMPORTED = _STREAM_READER + _STREAM_WRITER
def __getattr__(name: str): def __getattr__(name: str):
if name in _LAZILY_IMPORTED: if name in _LAZILY_IMPORTED:
torchaudio._extension._init_ffmpeg() torchaudio._extension._init_ffmpeg()
from . import _stream_reader if name in _STREAM_READER:
from . import _stream_reader
item = getattr(_stream_reader, name)
else:
from . import _stream_writer
item = getattr(_stream_writer, name)
item = getattr(_stream_reader, name)
globals()[name] = item globals()[name] = item
return item return item
raise AttributeError(f"module {__name__} has no attribute {name}") raise AttributeError(f"module {__name__} has no attribute {name}")
......
from typing import Dict, Optional
import torch
def _format_doc(**kwargs):
def decorator(obj):
obj.__doc__ = obj.__doc__.format(**kwargs)
return obj
return decorator
# Shared docstring fragments for the ``encoder`` / ``encoder_option`` /
# ``encoder_format`` arguments of ``add_audio_stream`` and
# ``add_video_stream``. They are interpolated into those methods' docstrings
# by the ``_format_common_args`` decorator below.
_encoder = """The name of the encoder to be used.
When provided, use the specified encoder instead of the default one.
To list the available encoders, you can use ``ffmpeg -h encoders`` command.
Default: ``None``."""
_encoder_option = """Options passed to encoder.
Mapping from str to str.
To list encoder options for a encoder, you can use
``ffmpeg -h encoder=<ENCODER>`` command.
Default: ``None``."""
_encoder_format = """Format used to encode media.
When encoder supports multiple formats, passing this argument will override
the format used for encoding.
To list supported formats for the encoder, you can use
``ffmpeg -h encoder=<ENCODER>`` command.
Default: ``None``."""
# Decorator that substitutes the fragments above into a method's docstring.
_format_common_args = _format_doc(
    encoder=_encoder,
    encoder_option=_encoder_option,
    encoder_format=_encoder_format,
)
class StreamWriter:
    """Encode and write audio/video streams chunk by chunk

    Args:
        dst (str): The destination where the encoded data are written.
            The supported value depends on the FFmpeg found in the system.
        format (str or None, optional):
            Override the output format, or specify the output media device.
            Default: ``None`` (no override nor device output).

            This argument serves two different use cases.

            1) Override the output format.
            This is useful when writing raw data or in a format different from the extension.

            2) Specify the output device.
            This allows to output media streams to hardware devices,
            such as speaker and video screen.

            .. note::

                This option roughly corresponds to ``-f`` option of ``ffmpeg`` command.
                Please refer to the ffmpeg documentations for possible values.

                https://ffmpeg.org/ffmpeg-formats.html#Muxers

                Use `ffmpeg -muxers` to list the values available in the current environment.

                For device access, the available values vary based on hardware (AV device) and
                software configuration (ffmpeg build).
                Please refer to the ffmpeg documentations for possible values.

                https://ffmpeg.org/ffmpeg-devices.html#Output-Devices

                Use `ffmpeg -devices` to list the values available in the current environment.
    """

    def __init__(
        self,
        dst: str,
        format: Optional[str] = None,
    ):
        # Underlying TorchBind custom class (ffmpeg_StreamWriter) that
        # performs the actual FFmpeg-based encoding and muxing.
        self._s = torch.classes.torchaudio.ffmpeg_StreamWriter(dst, format)
        # True between a successful ``open()`` and the matching ``close()``;
        # used to make open/close idempotent.
        self._is_open = False

    @_format_common_args
    def add_audio_stream(
        self,
        sample_rate: int,
        num_channels: int,
        format: str = "flt",
        encoder: Optional[str] = None,
        encoder_option: Optional[Dict[str, str]] = None,
        encoder_format: Optional[str] = None,
    ):
        """Add an output audio stream.

        Args:
            sample_rate (int): The sample rate.
            num_channels (int): The number of channels.
            format (str, optional): Input sample format, which determines the dtype
                of the input tensor.

                - ``"u8"``: The input tensor must be ``torch.uint8`` type.
                - ``"s16"``: The input tensor must be ``torch.int16`` type.
                - ``"s32"``: The input tensor must be ``torch.int32`` type.
                - ``"s64"``: The input tensor must be ``torch.int64`` type.
                - ``"flt"``: The input tensor must be ``torch.float32`` type.
                - ``"dbl"``: The input tensor must be ``torch.float64`` type.

                Default: ``"flt"``.
            encoder (str or None, optional): {encoder}
            encoder_option (dict or None, optional): {encoder_option}
            encoder_format (str or None, optional): {encoder_format}
        """
        self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format)

    @_format_common_args
    def add_video_stream(
        self,
        frame_rate: float,
        width: int,
        height: int,
        format: str = "rgb24",
        encoder: Optional[str] = None,
        encoder_option: Optional[Dict[str, str]] = None,
        encoder_format: Optional[str] = None,
    ):
        """Add an output video stream.

        This method has to be called before `open` is called.

        Args:
            frame_rate (float): Frame rate of the video.
            width (int): Width of the video frame.
            height (int): Height of the video frame.
            format (str, optional): Input pixel format, which determines the
                color channel order of the input tensor.

                - ``"gray8"``: One channel, grayscale.
                - ``"rgb24"``: Three channels in the order of RGB.
                - ``"bgr24"``: Three channels in the order of BGR.
                - ``"yuv444p"``: Three channels in the order of YUV.

                Default: ``"rgb24"``.

                In either case, the input tensor has to be ``torch.uint8`` type and
                the shape must be (frame, channel, height, width).
            encoder (str or None, optional): {encoder}
            encoder_option (dict or None, optional): {encoder_option}
            encoder_format (str or None, optional): {encoder_format}
        """
        self._s.add_video_stream(frame_rate, width, height, format, encoder, encoder_option, encoder_format)

    def set_metadata(self, metadata: Dict[str, str]):
        """Set file-level metadata

        Args:
            metadata (dict): File-level metadata, mapping from str to str.
        """
        self._s.set_metadata(metadata)

    def _print_output_stream(self, i: int):
        """[debug] Print the registered stream information to stdout."""
        self._s.dump_format(i)

    def open(self, option: Optional[Dict[str, str]] = None):
        """Open the output file / device and write the header.

        Returns ``self``, so the call can be used directly as a context
        manager (see examples below).

        Args:
            option (dict or None, optional): Private options for protocol, device and muxer. See example.

        Example - Protocol option
            >>> s = StreamWriter(dst="rtmp://localhost:1234/live/app", format="flv")
            >>> s.add_video_stream(...)
            >>> # Passing protocol option `listen=1` makes StreamWriter act as RTMP server.
            >>> with s.open(option={"listen": "1"}) as f:
            >>>     f.write_video_chunk(...)

        Example - Device option
            >>> s = StreamWriter("-", format="sdl")
            >>> s.add_video_stream(..., encoder_format="rgb24")
            >>> # Open SDL video player with fullscreen
            >>> with s.open(option={"window_fullscreen": "1"}) as f:
            >>>     f.write_video_chunk(...)

        Example - Muxer option
            >>> s = StreamWriter("foo.flac")
            >>> s.add_audio_stream(...)
            >>> s.set_metadata({"artist": "torchaudio contributors"})
            >>> # FLAC muxer has a private option to not write the header.
            >>> # The resulting file does not contain the above metadata.
            >>> with s.open(option={"write_header": "false"}) as f:
            >>>     f.write_audio_chunk(...)
        """
        # No-op when already open, so nesting / repeated calls are safe.
        if not self._is_open:
            self._s.open(option)
            self._is_open = True
        return self

    def close(self):
        """Close the output"""
        # No-op when not open, mirroring the idempotence of `open`.
        if self._is_open:
            self._s.close()
            self._is_open = False

    def write_audio_chunk(self, i: int, chunk: torch.Tensor):
        """Write the audio data

        Args:
            i (int): Stream index.
            chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
                The ``dtype`` must match what was passed to :py:func:`add_audio_stream` method.
        """
        self._s.write_audio_chunk(i, chunk)

    def write_video_chunk(self, i: int, chunk: torch.Tensor):
        """Write the video data

        Args:
            i (int): Stream index.
            chunk (Tensor): Video/image tensor. Shape: `(frame, channel, height, width)`.
                ``dtype``: ``torch.uint8``.
        """
        self._s.write_video_chunk(i, chunk)

    def flush(self):
        """Flush the frames from encoders and write the frames to the destination."""
        self._s.flush()

    def __enter__(self):
        """Context manager so that the destination is closed and data are flushed automatically."""
        # Does not call `open`; `open()` returns self so that
        # ``with s.open(...) as f:`` works as shown in the examples.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """Context manager so that the destination is closed and data are flushed automatically."""
        self.flush()
        self.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment