Commit 014d7140 authored by moto, committed by Facebook GitHub Bot

Refactor Tensor conversion in StreamReader (#3170)

Summary:
Currently, when the Buffer converts AVFrame* to torch::Tensor,
it checks the format each time a frame is passed, and then
performs the conversion.

This commit changes it so that the conversion operation is
pre-instantiated at the time the output stream is configured.

It introduces Converter implementations for various formats
and uses templates to embed them in the Buffer class.
This way, branching such as if/switch is eliminated from the
decoding path.
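
To illustrate the idea, here is a minimal, self-contained sketch of
the pattern. Frame, Buffer, the two converters and make_buffer below
are simplified stand-ins for illustration, not the actual torchaudio
types: the point is that the format switch runs once when the buffer
is built, and the per-frame path is a direct call into the embedded
converter.

#include <cstdio>
#include <memory>
#include <utility>

struct Frame { int format; };  // stand-in for AVFrame

struct Buffer {  // stand-in for the Buffer interface
  virtual ~Buffer() = default;
  virtual void push_frame(Frame* frame) = 0;
};

// One converter per format; convert() is format-specific and branch-free.
struct U8Converter {
  void convert(Frame*) { std::printf("u8 conversion\n"); }
};
struct S16Converter {
  void convert(Frame*) { std::printf("s16 conversion\n"); }
};

// The converter type is a template parameter, so push_frame contains
// no if/switch on the format.
template <typename Converter>
struct TypedBuffer : Buffer {
  Converter converter;
  explicit TypedBuffer(Converter&& c) : converter(std::move(c)) {}
  void push_frame(Frame* frame) override { converter.convert(frame); }
};

// The only branching happens here, once, at stream-configuration time.
std::unique_ptr<Buffer> make_buffer(int format) {
  switch (format) {
    case 0:
      return std::make_unique<TypedBuffer<U8Converter>>(U8Converter{});
    case 1:
      return std::make_unique<TypedBuffer<S16Converter>>(S16Converter{});
    default:
      return nullptr;
  }
}

int main() {
  Frame frame{1};
  auto buffer = make_buffer(frame.format);
  buffer->push_frame(&frame);  // hot path: direct call, no format dispatch
}

In this change, get_chunked_buffer / get_unchunked_buffer play the role
of make_buffer, and ChunkedBuffer<Converter> / UnchunkedBuffer<Converter>
the role of TypedBuffer.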

Pull Request resolved: https://github.com/pytorch/audio/pull/3170

Reviewed By: xiaohui-zhang

Differential Revision: D44048293

Pulled By: mthrok

fbshipit-source-id: 30d8b240a5695d7513f499ce17853f2f0ffcab9f
parent 92f2ea89
@@ -832,11 +832,49 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
if self.test_type == "fileobj":
src.seek(0)
self._test_wav(src, original, fmt=None)
# convert to float32
expected = _to_fltp(original)
if self.test_type == "fileobj":
src.seek(0)
self._test_wav(src, expected, fmt="fltp")
def test_audio_stream_format(self):
num_channels = 2
src, s32 = self.get_src(8000, dtype="int32", num_channels=num_channels)
args = {
"num_channels": num_channels,
"normalize": False,
"channels_first": False,
"num_frames": 1 << 16,
}
u8 = get_wav_data("uint8", **args)
s16 = get_wav_data("int16", **args)
s64 = s32.to(torch.int64) * (1 << 32)
f32 = get_wav_data("float32", **args)
f64 = get_wav_data("float64", **args)
s = StreamReader(src)
s.add_basic_audio_stream(frames_per_chunk=-1, format="u8")
s.add_basic_audio_stream(frames_per_chunk=-1, format="u8p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s16")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s16p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s32")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s32p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s64")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s64p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="flt")
s.add_basic_audio_stream(frames_per_chunk=-1, format="fltp")
s.add_basic_audio_stream(frames_per_chunk=-1, format="dbl")
s.add_basic_audio_stream(frames_per_chunk=-1, format="dblp")
s.process_all_packets()
chunks = s.pop_chunks()
self.assertEqual(chunks[0], u8, atol=1, rtol=0)
self.assertEqual(chunks[1], u8, atol=1, rtol=0)
self.assertEqual(chunks[2], s16)
self.assertEqual(chunks[3], s16)
self.assertEqual(chunks[4], s32)
self.assertEqual(chunks[5], s32)
self.assertEqual(chunks[6], s64)
self.assertEqual(chunks[7], s64)
self.assertEqual(chunks[8], f32)
self.assertEqual(chunks[9], f32)
self.assertEqual(chunks[10], f64)
self.assertEqual(chunks[11], f64)
@nested_params(
["int16", "uint8", "int32"], # "float", "double", "int64"]
@@ -969,6 +1007,7 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
rgb = torch.empty(1, 3, 256, 256, dtype=torch.uint8)
rgb[0, 0] = torch.arange(256, dtype=torch.uint8).reshape([1, -1])
rgb[0, 1] = torch.arange(256, dtype=torch.uint8).reshape([-1, 1])
alpha = torch.full((1, 1, 256, 256), 255, dtype=torch.uint8)
for i in range(256):
rgb[0, 2] = i
path = self.get_temp_path(f"ref_{i}.png")
@@ -979,6 +1018,10 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
yuv = rgb_to_yuv_ccir(rgb)
bgr = rgb[:, [2, 1, 0], :, :]
gray = rgb_to_gray(rgb)
argb = torch.cat([alpha, rgb], dim=1)
rgba = torch.cat([rgb, alpha], dim=1)
abgr = torch.cat([alpha, bgr], dim=1)
bgra = torch.cat([bgr, alpha], dim=1)
s = StreamReader(path)
s.add_basic_video_stream(frames_per_chunk=-1, format="yuv444p")
@@ -988,12 +1031,20 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
s.add_basic_video_stream(frames_per_chunk=-1, format="gray8")
s.add_basic_video_stream(frames_per_chunk=-1, format="rgb48le")
s.add_basic_video_stream(frames_per_chunk=-1, format="argb")
s.add_basic_video_stream(frames_per_chunk=-1, format="rgba")
s.add_basic_video_stream(frames_per_chunk=-1, format="abgr")
s.add_basic_video_stream(frames_per_chunk=-1, format="bgra")
s.process_all_packets()
yuv444, yuv420, nv12, rgb24, bgr24, gray8, rgb48le = s.pop_chunks()
self.assertEqual(yuv, yuv444, atol=1, rtol=0)
self.assertEqual(yuv, yuv420, atol=1, rtol=0)
self.assertEqual(yuv, nv12, atol=1, rtol=0)
self.assertEqual(rgb, rgb24, atol=0, rtol=0)
self.assertEqual(bgr, bgr24, atol=0, rtol=0)
self.assertEqual(gray, gray8, atol=1, rtol=0)
self.assertEqual(rgb16, rgb48le, atol=256, rtol=0)
chunks = s.pop_chunks()
self.assertEqual(chunks[0], yuv, atol=1, rtol=0)
self.assertEqual(chunks[1], yuv, atol=1, rtol=0)
self.assertEqual(chunks[2], yuv, atol=1, rtol=0)
self.assertEqual(chunks[3], rgb, atol=0, rtol=0)
self.assertEqual(chunks[4], bgr, atol=0, rtol=0)
self.assertEqual(chunks[5], gray, atol=1, rtol=0)
self.assertEqual(chunks[6], rgb16, atol=256, rtol=0)
self.assertEqual(chunks[7], argb, atol=0, rtol=0)
self.assertEqual(chunks[8], rgba, atol=0, rtol=0)
self.assertEqual(chunks[9], abgr, atol=0, rtol=0)
self.assertEqual(chunks[10], bgra, atol=0, rtol=0)
@@ -9,9 +9,9 @@ set(
sources
ffmpeg.cpp
filter_graph.cpp
stream_reader/buffer/common.cpp
stream_reader/buffer/chunked_buffer.cpp
stream_reader/buffer/unchunked_buffer.cpp
stream_reader/conversion.cpp
stream_reader/sink.cpp
stream_reader/stream_processor.cpp
stream_reader/stream_reader.cpp
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/common.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
namespace torchaudio {
namespace io {
namespace detail {
namespace torchaudio::io::detail {
ChunkedBuffer::ChunkedBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration)
: frame_duration(frame_duration),
frames_per_chunk(frames_per_chunk),
num_chunks(num_chunks) {}
ChunkedAudioBuffer::ChunkedAudioBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration)
: ChunkedBuffer(frames_per_chunk, num_chunks, frame_duration) {}
ChunkedVideoBuffer::ChunkedVideoBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
const torch::Device& device_)
: ChunkedBuffer(frames_per_chunk, num_chunks, frame_duration),
device(device_) {}
template <typename Converter>
ChunkedBuffer<Converter>::ChunkedBuffer(
int frames_per_chunk_,
int num_chunks_,
double frame_duration_,
Converter&& converter_)
: frame_duration(frame_duration_),
frames_per_chunk(frames_per_chunk_),
num_chunks(num_chunks_),
converter(std::move(converter_)){};
bool ChunkedBuffer::is_ready() const {
template <typename Converter>
bool ChunkedBuffer<Converter>::is_ready() const {
return num_buffered_frames >= frames_per_chunk;
}
void ChunkedBuffer::push_tensor(torch::Tensor frame, double pts_) {
template <typename Converter>
void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_, double pts_) {
torch::Tensor frame = converter.convert(frame_);
using namespace torch::indexing;
// Note:
// Audio tensors contain multiple frames while video tensors contain only
@@ -122,7 +113,8 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame, double pts_) {
}
}
c10::optional<Chunk> ChunkedBuffer::pop_chunk() {
template <typename Converter>
c10::optional<Chunk> ChunkedBuffer<Converter>::pop_chunk() {
using namespace torch::indexing;
if (!num_buffered_frames) {
return {};
@@ -138,19 +130,171 @@ c10::optional<Chunk> ChunkedBuffer::pop_chunk() {
return {Chunk{chunk, pts_val}};
}
void ChunkedAudioBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_audio(frame), pts_);
template <typename Converter>
void ChunkedBuffer<Converter>::flush() {
num_buffered_frames = 0;
chunks.clear();
}
void ChunkedVideoBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_image(frame, device), pts_);
std::unique_ptr<Buffer> get_chunked_buffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
AVSampleFormat fmt,
int channels) {
switch (fmt) {
case AV_SAMPLE_FMT_U8: {
using Converter = AudioConverter<torch::kUInt8, false>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_S16: {
using Converter = AudioConverter<torch::kInt16, false>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_S32: {
using Converter = AudioConverter<torch::kInt32, false>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_S64: {
using Converter = AudioConverter<torch::kInt64, false>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_FLT: {
using Converter = AudioConverter<torch::kFloat32, false>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_DBL: {
using Converter = AudioConverter<torch::kFloat64, false>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_U8P: {
using Converter = AudioConverter<torch::kUInt8, true>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_S16P: {
using Converter = AudioConverter<torch::kInt16, true>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_S32P: {
using Converter = AudioConverter<torch::kInt32, true>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_S64P: {
using Converter = AudioConverter<torch::kInt64, true>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_FLTP: {
using Converter = AudioConverter<torch::kFloat32, true>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
case AV_SAMPLE_FMT_DBLP: {
using Converter = AudioConverter<torch::kFloat64, true>;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels});
}
default:
TORCH_INTERNAL_ASSERT(
false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
}
}
void ChunkedBuffer::flush() {
num_buffered_frames = 0;
chunks.clear();
std::unique_ptr<Buffer> get_chunked_buffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
AVPixelFormat fmt,
int h,
int w,
const torch::Device& device) {
if (device.type() == at::DeviceType::CUDA) {
#ifndef USE_CUDA
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
false,
"USE_CUDA is not defined, and it should be guarded before here.");
#else
switch (fmt) {
case AV_PIX_FMT_NV12: {
using Conv = NV12CudaConverter;
return std::make_unique<ChunkedBuffer<Conv>>(
frames_per_chunk, num_chunks, frame_duration, Conv{h, w, device});
}
case AV_PIX_FMT_P010: {
using Conv = P010CudaConverter;
return std::make_unique<ChunkedBuffer<Conv>>(
frames_per_chunk, num_chunks, frame_duration, Conv{h, w, device});
}
case AV_PIX_FMT_P016: {
TORCH_CHECK(
false,
"Unsupported video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
default: {
TORCH_CHECK(
false,
"Unexpected video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
}
#endif
}
switch (fmt) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: {
using Converter = InterlacedImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 3});
}
case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA: {
using Converter = InterlacedImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 4});
}
case AV_PIX_FMT_GRAY8: {
using Converter = InterlacedImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 1});
}
case AV_PIX_FMT_RGB48LE: {
using Converter = Interlaced16BitImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 3});
}
case AV_PIX_FMT_YUV444P: {
using Converter = PlanarImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 3});
}
case AV_PIX_FMT_YUV420P: {
using Converter = YUV420PConverter;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w});
}
case AV_PIX_FMT_NV12: {
using Converter = NV12Converter;
return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w});
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
}
}
}
} // namespace detail
} // namespace io
} // namespace torchaudio
} // namespace torchaudio::io::detail
@@ -2,14 +2,13 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer.h>
namespace torchaudio {
namespace io {
namespace detail {
namespace torchaudio::io::detail {
//////////////////////////////////////////////////////////////////////////////
// Chunked Buffer Implementation
//////////////////////////////////////////////////////////////////////////////
// Common to both audio and video
template <typename Converter>
class ChunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
@@ -28,40 +27,35 @@ class ChunkedBuffer : public Buffer {
// one Tensor contains multiple samples, so we track here.
int64_t num_buffered_frames = 0;
protected:
ChunkedBuffer(int frames_per_chunk, int num_chunks, double frame_duration);
void push_tensor(torch::Tensor frame, double pts);
Converter converter;
public:
bool is_ready() const override;
void flush() override;
c10::optional<Chunk> pop_chunk() override;
};
class ChunkedAudioBuffer : public ChunkedBuffer {
public:
ChunkedAudioBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration);
void push_frame(AVFrame* frame, double pts) override;
};
class ChunkedVideoBuffer : public ChunkedBuffer {
const torch::Device device;
public:
ChunkedVideoBuffer(
ChunkedBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
const torch::Device& device);
Converter&& converter);
void push_frame(AVFrame* frame, double pts) override;
bool is_ready() const override;
void flush() override;
c10::optional<Chunk> pop_chunk() override;
void push_frame(AVFrame* frame_, double pts_) override;
};
} // namespace detail
} // namespace io
} // namespace torchaudio
std::unique_ptr<Buffer> get_chunked_buffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
AVSampleFormat fmt,
int num_channels);
std::unique_ptr<Buffer> get_chunked_buffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
AVPixelFormat fmt,
int height,
int width,
const torch::Device& device);
} // namespace torchaudio::io::detail
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/common.h>
#include <stdexcept>
#include <vector>
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
namespace torchaudio {
namespace io {
namespace detail {
torch::Tensor convert_audio(AVFrame* pFrame) {
// ref: https://ffmpeg.org/doxygen/4.1/filter__audio_8c_source.html#l00215
AVSampleFormat format = static_cast<AVSampleFormat>(pFrame->format);
int num_channels = pFrame->channels;
int bps = av_get_bytes_per_sample(format);
// Note
// FFmpeg's `nb_samples` represents the number of samples per channel.
// This corresponds to `num_frames` in torchaudio's notation.
// Also torchaudio uses `num_samples` as the number of samples
// across channels.
int num_frames = pFrame->nb_samples;
int is_planar = av_sample_fmt_is_planar(format);
int num_planes = is_planar ? num_channels : 1;
int plane_size = bps * num_frames * (is_planar ? 1 : num_channels);
std::vector<int64_t> shape = is_planar
? std::vector<int64_t>{num_channels, num_frames}
: std::vector<int64_t>{num_frames, num_channels};
torch::Tensor t;
uint8_t* ptr = nullptr;
switch (format) {
case AV_SAMPLE_FMT_U8:
case AV_SAMPLE_FMT_U8P: {
t = torch::empty(shape, torch::kUInt8);
ptr = t.data_ptr<uint8_t>();
break;
}
case AV_SAMPLE_FMT_S16:
case AV_SAMPLE_FMT_S16P: {
t = torch::empty(shape, torch::kInt16);
ptr = reinterpret_cast<uint8_t*>(t.data_ptr<int16_t>());
break;
}
case AV_SAMPLE_FMT_S32:
case AV_SAMPLE_FMT_S32P: {
t = torch::empty(shape, torch::kInt32);
ptr = reinterpret_cast<uint8_t*>(t.data_ptr<int32_t>());
break;
}
case AV_SAMPLE_FMT_S64:
case AV_SAMPLE_FMT_S64P: {
t = torch::empty(shape, torch::kInt64);
ptr = reinterpret_cast<uint8_t*>(t.data_ptr<int64_t>());
break;
}
case AV_SAMPLE_FMT_FLT:
case AV_SAMPLE_FMT_FLTP: {
t = torch::empty(shape, torch::kFloat32);
ptr = reinterpret_cast<uint8_t*>(t.data_ptr<float>());
break;
}
case AV_SAMPLE_FMT_DBL:
case AV_SAMPLE_FMT_DBLP: {
t = torch::empty(shape, torch::kFloat64);
ptr = reinterpret_cast<uint8_t*>(t.data_ptr<double>());
break;
}
default:
TORCH_CHECK(
false,
"Unsupported audio format: " +
std::string(av_get_sample_fmt_name(format)));
}
for (int i = 0; i < num_planes; ++i) {
memcpy(ptr, pFrame->extended_data[i], plane_size);
ptr += plane_size;
}
if (is_planar) {
t = t.t();
}
return t;
}
namespace {
torch::Tensor get_buffer(
at::IntArrayRef shape,
const torch::Device& device = torch::Device(torch::kCPU),
const torch::Dtype dtype = torch::kUInt8) {
auto options = torch::TensorOptions()
.dtype(dtype)
.layout(torch::kStrided)
.device(device.type(), device.index());
return torch::empty(shape, options);
}
std::tuple<torch::Tensor, bool> get_image_buffer(
AVFrame* frame,
int num_frames,
const torch::Device& device) {
auto fmt = static_cast<AVPixelFormat>(frame->format);
const AVPixFmtDescriptor* desc = [&]() {
if (fmt == AV_PIX_FMT_CUDA) {
AVHWFramesContext* hwctx = (AVHWFramesContext*)frame->hw_frames_ctx->data;
return av_pix_fmt_desc_get(hwctx->sw_format);
}
return av_pix_fmt_desc_get(fmt);
}();
int channels = desc->nb_components;
// Note
// AVPixFmtDescriptor::nb_components represents the number of
// color components. This is different from the number of planes.
//
// For example, YUV420P has three color components Y, U and V, but
// U and V are squashed into the same plane, so there are only
// two planes.
//
// In our application, we cannot express the bare YUV420P as a
// single tensor, so we convert it to a 3-channel tensor.
// For this reason, we use nb_components for the number of channels,
// instead of the number of planes.
//
// The actual number of planes can be retrieved with
// av_pix_fmt_count_planes.
int height = frame->height;
int width = frame->width;
int depth = desc->comp[0].depth;
auto dtype = (depth > 8) ? torch::kInt16 : torch::kUInt8;
if (desc->flags & AV_PIX_FMT_FLAG_PLANAR) {
auto buffer =
get_buffer({num_frames, channels, height, width}, device, dtype);
return std::make_tuple(buffer, true);
}
auto buffer =
get_buffer({num_frames, height, width, channels}, device, dtype);
return std::make_tuple(buffer, false);
}
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
auto ptr = frame.data_ptr<uint8_t>();
uint8_t* buf = pFrame->data[0];
size_t height = frame.size(1);
size_t stride = frame.size(2) * frame.size(3);
for (int i = 0; i < height; ++i) {
memcpy(ptr, buf, stride);
buf += pFrame->linesize[0];
ptr += stride;
}
}
void write_interlaced_image16(AVFrame* pFrame, torch::Tensor& frame) {
auto ptr = frame.data_ptr<int16_t>();
uint8_t* buf = pFrame->data[0];
size_t height = frame.size(1);
size_t stride = frame.size(2) * frame.size(3);
for (int i = 0; i < height; ++i) {
memcpy(ptr, buf, stride * 2);
buf += pFrame->linesize[0];
ptr += stride;
}
// correct for int16
frame += 32768;
}
void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) {
int num_planes = static_cast<int>(frame.size(1));
int height = static_cast<int>(frame.size(2));
int width = static_cast<int>(frame.size(3));
for (int i = 0; i < num_planes; ++i) {
torch::Tensor plane = frame.index({0, i});
uint8_t* tgt = plane.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[i];
int linesize = pFrame->linesize[i];
for (int h = 0; h < height; ++h) {
memcpy(tgt, src, width);
tgt += width;
src += linesize;
}
}
}
void write_yuv420p(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
int width = static_cast<int>(yuv.size(3));
// Write Y plane directly
{
uint8_t* tgt = yuv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[0];
int linesize = pFrame->linesize[0];
for (int h = 0; h < height; ++h) {
memcpy(tgt, src, width);
tgt += width;
src += linesize;
}
}
// Prepare intermediate UV plane
torch::Tensor uv = get_buffer({1, 2, height / 2, width / 2});
{
uint8_t* tgt = uv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[1];
int linesize = pFrame->linesize[1];
for (int h = 0; h < height / 2; ++h) {
memcpy(tgt, src, width / 2);
tgt += width / 2;
src += linesize;
}
src = pFrame->data[2];
linesize = pFrame->linesize[2];
for (int h = 0; h < height / 2; ++h) {
memcpy(tgt, src, width / 2);
tgt += width / 2;
src += linesize;
}
}
// Upsample width and height
namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate(
uv,
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
}
void write_nv12_cpu(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
int width = static_cast<int>(yuv.size(3));
// Write Y plane directly
{
uint8_t* tgt = yuv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[0];
int linesize = pFrame->linesize[0];
for (int h = 0; h < height; ++h) {
memcpy(tgt, src, width);
tgt += width;
src += linesize;
}
}
// Prepare intermediate UV plane
torch::Tensor uv = get_buffer({1, height / 2, width / 2, 2});
{
uint8_t* tgt = uv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[1];
int linesize = pFrame->linesize[1];
for (int h = 0; h < height / 2; ++h) {
memcpy(tgt, src, width);
tgt += width;
src += linesize;
}
}
// Upsample width and height
namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate(
uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
}
#ifdef USE_CUDA
void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
int width = static_cast<int>(yuv.size(3));
// Write Y plane directly
{
uint8_t* tgt = yuv.data_ptr<uint8_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[0];
int linesize = pFrame->linesize[0];
TORCH_CHECK(
cudaSuccess ==
cudaMemcpy2D(
(void*)tgt,
width,
(const void*)src,
linesize,
width,
height,
cudaMemcpyDeviceToDevice),
"Failed to copy Y plane to Cuda tensor.");
}
// Prepare intermediate UV planes
torch::Tensor uv = get_buffer({1, height / 2, width / 2, 2}, yuv.device());
{
uint8_t* tgt = uv.data_ptr<uint8_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[1];
int linesize = pFrame->linesize[1];
TORCH_CHECK(
cudaSuccess ==
cudaMemcpy2D(
(void*)tgt,
width,
(const void*)src,
linesize,
width,
height / 2,
cudaMemcpyDeviceToDevice),
"Failed to copy UV plane to Cuda tensor.");
}
// Upsample width and height
namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate(
uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
}
void write_p010_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
int width = static_cast<int>(yuv.size(3));
// Write Y plane directly
{
int16_t* tgt = yuv.data_ptr<int16_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[0];
int linesize = pFrame->linesize[0];
TORCH_CHECK(
cudaSuccess ==
cudaMemcpy2D(
(void*)tgt,
width * 2,
(const void*)src,
linesize,
width * 2,
height,
cudaMemcpyDeviceToDevice),
"Failed to copy Y plane to Cuda tensor.");
}
// Prepare intermediate UV planes
torch::Tensor uv =
get_buffer({1, height / 2, width / 2, 2}, yuv.device(), torch::kInt16);
{
int16_t* tgt = uv.data_ptr<int16_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[1];
int linesize = pFrame->linesize[1];
TORCH_CHECK(
cudaSuccess ==
cudaMemcpy2D(
(void*)tgt,
width * 2,
(const void*)src,
linesize,
width * 2,
height / 2,
cudaMemcpyDeviceToDevice),
"Failed to copy UV plane to Cuda tensor.");
}
uv = uv.permute({0, 3, 1, 2});
using namespace torch::indexing;
// Write to the UV plane
// very simplistic upscale using indexing since interpolate doesn't support
// shorts
yuv.index_put_(
{Slice(), Slice(1, 3), Slice(None, None, 2), Slice(None, None, 2)}, uv);
yuv.index_put_(
{Slice(), Slice(1, 3), Slice(1, None, 2), Slice(None, None, 2)}, uv);
yuv.index_put_(
{Slice(), Slice(1, 3), Slice(None, None, 2), Slice(1, None, 2)}, uv);
yuv.index_put_(
{Slice(), Slice(1, 3), Slice(1, None, 2), Slice(1, None, 2)}, uv);
// correct for int16
yuv += 32768;
}
#endif
void write_image(AVFrame* frame, torch::Tensor& buf) {
// ref:
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
// https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
AVPixelFormat format = static_cast<AVPixelFormat>(frame->format);
switch (format) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA:
case AV_PIX_FMT_GRAY8: {
write_interlaced_image(frame, buf);
return;
}
case AV_PIX_FMT_YUV444P: {
write_planar_image(frame, buf);
return;
}
case AV_PIX_FMT_YUV420P: {
write_yuv420p(frame, buf);
return;
}
case AV_PIX_FMT_NV12: {
write_nv12_cpu(frame, buf);
return;
}
case AV_PIX_FMT_RGB48LE: {
write_interlaced_image16(frame, buf);
return;
}
#ifdef USE_CUDA
case AV_PIX_FMT_CUDA: {
AVHWFramesContext* hwctx = (AVHWFramesContext*)frame->hw_frames_ctx->data;
AVPixelFormat sw_format = hwctx->sw_format;
// cuvid decoder (nvdec frontend of ffmpeg) only supports the following
// output formats
// https://github.com/FFmpeg/FFmpeg/blob/072101bd52f7f092ee976f4e6e41c19812ad32fd/libavcodec/cuviddec.c#L1121-L1124
switch (sw_format) {
case AV_PIX_FMT_NV12: {
write_nv12_cuda(frame, buf);
return;
}
case AV_PIX_FMT_P010: {
write_p010_cuda(frame, buf);
return;
}
case AV_PIX_FMT_P016:
TORCH_CHECK(
false,
"Unsupported video format found in CUDA HW: " +
std::string(av_get_pix_fmt_name(sw_format)));
default:
TORCH_CHECK(
false,
"Unexpected video format found in CUDA HW: " +
std::string(av_get_pix_fmt_name(sw_format)));
}
}
#endif
default:
TORCH_CHECK(
false,
"Unexpected video format: " +
std::string(av_get_pix_fmt_name(format)));
}
}
} // namespace
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
auto [buffer, is_planar] = get_image_buffer(frame, 1, device);
write_image(frame, buffer);
return is_planar ? buffer : buffer.permute({0, 3, 1, 2});
}
} // namespace detail
} // namespace io
} // namespace torchaudio
#pragma once
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio {
namespace io {
namespace detail {
//////////////////////////////////////////////////////////////////////////////
// Helper functions
//////////////////////////////////////////////////////////////////////////////
torch::Tensor convert_audio(AVFrame* frame);
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device);
} // namespace detail
} // namespace io
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/common.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
namespace torchaudio {
namespace io {
namespace detail {
UnchunkedVideoBuffer::UnchunkedVideoBuffer(const torch::Device& device)
: device(device) {}
template <typename Converter>
UnchunkedBuffer<Converter>::UnchunkedBuffer(Converter&& converter_)
: converter(std::move(converter_)) {}
bool UnchunkedBuffer::is_ready() const {
template <typename Converter>
bool UnchunkedBuffer<Converter>::is_ready() const {
return chunks.size() > 0;
}
void UnchunkedBuffer::push_tensor(const torch::Tensor& t, double pts_) {
template <typename Converter>
void UnchunkedBuffer<Converter>::push_frame(AVFrame* frame, double pts_) {
if (chunks.size() == 0) {
pts = pts_;
}
chunks.push_back(t);
chunks.push_back(converter.convert(frame));
}
void UnchunkedAudioBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_audio(frame), pts_);
}
void UnchunkedVideoBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_image(frame, device), pts_);
}
c10::optional<Chunk> UnchunkedBuffer::pop_chunk() {
template <typename Converter>
c10::optional<Chunk> UnchunkedBuffer<Converter>::pop_chunk() {
if (chunks.size() == 0) {
return {};
}
@@ -38,10 +34,143 @@ c10::optional<Chunk> UnchunkedBuffer::pop_chunk() {
return {Chunk{frames, pts}};
}
void UnchunkedBuffer::flush() {
template <typename Converter>
void UnchunkedBuffer<Converter>::flush() {
chunks.clear();
}
std::unique_ptr<Buffer> get_unchunked_buffer(AVSampleFormat fmt, int channels) {
switch (fmt) {
case AV_SAMPLE_FMT_U8: {
using Converter = AudioConverter<torch::kUInt8, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_S16: {
using Converter = AudioConverter<torch::kInt16, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_S32: {
using Converter = AudioConverter<torch::kInt32, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_S64: {
using Converter = AudioConverter<torch::kInt64, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_FLT: {
using Converter = AudioConverter<torch::kFloat32, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_DBL: {
using Converter = AudioConverter<torch::kFloat64, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_U8P: {
using Converter = AudioConverter<torch::kUInt8, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_S16P: {
using Converter = AudioConverter<torch::kInt16, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_S32P: {
using Converter = AudioConverter<torch::kInt32, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_S64P: {
using Converter = AudioConverter<torch::kInt64, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_FLTP: {
using Converter = AudioConverter<torch::kFloat32, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
case AV_SAMPLE_FMT_DBLP: {
using Converter = AudioConverter<torch::kFloat64, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels});
}
default:
TORCH_INTERNAL_ASSERT(
false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
}
}
std::unique_ptr<Buffer> get_unchunked_buffer(
AVPixelFormat fmt,
int h,
int w,
const torch::Device& device) {
if (device.type() == at::DeviceType::CUDA) {
#ifndef USE_CUDA
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
false,
"USE_CUDA is not defined, and it should be guarded before here.");
#else
switch (fmt) {
case AV_PIX_FMT_NV12: {
using Conv = NV12CudaConverter;
return std::make_unique<UnchunkedBuffer<Conv>>(Conv{h, w, device});
}
case AV_PIX_FMT_P010: {
using Conv = P010CudaConverter;
return std::make_unique<UnchunkedBuffer<Conv>>(Conv{h, w, device});
}
case AV_PIX_FMT_P016: {
TORCH_CHECK(
false,
"Unsupported video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
default: {
TORCH_CHECK(
false,
"Unexpected video format found in CUDA HW: ",
av_get_pix_fmt_name(fmt));
}
}
#endif
}
switch (fmt) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: {
using Converter = InterlacedImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 3});
}
case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA: {
using Converter = InterlacedImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 4});
}
case AV_PIX_FMT_GRAY8: {
using Converter = InterlacedImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 1});
}
case AV_PIX_FMT_RGB48LE: {
using Converter = Interlaced16BitImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 3});
}
case AV_PIX_FMT_YUV444P: {
using Converter = PlanarImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 3});
}
case AV_PIX_FMT_YUV420P: {
using Converter = YUV420PConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w});
}
case AV_PIX_FMT_NV12: {
using Converter = NV12Converter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w});
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
}
}
}
} // namespace detail
} // namespace io
} // namespace torchaudio
@@ -4,43 +4,36 @@
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer.h>
#include <deque>
namespace torchaudio {
namespace io {
namespace detail {
namespace torchaudio::io::detail {
//////////////////////////////////////////////////////////////////////////////
// Unchunked Buffer Interface
//////////////////////////////////////////////////////////////////////////////
// Partial implementation for unchunked buffer common to both audio and video
// Used for buffering audio/video streams without chunking
template <typename Converter>
class UnchunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
double pts = -1.;
protected:
void push_tensor(const torch::Tensor& t, double pts);
Converter converter;
public:
UnchunkedBuffer(Converter&& converter);
bool is_ready() const override;
void push_frame(AVFrame* frame, double pts_) override;
c10::optional<Chunk> pop_chunk() override;
void flush() override;
};
class UnchunkedAudioBuffer : public UnchunkedBuffer {
public:
void push_frame(AVFrame* frame, double pts) override;
};
class UnchunkedVideoBuffer : public UnchunkedBuffer {
const torch::Device device;
std::unique_ptr<Buffer> get_unchunked_buffer(
AVSampleFormat fmt,
int num_channels);
public:
explicit UnchunkedVideoBuffer(const torch::Device& device);
void push_frame(AVFrame* frame, double pts) override;
};
std::unique_ptr<Buffer> get_unchunked_buffer(
AVPixelFormat fmt,
int height,
int width,
const torch::Device& device);
} // namespace detail
} // namespace io
} // namespace torchaudio
} // namespace torchaudio::io::detail
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////
template <c10::ScalarType dtype, bool is_planar>
AudioConverter<dtype, is_planar>::AudioConverter(int c) : num_channels(c) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels > 0);
}
template <c10::ScalarType dtype, bool is_planar>
torch::Tensor AudioConverter<dtype, is_planar>::convert(const AVFrame* src) {
if constexpr (is_planar) {
torch::Tensor dst = torch::empty({num_channels, src->nb_samples}, dtype);
convert(src, dst);
return dst.permute({1, 0});
} else {
torch::Tensor dst = torch::empty({src->nb_samples, num_channels}, dtype);
convert(src, dst);
return dst;
}
}
// Converts AVFrame* into pre-allocated Tensor.
// The shape must be [C, T] if is_planar otherwise [T, C]
template <c10::ScalarType dtype, bool is_planar>
void AudioConverter<dtype, is_planar>::convert(
const AVFrame* src,
torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels == src->channels);
constexpr int bps = []() {
switch (dtype) {
case torch::kUInt8:
return 1;
case torch::kInt16:
return 2;
case torch::kInt32:
case torch::kFloat32:
return 4;
case torch::kInt64:
case torch::kFloat64:
return 8;
}
}();
// Note
// FFmpeg's `nb_samples` represents the number of samples per channel,
// whereas in torchaudio, `num_samples` is used to represent the number
// of samples across channels. torchaudio uses `num_frames` for
// per-channel samples.
if constexpr (is_planar) {
int plane_size = bps * src->nb_samples;
uint8_t* p_dst = static_cast<uint8_t*>(dst.data_ptr());
for (int i = 0; i < num_channels; ++i) {
memcpy(p_dst, src->extended_data[i], plane_size);
p_dst += plane_size;
}
} else {
int plane_size = bps * src->nb_samples * num_channels;
memcpy(dst.data_ptr(), src->extended_data[0], plane_size);
}
}
// Explicit instantiation
template class AudioConverter<torch::kUInt8, false>;
template class AudioConverter<torch::kUInt8, true>;
template class AudioConverter<torch::kInt16, false>;
template class AudioConverter<torch::kInt16, true>;
template class AudioConverter<torch::kInt32, false>;
template class AudioConverter<torch::kInt32, true>;
template class AudioConverter<torch::kInt64, false>;
template class AudioConverter<torch::kInt64, true>;
template class AudioConverter<torch::kFloat32, false>;
template class AudioConverter<torch::kFloat32, true>;
template class AudioConverter<torch::kFloat64, false>;
template class AudioConverter<torch::kFloat64, true>;
////////////////////////////////////////////////////////////////////////////////
// Image
////////////////////////////////////////////////////////////////////////////////
namespace {
torch::Tensor get_image_buffer(
at::IntArrayRef shape,
const torch::Dtype dtype = torch::kUInt8) {
return torch::empty(
shape, torch::TensorOptions().dtype(dtype).layout(torch::kStrided));
}
torch::Tensor get_image_buffer(
at::IntArrayRef shape,
torch::Device device,
const torch::Dtype dtype = torch::kUInt8) {
return torch::empty(
shape,
torch::TensorOptions()
.dtype(dtype)
.layout(torch::kStrided)
.device(device));
}
} // namespace
ImageConverterBase::ImageConverterBase(int h, int w, int c)
: height(h), width(w), num_channels(c) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(height > 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(width > 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels > 0);
}
////////////////////////////////////////////////////////////////////////////////
// Interlaced Image
////////////////////////////////////////////////////////////////////////////////
void InterlacedImageConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == height);
int stride = width * num_channels;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) * dst.size(3) == stride);
auto p_dst = dst.data_ptr<uint8_t>();
uint8_t* p_src = src->data[0];
for (int i = 0; i < height; ++i) {
memcpy(p_dst, p_src, stride);
p_src += src->linesize[0];
p_dst += stride;
}
}
torch::Tensor InterlacedImageConverter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, height, width, num_channels});
convert(src, buffer);
return buffer.permute({0, 3, 1, 2});
}
////////////////////////////////////////////////////////////////////////////////
// Interlaced 16 Bit Image
////////////////////////////////////////////////////////////////////////////////
void Interlaced16BitImageConverter::convert(
const AVFrame* src,
torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == height);
int stride = width * num_channels;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) * dst.size(3) == stride);
auto p_dst = dst.data_ptr<int16_t>();
uint8_t* p_src = src->data[0];
for (int i = 0; i < height; ++i) {
memcpy(p_dst, p_src, stride * 2);
p_src += src->linesize[0];
p_dst += stride;
}
// correct for int16: shift the unsigned 16-bit data into int16 range
// (torch has no uint16 dtype)
dst += 32768;
}
torch::Tensor Interlaced16BitImageConverter::convert(const AVFrame* src) {
torch::Tensor buffer =
get_image_buffer({1, height, width, num_channels}, torch::kInt16);
convert(src, buffer);
return buffer.permute({0, 3, 1, 2});
}
////////////////////////////////////////////////////////////////////////////////
// Planar Image
////////////////////////////////////////////////////////////////////////////////
void PlanarImageConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == num_channels);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
for (int i = 0; i < num_channels; ++i) {
torch::Tensor plane = dst.index({0, i});
uint8_t* p_dst = plane.data_ptr<uint8_t>();
uint8_t* p_src = src->data[i];
int linesize = src->linesize[i];
for (int h = 0; h < height; ++h) {
memcpy(p_dst, p_src, width);
p_src += linesize;
p_dst += width;
}
}
}
torch::Tensor PlanarImageConverter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, num_channels, height, width});
convert(src, buffer);
return buffer;
}
////////////////////////////////////////////////////////////////////////////////
// YUV420P
////////////////////////////////////////////////////////////////////////////////
YUV420PConverter::YUV420PConverter(int h, int w)
: ImageConverterBase(h, w, 3),
tmp_uv(get_image_buffer({1, 2, height / 2, width / 2})) {}
void YUV420PConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
(AVPixelFormat)(src->format) == AV_PIX_FMT_YUV420P);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
// Write Y plane directly
{
uint8_t* p_dst = dst.data_ptr<uint8_t>();
uint8_t* p_src = src->data[0];
for (int h = 0; h < height; ++h) {
memcpy(p_dst, p_src, width);
p_dst += width;
p_src += src->linesize[0];
}
}
// Write intermediate UV plane
{
uint8_t* p_dst = tmp_uv.data_ptr<uint8_t>();
uint8_t* p_src = src->data[1];
for (int h = 0; h < height / 2; ++h) {
memcpy(p_dst, p_src, width / 2);
p_dst += width / 2;
p_src += src->linesize[1];
}
p_src = src->data[2];
for (int h = 0; h < height / 2; ++h) {
memcpy(p_dst, p_src, width / 2);
p_dst += width / 2;
p_src += src->linesize[2];
}
}
// Upsample width and height
namespace F = torch::nn::functional;
torch::Tensor uv = F::interpolate(
tmp_uv,
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// dst[:, 1:] = uv
using namespace torch::indexing;
dst.index_put_({Slice(), Slice(1)}, uv);
}
torch::Tensor YUV420PConverter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, num_channels, height, width});
convert(src, buffer);
return buffer;
}
////////////////////////////////////////////////////////////////////////////////
// NV12
////////////////////////////////////////////////////////////////////////////////
NV12Converter::NV12Converter(int h, int w)
: ImageConverterBase(h, w, 3),
tmp_uv(get_image_buffer({1, height / 2, width / 2, 2})) {}
void NV12Converter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
(AVPixelFormat)(src->format) == AV_PIX_FMT_NV12);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
// Write Y plane directly
{
uint8_t* p_dst = dst.data_ptr<uint8_t>();
uint8_t* p_src = src->data[0];
for (int h = 0; h < height; ++h) {
memcpy(p_dst, p_src, width);
p_dst += width;
p_src += src->linesize[0];
}
}
// Write intermediate UV plane
{
uint8_t* p_dst = tmp_uv.data_ptr<uint8_t>();
uint8_t* p_src = src->data[1];
for (int h = 0; h < height / 2; ++h) {
memcpy(p_dst, p_src, width);
p_dst += width;
p_src += src->linesize[1];
}
}
// Upsample width and height
namespace F = torch::nn::functional;
torch::Tensor uv = F::interpolate(
tmp_uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// dst[:, 1:] = uv
using namespace torch::indexing;
dst.index_put_({Slice(), Slice(1)}, uv);
}
torch::Tensor NV12Converter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer({1, num_channels, height, width});
convert(src, buffer);
return buffer;
}
#ifdef USE_CUDA
////////////////////////////////////////////////////////////////////////////////
// NV12 CUDA
////////////////////////////////////////////////////////////////////////////////
NV12CudaConverter::NV12CudaConverter(int h, int w, const torch::Device& device)
: ImageConverterBase(h, w, 3),
tmp_uv(get_image_buffer(
{1, height / 2, width / 2, 2},
device,
torch::kUInt8)) {}
void NV12CudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kUInt8);
auto fmt = (AVPixelFormat)(src->format);
AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
AVPixelFormat sw_fmt = hwctx->sw_format;
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_CUDA == fmt,
"Expected CUDA frame. Found: ",
av_get_pix_fmt_name(fmt));
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_NV12 == sw_fmt,
"Expected NV12 format. Found: ",
av_get_pix_fmt_name(sw_fmt));
// Write Y plane directly
auto status = cudaMemcpy2D(
dst.data_ptr(),
width,
src->data[0],
src->linesize[0],
width,
height,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy Y plane to Cuda tensor.");
// Prepare intermediate UV planes
status = cudaMemcpy2D(
tmp_uv.data_ptr(),
width,
src->data[1],
src->linesize[1],
width,
height / 2,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy UV plane to Cuda tensor.");
// Upsample width and height
namespace F = torch::nn::functional;
torch::Tensor uv = F::interpolate(
tmp_uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// dst[:, 1:] = uv
using namespace torch::indexing;
dst.index_put_({Slice(), Slice(1)}, uv);
}
torch::Tensor NV12CudaConverter::convert(const AVFrame* src) {
torch::Tensor buffer =
get_image_buffer({1, num_channels, height, width}, tmp_uv.device());
convert(src, buffer);
return buffer;
}
////////////////////////////////////////////////////////////////////////////////
// P010 CUDA
////////////////////////////////////////////////////////////////////////////////
P010CudaConverter::P010CudaConverter(int h, int w, const torch::Device& device)
: ImageConverterBase(h, w, 3),
tmp_uv(get_image_buffer(
{1, height / 2, width / 2, 2},
device,
torch::kInt16)) {}
void P010CudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kInt16);
auto fmt = (AVPixelFormat)(src->format);
AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
AVPixelFormat sw_fmt = hwctx->sw_format;
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_CUDA == fmt,
"Expected CUDA frame. Found: ",
av_get_pix_fmt_name(fmt));
TORCH_INTERNAL_ASSERT(
AV_PIX_FMT_P010 == sw_fmt,
"Expected P010 format. Found: ",
av_get_pix_fmt_name(sw_fmt));
// Write Y plane directly
auto status = cudaMemcpy2D(
dst.data_ptr(),
width * 2,
src->data[0],
src->linesize[0],
width * 2,
height,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy Y plane to CUDA tensor.");
// Prepare intermediate UV planes
status = cudaMemcpy2D(
tmp_uv.data_ptr(),
width * 2,
src->data[1],
src->linesize[1],
width * 2,
height / 2,
cudaMemcpyDeviceToDevice);
TORCH_CHECK(cudaSuccess == status, "Failed to copy UV plane to CUDA tensor.");
// Write to the UV plane
torch::Tensor uv = tmp_uv.permute({0, 3, 1, 2});
using namespace torch::indexing;
// very simplistic upscale using indexing since interpolate doesn't support
// shorts
dst.index_put_(
{Slice(), Slice(1, 3), Slice(None, None, 2), Slice(None, None, 2)}, uv);
dst.index_put_(
{Slice(), Slice(1, 3), Slice(1, None, 2), Slice(None, None, 2)}, uv);
dst.index_put_(
{Slice(), Slice(1, 3), Slice(None, None, 2), Slice(1, None, 2)}, uv);
dst.index_put_(
{Slice(), Slice(1, 3), Slice(1, None, 2), Slice(1, None, 2)}, uv);
// correct for int16: shift the unsigned 16-bit data into int16 range
// (torch has no uint16 dtype)
dst += 32768;
}
torch::Tensor P010CudaConverter::convert(const AVFrame* src) {
torch::Tensor buffer = get_image_buffer(
{1, num_channels, height, width}, tmp_uv.device(), torch::kInt16);
convert(src, buffer);
return buffer;
}
#endif
} // namespace torchaudio::io
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////
template <c10::ScalarType dtype, bool is_planar>
class AudioConverter {
const int num_channels;
public:
AudioConverter(int num_channels);
// Converts AVFrame* into Tensor of [T, C]
torch::Tensor convert(const AVFrame* src);
// Converts AVFrame* into pre-allocated Tensor.
// The shape must be [C, T] if is_planar otherwise [T, C]
void convert(const AVFrame* src, torch::Tensor& dst);
};
////////////////////////////////////////////////////////////////////////////////
// Image
////////////////////////////////////////////////////////////////////////////////
struct ImageConverterBase {
const int height;
const int width;
const int num_channels;
ImageConverterBase(int h, int w, int c);
};
////////////////////////////////////////////////////////////////////////////////
// Interlaced Images - NHWC
////////////////////////////////////////////////////////////////////////////////
struct InterlacedImageConverter : public ImageConverterBase {
using ImageConverterBase::ImageConverterBase;
// convert AVFrame* into Tensor of NCHW format
torch::Tensor convert(const AVFrame* src);
// convert AVFrame* into pre-allocated Tensor of NHWC format
void convert(const AVFrame* src, torch::Tensor& dst);
};
struct Interlaced16BitImageConverter : public ImageConverterBase {
using ImageConverterBase::ImageConverterBase;
// convert AVFrame* into Tensor of NCHW format
torch::Tensor convert(const AVFrame* src);
// convert AVFrame* into pre-allocated Tensor of NHWC format
void convert(const AVFrame* src, torch::Tensor& dst);
};
////////////////////////////////////////////////////////////////////////////////
// Planar Images - NCHW
////////////////////////////////////////////////////////////////////////////////
struct PlanarImageConverter : public ImageConverterBase {
using ImageConverterBase::ImageConverterBase;
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
////////////////////////////////////////////////////////////////////////////////
// Family of YUVs - NCHW
////////////////////////////////////////////////////////////////////////////////
class YUV420PConverter : public ImageConverterBase {
torch::Tensor tmp_uv;
public:
YUV420PConverter(int height, int width);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
class NV12Converter : public ImageConverterBase {
torch::Tensor tmp_uv;
public:
NV12Converter(int height, int width);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
#ifdef USE_CUDA
class NV12CudaConverter : ImageConverterBase {
torch::Tensor tmp_uv;
public:
NV12CudaConverter(int height, int width, const torch::Device& device);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
class P010CudaConverter : ImageConverterBase {
torch::Tensor tmp_uv;
public:
P010CudaConverter(int height, int width, const torch::Device& device);
void convert(const AVFrame* src, torch::Tensor& dst);
torch::Tensor convert(const AVFrame* src);
};
#endif
} // namespace torchaudio::io
@@ -8,10 +8,10 @@ namespace io {
namespace {
std::unique_ptr<Buffer> get_buffer(
AVMediaType type,
AVCodecContext* codec_ctx,
FilterGraph& filter,
int frames_per_chunk,
int num_chunks,
double frame_duration,
const torch::Device& device) {
TORCH_CHECK(
frames_per_chunk > 0 || frames_per_chunk == -1,
@@ -23,47 +23,71 @@ std::unique_ptr<Buffer> get_buffer(
"`num_chunks` must be positive or -1. Found: ",
num_chunks);
TORCH_INTERNAL_ASSERT(
type == AVMEDIA_TYPE_AUDIO || type == AVMEDIA_TYPE_VIDEO,
auto info = filter.get_output_info();
TORCH_CHECK(
info.type == AVMEDIA_TYPE_AUDIO || info.type == AVMEDIA_TYPE_VIDEO,
"Unsupported media type: ",
av_get_media_type_string(type),
av_get_media_type_string(info.type),
". Only video or audio is supported ");
// Chunked Mode
if (frames_per_chunk > 0) {
if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>(new detail::ChunkedAudioBuffer(
frames_per_chunk, num_chunks, frame_duration));
auto time_base = filter.get_output_timebase();
double frame_duration = double(time_base.num) / time_base.den;
if (info.type == AVMEDIA_TYPE_AUDIO) {
AVSampleFormat fmt = (AVSampleFormat)(info.format);
if (frames_per_chunk == -1) {
return detail::get_unchunked_buffer(fmt, codec_ctx->channels);
} else {
return std::unique_ptr<Buffer>(new detail::ChunkedVideoBuffer(
frames_per_chunk, num_chunks, frame_duration, device));
return detail::get_chunked_buffer(
frames_per_chunk,
num_chunks,
frame_duration,
fmt,
codec_ctx->channels);
}
} else {
// Note
// When using a HW decoder, the pixel format is CUDA, and FilterGraph
// does not yet support CUDA frames or propagate the software pixel
// format, so we refer to the AVCodecContext* for the pixel format.
AVPixelFormat fmt = (AVPixelFormat)(info.format);
if (fmt == AV_PIX_FMT_CUDA) {
fmt = codec_ctx->sw_pix_fmt;
}
} else { // unchunked mode
if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>(new detail::UnchunkedAudioBuffer());
if (frames_per_chunk == -1) {
return detail::get_unchunked_buffer(fmt, info.height, info.width, device);
} else {
return std::unique_ptr<Buffer>(new detail::UnchunkedVideoBuffer(device));
return detail::get_chunked_buffer(
frames_per_chunk,
num_chunks,
frame_duration,
fmt,
info.height,
info.width,
device);
}
}
}
std::unique_ptr<FilterGraph> get_filter_graph(
FilterGraph get_filter_graph(
AVRational input_time_base,
AVCodecContext* codec_ctx,
AVRational frame_rate,
const std::string& filter_description) {
auto p = std::make_unique<FilterGraph>(codec_ctx->codec_type);
auto p = FilterGraph{codec_ctx->codec_type};
switch (codec_ctx->codec_type) {
case AVMEDIA_TYPE_AUDIO:
p->add_audio_src(
p.add_audio_src(
codec_ctx->sample_fmt,
input_time_base,
codec_ctx->sample_rate,
codec_ctx->channel_layout);
break;
case AVMEDIA_TYPE_VIDEO:
p->add_video_src(
p.add_video_src(
codec_ctx->pix_fmt,
input_time_base,
frame_rate,
@@ -74,9 +98,9 @@ std::unique_ptr<FilterGraph> get_filter_graph(
default:
TORCH_CHECK(false, "Only audio/video are supported.");
}
p->add_sink();
p->add_process(filter_description);
p->create_filter();
p.add_sink();
p.add_process(filter_description);
p.create_filter();
return p;
}
@@ -100,20 +124,17 @@ Sink::Sink(
codec_ctx,
frame_rate,
filter_description)),
output_time_base(filter->get_output_timebase()),
buffer(get_buffer(
codec_ctx->codec_type,
frames_per_chunk,
num_chunks,
double(output_time_base.num) / output_time_base.den,
device)) {}
output_time_base(filter.get_output_timebase()),
buffer(
get_buffer(codec_ctx, filter, frames_per_chunk, num_chunks, device)) {
}
// 0: some kind of success
// <0: Some error happened
int Sink::process_frame(AVFrame* pFrame) {
int ret = filter->add_frame(pFrame);
int ret = filter.add_frame(pFrame);
while (ret >= 0) {
ret = filter->get_frame(frame);
ret = filter.get_frame(frame);
// AVERROR(EAGAIN) means that new input data is required to return new
// output.
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
@@ -134,7 +155,7 @@ std::string Sink::get_filter_description() const {
}
FilterGraphOutputInfo Sink::get_filter_output_info() const {
return filter->get_output_info();
return filter.get_output_info();
}
void Sink::flush() {
@@ -15,7 +15,7 @@ class Sink {
AVCodecContext* codec_ctx;
AVRational frame_rate;
std::string filter_description;
std::unique_ptr<FilterGraph> filter;
FilterGraph filter;
// time_base of filter graph output, used for PTS calc
AVRational output_time_base;
@@ -50,6 +50,29 @@ enum AVPixelFormat get_hw_format(
const AVCodecHWConfig* cfg = static_cast<AVCodecHWConfig*>(codec_ctx->opaque);
for (const enum AVPixelFormat* p = pix_fmts; *p != -1; p++) {
if (*p == cfg->pix_fmt) {
// Note
// The HW decode example uses a generic approach:
// https://ffmpeg.org/doxygen/4.1/hw__decode_8c_source.html#l00063
// But that approach finalizes the codec configuration only when the
// first frame comes in, and we need to inspect the codec configuration
// right after the codec is opened.
// So we add a shortcut for known patterns:
// yuv420p (h264) -> nv12
// yuv420p10le (hevc/h265) -> p010le
switch (codec_ctx->pix_fmt) {
case AV_PIX_FMT_YUV420P: {
codec_ctx->pix_fmt = AV_PIX_FMT_CUDA;
codec_ctx->sw_pix_fmt = AV_PIX_FMT_NV12;
break;
}
case AV_PIX_FMT_YUV420P10LE: {
codec_ctx->pix_fmt = AV_PIX_FMT_CUDA;
codec_ctx->sw_pix_fmt = AV_PIX_FMT_P010LE;
break;
}
default:;
}
return *p;
}
}