Commit ffeba11a authored by mayp777

UPDATE

parent 29deb085
@@ -2,12 +2,27 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio {
namespace ffmpeg {
namespace io {
class FilterGraph {
AVMediaType media_type;
/// Used to report the output formats of the filter graph.
struct FilterGraphOutputInfo {
AVMediaType type = AVMEDIA_TYPE_UNKNOWN;
int format = -1;
AVRational time_base = {1, 1};
// Audio
int sample_rate = -1;
int num_channels = -1;
AVFilterGraphPtr pFilterGraph;
// Video
AVRational frame_rate = {0, 1};
int height = -1;
int width = -1;
};
class FilterGraph {
AVFilterGraphPtr graph;
// AVFilterContext is freed as a part of AVFilterGraph
// so we do not manage the resource.
@@ -15,7 +30,7 @@ class FilterGraph {
AVFilterContext* buffersink_ctx = nullptr;
public:
explicit FilterGraph(AVMediaType media_type);
explicit FilterGraph();
// Destructor is defaulted; AVFilterGraphPtr releases the AVFilterGraph*.
~FilterGraph() = default;
// Non-copyable
@@ -37,17 +52,29 @@ class FilterGraph {
void add_video_src(
AVPixelFormat format,
AVRational time_base,
AVRational frame_rate,
int width,
int height,
AVRational sample_aspect_ratio);
void add_src(const std::string& arg);
void add_audio_sink();
void add_sink();
void add_video_sink();
void add_process(const std::string& filter_description);
void create_filter();
void create_filter(AVBufferRef* hw_frames_ctx = nullptr);
private:
void add_src(const AVFilter* buffersrc, const std::string& arg);
void add_sink(const AVFilter* buffersink);
//////////////////////////////////////////////////////////////////////////////
// Query methods
//////////////////////////////////////////////////////////////////////////////
public:
[[nodiscard]] FilterGraphOutputInfo get_output_info() const;
//////////////////////////////////////////////////////////////////////////////
// Streaming process
@@ -57,5 +84,5 @@ class FilterGraph {
int get_frame(AVFrame* pOutputFrame);
};
} // namespace ffmpeg
} // namespace io
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/hw_context.h>
namespace torchaudio::io {
namespace {
static std::mutex MUTEX;
static std::map<int, AVBufferRefPtr> CUDA_CONTEXT_CACHE;
} // namespace
AVBufferRef* get_cuda_context(int index) {
std::lock_guard<std::mutex> lock(MUTEX);
if (index == -1) {
index = 0;
}
if (CUDA_CONTEXT_CACHE.count(index) == 0) {
AVBufferRef* p = nullptr;
int ret = av_hwdevice_ctx_create(
&p, AV_HWDEVICE_TYPE_CUDA, std::to_string(index).c_str(), nullptr, 0);
TORCH_CHECK(
ret >= 0,
"Failed to create CUDA device context on device ",
index,
"(",
av_err2string(ret),
")");
assert(p);
CUDA_CONTEXT_CACHE.emplace(index, p);
return p;
}
AVBufferRefPtr& buffer = CUDA_CONTEXT_CACHE.at(index);
return buffer;
}
void clear_cuda_context_cache() {
std::lock_guard<std::mutex> lock(MUTEX);
CUDA_CONTEXT_CACHE.clear();
}
} // namespace torchaudio::io
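// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of this commit): a decoder that wants
// hardware frames attaches the cached CUDA device context to its codec
// context. The cache hands out a borrowed reference, so the caller takes its
// own reference with av_buffer_ref(); `codec_ctx` is a hypothetical
// AVCodecContext*.
//
//   AVBufferRef* cuda_ctx = torchaudio::io::get_cuda_context(/*index=*/0);
//   codec_ctx->hw_device_ctx = av_buffer_ref(cuda_ctx);
// ---------------------------------------------------------------------------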
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio::io {
AVBufferRef* get_cuda_context(int index);
void clear_cuda_context_cache();
} // namespace torchaudio::io
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio {
namespace ffmpeg {
namespace torchaudio::io {
namespace {
PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamWriterFileObj, c10::intrusive_ptr<StreamWriterFileObj>>(
m, "StreamWriterFileObj")
std::map<std::string, std::tuple<int64_t, int64_t, int64_t>> get_versions() {
std::map<std::string, std::tuple<int64_t, int64_t, int64_t>> ret;
#define add_version(NAME) \
{ \
int ver = NAME##_version(); \
ret.emplace( \
"lib" #NAME, \
std::make_tuple<>( \
AV_VERSION_MAJOR(ver), \
AV_VERSION_MINOR(ver), \
AV_VERSION_MICRO(ver))); \
}
add_version(avutil);
add_version(avcodec);
add_version(avformat);
add_version(avfilter);
add_version(avdevice);
return ret;
#undef add_version
}
std::map<std::string, std::string> get_demuxers(bool req_device) {
std::map<std::string, std::string> ret;
const AVInputFormat* fmt = nullptr;
void* i = nullptr;
while ((fmt = av_demuxer_iterate(&i))) {
assert(fmt);
bool is_device = [&]() {
const AVClass* avclass = fmt->priv_class;
return avclass && AV_IS_INPUT_DEVICE(avclass->category);
}();
if (req_device == is_device) {
ret.emplace(fmt->name, fmt->long_name);
}
}
return ret;
}
std::map<std::string, std::string> get_muxers(bool req_device) {
std::map<std::string, std::string> ret;
const AVOutputFormat* fmt = nullptr;
void* i = nullptr;
while ((fmt = av_muxer_iterate(&i))) {
assert(fmt);
bool is_device = [&]() {
const AVClass* avclass = fmt->priv_class;
return avclass && AV_IS_OUTPUT_DEVICE(avclass->category);
}();
if (req_device == is_device) {
ret.emplace(fmt->name, fmt->long_name);
}
}
return ret;
}
std::map<std::string, std::string> get_codecs(
AVMediaType type,
bool req_encoder) {
const AVCodec* c = nullptr;
void* i = nullptr;
std::map<std::string, std::string> ret;
while ((c = av_codec_iterate(&i))) {
assert(c);
if ((req_encoder && av_codec_is_encoder(c)) ||
(!req_encoder && av_codec_is_decoder(c))) {
if (c->type == type && c->name) {
ret.emplace(c->name, c->long_name ? c->long_name : "");
}
}
}
return ret;
}
std::vector<std::string> get_protocols(bool output) {
void* opaque = nullptr;
const char* name = nullptr;
std::vector<std::string> ret;
while ((name = avio_enum_protocols(&opaque, output))) {
assert(name);
ret.emplace_back(name);
}
return ret;
}
std::string get_build_config() {
return avcodec_configuration();
}
//////////////////////////////////////////////////////////////////////////////
// StreamReader/Writer FileObj
//////////////////////////////////////////////////////////////////////////////
struct FileObj {
py::object fileobj;
int buffer_size;
};
namespace {
static int read_func(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
TORCH_CHECK(
chunk_len <= request,
"Requested up to ",
request,
" bytes but, received ",
chunk_len,
" bytes. The given object does not confirm to read protocol of file object.");
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += static_cast<int>(chunk_len);
}
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int write_func(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
py::bytes b(reinterpret_cast<const char*>(buf), buf_size);
// TODO: check the return value
fileobj->fileobj.attr("write")(b);
return buf_size;
}
static int64_t seek_func(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
} // namespace
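// ---------------------------------------------------------------------------
// Sketch (illustrative, not part of this commit): the callbacks above follow
// FFmpeg's custom I/O protocol; StreamReaderCustomIO / StreamWriterCustomIO
// take them and wire them up internally. A minimal standalone setup, assuming
// `py_fileobj` is a Python file-like object that outlives the context:
//
//   FileObj fileobj{py_fileobj, 4096};
//   auto* io_buffer = static_cast<unsigned char*>(av_malloc(4096));
//   AVIOContext* avio_ctx = avio_alloc_context(
//       io_buffer, 4096, /*write_flag=*/0, &fileobj,
//       read_func, /*write_packet=*/nullptr, seek_func);
//   // assign avio_ctx to AVFormatContext::pb before avformat_open_input
// ---------------------------------------------------------------------------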
struct StreamReaderFileObj : private FileObj, public StreamReaderCustomIO {
StreamReaderFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option,
int buffer_size)
: FileObj{fileobj, buffer_size},
StreamReaderCustomIO(
this,
format,
buffer_size,
read_func,
py::hasattr(fileobj, "seek") ? &seek_func : nullptr,
option) {}
};
struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO {
StreamWriterFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
int buffer_size)
: FileObj{fileobj, buffer_size},
StreamWriterCustomIO(
this,
format,
buffer_size,
write_func,
py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {}
};
#ifndef TORCHAUDIO_FFMPEG_EXT_NAME
#error TORCHAUDIO_FFMPEG_EXT_NAME must be defined.
#endif
PYBIND11_MODULE(TORCHAUDIO_FFMPEG_EXT_NAME, m) {
m.def("init", []() { avdevice_register_all(); });
m.def("get_log_level", []() { return av_log_get_level(); });
m.def("set_log_level", [](int level) { av_log_set_level(level); });
m.def("get_versions", &get_versions);
m.def("get_muxers", []() { return get_muxers(false); });
m.def("get_demuxers", []() { return get_demuxers(false); });
m.def("get_input_devices", []() { return get_demuxers(true); });
m.def("get_build_config", &get_build_config);
m.def("get_output_devices", []() { return get_muxers(true); });
m.def("get_audio_decoders", []() {
return get_codecs(AVMEDIA_TYPE_AUDIO, false);
});
m.def("get_audio_encoders", []() {
return get_codecs(AVMEDIA_TYPE_AUDIO, true);
});
m.def("get_video_decoders", []() {
return get_codecs(AVMEDIA_TYPE_VIDEO, false);
});
m.def("get_video_encoders", []() {
return get_codecs(AVMEDIA_TYPE_VIDEO, true);
});
m.def("get_input_protocols", []() { return get_protocols(false); });
m.def("get_output_protocols", []() { return get_protocols(true); });
m.def("clear_cuda_context_cache", &clear_cuda_context_cache);
py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts);
py::class_<CodecConfig>(m, "CodecConfig", py::module_local())
.def(py::init<int, int, const c10::optional<int>&, int, int>());
py::class_<StreamWriter>(m, "StreamWriter", py::module_local())
.def(py::init<const std::string&, const c10::optional<std::string>&>())
.def("set_metadata", &StreamWriter::set_metadata)
.def("add_audio_stream", &StreamWriter::add_audio_stream)
.def("add_video_stream", &StreamWriter::add_video_stream)
.def("dump_format", &StreamWriter::dump_format)
.def("open", &StreamWriter::open)
.def("write_audio_chunk", &StreamWriter::write_audio_chunk)
.def("write_video_chunk", &StreamWriter::write_video_chunk)
.def("flush", &StreamWriter::flush)
.def("close", &StreamWriter::close);
py::class_<StreamWriterFileObj>(m, "StreamWriterFileObj", py::module_local())
.def(py::init<py::object, const c10::optional<std::string>&, int64_t>())
.def("set_metadata", &StreamWriterFileObj::set_metadata)
.def("add_audio_stream", &StreamWriterFileObj::add_audio_stream)
@@ -20,12 +243,92 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("write_video_chunk", &StreamWriterFileObj::write_video_chunk)
.def("flush", &StreamWriterFileObj::flush)
.def("close", &StreamWriterFileObj::close);
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>(
m, "StreamReaderFileObj")
py::class_<OutputStreamInfo>(m, "OutputStreamInfo", py::module_local())
.def_readonly("source_index", &OutputStreamInfo::source_index)
.def_readonly("filter_description", &OutputStreamInfo::filter_description)
.def_property_readonly(
"media_type",
[](const OutputStreamInfo& o) -> std::string {
return av_get_media_type_string(o.media_type);
})
.def_property_readonly(
"format",
[](const OutputStreamInfo& o) -> std::string {
switch (o.media_type) {
case AVMEDIA_TYPE_AUDIO:
return av_get_sample_fmt_name((AVSampleFormat)(o.format));
case AVMEDIA_TYPE_VIDEO:
return av_get_pix_fmt_name((AVPixelFormat)(o.format));
default:
TORCH_INTERNAL_ASSERT(
false,
"FilterGraph is returning unexpected media type: ",
av_get_media_type_string(o.media_type));
}
})
.def_readonly("sample_rate", &OutputStreamInfo::sample_rate)
.def_readonly("num_channels", &OutputStreamInfo::num_channels)
.def_readonly("width", &OutputStreamInfo::width)
.def_readonly("height", &OutputStreamInfo::height)
.def_property_readonly(
"frame_rate", [](const OutputStreamInfo& o) -> double {
if (o.frame_rate.den == 0) {
TORCH_WARN(
"Invalid frame rate is found: ",
o.frame_rate.num,
"/",
o.frame_rate.den);
return -1;
}
return static_cast<double>(o.frame_rate.num) / o.frame_rate.den;
});
py::class_<SrcStreamInfo>(m, "SourceStreamInfo", py::module_local())
.def_property_readonly(
"media_type",
[](const SrcStreamInfo& s) {
return av_get_media_type_string(s.media_type);
})
.def_readonly("codec_name", &SrcStreamInfo::codec_name)
.def_readonly("codec_long_name", &SrcStreamInfo::codec_long_name)
.def_readonly("format", &SrcStreamInfo::fmt_name)
.def_readonly("bit_rate", &SrcStreamInfo::bit_rate)
.def_readonly("num_frames", &SrcStreamInfo::num_frames)
.def_readonly("bits_per_sample", &SrcStreamInfo::bits_per_sample)
.def_readonly("metadata", &SrcStreamInfo::metadata)
.def_readonly("sample_rate", &SrcStreamInfo::sample_rate)
.def_readonly("num_channels", &SrcStreamInfo::num_channels)
.def_readonly("width", &SrcStreamInfo::width)
.def_readonly("height", &SrcStreamInfo::height)
.def_readonly("frame_rate", &SrcStreamInfo::frame_rate);
py::class_<StreamReader>(m, "StreamReader", py::module_local())
.def(py::init<
const std::string&,
const c10::optional<std::string>&,
const c10::optional<OptionDict>&>())
.def("num_src_streams", &StreamReader::num_src_streams)
.def("num_out_streams", &StreamReader::num_out_streams)
.def("find_best_audio_stream", &StreamReader::find_best_audio_stream)
.def("find_best_video_stream", &StreamReader::find_best_video_stream)
.def("get_metadata", &StreamReader::get_metadata)
.def("get_src_stream_info", &StreamReader::get_src_stream_info)
.def("get_out_stream_info", &StreamReader::get_out_stream_info)
.def("seek", &StreamReader::seek)
.def("add_audio_stream", &StreamReader::add_audio_stream)
.def("add_video_stream", &StreamReader::add_video_stream)
.def("remove_stream", &StreamReader::remove_stream)
.def(
"process_packet",
py::overload_cast<const c10::optional<double>&, const double>(
&StreamReader::process_packet))
.def("process_all_packets", &StreamReader::process_all_packets)
.def("fill_buffer", &StreamReader::fill_buffer)
.def("is_buffer_ready", &StreamReader::is_buffer_ready)
.def("pop_chunks", &StreamReader::pop_chunks);
py::class_<StreamReaderFileObj>(m, "StreamReaderFileObj", py::module_local())
.def(py::init<
py::object,
const c10::optional<std::string>&,
const c10::optional<OptionMap>&,
const c10::optional<OptionDict>&,
int64_t>())
.def("num_src_streams", &StreamReaderFileObj::num_src_streams)
.def("num_out_streams", &StreamReaderFileObj::num_out_streams)
@@ -42,12 +345,15 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
.def("add_video_stream", &StreamReaderFileObj::add_video_stream)
.def("remove_stream", &StreamReaderFileObj::remove_stream)
.def("process_packet", &StreamReaderFileObj::process_packet)
.def(
"process_packet",
py::overload_cast<const c10::optional<double>&, const double>(
&StreamReader::process_packet))
.def("process_all_packets", &StreamReaderFileObj::process_all_packets)
.def("fill_buffer", &StreamReaderFileObj::fill_buffer)
.def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready)
.def("pop_chunks", &StreamReaderFileObj::pop_chunks);
}
} // namespace
} // namespace ffmpeg
} // namespace torchaudio
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
namespace torchaudio::io::detail {
ChunkedBuffer::ChunkedBuffer(
AVRational time_base,
int frames_per_chunk_,
int num_chunks_)
: time_base(time_base),
frames_per_chunk(frames_per_chunk_),
num_chunks(num_chunks_) {}
bool ChunkedBuffer::is_ready() const {
return num_buffered_frames >= frames_per_chunk;
}
void ChunkedBuffer::push_frame(torch::Tensor frame, int64_t pts_) {
using namespace torch::indexing;
// Note:
// Audio tensors contain multiple frames while video tensors contain only
// one frame. Video tensors can be regarded as a special degenerate case of
// audio, so in the following, we only consider audio processing.
//
// The incoming Tensor might contain more frames than the value of
// `frames_per_chunk`.
// If we pushed the input tensor to the deque as-is, then, at the trimming
// stage, whole batches of frames would be trimmed at once, which is not
// ideal. We want to keep at most `frames_per_chunk * num_chunks` frames.
// So we slice the incoming Tensor and push the slices.
//
// 1. Check if the last chunk is fully filled. If not, fill it.
//
// <----- frames per chunk ----->^
// x x x x x x x x x x x x x x x |
// x x x x x x x + + + + + + - - | num_chunks
// - - - - - - - - - - - - - - - |
// <-- filled --><--- remain --->v
// <- append->
//
if (int64_t filled = num_buffered_frames % frames_per_chunk) {
TORCH_INTERNAL_ASSERT(
chunks.size() > 0,
"There is supposed to be left over frames, but the buffer dequeue is empty.");
int64_t num_frames = frame.size(0);
int64_t remain = frames_per_chunk - filled;
int64_t append = remain < num_frames ? remain : num_frames;
torch::Tensor prev = chunks.back();
// prev[filled:filled+append] = frame[:append]
prev.index_put_(
{Slice(filled, filled + append)}, frame.index({Slice(None, append)}));
num_buffered_frames += append;
// frame = frame[append:]
frame = frame.index({Slice(append)});
pts_ += append;
}
// 2. Return if all the input frames were consumed while filling the last
// chunk, i.e. there is nothing left to push.
if (frame.numel() == 0) {
return;
}
// 3. Now the existing buffer chunks are fully filled, start adding new chunks
//
// <----- frames per chunk ----->^
// x x x x x x x x x x x x x x x |
// x x x x x x x x x x x x x x x | num_chunks
// + + + + + + + + + + + + + + + |
// <---------- append ---------->v
//
int64_t num_frames = frame.size(0);
int64_t num_splits =
num_frames / frames_per_chunk + (num_frames % frames_per_chunk ? 1 : 0);
for (int64_t i = 0; i < num_splits; ++i) {
int64_t start = i * frames_per_chunk;
// chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk]
auto chunk = frame.index({Slice(start, start + frames_per_chunk)});
int64_t pts_val = pts_ + start;
int64_t chunk_size = chunk.size(0);
TORCH_INTERNAL_ASSERT(
chunk_size <= frames_per_chunk,
"Chunk size is larger than frames per chunk.");
if (chunk_size < frames_per_chunk) {
auto shape = chunk.sizes().vec();
shape[0] = frames_per_chunk;
auto temp = torch::empty(shape, frame.options());
temp.index_put_({Slice(None, chunk_size)}, chunk);
chunk = temp;
}
chunks.push_back(chunk);
pts.push_back(pts_val);
num_buffered_frames += chunk_size;
// Trim if num_chunks > 0
if (num_chunks > 0 && chunks.size() > num_chunks) {
TORCH_WARN_ONCE(
"The number of buffered frames exceeded the buffer size. "
"Dropping the old frames. "
"To avoid this, you can set a higher buffer_chunk_size value.");
chunks.pop_front();
num_buffered_frames -= frames_per_chunk;
}
}
}
c10::optional<Chunk> ChunkedBuffer::pop_chunk() {
using namespace torch::indexing;
if (!num_buffered_frames) {
return {};
}
torch::Tensor chunk = chunks.front();
double pts_val = double(pts.front()) * time_base.num / time_base.den;
chunks.pop_front();
pts.pop_front();
if (num_buffered_frames < frames_per_chunk) {
chunk = chunk.index({Slice(None, num_buffered_frames)});
}
num_buffered_frames -= chunk.size(0);
return {Chunk{chunk, pts_val}};
}
void ChunkedBuffer::flush() {
num_buffered_frames = 0;
chunks.clear();
}
} // namespace torchaudio::io::detail
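// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of this commit): with
// frames_per_chunk=3 and num_chunks=2, pushing an 8-frame audio tensor splits
// it into chunks of 3, 3 and 2 frames; adding the third chunk exceeds
// num_chunks, so the oldest chunk is dropped with a warning.
//
//   ChunkedBuffer buffer{
//       AVRational{1, 16000}, /*frames_per_chunk=*/3, /*num_chunks=*/2};
//   buffer.push_frame(torch::zeros({8, 2}), /*pts_=*/0);
//   while (auto chunk = buffer.pop_chunk()) {
//     // First pop yields frames 3-5 (pts 3/16000 sec); the second pop yields
//     // the final 2-frame chunk, with its padding sliced off.
//   }
// ---------------------------------------------------------------------------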
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio::io::detail {
class ChunkedBuffer {
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
// Time stamps corresponding to the first frame of each chunk
std::deque<int64_t> pts;
AVRational time_base;
// The number of frames to return as a chunk
// If <0, then the user wants to receive all the frames
const int64_t frames_per_chunk;
// The number of chunks to retain
const int64_t num_chunks;
// The number of currently buffered frames
// For video, one Tensor corresponds to one frame, but for audio,
// one Tensor contains multiple frames, so we track the count here.
int64_t num_buffered_frames = 0;
public:
ChunkedBuffer(AVRational time_base, int frames_per_chunk, int num_chunks);
bool is_ready() const;
void flush();
c10::optional<Chunk> pop_chunk();
void push_frame(torch::Tensor frame, int64_t pts_);
};
} // namespace torchaudio::io::detail
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
namespace torchaudio::io::detail {
UnchunkedBuffer::UnchunkedBuffer(AVRational time_base) : time_base(time_base) {}
bool UnchunkedBuffer::is_ready() const {
return chunks.size() > 0;
}
void UnchunkedBuffer::push_frame(torch::Tensor frame, int64_t pts_) {
if (chunks.size() == 0) {
pts = double(pts_) * time_base.num / time_base.den;
}
chunks.push_back(frame);
}
c10::optional<Chunk> UnchunkedBuffer::pop_chunk() {
if (chunks.size() == 0) {
return {};
}
auto frames =
torch::cat(std::vector<torch::Tensor>{chunks.begin(), chunks.end()}, 0);
chunks.clear();
return {Chunk{frames, pts}};
}
void UnchunkedBuffer::flush() {
chunks.clear();
}
} // namespace torchaudio::io::detail
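// ---------------------------------------------------------------------------
// Sketch (illustrative, not part of this commit): UnchunkedBuffer simply
// accumulates frames and concatenates them on pop, stamping the result with
// the pts of the first buffered frame.
//
//   UnchunkedBuffer buffer{AVRational{1, 30}};
//   buffer.push_frame(torch::zeros({1, 3, 4, 4}), /*pts_=*/30);
//   buffer.push_frame(torch::zeros({1, 3, 4, 4}), /*pts_=*/31);
//   auto chunk = buffer.pop_chunk();  // frames: [2, 3, 4, 4], pts: 1.0 sec
// ---------------------------------------------------------------------------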
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <deque>
namespace torchaudio::io::detail {
class UnchunkedBuffer {
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
double pts = -1.;
AVRational time_base;
public:
UnchunkedBuffer(AVRational time_base);
bool is_ready() const;
void push_frame(torch::Tensor frame, int64_t pts_);
c10::optional<Chunk> pop_chunk();
void flush();
};
} // namespace torchaudio::io::detail
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////
template <c10::ScalarType dtype, bool is_planar>
AudioConverter<dtype, is_planar>::AudioConverter(int c) : num_channels(c) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels > 0);
}
template <c10::ScalarType dtype, bool is_planar>
torch::Tensor AudioConverter<dtype, is_planar>::convert(const AVFrame* src) {
if constexpr (is_planar) {
torch::Tensor dst = torch::empty({num_channels, src->nb_samples}, dtype);
convert(src, dst);
return dst.permute({1, 0});
} else {
torch::Tensor dst = torch::empty({src->nb_samples, num_channels}, dtype);
convert(src, dst);
return dst;
}
}
// Converts AVFrame* into pre-allocated Tensor.
// The shape must be [C, T] if is_planar otherwise [T, C]
template <c10::ScalarType dtype, bool is_planar>
void AudioConverter<dtype, is_planar>::convert(
const AVFrame* src,
torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels == src->channels);
constexpr int bps = []() {
switch (dtype) {
case torch::kUInt8:
return 1;
case torch::kInt16:
return 2;
case torch::kInt32:
case torch::kFloat32:
return 4;
case torch::kInt64:
case torch::kFloat64:
return 8;
}
}();
// Note
// FFmpeg's `nb_samples` represents the number of samples per channel,
// whereas in torchaudio, `num_samples` is used to represent the number of
// samples across channels. torchaudio uses `num_frames` for per-channel
// samples.
if constexpr (is_planar) {
int plane_size = bps * src->nb_samples;
uint8_t* p_dst = static_cast<uint8_t*>(dst.data_ptr());
for (int i = 0; i < num_channels; ++i) {
memcpy(p_dst, src->extended_data[i], plane_size);
p_dst += plane_size;
}
} else {
int plane_size = bps * src->nb_samples * num_channels;
memcpy(dst.data_ptr(), src->extended_data[0], plane_size);
}
}
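// Sketch (illustrative, not part of this commit): converting a planar float
// frame with two channels; `frame` is a hypothetical decoded AVFrame*.
//
//   AudioConverter<torch::kFloat32, /*is_planar=*/true> cvt{/*num_channels=*/2};
//   torch::Tensor waveform = cvt.convert(frame);  // shape: [frame->nb_samples, 2]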
// Explicit instantiation
template class AudioConverter<torch::kUInt8, false>;
template class AudioConverter<torch::kUInt8, true>;
template class AudioConverter<torch::kInt16, false>;
template class AudioConverter<torch::kInt16, true>;
template class AudioConverter<torch::kInt32, false>;
template class AudioConverter<torch::kInt32, true>;
template class AudioConverter<torch::kInt64, false>;
template class AudioConverter<torch::kInt64, true>;
template class AudioConverter<torch::kFloat32, false>;
template class AudioConverter<torch::kFloat32, true>;
template class AudioConverter<torch::kFloat64, false>;
template class AudioConverter<torch::kFloat64, true>;
////////////////////////////////////////////////////////////////////////////////
// Image
////////////////////////////////////////////////////////////////////////////////
namespace {
torch::Tensor get_image_buffer(
at::IntArrayRef shape,
const torch::Dtype dtype = torch::kUInt8) {
return torch::empty(
shape, torch::TensorOptions().dtype(dtype).layout(torch::kStrided));
}
torch::Tensor get_image_buffer(
at::IntArrayRef shape,
torch::Device device,
const torch::Dtype dtype = torch::kUInt8) {
return torch::empty(
shape,
torch::TensorOptions()
.dtype(dtype)
.layout(torch::kStrided)
.device(device));
}
} // namespace
ImageConverterBase::ImageConverterBase(int h, int w, int c)
: height(h), width(w), num_channels(c) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(height > 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(width > 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels > 0);
}
////////////////////////////////////////////////////////////////////////////////
// Interlaced Image
////////////////////////////////////////////////////////////////////////////////
void InterlacedImageConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == height);
int stride = width * num_channels;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) * dst.size(3) == stride);
auto p_dst = dst.data_ptr<uint8_t>();
uint8_t* p_src = src->data[0];
for (int i = 0; i < height; ++i) {
memcpy(p_dst, p_src, stride);
p_src += src->linesize[0];
p_dst += stride;
}
}
torch::Tensor InterlacedImageConverter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, height, width, num_channels});
convert(src, buffer);
return buffer.permute({0, 3, 1, 2});
}
////////////////////////////////////////////////////////////////////////////////
// Interlaced 16 Bit Image
////////////////////////////////////////////////////////////////////////////////
void Interlaced16BitImageConverter::convert(
const AVFrame* src,
torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == height);
int stride = width * num_channels;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) * dst.size(3) == stride);
auto p_dst = dst.data_ptr<int16_t>();
uint8_t* p_src = src->data[0];
for (int i = 0; i < height; ++i) {
memcpy(p_dst, p_src, stride * 2);
p_src += src->linesize[0];
p_dst += stride;
}
// correct for int16
dst += 32768;
}
torch::Tensor Interlaced16BitImageConverter::convert(const AVFrame* src) {
torch::Tensor buffer =
get_image_buffer({1, height, width, num_channels}, torch::kInt16);
convert(src, buffer);
return buffer.permute({0, 3, 1, 2});
}
////////////////////////////////////////////////////////////////////////////////
// Planar Image
////////////////////////////////////////////////////////////////////////////////
void PlanarImageConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == num_channels);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
for (int i = 0; i < num_channels; ++i) {
torch::Tensor plane = dst.index({0, i});
uint8_t* p_dst = plane.data_ptr<uint8_t>();
uint8_t* p_src = src->data[i];
int linesize = src->linesize[i];
for (int h = 0; h < height; ++h) {
memcpy(p_dst, p_src, width);
p_src += linesize;
p_dst += width;
}
}
}
torch::Tensor PlanarImageConverter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, num_channels, height, width});
convert(src, buffer);
return buffer;
}
////////////////////////////////////////////////////////////////////////////////
// YUV420P
////////////////////////////////////////////////////////////////////////////////
YUV420PConverter::YUV420PConverter(int h, int w) : ImageConverterBase(h, w, 3) {
TORCH_WARN_ONCE(
"The output format YUV420P is selected. "
"This will be implicitly converted to YUV444P, "
"in which all the color components Y, U, V have the same dimension.");
}
void YUV420PConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
(AVPixelFormat)(src->format) == AV_PIX_FMT_YUV420P);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
// Write Y plane directly
{
uint8_t* p_dst = dst.data_ptr<uint8_t>();
uint8_t* p_src = src->data[0];
for (int h = 0; h < height; ++h) {
memcpy(p_dst, p_src, width);
p_dst += width;
p_src += src->linesize[0];
}
}
// Chroma (U and V planes) are subsampled by 2 in both the vertical and
// horizontal directions.
// https://en.wikipedia.org/wiki/Chroma_subsampling
// Since we are returning data in Tensor, which has the same size for all
// color planes, we need to upsample the UV planes. PyTorch has interpolate
// function but it does not work for int16 type. So we manually copy them.
//
// block1 block2 block3 block4
// ab -> aabb = a b * a b * *
// cd aabb a b a b
// ccdd c d c d
// ccdd c d c d
//
auto block00 = dst.slice(2, 0, {}, 2).slice(3, 0, {}, 2);
auto block01 = dst.slice(2, 0, {}, 2).slice(3, 1, {}, 2);
auto block10 = dst.slice(2, 1, {}, 2).slice(3, 0, {}, 2);
auto block11 = dst.slice(2, 1, {}, 2).slice(3, 1, {}, 2);
for (int i = 1; i < 3; ++i) {
// borrow data
auto tmp = torch::from_blob(
src->data[i],
{height / 2, width / 2},
{src->linesize[i], 1},
[](void*) {},
torch::TensorOptions().dtype(torch::kUInt8).layout(torch::kStrided));
// Copy to each block
block00.slice(1, i, i + 1).copy_(tmp);
block01.slice(1, i, i + 1).copy_(tmp);
block10.slice(1, i, i + 1).copy_(tmp);
block11.slice(1, i, i + 1).copy_(tmp);
}
}
torch::Tensor YUV420PConverter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, num_channels, height, width});
convert(src, buffer);
return buffer;
}
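// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of this commit): the four strided
// slices above implement the 2x nearest-neighbor upsample. For a 2x2 chroma
// plane, each value is replicated into a 2x2 block of the 4x4 destination:
//
//   auto uv = torch::tensor({{1, 2}, {3, 4}}, torch::kUInt8);
//   auto dst = torch::empty({4, 4}, torch::kUInt8);
//   dst.slice(0, 0, {}, 2).slice(1, 0, {}, 2).copy_(uv);
//   dst.slice(0, 0, {}, 2).slice(1, 1, {}, 2).copy_(uv);
//   dst.slice(0, 1, {}, 2).slice(1, 0, {}, 2).copy_(uv);
//   dst.slice(0, 1, {}, 2).slice(1, 1, {}, 2).copy_(uv);
//   // dst == [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
// ---------------------------------------------------------------------------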
////////////////////////////////////////////////////////////////////////////////
// YUV420P10LE
////////////////////////////////////////////////////////////////////////////////
YUV420P10LEConverter::YUV420P10LEConverter(int h, int w)
: ImageConverterBase(h, w, 3) {
TORCH_WARN_ONCE(
"The output format YUV420PLE is selected. "
"This will be implicitly converted to YUV444P (16-bit), "
"in which all the color components Y, U, V have the same dimension.");
}
void YUV420P10LEConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
(AVPixelFormat)(src->format) == AV_PIX_FMT_YUV420P10LE);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kInt16);
// Write Y plane directly
{
int16_t* p_dst = dst.data_ptr<int16_t>();
uint8_t* p_src = src->data[0];
for (int h = 0; h < height; ++h) {
memcpy(p_dst, p_src, (size_t)width * 2);
p_dst += width;
p_src += src->linesize[0];
}
}
// Chroma (U and V planes) are subsampled by 2 in both the vertical and
// horizontal directions.
// https://en.wikipedia.org/wiki/Chroma_subsampling
// Since we are returning data in Tensor, which has the same size for all
// color planes, we need to upsample the UV planes. PyTorch has interpolate
// function but it does not work for int16 type. So we manually copy them.
//
// block1 block2 block3 block4
// ab -> aabb = a b * a b * *
// cd aabb a b a b
// ccdd c d c d
// ccdd c d c d
//
auto block00 = dst.slice(2, 0, {}, 2).slice(3, 0, {}, 2);
auto block01 = dst.slice(2, 0, {}, 2).slice(3, 1, {}, 2);
auto block10 = dst.slice(2, 1, {}, 2).slice(3, 0, {}, 2);
auto block11 = dst.slice(2, 1, {}, 2).slice(3, 1, {}, 2);
for (int i = 1; i < 3; ++i) {
// borrow data
auto tmp = torch::from_blob(
src->data[i],
{height / 2, width / 2},
{src->linesize[i] / 2, 1},
[](void*) {},
torch::TensorOptions().dtype(torch::kInt16).layout(torch::kStrided));
// Copy to each block
block00.slice(1, i, i + 1).copy_(tmp);
block01.slice(1, i, i + 1).copy_(tmp);
block10.slice(1, i, i + 1).copy_(tmp);
block11.slice(1, i, i + 1).copy_(tmp);
}
}
torch::Tensor YUV420P10LEConverter::convert(const AVFrame* src) {
torch::Tensor buffer =
get_image_buffer({1, num_channels, height, width}, torch::kInt16);
convert(src, buffer);
return buffer;
}
////////////////////////////////////////////////////////////////////////////////
// NV12
////////////////////////////////////////////////////////////////////////////////
NV12Converter::NV12Converter(int h, int w) : ImageConverterBase(h, w, 3) {
TORCH_WARN_ONCE(
"The output format NV12 is selected. "
"This will be implicitly converted to YUV444P, "
"in which all the color components Y, U, V have the same dimension.");
}
void NV12Converter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
(AVPixelFormat)(src->format) == AV_PIX_FMT_NV12);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
// Write Y plane directly
{
uint8_t* p_dst = dst.data_ptr<uint8_t>();
uint8_t* p_src = src->data[0];
for (int h = 0; h < height; ++h) {
memcpy(p_dst, p_src, width);
p_dst += width;
p_src += src->linesize[0];
}
}
// Write intermediate UV plane
{
auto tmp = torch::from_blob(
src->data[1],
{height / 2, width},
{src->linesize[1], 1},
[](void*) {},
torch::TensorOptions().dtype(torch::kUInt8).layout(torch::kStrided));
tmp = tmp.view({1, height / 2, width / 2, 2}).permute({0, 3, 1, 2});
auto dst_uv = dst.slice(1, 1, 3);
dst_uv.slice(2, 0, {}, 2).slice(3, 0, {}, 2).copy_(tmp);
dst_uv.slice(2, 0, {}, 2).slice(3, 1, {}, 2).copy_(tmp);
dst_uv.slice(2, 1, {}, 2).slice(3, 0, {}, 2).copy_(tmp);
dst_uv.slice(2, 1, {}, 2).slice(3, 1, {}, 2).copy_(tmp);
}
}
torch::Tensor NV12Converter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, num_channels, height, width});
convert(src, buffer);
return buffer;
}
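// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of this commit): NV12 stores U and V
// interleaved (UVUV...). The view + permute above deinterleaves them into two
// planes without copying:
//
//   auto uv = torch::tensor({10, 20, 11, 21, 12, 22, 13, 23}, torch::kUInt8);
//   auto planes = uv.view({1, 2, 2, 2}).permute({0, 3, 1, 2});
//   // planes[0][0] (U) == [[10, 11], [12, 13]]
//   // planes[0][1] (V) == [[20, 21], [22, 23]]
// ---------------------------------------------------------------------------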
#ifdef USE_CUDA
CudaImageConverterBase::CudaImageConverterBase(const torch::Device& device)
: device(device) {}
////////////////////////////////////////////////////////////////////////////////
// NV12 CUDA
////////////////////////////////////////////////////////////////////////////////
NV12CudaConverter::NV12CudaConverter(const torch::Device& device)
: CudaImageConverterBase(device) {
TORCH_WARN_ONCE(
"The output format NV12 is selected. "
"This will be implicitly converted to YUV444P, "
"in which all the color components Y, U, V have the same dimension.");
}
void NV12CudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kUInt8);
auto fmt = (AVPixelFormat)(src->format);
AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
AVPixelFormat sw_fmt = hwctx->sw_format;
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_CUDA == fmt,
"Expected CUDA frame. Found: ",
av_get_pix_fmt_name(fmt));
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_NV12 == sw_fmt,
"Expected NV12 format. Found: ",
av_get_pix_fmt_name(sw_fmt));
// Write Y plane directly
auto status = cudaMemcpy2D(
dst.data_ptr(),
width,
src->data[0],
src->linesize[0],
width,
height,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy Y plane to CUDA tensor.");
// Prepare intermediate UV planes
status = cudaMemcpy2D(
tmp_uv.data_ptr(),
width,
src->data[1],
src->linesize[1],
width,
height / 2,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy UV plane to CUDA tensor.");
// Upsample width and height
namespace F = torch::nn::functional;
torch::Tensor uv = F::interpolate(
tmp_uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// dst[:, 1:] = uv
using namespace torch::indexing;
dst.index_put_({Slice(), Slice(1)}, uv);
}
torch::Tensor NV12CudaConverter::convert(const AVFrame* src) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
if (!init) {
height = src->height;
width = src->width;
tmp_uv =
get_image_buffer({1, height / 2, width / 2, 2}, device, torch::kUInt8);
init = true;
}
torch::Tensor buffer = get_image_buffer({1, 3, height, width}, device);
convert(src, buffer);
return buffer;
}
////////////////////////////////////////////////////////////////////////////////
// P010 CUDA
////////////////////////////////////////////////////////////////////////////////
P010CudaConverter::P010CudaConverter(const torch::Device& device)
: CudaImageConverterBase{device} {
TORCH_WARN_ONCE(
"The output format P010 is selected. "
"This will be implicitly converted to YUV444P, "
"in which all the color components Y, U, V have the same dimension.");
}
void P010CudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kInt16);
auto fmt = (AVPixelFormat)(src->format);
AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
AVPixelFormat sw_fmt = hwctx->sw_format;
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_CUDA == fmt,
"Expected CUDA frame. Found: ",
av_get_pix_fmt_name(fmt));
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_P010 == sw_fmt,
"Expected P010 format. Found: ",
av_get_pix_fmt_name(sw_fmt));
// Write Y plane directly
auto status = cudaMemcpy2D(
dst.data_ptr(),
width * 2,
src->data[0],
src->linesize[0],
width * 2,
height,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy Y plane to CUDA tensor.");
// Prepare intermediate UV planes
status = cudaMemcpy2D(
tmp_uv.data_ptr(),
width * 2,
src->data[1],
src->linesize[1],
width * 2,
height / 2,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy UV plane to CUDA tensor.");
// Write to the UV plane
torch::Tensor uv = tmp_uv.permute({0, 3, 1, 2});
using namespace torch::indexing;
// Very simplistic upscale using indexing, since interpolate does not support
// int16.
dst.index_put_(
{Slice(), Slice(1, 3), Slice(None, None, 2), Slice(None, None, 2)}, uv);
dst.index_put_(
{Slice(), Slice(1, 3), Slice(1, None, 2), Slice(None, None, 2)}, uv);
dst.index_put_(
{Slice(), Slice(1, 3), Slice(None, None, 2), Slice(1, None, 2)}, uv);
dst.index_put_(
{Slice(), Slice(1, 3), Slice(1, None, 2), Slice(1, None, 2)}, uv);
// correct for int16
dst += 32768;
}
torch::Tensor P010CudaConverter::convert(const AVFrame* src) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
if (!init) {
height = src->height;
width = src->width;
tmp_uv =
get_image_buffer({1, height / 2, width / 2, 2}, device, torch::kInt16);
init = true;
}
torch::Tensor buffer =
get_image_buffer({1, 3, height, width}, device, torch::kInt16);
convert(src, buffer);
return buffer;
}
////////////////////////////////////////////////////////////////////////////////
// YUV444P CUDA
////////////////////////////////////////////////////////////////////////////////
YUV444PCudaConverter::YUV444PCudaConverter(const torch::Device& device)
: CudaImageConverterBase(device) {}
void YUV444PCudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kUInt8);
auto fmt = (AVPixelFormat)(src->format);
AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
AVPixelFormat sw_fmt = hwctx->sw_format;
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_CUDA == fmt,
"Expected CUDA frame. Found: ",
av_get_pix_fmt_name(fmt));
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_YUV444P == sw_fmt,
"Expected YUV444P format. Found: ",
av_get_pix_fmt_name(sw_fmt));
// Copy the Y, U and V planes directly
for (int i = 0; i < 3; ++i) {
auto status = cudaMemcpy2D(
dst.index({0, i}).data_ptr(),
width,
src->data[i],
src->linesize[i],
width,
height,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(
cudaSuccess == status, "Failed to copy plane ", i, " to CUDA tensor.");
}
}
torch::Tensor YUV444PCudaConverter::convert(const AVFrame* src) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
if (!init) {
height = src->height;
width = src->width;
init = true;
}
torch::Tensor buffer = get_image_buffer({1, 3, height, width}, device);
convert(src, buffer);
return buffer;
}
#endif
} // namespace torchaudio::io
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////
template <c10::ScalarType dtype, bool is_planar>
class AudioConverter {
const int num_channels;
public:
AudioConverter(int num_channels);
// Converts AVFrame* into Tensor of [T, C]
torch::Tensor convert(const AVFrame* src);
// Converts AVFrame* into pre-allocated Tensor.
// The shape must be [C, T] if is_planar otherwise [T, C]
void convert(const AVFrame* src, torch::Tensor& dst);
};
////////////////////////////////////////////////////////////////////////////////
// Image
////////////////////////////////////////////////////////////////////////////////
struct ImageConverterBase {
const int height;
const int width;
const int num_channels;
ImageConverterBase(int h, int w, int c);
};
////////////////////////////////////////////////////////////////////////////////
// Interlaced Images - NHWC
////////////////////////////////////////////////////////////////////////////////
struct InterlacedImageConverter : public ImageConverterBase {
using ImageConverterBase::ImageConverterBase;
// convert AVFrame* into Tensor of NCHW format
torch::Tensor convert(const AVFrame* src);
// convert AVFrame* into pre-allocated Tensor of NHWC format
void convert(const AVFrame* src, torch::Tensor& dst);
};
struct Interlaced16BitImageConverter : public ImageConverterBase {
using ImageConverterBase::ImageConverterBase;
// convert AVFrame* into Tensor of NCHW format
torch::Tensor convert(const AVFrame* src);
// convert AVFrame* into pre-allocated Tensor of NHWC format
void convert(const AVFrame* src, torch::Tensor& dst);
};
////////////////////////////////////////////////////////////////////////////////
// Planar Images - NCHW
////////////////////////////////////////////////////////////////////////////////
struct PlanarImageConverter : public ImageConverterBase {
using ImageConverterBase::ImageConverterBase;
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
////////////////////////////////////////////////////////////////////////////////
// Family of YUVs - NCHW
////////////////////////////////////////////////////////////////////////////////
class YUV420PConverter : public ImageConverterBase {
public:
YUV420PConverter(int height, int width);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
class YUV420P10LEConverter : public ImageConverterBase {
public:
YUV420P10LEConverter(int height, int width);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
class NV12Converter : public ImageConverterBase {
public:
NV12Converter(int height, int width);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
#ifdef USE_CUDA
// Note:
// GPU decoders are tricky. They allow changing the resolution as part of the
// decoder options, and the resulting resolution is (seemingly) not retrievable.
// Therefore, we adopt delayed frame size initialization.
// For that purpose, we do not inherit from ImageConverterBase.
struct CudaImageConverterBase {
const torch::Device device;
bool init = false;
int height = -1;
int width = -1;
explicit CudaImageConverterBase(const torch::Device& device);
};
class NV12CudaConverter : CudaImageConverterBase {
torch::Tensor tmp_uv{};
public:
explicit NV12CudaConverter(const torch::Device& device);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
class P010CudaConverter : CudaImageConverterBase {
torch::Tensor tmp_uv{};
public:
explicit P010CudaConverter(const torch::Device& device);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
class YUV444PCudaConverter : CudaImageConverterBase {
public:
explicit YUV444PCudaConverter(const torch::Device& device);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
#endif
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.h>
namespace torchaudio::io {
void PacketBuffer::push_packet(AVPacket* packet) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(packet, "Packet is null.");
AVPacket* p = av_packet_clone(packet);
TORCH_INTERNAL_ASSERT(p, "Failed to clone packet.");
packets.emplace_back(p);
}
std::vector<AVPacketPtr> PacketBuffer::pop_packets() {
std::vector<AVPacketPtr> ret{
std::make_move_iterator(packets.begin()),
std::make_move_iterator(packets.end())};
packets.clear();
return ret;
}
bool PacketBuffer::has_packets() {
return packets.size() > 0;
}
} // namespace torchaudio::io
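// ---------------------------------------------------------------------------
// Sketch (illustrative, not part of this commit): pop_packets() transfers
// ownership out of the deque via move iterators, leaving the buffer empty;
// `packet` is a hypothetical AVPacket* owned by the demuxing loop.
//
//   PacketBuffer buffer;
//   buffer.push_packet(packet);  // stores a clone of the packet
//   std::vector<AVPacketPtr> pkts = buffer.pop_packets();
//   // buffer.has_packets() == false; pkts now owns the clones.
// ---------------------------------------------------------------------------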
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio {
namespace io {
class PacketBuffer {
public:
void push_packet(AVPacket* packet);
std::vector<AVPacketPtr> pop_packets();
bool has_packets();
private:
std::deque<AVPacketPtr> packets;
};
} // namespace io
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/post_process.h>
namespace torchaudio::io {
namespace detail {
namespace {
///////////////////////////////////////////////////////////////////////////////
// FilterGraphWrapper (FilterGraph + reset feature)
///////////////////////////////////////////////////////////////////////////////
using FilterGraphFactory = std::function<FilterGraph(const std::string&)>;
FilterGraphFactory get_audio_factory(
AVRational time_base,
AVCodecContext* codec_ctx) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(codec_ctx->codec_type == AVMEDIA_TYPE_AUDIO);
return [fmt = codec_ctx->sample_fmt,
time_base,
rate = codec_ctx->sample_rate,
channel_layout = codec_ctx->channel_layout](
const std::string& filter_desc) -> FilterGraph {
FilterGraph f;
f.add_audio_src(fmt, time_base, rate, channel_layout);
f.add_audio_sink();
f.add_process(filter_desc);
f.create_filter();
return f;
};
}
FilterGraphFactory get_video_factory(
AVRational time_base,
AVRational frame_rate,
AVCodecContext* codec_ctx) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(codec_ctx->codec_type == AVMEDIA_TYPE_VIDEO);
return [fmt = codec_ctx->pix_fmt,
time_base,
frame_rate,
w = codec_ctx->width,
h = codec_ctx->height,
ratio = codec_ctx->sample_aspect_ratio,
hw_frames_ctx = codec_ctx->hw_frames_ctx](
const std::string& filter_desc) -> FilterGraph {
FilterGraph f;
f.add_video_src(fmt, time_base, frame_rate, w, h, ratio);
f.add_video_sink();
f.add_process(filter_desc);
if (hw_frames_ctx) {
f.create_filter(av_buffer_ref(hw_frames_ctx));
} else {
f.create_filter();
}
return f;
};
}
struct FilterGraphWrapper {
const std::string desc;
private:
FilterGraphFactory factory;
public:
FilterGraph filter;
// Constructor for audio input
FilterGraphWrapper(
AVRational input_time_base,
AVCodecContext* codec_ctx,
const std::string& desc)
: desc(desc),
factory(get_audio_factory(input_time_base, codec_ctx)),
filter(factory(desc)) {}
// Constructor for video input
FilterGraphWrapper(
AVRational input_time_base,
AVRational frame_rate,
AVCodecContext* codec_ctx,
const std::string& desc)
: desc(desc),
factory(get_video_factory(input_time_base, frame_rate, codec_ctx)),
filter(factory(desc)) {}
void reset() {
filter = factory(desc);
}
};
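// Sketch (illustrative, not part of this commit): the wrapper stores the
// factory so that a flush can rebuild an identical graph from the same
// description; `codec_ctx` is a hypothetical audio AVCodecContext*.
//
//   FilterGraphWrapper wrapper{AVRational{1, 16000}, codec_ctx, "aresample=8000"};
//   // ... stream frames through wrapper.filter ...
//   wrapper.reset();  // e.g. after a seek: fresh graph, same filter desc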
///////////////////////////////////////////////////////////////////////////////
// ProcessImpl
///////////////////////////////////////////////////////////////////////////////
template <typename Converter, typename Buffer>
struct ProcessImpl : public IPostDecodeProcess {
private:
AVFramePtr frame{alloc_avframe()};
FilterGraphWrapper filter_wrapper;
public:
Converter converter;
Buffer buffer;
ProcessImpl(
FilterGraphWrapper&& filter_wrapper,
Converter&& converter,
Buffer&& buffer)
: filter_wrapper(std::move(filter_wrapper)),
converter(std::move(converter)),
buffer(std::move(buffer)) {}
bool is_buffer_ready() const override {
return buffer.is_ready();
}
const std::string& get_filter_desc() const override {
return filter_wrapper.desc;
};
FilterGraphOutputInfo get_filter_output_info() const override {
return filter_wrapper.filter.get_output_info();
};
void flush() override {
filter_wrapper.reset();
buffer.flush();
}
int process_frame(AVFrame* in_frame) override {
int ret = filter_wrapper.filter.add_frame(in_frame);
while (ret >= 0) {
ret = filter_wrapper.filter.get_frame(frame);
// AVERROR(EAGAIN) means that new input data is required to return new
// output.
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
return 0;
}
if (ret >= 0) {
buffer.push_frame(converter.convert(frame), frame->pts);
}
av_frame_unref(frame);
}
return ret;
}
c10::optional<Chunk> pop_chunk() override {
return buffer.pop_chunk();
}
};
///////////////////////////////////////////////////////////////////////////////
// Audio
///////////////////////////////////////////////////////////////////////////////
std::unique_ptr<IPostDecodeProcess> get_unchunked_audio_process(
FilterGraphWrapper&& filter) {
auto i = filter.filter.get_output_info();
TORCH_INTERNAL_ASSERT(
i.type == AVMEDIA_TYPE_AUDIO,
"Unsupported media type found: ",
av_get_media_type_string(i.type));
using B = UnchunkedBuffer;
switch (auto fmt = (AVSampleFormat)i.format; fmt) {
case AV_SAMPLE_FMT_U8: {
using C = AudioConverter<torch::kUInt8, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_S16: {
using C = AudioConverter<torch::kInt16, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_S32: {
using C = AudioConverter<torch::kInt32, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_S64: {
using C = AudioConverter<torch::kInt64, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_FLT: {
using C = AudioConverter<torch::kFloat32, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_DBL: {
using C = AudioConverter<torch::kFloat64, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_U8P: {
using C = AudioConverter<torch::kUInt8, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_S16P: {
using C = AudioConverter<torch::kInt16, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_S32P: {
using C = AudioConverter<torch::kInt32, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_S64P: {
using C = AudioConverter<torch::kInt64, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_FLTP: {
using C = AudioConverter<torch::kFloat32, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
case AV_SAMPLE_FMT_DBLP: {
using C = AudioConverter<torch::kFloat64, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, B{i.time_base});
}
default:
TORCH_INTERNAL_ASSERT(
false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
}
}
std::unique_ptr<IPostDecodeProcess> get_chunked_audio_process(
FilterGraphWrapper&& filter,
int frames_per_chunk,
int num_chunks) {
auto i = filter.filter.get_output_info();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
i.type == AVMEDIA_TYPE_AUDIO,
"Unsupported media type found: ",
av_get_media_type_string(i.type));
using B = ChunkedBuffer;
B buffer{i.time_base, frames_per_chunk, num_chunks};
switch (auto fmt = (AVSampleFormat)i.format; fmt) {
case AV_SAMPLE_FMT_U8: {
using C = AudioConverter<torch::kUInt8, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_S16: {
using C = AudioConverter<torch::kInt16, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_S32: {
using C = AudioConverter<torch::kInt32, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_S64: {
using C = AudioConverter<torch::kInt64, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_FLT: {
using C = AudioConverter<torch::kFloat32, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_DBL: {
using C = AudioConverter<torch::kFloat64, false>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_U8P: {
using C = AudioConverter<torch::kUInt8, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_S16P: {
using C = AudioConverter<torch::kInt16, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_S32P: {
using C = AudioConverter<torch::kInt32, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_S64P: {
using C = AudioConverter<torch::kInt64, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_FLTP: {
using C = AudioConverter<torch::kFloat32, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
case AV_SAMPLE_FMT_DBLP: {
using C = AudioConverter<torch::kFloat64, true>;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{i.num_channels}, std::move(buffer));
}
default:
TORCH_INTERNAL_ASSERT(
false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
}
}
///////////////////////////////////////////////////////////////////////////////
// Video
///////////////////////////////////////////////////////////////////////////////
std::unique_ptr<IPostDecodeProcess> get_unchunked_video_process(
FilterGraphWrapper&& filter) {
auto i = filter.filter.get_output_info();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
i.type == AVMEDIA_TYPE_VIDEO,
"Unsupported media type found: ",
av_get_media_type_string(i.type));
auto h = i.height;
auto w = i.width;
auto tb = i.time_base;
using B = UnchunkedBuffer;
switch (auto fmt = (AVPixelFormat)i.format; fmt) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: {
using C = InterlacedImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 3}, B{tb});
}
case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA: {
using C = InterlacedImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 4}, B{tb});
}
case AV_PIX_FMT_GRAY8: {
using C = InterlacedImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 1}, B{tb});
}
case AV_PIX_FMT_RGB48LE: {
using C = Interlaced16BitImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 3}, B{tb});
}
case AV_PIX_FMT_YUV444P: {
using C = PlanarImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 3}, B{tb});
}
case AV_PIX_FMT_YUV420P: {
using C = YUV420PConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w}, B{tb});
}
case AV_PIX_FMT_YUV420P10LE: {
using C = YUV420P10LEConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w}, B{tb});
}
case AV_PIX_FMT_NV12: {
using C = NV12Converter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w}, B{tb});
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
}
}
}
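// CUDA variants: decoded frames stay on the GPU and the converters produce
// tensors on the requested CUDA device. Available only when compiled with
// USE_CUDA.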
std::unique_ptr<IPostDecodeProcess> get_unchunked_cuda_video_process(
FilterGraphWrapper&& filter,
const torch::Device& device) {
#ifndef USE_CUDA
TORCH_INTERNAL_ASSERT(
false,
"USE_CUDA is not defined, but CUDA decoding process was requested.");
#else
auto i = filter.filter.get_output_info();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
i.type == AVMEDIA_TYPE_VIDEO,
"Unsupported media type found: ",
av_get_media_type_string(i.type));
using B = UnchunkedBuffer;
switch (auto fmt = (AVPixelFormat)i.format; fmt) {
case AV_PIX_FMT_NV12: {
using C = NV12CudaConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{device}, B{i.time_base});
}
case AV_PIX_FMT_P010: {
using C = P010CudaConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{device}, B{i.time_base});
}
case AV_PIX_FMT_YUV444P: {
using C = YUV444PCudaConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{device}, B{i.time_base});
}
case AV_PIX_FMT_P016: {
TORCH_CHECK(
false,
"Unsupported video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
default: {
TORCH_CHECK(
false,
"Unexpected video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
}
#endif
}
std::unique_ptr<IPostDecodeProcess> get_chunked_video_process(
FilterGraphWrapper&& filter,
int frames_per_chunk,
int num_chunks) {
auto i = filter.filter.get_output_info();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
i.type == AVMEDIA_TYPE_VIDEO,
"Unsupported media type found: ",
av_get_media_type_string(i.type));
auto h = i.height;
auto w = i.width;
auto tb = i.time_base;
using B = ChunkedBuffer;
switch (auto fmt = (AVPixelFormat)i.format; fmt) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: {
using C = InterlacedImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 3}, B{tb, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA: {
using C = InterlacedImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 4}, B{tb, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_GRAY8: {
using C = InterlacedImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 1}, B{tb, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_RGB48LE: {
using C = Interlaced16BitImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 3}, B{tb, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_YUV444P: {
using C = PlanarImageConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w, 3}, B{tb, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_YUV420P: {
using C = YUV420PConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w}, B{tb, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_YUV420P10LE: {
using C = YUV420P10LEConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w}, B{tb, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_NV12: {
using C = NV12Converter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter), C{h, w}, B{tb, frames_per_chunk, num_chunks});
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
}
}
}
std::unique_ptr<IPostDecodeProcess> get_chunked_cuda_video_process(
FilterGraphWrapper&& filter,
int frames_per_chunk,
int num_chunks,
const torch::Device& device) {
#ifndef USE_CUDA
TORCH_INTERNAL_ASSERT(
false,
"USE_CUDA is not defined, but CUDA decoding process was requested.");
#else
auto i = filter.filter.get_output_info();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
i.type == AVMEDIA_TYPE_VIDEO,
"Unsupported media type found: ",
av_get_media_type_string(i.type));
using B = ChunkedBuffer;
switch (auto fmt = (AVPixelFormat)i.format; fmt) {
case AV_PIX_FMT_NV12: {
using C = NV12CudaConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter),
C{device},
B{i.time_base, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_P010: {
using C = P010CudaConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter),
C{device},
B{i.time_base, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_YUV444P: {
using C = YUV444PCudaConverter;
return std::make_unique<ProcessImpl<C, B>>(
std::move(filter),
C{device},
B{i.time_base, frames_per_chunk, num_chunks});
}
case AV_PIX_FMT_P016: {
TORCH_CHECK(
false,
"Unsupported video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
default: {
TORCH_CHECK(
false,
"Unexpected video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
}
#endif
}
} // namespace
} // namespace detail
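// Public factories. They validate the chunking parameters, build the filter
// graph from the decoder output, and dispatch to the unchunked
// (frames_per_chunk == -1) or chunked implementations above.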
std::unique_ptr<IPostDecodeProcess> get_audio_process(
AVRational input_time_base,
AVCodecContext* codec_ctx,
const std::string& desc,
int frames_per_chunk,
int num_chunks) {
TORCH_CHECK(
frames_per_chunk > 0 || frames_per_chunk == -1,
"`frames_per_chunk` must be positive or -1. Found: ",
frames_per_chunk);
TORCH_CHECK(
num_chunks > 0 || num_chunks == -1,
"`num_chunks` must be positive or -1. Found: ",
num_chunks);
detail::FilterGraphWrapper filter{input_time_base, codec_ctx, desc};
if (frames_per_chunk == -1) {
return detail::get_unchunked_audio_process(std::move(filter));
}
return detail::get_chunked_audio_process(
std::move(filter), frames_per_chunk, num_chunks);
}
std::unique_ptr<IPostDecodeProcess> get_video_process(
AVRational input_time_base,
AVRational frame_rate,
AVCodecContext* codec_ctx,
const std::string& desc,
int frames_per_chunk,
int num_chunks,
const torch::Device& device) {
TORCH_CHECK(
frames_per_chunk > 0 || frames_per_chunk == -1,
"`frames_per_chunk` must be positive or -1. Found: ",
frames_per_chunk);
TORCH_CHECK(
num_chunks > 0 || num_chunks == -1,
"`num_chunks` must be positive or -1. Found: ",
num_chunks);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
device.is_cuda() || device.is_cpu(), "Unexpected device type: ", device);
detail::FilterGraphWrapper filter{
input_time_base, frame_rate, codec_ctx, desc};
if (frames_per_chunk == -1) {
if (device.is_cuda()) {
return detail::get_unchunked_cuda_video_process(
std::move(filter), device);
}
return detail::get_unchunked_video_process(std::move(filter));
}
if (device.is_cuda()) {
return detail::get_chunked_cuda_video_process(
std::move(filter), frames_per_chunk, num_chunks, device);
}
return detail::get_chunked_video_process(
std::move(filter), frames_per_chunk, num_chunks);
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio::io {
struct IPostDecodeProcess {
virtual ~IPostDecodeProcess() = default;
virtual int process_frame(AVFrame* frame) = 0;
virtual c10::optional<Chunk> pop_chunk() = 0;
virtual bool is_buffer_ready() const = 0;
virtual const std::string& get_filter_desc() const = 0;
virtual FilterGraphOutputInfo get_filter_output_info() const = 0;
virtual void flush() = 0;
};
std::unique_ptr<IPostDecodeProcess> get_audio_process(
AVRational input_time_base,
AVCodecContext* codec_ctx,
const std::string& desc,
int frames_per_chunk,
int num_chunks);
std::unique_ptr<IPostDecodeProcess> get_video_process(
AVRational input_time_base,
AVRational frame_rate,
AVCodecContext* codec_ctx,
const std::string& desc,
int frames_per_chunk,
int num_chunks,
const torch::Device& device);
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h>
#include <stdexcept>
#include <string_view>
namespace torchaudio::io {
using KeyType = StreamProcessor::KeyType;
namespace {
AVCodecContextPtr alloc_codec_context(
enum AVCodecID codec_id,
const c10::optional<std::string>& decoder_name) {
const AVCodec* codec = [&]() {
if (decoder_name) {
const AVCodec* c =
avcodec_find_decoder_by_name(decoder_name.value().c_str());
TORCH_CHECK(c, "Unsupported codec: ", decoder_name.value());
return c;
} else {
const AVCodec* c = avcodec_find_decoder(codec_id);
TORCH_CHECK(c, "Unsupported codec: ", avcodec_get_name(codec_id));
return c;
}
}();
AVCodecContext* codec_ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(codec_ctx, "Failed to allocate CodecContext.");
return AVCodecContextPtr(codec_ctx);
}
const AVCodecHWConfig* get_cuda_config(const AVCodec* codec) {
for (int i = 0;; ++i) {
const AVCodecHWConfig* config = avcodec_get_hw_config(codec, i);
if (!config) {
break;
}
if (config->device_type == AV_HWDEVICE_TYPE_CUDA &&
config->methods & AV_CODEC_HW_CONFIG_METHOD_HW_DEVICE_CTX) {
return config;
}
}
TORCH_CHECK(
false,
"CUDA device was requested, but the codec \"",
codec->name,
"\" is not supported.");
}
enum AVPixelFormat get_hw_format(
AVCodecContext* codec_ctx,
const enum AVPixelFormat* pix_fmts) {
const AVCodecHWConfig* cfg = static_cast<AVCodecHWConfig*>(codec_ctx->opaque);
for (const enum AVPixelFormat* p = pix_fmts; *p != -1; p++) {
if (*p == cfg->pix_fmt) {
      // Note
      // The HW decode example uses a generic approach
      // https://ffmpeg.org/doxygen/4.1/hw__decode_8c_source.html#l00063
      // but that approach finalizes the codec configuration only when the
      // first frame arrives. We need to inspect the codec configuration
      // right after the codec is opened, so we add a shortcut for known
      // patterns:
// yuv420p (h264) -> nv12
// yuv420p10le (hevc/h265) -> p010le
switch (codec_ctx->pix_fmt) {
case AV_PIX_FMT_YUV420P: {
codec_ctx->pix_fmt = AV_PIX_FMT_CUDA;
codec_ctx->sw_pix_fmt = AV_PIX_FMT_NV12;
break;
}
case AV_PIX_FMT_YUV420P10LE: {
codec_ctx->pix_fmt = AV_PIX_FMT_CUDA;
codec_ctx->sw_pix_fmt = AV_PIX_FMT_P010LE;
break;
}
default:;
}
return *p;
}
}
TORCH_WARN("Failed to get HW surface format.");
return AV_PIX_FMT_NONE;
}
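// Allocate and initialize an AVHWFramesContext (a pool of CUDA frame
// surfaces) matching the codec's negotiated pixel format and dimensions.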
AVBufferRef* get_hw_frames_ctx(AVCodecContext* codec_ctx) {
AVBufferRef* p = av_hwframe_ctx_alloc(codec_ctx->hw_device_ctx);
TORCH_CHECK(
p,
"Failed to allocate CUDA frame context from device context at ",
codec_ctx->hw_device_ctx);
auto frames_ctx = (AVHWFramesContext*)(p->data);
frames_ctx->format = codec_ctx->pix_fmt;
frames_ctx->sw_format = codec_ctx->sw_pix_fmt;
frames_ctx->width = codec_ctx->width;
frames_ctx->height = codec_ctx->height;
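  // Preallocate a small pool of surfaces for the decoder to write into.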
frames_ctx->initial_pool_size = 5;
int ret = av_hwframe_ctx_init(p);
if (ret >= 0) {
return p;
}
av_buffer_unref(&p);
TORCH_CHECK(
false, "Failed to initialize CUDA frame context: ", av_err2string(ret));
}
void configure_codec_context(
AVCodecContext* codec_ctx,
const AVCodecParameters* params,
const torch::Device& device) {
int ret = avcodec_parameters_to_context(codec_ctx, params);
TORCH_CHECK(
ret >= 0, "Failed to set CodecContext parameter: ", av_err2string(ret));
if (device.type() == c10::DeviceType::CUDA) {
#ifndef USE_CUDA
TORCH_CHECK(false, "torchaudio is not compiled with CUDA support.");
#else
const AVCodecHWConfig* cfg = get_cuda_config(codec_ctx->codec);
// https://www.ffmpeg.org/doxygen/trunk/hw__decode_8c_source.html#l00221
    // 1. Set the HW config on the opaque pointer.
    codec_ctx->opaque = static_cast<void*>(const_cast<AVCodecHWConfig*>(cfg));
    // 2. Set the codec_ctx->get_format callback, which retrieves the HW
    // pixel format from the opaque pointer.
codec_ctx->get_format = get_hw_format;
codec_ctx->hw_device_ctx = av_buffer_ref(get_cuda_context(device.index()));
TORCH_INTERNAL_ASSERT(
codec_ctx->hw_device_ctx, "Failed to reference HW device context.");
#endif
}
}
void open_codec(
AVCodecContext* codec_ctx,
const c10::optional<OptionDict>& decoder_option) {
AVDictionary* opts = get_option_dict(decoder_option);
// Default to single thread execution.
if (!av_dict_get(opts, "threads", nullptr, 0)) {
av_dict_set(&opts, "threads", "1", 0);
}
if (!codec_ctx->channel_layout) {
codec_ctx->channel_layout =
av_get_default_channel_layout(codec_ctx->channels);
}
int ret = avcodec_open2(codec_ctx, codec_ctx->codec, &opts);
clean_up_dict(opts);
TORCH_CHECK(
ret >= 0, "Failed to initialize CodecContext: ", av_err2string(ret));
}
bool ends_with(std::string_view str, std::string_view suffix) {
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
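// Compose the steps above: allocate the codec context, apply the stream
// parameters (plus the CUDA device context when requested), open the codec,
// and attach a HW frame pool when HW decoding is active.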
AVCodecContextPtr get_codec_ctx(
    const AVCodecParameters* params,
    const c10::optional<std::string>& decoder_name,
    const c10::optional<OptionDict>& decoder_option,
    const torch::Device& device) {
AVCodecContextPtr codec_ctx =
alloc_codec_context(params->codec_id, decoder_name);
configure_codec_context(codec_ctx, params, device);
open_codec(codec_ctx, decoder_option);
if (codec_ctx->hw_device_ctx) {
codec_ctx->hw_frames_ctx = get_hw_frames_ctx(codec_ctx);
}
if (ends_with(codec_ctx->codec->name, "_cuvid")) {
C10_LOG_API_USAGE_ONCE("torchaudio.io.StreamReaderCUDA");
}
return codec_ctx;
}
} // namespace
StreamProcessor::StreamProcessor(const AVRational& time_base)
: stream_time_base(time_base) {}
////////////////////////////////////////////////////////////////////////////////
// Configurations
////////////////////////////////////////////////////////////////////////////////
KeyType StreamProcessor::add_stream(
    int frames_per_chunk,
    int num_chunks,
    AVRational frame_rate,
    const std::string& filter_description,
const torch::Device& device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
is_decoder_set(), "Decoder hasn't been set.");
  // If a device is provided, check that codec_ctx has hw_device_ctx set.
  // Defining an output stream with HW accel on an input stream whose decoder
  // was set up without HW accel would cause a segfault, so the following
  // must be rejected here:
// reader = StreamReader(...)
// reader.add_video_stream(..., decoder="h264_cuvid")
// reader.add_video_stream(..., decoder="h264_cuvid", hw_accel="cuda")
// TODO:
// One idea to work around this is to always define HW device context, and
// if HW acceleration is not required, insert `hwdownload` filter.
// This way it will be possible to handle both cases at the same time.
switch (device.type()) {
case torch::kCPU:
TORCH_CHECK(
!codec_ctx->hw_device_ctx,
"Decoding without Hardware acceleration is requested, however, "
"the decoder has been already defined with a HW acceleration. "
"Decoding a stream with and without HW acceleration simultaneously "
"is not supported.");
break;
case torch::kCUDA:
TORCH_CHECK(
codec_ctx->hw_device_ctx,
"CUDA Hardware acceleration is requested, however, the decoder has "
"been already defined without a HW acceleration. "
"Decoding a stream with and without HW acceleration simultaneously "
"is not supported.");
break;
default:;
}
switch (codec_ctx->codec_type) {
case AVMEDIA_TYPE_AUDIO:
post_processes.emplace(
std::piecewise_construct,
std::forward_as_tuple(current_key),
std::forward_as_tuple(get_audio_process(
stream_time_base,
codec_ctx,
filter_description,
frames_per_chunk,
num_chunks)));
return current_key++;
case AVMEDIA_TYPE_VIDEO:
post_processes.emplace(
std::piecewise_construct,
std::forward_as_tuple(current_key),
std::forward_as_tuple(get_video_process(
stream_time_base,
frame_rate,
codec_ctx,
filter_description,
frames_per_chunk,
num_chunks,
device)));
return current_key++;
default:
TORCH_CHECK(false, "Only Audio and Video are supported");
}
}
void StreamProcessor::remove_stream(KeyType key) {
post_processes.erase(key);
}
void StreamProcessor::set_discard_timestamp(int64_t timestamp) {
TORCH_CHECK(timestamp >= 0, "timestamp must be non-negative.");
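  // Convert from AV_TIME_BASE units (microseconds) into the stream time
  // base; e.g. timestamp = 1'500'000 (1.5 s) with a stream_time_base of
  // {1, 90000} yields discard_before_pts = 135'000.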
discard_before_pts =
av_rescale_q(timestamp, av_get_time_base_q(), stream_time_base);
}
void StreamProcessor::set_decoder(
const AVCodecParameters* codecpar,
const c10::optional<std::string>& decoder_name,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!codec_ctx, "Decoder has already been set.");
codec_ctx = get_codec_ctx(codecpar, decoder_name, decoder_option, device);
}
////////////////////////////////////////////////////////////////////////////////
// Query methods
////////////////////////////////////////////////////////////////////////////////
std::string StreamProcessor::get_filter_description(KeyType key) const {
return post_processes.at(key)->get_filter_desc();
}
FilterGraphOutputInfo StreamProcessor::get_filter_output_info(
KeyType key) const {
return post_processes.at(key)->get_filter_output_info();
}
bool StreamProcessor::is_buffer_ready() const {
for (const auto& it : post_processes) {
if (!it.second->is_buffer_ready()) {
return false;
}
}
return true;
}
bool StreamProcessor::is_decoder_set() const {
return codec_ctx;
}
////////////////////////////////////////////////////////////////////////////////
// The streaming process
////////////////////////////////////////////////////////////////////////////////
// 0: some kind of success
// <0: Some error happened
int StreamProcessor::process_packet(AVPacket* packet) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
is_decoder_set(),
"Decoder must have been set prior to calling this function.");
int ret = avcodec_send_packet(codec_ctx, packet);
while (ret >= 0) {
ret = avcodec_receive_frame(codec_ctx, frame);
// AVERROR(EAGAIN) means that new input data is required to return new
// output.
if (ret == AVERROR(EAGAIN))
return 0;
if (ret == AVERROR_EOF)
return send_frame(nullptr);
if (ret < 0)
return ret;
    // If pts is undefined then overwrite with the best-effort estimate.
    // In this case, best_effort_timestamp is basically the number of frames
    // emitted from the decoder.
//
// We need valid pts because filter_graph does not fall back to
// best_effort_timestamp.
if (frame->pts == AV_NOPTS_VALUE) {
if (frame->best_effort_timestamp == AV_NOPTS_VALUE) {
// This happens in drain mode.
// When the decoder enters drain mode, it starts flushing the internally
// buffered frames, of which PTS cannot be estimated.
//
// This is because they might be intra-frames not in chronological
// order. In this case, we use received frames as-is in the order they
// are received.
frame->pts = codec_ctx->frame_number + 1;
} else {
frame->pts = frame->best_effort_timestamp;
}
}
    // When the value of discard_before_pts is 0, we consider that no seek
    // was performed and all the frames are passed downstream
    // unconditionally.
    //
    // Two reasons for this behavior:
// 1. When seek mode is not precise, we do not discard any frame.
// In this case discard_before_pts is set to zero.
// 2. When users seek to zero, what they expect is to get to the beginning
// of the data.
//
// Note: discard_before_pts < 0 is UB.
if (discard_before_pts <= 0 || frame->pts >= discard_before_pts) {
send_frame(frame);
}
// else we can just unref the frame and continue
av_frame_unref(frame);
}
return ret;
}
void StreamProcessor::flush() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
is_decoder_set(),
"Decoder must have been set prior to calling this function.");
avcodec_flush_buffers(codec_ctx);
for (auto& ite : post_processes) {
ite.second->flush();
}
}
// 0: some kind of success
// <0: Some error happened
int StreamProcessor::send_frame(AVFrame* frame_) {
int ret = 0;
for (auto& ite : post_processes) {
int ret2 = ite.second->process_frame(frame_);
if (ret2 < 0)
ret = ret2;
}
......@@ -110,9 +384,8 @@ int StreamProcessor::send_frame(AVFrame* pFrame) {
////////////////////////////////////////////////////////////////////////////////
// Retrieval
////////////////////////////////////////////////////////////////////////////////
c10::optional<Chunk> StreamProcessor::pop_chunk(KeyType key) {
return post_processes.at(key)->pop_chunk();
}
} // namespace torchaudio::io
#pragma once
#include <torch/torch.h>
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/post_process.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <map>
namespace torchaudio {
namespace io {
class StreamProcessor {
public:
using KeyType = int;
private:
// Stream time base which is not stored in AVCodecContextPtr
AVRational stream_time_base;
// Components for decoding source media
AVCodecContextPtr codec_ctx{nullptr};
AVFramePtr frame{alloc_avframe()};
KeyType current_key = 0;
std::map<KeyType, std::unique_ptr<IPostDecodeProcess>> post_processes;
// Used for precise seek.
// 0: no discard
  // Positive values: decoded frames with PTS values less than this are
// discarded.
// Negative values: UB. Should not happen.
int64_t discard_before_pts = 0;
public:
explicit StreamProcessor(const AVRational& time_base);
~StreamProcessor() = default;
// Non-copyable
StreamProcessor(const StreamProcessor&) = delete;
......@@ -48,21 +51,33 @@ class StreamProcessor {
// 3. Configure a buffer.
// 4. Return filter ID.
KeyType add_stream(
      int frames_per_chunk,
      int num_chunks,
AVRational frame_rate,
const std::string& filter_description,
const torch::Device& device);
// 1. Remove the stream
void remove_stream(KeyType key);
// Set discard
// The input timestamp must be expressed in AV_TIME_BASE unit.
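  // (AV_TIME_BASE is 1'000'000, so e.g. 2.5 seconds corresponds to
  // timestamp = 2'500'000.)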
void set_discard_timestamp(int64_t timestamp);
void set_decoder(
const AVCodecParameters* codecpar,
const c10::optional<std::string>& decoder_name,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device);
//////////////////////////////////////////////////////////////////////////////
// Query methods
//////////////////////////////////////////////////////////////////////////////
[[nodiscard]] std::string get_filter_description(KeyType key) const;
[[nodiscard]] FilterGraphOutputInfo get_filter_output_info(KeyType key) const;
bool is_buffer_ready() const;
[[nodiscard]] bool is_decoder_set() const;
//////////////////////////////////////////////////////////////////////////////
// The streaming process
......@@ -85,8 +100,8 @@ class StreamProcessor {
//////////////////////////////////////////////////////////////////////////////
public:
// Get the chunk from the given filter result
c10::optional<Chunk> pop_chunk(KeyType key);
};
} // namespace io
} // namespace torchaudio
......@@ -5,69 +5,119 @@
#include <stdexcept>
#include <thread>
namespace torchaudio::io {
using KeyType = StreamProcessor::KeyType;
//////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations
//////////////////////////////////////////////////////////////////////////////
namespace {
AVFormatContext* get_input_format_context(
const std::string& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
AVIOContext* io_ctx) {
AVFormatContext* p = avformat_alloc_context();
TORCH_CHECK(p, "Failed to allocate AVFormatContext.");
if (io_ctx) {
p->pb = io_ctx;
}
auto* pInputFormat = [&format]() -> AVFORMAT_CONST AVInputFormat* {
if (format.has_value()) {
std::string format_str = format.value();
AVFORMAT_CONST AVInputFormat* pInput =
av_find_input_format(format_str.c_str());
TORCH_CHECK(pInput, "Unsupported device/format: \"", format_str, "\"");
return pInput;
}
return nullptr;
}();
AVDictionary* opt = get_option_dict(option);
int ret = avformat_open_input(&p, src.c_str(), pInputFormat, &opt);
clean_up_dict(opt);
  TORCH_CHECK(
ret >= 0,
"Failed to open the input \"",
src,
"\" (",
av_err2string(ret),
").");
return p;
}
} // namespace
//////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations
//////////////////////////////////////////////////////////////////////////////
StreamReader::StreamReader(AVFormatContext* p) : format_ctx(p) {
C10_LOG_API_USAGE_ONCE("torchaudio.io.StreamReader");
int ret = avformat_find_stream_info(format_ctx, nullptr);
TORCH_CHECK(
ret >= 0, "Failed to find stream information: ", av_err2string(ret));
processors =
      std::vector<std::unique_ptr<StreamProcessor>>(format_ctx->nb_streams);
  for (int i = 0; i < static_cast<int>(format_ctx->nb_streams); ++i) {
switch (format_ctx->streams[i]->codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO:
case AVMEDIA_TYPE_VIDEO:
break;
default:
format_ctx->streams[i]->discard = AVDISCARD_ALL;
}
}
}
StreamReader::StreamReader(
AVIOContext* io_ctx,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option)
: StreamReader(get_input_format_context(
"Custom Input Context",
format,
option,
io_ctx)) {}
StreamReader::StreamReader(
const std::string& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option)
: StreamReader(get_input_format_context(src, format, option, nullptr)) {}
//////////////////////////////////////////////////////////////////////////////
// Helper methods
//////////////////////////////////////////////////////////////////////////////
void validate_open_stream(AVFormatContext* format_ctx) {
TORCH_CHECK(format_ctx, "Stream is not open.");
}
void validate_src_stream_index(AVFormatContext* format_ctx, int i) {
validate_open_stream(format_ctx);
TORCH_CHECK(
i >= 0 && i < static_cast<int>(format_ctx->nb_streams),
"Source stream index out of range");
}
void validate_src_stream_type(
AVFormatContext* format_ctx,
int i,
AVMediaType type) {
validate_src_stream_index(format_ctx, i);
TORCH_CHECK(
format_ctx->streams[i]->codecpar->codec_type == type,
"Stream ",
i,
" is not ",
av_get_media_type_string(type),
" stream.");
}
////////////////////////////////////////////////////////////////////////////////
// Query methods
////////////////////////////////////////////////////////////////////////////////
int64_t StreamReader::num_src_streams() const {
return format_ctx->nb_streams;
}
namespace {
......@@ -75,19 +125,20 @@ OptionDict parse_metadata(const AVDictionary* metadata) {
AVDictionaryEntry* tag = nullptr;
OptionDict ret;
while ((tag = av_dict_get(metadata, "", tag, AV_DICT_IGNORE_SUFFIX))) {
ret.emplace(std::string(tag->key), std::string(tag->value));
}
return ret;
}
} // namespace
OptionDict StreamReader::get_metadata() const {
return parse_metadata(format_ctx->metadata);
}
SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
validate_src_stream_index(format_ctx, i);
AVStream* stream = format_ctx->streams[i];
AVCodecParameters* codecpar = stream->codecpar;
SrcStreamInfo ret;
......@@ -127,34 +178,82 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
return ret;
}
namespace {
AVCodecParameters* get_codecpar() {
AVCodecParameters* ptr = avcodec_parameters_alloc();
TORCH_CHECK(ptr, "Failed to allocate resource.");
return ptr;
}
} // namespace
StreamParams StreamReader::get_src_stream_params(int i) {
validate_src_stream_index(format_ctx, i);
AVStream* stream = format_ctx->streams[i];
AVCodecParametersPtr codec_params(get_codecpar());
int ret = avcodec_parameters_copy(codec_params, stream->codecpar);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream's codec parameters. (",
av_err2string(ret),
")");
return {std::move(codec_params), stream->time_base, i};
}
int64_t StreamReader::num_out_streams() const {
return static_cast<int64_t>(stream_indices.size());
}
OutputStreamInfo StreamReader::get_out_stream_info(int i) const {
TORCH_CHECK(
i >= 0 && static_cast<size_t>(i) < stream_indices.size(),
"Output stream index out of range");
int i_src = stream_indices[i].first;
KeyType key = stream_indices[i].second;
FilterGraphOutputInfo info = processors[i_src]->get_filter_output_info(key);
OutputStreamInfo ret;
ret.source_index = i_src;
ret.filter_description = processors[i_src]->get_filter_description(key);
ret.media_type = info.type;
ret.format = info.format;
switch (info.type) {
case AVMEDIA_TYPE_AUDIO:
ret.sample_rate = info.sample_rate;
ret.num_channels = info.num_channels;
break;
case AVMEDIA_TYPE_VIDEO:
ret.width = info.width;
ret.height = info.height;
ret.frame_rate = info.frame_rate;
break;
default:;
}
return ret;
}
int64_t StreamReader::find_best_audio_stream() const {
return av_find_best_stream(
format_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, nullptr, 0);
}
int64_t StreamReader::find_best_video_stream() const {
return av_find_best_stream(
format_ctx, AVMEDIA_TYPE_VIDEO, -1, -1, nullptr, 0);
}
bool StreamReader::is_buffer_ready() const {
if (processors.empty()) {
    // If no decoding output streams exist, then determine overall readiness
    // from the readiness of the packet buffer.
return packet_buffer->has_packets();
} else {
// Otherwise, determine readiness solely from the readiness of the decoding
// output streams.
for (const auto& it : processors) {
if (it && !it->is_buffer_ready()) {
return false;
}
}
}
return true;
......@@ -163,15 +262,42 @@ bool StreamReader::is_buffer_ready() const {
////////////////////////////////////////////////////////////////////////////////
// Configure methods
////////////////////////////////////////////////////////////////////////////////
void StreamReader::seek(double timestamp_s, int64_t mode) {
TORCH_CHECK(timestamp_s >= 0, "timestamp must be non-negative.");
TORCH_CHECK(
format_ctx->nb_streams > 0,
"At least one stream must exist in this context");
int64_t timestamp_av_tb = static_cast<int64_t>(timestamp_s * AV_TIME_BASE);
int flag = AVSEEK_FLAG_BACKWARD;
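  // mode 0: key-frame seek (lands on the nearest preceding key frame).
  // mode 1: any-frame seek (may land on a frame that is not independently
  //         decodable).
  // mode 2: precise seek (key-frame seek, then frames decoded before the
  //         requested timestamp are discarded by the stream processors).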
switch (mode) {
case 0:
      // reset seek_timestamp as it is only used for precise seek
seek_timestamp = 0;
break;
case 1:
flag |= AVSEEK_FLAG_ANY;
      // reset seek_timestamp as it is only used for precise seek
seek_timestamp = 0;
break;
case 2:
seek_timestamp = timestamp_av_tb;
break;
default:
TORCH_CHECK(false, "Invalid mode value: ", mode);
}
int ret = av_seek_frame(format_ctx, -1, timestamp_av_tb, flag);
if (ret < 0) {
seek_timestamp = 0;
TORCH_CHECK(false, "Failed to seek. (" + av_err2string(ret) + ".)");
}
for (const auto& it : processors) {
if (it) {
it->flush();
it->set_discard_timestamp(seek_timestamp);
}
}
}
......@@ -188,7 +314,7 @@ void StreamReader::add_audio_stream(
AVMEDIA_TYPE_AUDIO,
static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks),
filter_desc.value_or("anull"),
decoder,
decoder_option,
torch::Device(torch::DeviceType::CPU));
......@@ -209,9 +335,7 @@ void StreamReader::add_video_stream(
#ifdef USE_CUDA
torch::Device d{hw_accel.value()};
TORCH_CHECK(
d.is_cuda(), "Only CUDA is supported for HW acceleration. Found: ", d);
return d;
#else
TORCH_CHECK(
......@@ -225,47 +349,75 @@ void StreamReader::add_video_stream(
AVMEDIA_TYPE_VIDEO,
static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks),
filter_desc.value_or("null"),
decoder,
decoder_option,
device);
}
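// Enable raw packet passthrough for stream `i`. Matching packets are
// retained in a PacketBuffer (in addition to being fed to any decoding
// output streams on the same source stream) and retrieved via pop_packets().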
void StreamReader::add_packet_stream(int i) {
validate_src_stream_index(format_ctx, i);
if (!packet_buffer) {
packet_buffer = std::make_unique<PacketBuffer>();
}
packet_stream_indices.emplace(i);
}
void StreamReader::add_stream(
int i,
AVMediaType media_type,
int frames_per_chunk,
int num_chunks,
const std::string& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device) {
validate_src_stream_type(format_ctx, i, media_type);
AVStream* stream = format_ctx->streams[i];
// When media source is file-like object, it is possible that source codec
// is not detected properly.
TORCH_CHECK(
stream->codecpar->format != -1,
"Failed to detect the source stream format.");
if (!processors[i]) {
processors[i] = std::make_unique<StreamProcessor>(stream->time_base);
processors[i]->set_discard_timestamp(seek_timestamp);
}
if (!processors[i]->is_decoder_set()) {
processors[i]->set_decoder(
stream->codecpar, decoder, decoder_option, device);
} else {
TORCH_CHECK(
!decoder && (!decoder_option || decoder_option.value().size() == 0),
"Decoder options were provided, but the decoder has already been initialized.")
}
stream->discard = AVDISCARD_DEFAULT;
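  // Determine the frame rate to report to the filter graph: video streams
  // use av_guess_frame_rate(), audio streams pass {0, 1} (unknown).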
auto frame_rate = [&]() -> AVRational {
switch (media_type) {
case AVMEDIA_TYPE_AUDIO:
return AVRational{0, 1};
case AVMEDIA_TYPE_VIDEO:
return av_guess_frame_rate(format_ctx, stream, nullptr);
default:
TORCH_INTERNAL_ASSERT(
false,
"Unexpected media type is given: ",
av_get_media_type_string(media_type));
}
}();
int key = processors[i]->add_stream(
frames_per_chunk, num_chunks, frame_rate, filter_desc, device);
stream_indices.push_back(std::make_pair<>(i, key));
}
void StreamReader::remove_stream(int64_t i) {
TORCH_CHECK(
i >= 0 && static_cast<size_t>(i) < stream_indices.size(),
"Output stream index out of range");
auto it = stream_indices.begin() + i;
int iP = it->first;
processors[iP]->remove_stream(it->second);
......@@ -293,7 +445,7 @@ void StreamReader::remove_stream(int64_t i) {
// 1: It's done, caller should stop calling
// <0: Some error happened
int StreamReader::process_packet() {
int ret = av_read_frame(format_ctx, packet);
if (ret == AVERROR_EOF) {
ret = drain();
return (ret < 0) ? ret : 1;
......@@ -301,12 +453,21 @@ int StreamReader::process_packet() {
if (ret < 0) {
return ret;
}
AutoPacketUnref auto_unref{packet};
int stream_index = packet->stream_index;
if (packet_stream_indices.count(stream_index)) {
packet_buffer->push_packet(packet);
}
auto& processor = processors[stream_index];
if (!processor) {
return 0;
}
ret = processor->process_packet(packet);
return (ret < 0) ? ret : 0;
}
......@@ -344,6 +505,39 @@ int StreamReader::process_packet_block(double timeout, double backoff) {
}
}
void StreamReader::process_all_packets() {
int64_t ret = 0;
do {
ret = process_packet();
} while (!ret);
}
int StreamReader::process_packet(
const c10::optional<double>& timeout,
const double backoff) {
int code = [&]() -> int {
if (timeout.has_value()) {
return process_packet_block(timeout.value(), backoff);
}
return process_packet();
}();
TORCH_CHECK(
code >= 0, "Failed to process a packet. (" + av_err2string(code) + "). ");
return code;
}
int StreamReader::fill_buffer(
const c10::optional<double>& timeout,
const double backoff) {
while (!is_buffer_ready()) {
int code = process_packet(timeout, backoff);
if (code != 0) {
return code;
}
}
return 0;
}
// <0: Some error happened.
int StreamReader::drain() {
int ret = 0, tmp = 0;
......@@ -358,13 +552,58 @@ int StreamReader::drain() {
return ret;
}
std::vector<c10::optional<Chunk>> StreamReader::pop_chunks() {
std::vector<c10::optional<Chunk>> ret;
ret.reserve(static_cast<size_t>(num_out_streams()));
for (auto& i : stream_indices) {
ret.emplace_back(processors[i.first]->pop_chunk(i.second));
}
return ret;
}
std::vector<AVPacketPtr> StreamReader::pop_packets() {
return packet_buffer->pop_packets();
}
//////////////////////////////////////////////////////////////////////////////
// StreamReaderCustomIO
//////////////////////////////////////////////////////////////////////////////
namespace detail {
namespace {
AVIOContext* get_io_context(
void* opaque,
int buffer_size,
int (*read_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence)) {
unsigned char* buffer = static_cast<unsigned char*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer.");
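  // write_flag=0 with a null write callback yields a read-only context.
  // On success, `buffer` is subsequently managed through the AVIOContext,
  // so it is freed manually only on the failure path below.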
AVIOContext* io_ctx = avio_alloc_context(
buffer, buffer_size, 0, opaque, read_packet, nullptr, seek);
if (!io_ctx) {
av_freep(&buffer);
TORCH_CHECK(false, "Failed to allocate AVIOContext.");
}
return io_ctx;
}
} // namespace
CustomInput::CustomInput(
void* opaque,
int buffer_size,
int (*read_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence))
: io_ctx(get_io_context(opaque, buffer_size, read_packet, seek)) {}
} // namespace detail
StreamReaderCustomIO::StreamReaderCustomIO(
void* opaque,
const c10::optional<std::string>& format,
int buffer_size,
int (*read_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence),
const c10::optional<OptionDict>& option)
: CustomInput(opaque, buffer_size, read_packet, seek),
StreamReader(io_ctx, format, option) {}
} // namespace torchaudio::io