Commit 54d2d04f authored by moto's avatar moto Committed by Facebook GitHub Bot

Add HW acceleration support on Streamer (#2331)

Summary:
This commit adds the `hw_accel` option to the `Streamer::add_video_stream` method.
Specifying `hw_accel="cuda"` creates the chunk Tensors directly on a CUDA device,
when the following conditions are met:
1. the video format is H264,
2. the underlying FFmpeg is compiled with NVDEC support, and
3. the client code specifies `decoder="h264_cuvid"`.

A simple benchmark yields a roughly 7x improvement in decoding speed.

<details>

```python
import time

from torchaudio.prototype.io import Streamer

srcs = [
    "https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4",
    "./NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4",  # offline version
]

patterns = [
    ("h264_cuvid", None, "cuda:0"),  # NVDEC on CUDA:0 -> CUDA:0
    ("h264_cuvid", None, "cuda:1"),  # NVDEC on CUDA:1 -> CUDA:1
    ("h264_cuvid", None, None),  # NVDEC -> CPU
    (None, None, None),  # CPU
]

for src in srcs:
    print(src, flush=True)
    for (decoder, decoder_options, hw_accel) in patterns:
        s = Streamer(src)
        s.add_video_stream(5, decoder=decoder, decoder_options=decoder_options, hw_accel=hw_accel)

        t0 = time.monotonic()
        num_frames = 0
        for i, (chunk, ) in enumerate(s.stream()):
            num_frames += chunk.shape[0]
        t1 = time.monotonic()
        print(chunk.dtype, chunk.shape, chunk.device)
        print(t1 - t0, num_frames, flush=True)
```
</details>

```
https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4
torch.uint8 torch.Size([5, 3, 1080, 1920]) cuda:0
10.781158386962488 6175
torch.uint8 torch.Size([5, 3, 1080, 1920]) cuda:1
10.771313901990652 6175
torch.uint8 torch.Size([5, 3, 1080, 1920]) cpu
27.88662809302332 6175
torch.uint8 torch.Size([5, 3, 1080, 1920]) cpu
83.22728440898936 6175
./NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4
torch.uint8 torch.Size([5, 3, 1080, 1920]) cuda:0
12.945253834011964 6175
torch.uint8 torch.Size([5, 3, 1080, 1920]) cuda:1
12.870224556012545 6175
torch.uint8 torch.Size([5, 3, 1080, 1920]) cpu
28.03406483103754 6175
torch.uint8 torch.Size([5, 3, 1080, 1920]) cpu
82.6120332319988 6175
```
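As a sanity check on the ~7x figure, the ratio of the CPU timing to the NVDEC timing from the streamed-source run is just arithmetic over the numbers printed above:

```python
# Timings (seconds) copied from the streamed-source run above.
nvdec_cuda0 = 10.781158386962488  # h264_cuvid -> cuda:0
cpu_decode = 83.22728440898936    # CPU decoding -> cpu

speedup = cpu_decode / nvdec_cuda0
print(f"{speedup:.1f}x")  # about 7.7x, consistent with the ~7x claim
```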

With HW resizing

<details>

```python
import time

from torchaudio.prototype.io import Streamer

srcs = [
    "./NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4",
    "https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4",
]

patterns = [
    # Decode with NVDEC, CUDA HW scaling -> CUDA:0
    ("h264_cuvid", {"resize": "960x540"}, "", "cuda:0"),
    # Decode with NVDEC, CUDA HW scaling -> CPU
    ("h264_cuvid", {"resize": "960x540"}, "", None),
    # CPU decoding, CPU scaling
    (None, None, "scale=width=960:height=540", None),
]

for src in srcs:
    print(src, flush=True)
    for (decoder, decoder_options, filter_desc, hw_accel) in patterns:
        s = Streamer(src)
        s.add_video_stream(
            5,
            decoder=decoder,
            decoder_options=decoder_options,
            filter_desc=filter_desc,
            hw_accel=hw_accel,
        )

        t0 = time.monotonic()
        num_frames = 0
        for i, (chunk, ) in enumerate(s.stream()):
            num_frames += chunk.shape[0]
        t1 = time.monotonic()
        print(chunk.dtype, chunk.shape, chunk.device)
        print(t1 - t0, num_frames, flush=True)
```

</details>

```
./NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4
torch.uint8 torch.Size([5, 3, 540, 960]) cuda:0
12.890056837990414 6175
torch.uint8 torch.Size([5, 3, 540, 960]) cpu
10.697489063022658 6175
torch.uint8 torch.Size([5, 3, 540, 960]) cpu
85.19899423001334 6175

https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4.mp4
torch.uint8 torch.Size([5, 3, 540, 960]) cuda:0
10.712715593050234 6175
torch.uint8 torch.Size([5, 3, 540, 960]) cpu
11.030170071986504 6175
torch.uint8 torch.Size([5, 3, 540, 960]) cpu
84.8515750519582 6175
```
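Two things stand out in the resize numbers, again just from the timings printed above: NVDEC plus CUDA HW scaling beats CPU decode-and-scale by roughly 6-8x, and NVDEC with HW downscaling followed by a copy to CPU (about 10.7 s) is faster than the full-resolution NVDEC-to-CPU case from the earlier run (about 27.9 s), presumably because the device-to-host transfer is a quarter of the size:

```python
# Timings (seconds) copied from the offline-source resize run above.
nvdec_hw_scale = 12.890056837990414  # NVDEC + CUDA HW scaling -> cuda:0
cpu_scale = 85.19899423001334        # CPU decoding + CPU scaling -> cpu

print(f"{cpu_scale / nvdec_hw_scale:.1f}x")  # about 6.6x
```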

Pull Request resolved: https://github.com/pytorch/audio/pull/2331

Reviewed By: hwangjeff

Differential Revision: D36217169

Pulled By: mthrok

fbshipit-source-id: 7979570b083cfc238ad4735b44305d8649f0607b
parent 638120ca
@@ -15,8 +15,11 @@ Buffer::Buffer(int frames_per_chunk, int num_chunks)
AudioBuffer::AudioBuffer(int frames_per_chunk, int num_chunks)
    : Buffer(frames_per_chunk, num_chunks) {}

VideoBuffer::VideoBuffer(
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device_)
    : Buffer(frames_per_chunk, num_chunks), device(device_) {}

////////////////////////////////////////////////////////////////////////////////
// Query
@@ -255,14 +258,15 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) {
}

#ifdef USE_CUDA
torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) {
  int width = pFrame->width;
  int height = pFrame->height;

  auto options = torch::TensorOptions()
                     .dtype(torch::kUInt8)
                     .layout(torch::kStrided)
                     .device(torch::kCUDA)
                     .device_index(device.index());

  torch::Tensor y = torch::empty({1, height, width, 1}, options);
  {
@@ -305,7 +309,9 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame) {
}
#endif

torch::Tensor convert_image_tensor(
    AVFrame* pFrame,
    const torch::Device& device) {
  // ref:
  // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
  // https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
@@ -344,7 +350,7 @@ torch::Tensor convert_image_tensor(AVFrame* pFrame) {
  // https://github.com/FFmpeg/FFmpeg/blob/072101bd52f7f092ee976f4e6e41c19812ad32fd/libavcodec/cuviddec.c#L1121-L1124
  switch (sw_format) {
    case AV_PIX_FMT_NV12:
      return convert_nv12_cuda(pFrame, device);
    case AV_PIX_FMT_P010:
    case AV_PIX_FMT_P016:
      throw std::runtime_error(
@@ -399,7 +405,7 @@ void VideoBuffer::push_tensor(torch::Tensor t) {
}

void VideoBuffer::push_frame(AVFrame* frame) {
  push_tensor(convert_image_tensor(frame, device));
}

////////////////////////////////////////////////////////////////////////////////
...
@@ -76,8 +76,13 @@ class AudioBuffer : public Buffer {
// But this mean that chunks consisting of multiple frames have to be created
// at popping time.
class VideoBuffer : public Buffer {
  const torch::Device device;

 public:
  VideoBuffer(
      int frames_per_chunk,
      int num_chunks,
      const torch::Device& device);

  void push_frame(AVFrame* frame);
...
@@ -9,8 +9,9 @@ namespace ffmpeg {
Decoder::Decoder(
    AVCodecParameters* pParam,
    const std::string& decoder_name,
    const std::map<std::string, std::string>& decoder_option,
    const torch::Device& device)
    : pCodecContext(pParam, decoder_name, decoder_option, device) {}

int Decoder::process_packet(AVPacket* pPacket) {
  return avcodec_send_packet(pCodecContext, pPacket);
...
@@ -13,7 +13,8 @@ class Decoder {
  Decoder(
      AVCodecParameters* pParam,
      const std::string& decoder_name,
      const std::map<std::string, std::string>& decoder_option,
      const torch::Device& device);
  // Custom destructor to clean up the resources
  ~Decoder() = default;
  // Non-copyable
...
#include <c10/util/Exception.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <sstream>
#include <stdexcept>
@@ -176,11 +177,46 @@ AVCodecContext* get_codec_context(
  return pCodecContext;
}

#ifdef USE_CUDA
enum AVPixelFormat get_hw_format(
    AVCodecContext* ctx,
    const enum AVPixelFormat* pix_fmts) {
  const enum AVPixelFormat* p = nullptr;
  AVPixelFormat pix_fmt = *static_cast<AVPixelFormat*>(ctx->opaque);
  for (p = pix_fmts; *p != -1; p++) {
    if (*p == pix_fmt) {
      return *p;
    }
  }
  TORCH_WARN("Failed to get HW surface format.");
  return AV_PIX_FMT_NONE;
}

const AVCodecHWConfig* get_cuda_config(const AVCodec* pCodec) {
  for (int i = 0;; ++i) {
    const AVCodecHWConfig* config = avcodec_get_hw_config(pCodec, i);
    if (!config) {
      break;
    }
    if (config->device_type == AV_HWDEVICE_TYPE_CUDA &&
        config->methods & AV_CODEC_HW_CONFIG_METHOD_HW_DEVICE_CTX) {
      return config;
    }
  }
  std::stringstream ss;
  ss << "CUDA device was requested, but the codec \"" << pCodec->name
     << "\" is not supported.";
  throw std::runtime_error(ss.str());
}
#endif

void init_codec_context(
    AVCodecContext* pCodecContext,
    AVCodecParameters* pParams,
    const std::string& decoder_name,
    const std::map<std::string, std::string>& decoder_option,
    const torch::Device& device,
    AVBufferRefPtr& pHWBufferRef) {
  const AVCodec* pCodec = decoder_name.empty()
      ? avcodec_find_decoder(pParams->codec_id)
      : avcodec_find_decoder_by_name(decoder_name.c_str());
@@ -192,6 +228,40 @@ void init_codec_context(
    throw std::runtime_error("Failed to set CodecContext parameter.");
  }

#ifdef USE_CUDA
  // Enable HW Acceleration
  if (device.type() == c10::DeviceType::CUDA) {
    const AVCodecHWConfig* config = get_cuda_config(pCodec);
    // TODO: check how to log
    // C10_LOG << "Decoder " << pCodec->name << " supports device " <<
    // av_hwdevice_get_type_name(config->device_type);

    // https://www.ffmpeg.org/doxygen/trunk/hw__decode_8c_source.html#l00221
    // 1. Set HW pixel format (config->pix_fmt) to opaque pointer.
    static thread_local AVPixelFormat pix_fmt = config->pix_fmt;
    pCodecContext->opaque = static_cast<void*>(&pix_fmt);
    // 2. Set pCodecContext->get_format callback function which
    // will retrieve the HW pixel format from opaque pointer.
    pCodecContext->get_format = get_hw_format;

    // 3. Create HW device context and set to pCodecContext.
    AVBufferRef* hw_device_ctx = nullptr;
    // TODO: check how to deallocate the context
    int err = av_hwdevice_ctx_create(
        &hw_device_ctx,
        AV_HWDEVICE_TYPE_CUDA,
        std::to_string(device.index()).c_str(),
        nullptr,
        0);
    if (err < 0) {
      throw std::runtime_error(
          "Failed to create CUDA device context: " + av_err2string(err));
    }
    assert(hw_device_ctx);
    pCodecContext->hw_device_ctx = av_buffer_ref(hw_device_ctx);
    pHWBufferRef.reset(hw_device_ctx);
  }
#endif

  AVDictionary* opts = get_option_dict(decoder_option);
  if (avcodec_open2(pCodecContext, pCodec, &opts) < 0) {
    throw std::runtime_error("Failed to initialize CodecContext.");
@@ -211,11 +281,32 @@ void init_codec_context(
AVCodecContextPtr::AVCodecContextPtr(
    AVCodecParameters* pParam,
    const std::string& decoder_name,
    const std::map<std::string, std::string>& decoder_option,
    const torch::Device& device)
    : Wrapper<AVCodecContext, AVCodecContextDeleter>(
          get_codec_context(pParam->codec_id, decoder_name)),
      pHWBufferRef() {
  init_codec_context(
      ptr.get(), pParam, decoder_name, decoder_option, device, pHWBufferRef);
}

////////////////////////////////////////////////////////////////////////////////
// AVBufferRefPtr
////////////////////////////////////////////////////////////////////////////////
void AutoBufferUnref::operator()(AVBufferRef* p) {
  av_buffer_unref(&p);
}

AVBufferRefPtr::AVBufferRefPtr()
    : Wrapper<AVBufferRef, AutoBufferUnref>(nullptr) {}

void AVBufferRefPtr::reset(AVBufferRef* p) {
  TORCH_CHECK(
      !ptr,
      "InternalError: A valid AVBufferRefPtr is being reset. Please file an issue.");
  ptr.reset(p);
}

////////////////////////////////////////////////////////////////////////////////
// AVFilterGraph
////////////////////////////////////////////////////////////////////////////////
...
// One stop header for all ffmpeg needs
#pragma once
#include <torch/torch.h>
#include <cstdint>
#include <map>
#include <memory>
@@ -110,6 +111,19 @@ struct AVFramePtr : public Wrapper<AVFrame, AVFrameDeleter> {
  AVFramePtr();
};

////////////////////////////////////////////////////////////////////////////////
// AutoBufferUnref is responsible for performing unref at the end of lifetime
// of AVBufferRefPtr.
////////////////////////////////////////////////////////////////////////////////
struct AutoBufferUnref {
  void operator()(AVBufferRef* p);
};

struct AVBufferRefPtr : public Wrapper<AVBufferRef, AutoBufferUnref> {
  AVBufferRefPtr();
  void reset(AVBufferRef* p);
};

////////////////////////////////////////////////////////////////////////////////
// AVCodecContext
////////////////////////////////////////////////////////////////////////////////
@@ -118,10 +132,13 @@ struct AVCodecContextDeleter {
};
struct AVCodecContextPtr
    : public Wrapper<AVCodecContext, AVCodecContextDeleter> {
  AVBufferRefPtr pHWBufferRef;

  AVCodecContextPtr(
      AVCodecParameters* pParam,
      const std::string& decoder,
      const std::map<std::string, std::string>& decoder_option,
      const torch::Device& device);
};

////////////////////////////////////////////////////////////////////////////////
...
@@ -180,8 +180,11 @@ void FilterGraph::add_process() {
}

void FilterGraph::create_filter() {
  int ret = avfilter_graph_config(pFilterGraph, nullptr);
  if (ret < 0) {
    throw std::runtime_error(
        "Failed to configure the graph: " + av_err2string(ret));
  }

  // char* desc = avfilter_graph_dump(pFilterGraph.get(), NULL);
  // std::cerr << "Filter created:\n" << desc << std::endl;
  // av_free(static_cast<void*>(desc));
...
@@ -246,7 +246,14 @@ void add_basic_video_stream(
    const c10::optional<int64_t>& height,
    const c10::optional<std::string>& format) {
  std::string filter_desc = get_vfilter_desc(frame_rate, width, height, format);
  s->s.add_video_stream(
      static_cast<int>(i),
      static_cast<int>(frames_per_chunk),
      static_cast<int>(num_chunks),
      std::move(filter_desc),
      "",
      {},
      torch::Device(c10::DeviceType::CPU));
}

void add_audio_stream(
@@ -273,14 +280,35 @@ void add_video_stream(
    int64_t num_chunks,
    const c10::optional<std::string>& filter_desc,
    const c10::optional<std::string>& decoder,
    const c10::optional<OptionDict>& decoder_options,
    const c10::optional<std::string>& hw_accel) {
  const torch::Device device = [&]() {
    if (!hw_accel) {
      return torch::Device{c10::DeviceType::CPU};
    }
#ifdef USE_CUDA
    torch::Device d{hw_accel.value()};
    if (d.type() != c10::DeviceType::CUDA) {
      std::stringstream ss;
      ss << "Only CUDA is supported for hardware acceleration. Found: "
         << d.str();
      throw std::runtime_error(ss.str());
    }
    return d;
#else
    throw std::runtime_error(
        "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
  }();

  s->s.add_video_stream(
      i,
      frames_per_chunk,
      num_chunks,
      filter_desc.value_or(""),
      decoder.value_or(""),
      convert_dict(decoder_options),
      device);
}

void remove_stream(S s, int64_t i) {
...
@@ -8,14 +8,15 @@ namespace {
std::unique_ptr<Buffer> get_buffer(
    AVMediaType type,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device) {
  switch (type) {
    case AVMEDIA_TYPE_AUDIO:
      return std::unique_ptr<Buffer>(
          new AudioBuffer(frames_per_chunk, num_chunks));
    case AVMEDIA_TYPE_VIDEO:
      return std::unique_ptr<Buffer>(
          new VideoBuffer(frames_per_chunk, num_chunks, device));
    default:
      throw std::runtime_error(
          std::string("Unsupported media type: ") +
@@ -29,9 +30,14 @@ Sink::Sink(
    AVCodecParameters* codecpar,
    int frames_per_chunk,
    int num_chunks,
    std::string filter_description,
    const torch::Device& device)
    : filter(input_time_base, codecpar, std::move(filter_description)),
      buffer(get_buffer(
          codecpar->codec_type,
          frames_per_chunk,
          num_chunks,
          device)) {}

// 0: some kind of success
// <0: Some error happened
...
@@ -18,7 +18,8 @@ class Sink {
      AVCodecParameters* codecpar,
      int frames_per_chunk,
      int num_chunks,
      std::string filter_description,
      const torch::Device& device);

  int process_frame(AVFrame* frame);
  bool is_buffer_ready() const;
...
@@ -9,8 +9,9 @@ using KeyType = StreamProcessor::KeyType;
StreamProcessor::StreamProcessor(
    AVCodecParameters* codecpar,
    const std::string& decoder_name,
    const std::map<std::string, std::string>& decoder_option,
    const torch::Device& device)
    : decoder(codecpar, decoder_name, decoder_option, device) {}

////////////////////////////////////////////////////////////////////////////////
// Configurations
@@ -20,7 +21,8 @@ KeyType StreamProcessor::add_stream(
    AVCodecParameters* codecpar,
    int frames_per_chunk,
    int num_chunks,
    std::string filter_description,
    const torch::Device& device) {
  switch (codecpar->codec_type) {
    case AVMEDIA_TYPE_AUDIO:
    case AVMEDIA_TYPE_VIDEO:
@@ -37,7 +39,8 @@ KeyType StreamProcessor::add_stream(
          codecpar,
          frames_per_chunk,
          num_chunks,
          std::move(filter_description),
          device));
  decoder_time_base = av_q2d(input_time_base);
  return key;
}
...
@@ -28,7 +28,8 @@ class StreamProcessor {
  StreamProcessor(
      AVCodecParameters* codecpar,
      const std::string& decoder_name,
      const std::map<std::string, std::string>& decoder_option,
      const torch::Device& device);
  ~StreamProcessor() = default;
  // Non-copyable
  StreamProcessor(const StreamProcessor&) = delete;
@@ -51,7 +52,8 @@ class StreamProcessor {
      AVCodecParameters* codecpar,
      int frames_per_chunk,
      int num_chunks,
      std::string filter_description,
      const torch::Device& device);

  // 1. Remove the stream
  void remove_stream(KeyType key);
...
@@ -166,7 +166,8 @@ void Streamer::add_audio_stream(
      num_chunks,
      std::move(filter_desc),
      decoder,
      decoder_option,
      torch::Device(torch::DeviceType::CPU));
}

void Streamer::add_video_stream(
@@ -175,7 +176,8 @@ void Streamer::add_video_stream(
    int num_chunks,
    std::string filter_desc,
    const std::string& decoder,
    const std::map<std::string, std::string>& decoder_option,
    const torch::Device& device) {
  add_stream(
      i,
      AVMEDIA_TYPE_VIDEO,
@@ -183,7 +185,8 @@ void Streamer::add_video_stream(
      num_chunks,
      std::move(filter_desc),
      decoder,
      decoder_option,
      device);
}

void Streamer::add_stream(
@@ -193,19 +196,22 @@ void Streamer::add_stream(
    int num_chunks,
    std::string filter_desc,
    const std::string& decoder,
    const std::map<std::string, std::string>& decoder_option,
    const torch::Device& device) {
  validate_src_stream_type(i, media_type);

  AVStream* stream = pFormatContext->streams[i];
  stream->discard = AVDISCARD_DEFAULT;
  if (!processors[i])
    processors[i] = std::make_unique<StreamProcessor>(
        stream->codecpar, decoder, decoder_option, device);
  int key = processors[i]->add_stream(
      stream->time_base,
      stream->codecpar,
      frames_per_chunk,
      num_chunks,
      std::move(filter_desc),
      device);
  stream_indices.push_back(std::make_pair<>(i, key));
}
...
@@ -75,7 +75,8 @@ class Streamer {
      int num_chunks,
      std::string filter_desc,
      const std::string& decoder,
      const std::map<std::string, std::string>& decoder_option,
      const torch::Device& device);

  void remove_stream(int i);

 private:
@@ -86,7 +87,8 @@ class Streamer {
      int num_chunks,
      std::string filter_desc,
      const std::string& decoder,
      const std::map<std::string, std::string>& decoder_option,
      const torch::Device& device);

 public:
  //////////////////////////////////////////////////////////////////////////////
...
@@ -403,6 +403,7 @@ class Streamer:
        filter_desc: Optional[str] = None,
        decoder: Optional[str] = None,
        decoder_options: Optional[Dict[str, str]] = None,
        hw_accel: Optional[str] = None,
    ):
        """Add output video stream
@@ -429,6 +430,46 @@ class Streamer:
            decoder_options (dict or None, optional): Options passed to decoder.
                Mapping from str to str.

            hw_accel (str or None, optional): Enable hardware acceleration.
                The valid choice is "cuda" or ``None``.
                Default: ``None``. (No hardware acceleration.)

                When the following conditions are met, providing ``hw_accel="cuda"``
                will create the Tensor directly from the CUDA HW decoder.

                1. TorchAudio is compiled with CUDA support.
                2. The dynamically linked FFmpeg libraries are compiled with NVDEC support.
                3. The codec is supported by NVDEC. (Currently, ``"h264_cuvid"`` is supported.)

                Example - HW decoding::

                    >>> # Decode video with NVDEC, create Tensor on CPU.
                    >>> streamer = Streamer(src="input.mp4")
                    >>> streamer.add_video_stream(10, decoder="h264_cuvid", hw_accel=None)
                    >>>
                    >>> chunk, = next(streamer.stream())
                    >>> print(chunk.device)
                    ... cpu

                    >>> # Decode video with NVDEC, create Tensor directly on CUDA
                    >>> streamer = Streamer(src="input.mp4")
                    >>> streamer.add_video_stream(10, decoder="h264_cuvid", hw_accel="cuda:1")
                    >>>
                    >>> chunk, = next(streamer.stream())
                    >>> print(chunk.device)
                    ... cuda:1

                    >>> # Decode and resize video with NVDEC, create Tensor directly on CUDA
                    >>> streamer = Streamer(src="input.mp4")
                    >>> streamer.add_video_stream(
                    >>>     10, decoder="h264_cuvid",
                    >>>     decoder_options={"resize": "240x360"}, hw_accel="cuda:1")
                    >>>
                    >>> chunk, = next(streamer.stream())
                    >>> print(chunk.device)
                    ... cuda:1
        """
        i = self.default_video_stream if stream_index is None else stream_index
        torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
@@ -439,6 +480,7 @@ class Streamer:
            filter_desc,
            decoder,
            decoder_options,
            hw_accel,
        )

    def remove_stream(self, i: int):
...