Commit 0c8c138c authored by moto, committed by Facebook GitHub Bot
Browse files

Cache HW device context (#3178)

Summary:
TODO: add cache release

Pull Request resolved: https://github.com/pytorch/audio/pull/3178

Reviewed By: hwangjeff

Differential Revision: D44136275

Pulled By: mthrok

fbshipit-source-id: 4eaf646fe17a469e8bbbdf43441d5532f9f8461d
parent 59f067b7
...@@ -9,6 +9,7 @@ set( ...@@ -9,6 +9,7 @@ set(
sources sources
ffmpeg.cpp ffmpeg.cpp
filter_graph.cpp filter_graph.cpp
hw_context.cpp
stream_reader/buffer/chunked_buffer.cpp stream_reader/buffer/chunked_buffer.cpp
stream_reader/buffer/unchunked_buffer.cpp stream_reader/buffer/unchunked_buffer.cpp
stream_reader/conversion.cpp stream_reader/conversion.cpp
......
...@@ -128,8 +128,8 @@ void AutoBufferUnref::operator()(AVBufferRef* p) { ...@@ -128,8 +128,8 @@ void AutoBufferUnref::operator()(AVBufferRef* p) {
av_buffer_unref(&p); av_buffer_unref(&p);
} }
AVBufferRefPtr::AVBufferRefPtr() AVBufferRefPtr::AVBufferRefPtr(AVBufferRef* p)
: Wrapper<AVBufferRef, AutoBufferUnref>(nullptr) {} : Wrapper<AVBufferRef, AutoBufferUnref>(p) {}
void AVBufferRefPtr::reset(AVBufferRef* p) { void AVBufferRefPtr::reset(AVBufferRef* p) {
TORCH_CHECK( TORCH_CHECK(
......
...@@ -164,7 +164,7 @@ struct AutoBufferUnref { ...@@ -164,7 +164,7 @@ struct AutoBufferUnref {
}; };
struct AVBufferRefPtr : public Wrapper<AVBufferRef, AutoBufferUnref> { struct AVBufferRefPtr : public Wrapper<AVBufferRef, AutoBufferUnref> {
AVBufferRefPtr(); AVBufferRefPtr(AVBufferRef* p = nullptr);
void reset(AVBufferRef* p); void reset(AVBufferRef* p);
}; };
......
#include <torchaudio/csrc/ffmpeg/hw_context.h>
namespace torchaudio::io {
namespace {
// Lock taken by every function in this file before it touches
// CUDA_CONTEXT_CACHE.
static std::mutex MUTEX;
// Cached CUDA hardware device contexts, keyed by CUDA device index.
// Each entry is held through AVBufferRefPtr, so the cache keeps its own
// handle on the context for as long as the entry exists.
static std::map<int, AVBufferRefPtr> CUDA_CONTEXT_CACHE;
} // namespace
/// Look up (or lazily create) the cached CUDA hardware device context for a
/// given device.
///
/// @param index CUDA device index. The value -1 is treated as device 0.
/// @return The AVBufferRef stored in the cache for that device. No additional
///         reference is taken on behalf of the caller, so a caller that wants
///         to keep the context should acquire its own reference (e.g. with
///         av_buffer_ref) rather than hold on to the returned pointer.
/// @throws Whatever TORCH_CHECK raises when av_hwdevice_ctx_create reports a
///         failure for the requested device.
AVBufferRef* get_cuda_context(int index) {
  std::lock_guard<std::mutex> guard(MUTEX);

  // Normalise the "unspecified" index to the first device.
  const int device = (index == -1) ? 0 : index;

  // Fast path: a context for this device has already been created.
  const auto cached = CUDA_CONTEXT_CACHE.find(device);
  if (cached != CUDA_CONTEXT_CACHE.end()) {
    return cached->second;
  }

  // Slow path: create the context and remember it. The device is passed to
  // FFmpeg as its decimal string representation.
  AVBufferRef* ctx = nullptr;
  const std::string device_name = std::to_string(device);
  const int status = av_hwdevice_ctx_create(
      &ctx, AV_HWDEVICE_TYPE_CUDA, device_name.c_str(), nullptr, 0);
  TORCH_CHECK(
      status >= 0,
      "Failed to create CUDA device context on device ",
      device,
      "(",
      av_err2string(status),
      ")");
  assert(ctx);

  // The cache entry now holds the context; hand back the same raw pointer.
  CUDA_CONTEXT_CACHE.emplace(device, ctx);
  return ctx;
}
void clear_cuda_context_cache() {
std::lock_guard<std::mutex> lock(MUTEX);
CUDA_CONTEXT_CACHE.clear();
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio::io {
/// Return the cached CUDA hardware device context for `index`, creating and
/// caching it on first use. An index of -1 is treated as device 0. Fails via
/// TORCH_CHECK if the context cannot be created.
AVBufferRef* get_cuda_context(int index);
/// Drop every CUDA hardware device context held in the cache.
void clear_cuda_context_cache();
} // namespace torchaudio::io
#include <torch/extension.h> #include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/pybind/fileobj.h> #include <torchaudio/csrc/ffmpeg/pybind/fileobj.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h> #include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h> #include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
...@@ -30,6 +31,7 @@ struct StreamWriterFileObj : private FileObj, public StreamWriter { ...@@ -30,6 +31,7 @@ struct StreamWriterFileObj : private FileObj, public StreamWriter {
}; };
PYBIND11_MODULE(_torchaudio_ffmpeg, m) { PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
m.def("clear_cuda_context_cache", &clear_cuda_context_cache);
py::class_<Chunk>(m, "Chunk", py::module_local()) py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames) .def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts); .def_readwrite("pts", &Chunk::pts);
......
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h> #include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h>
#include <stdexcept> #include <stdexcept>
...@@ -99,6 +100,7 @@ void configure_codec_context( ...@@ -99,6 +100,7 @@ void configure_codec_context(
// 2. Set pCodecContext->get_format call back function which // 2. Set pCodecContext->get_format call back function which
// will retrieve the HW pixel format from opaque pointer. // will retrieve the HW pixel format from opaque pointer.
codec_ctx->get_format = get_hw_format; codec_ctx->get_format = get_hw_format;
codec_ctx->hw_device_ctx = av_buffer_ref(get_cuda_context(device.index()));
#endif #endif
} }
} }
......
...@@ -253,3 +253,9 @@ def get_build_config() -> str: ...@@ -253,3 +253,9 @@ def get_build_config() -> str:
--prefix=/Users/runner/miniforge3 --cc=arm64-apple-darwin20.0.0-clang --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-neon --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-libvpx --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1646229390493/_build_env/bin/pkg-config --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1646229390493/_build_env/bin/x86_64-apple-darwin13.4.0-clang # noqa --prefix=/Users/runner/miniforge3 --cc=arm64-apple-darwin20.0.0-clang --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-neon --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-libvpx --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1646229390493/_build_env/bin/pkg-config --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1646229390493/_build_env/bin/x86_64-apple-darwin13.4.0-clang # noqa
""" """
return torch.ops.torchaudio.ffmpeg_get_build_config() return torch.ops.torchaudio.ffmpeg_get_build_config()
@torchaudio._extension.fail_if_no_ffmpeg
def clear_cuda_context_cache():
    """Release the cached CUDA contexts used for hardware-accelerated video decoding.

    This forwards to ``clear_cuda_context_cache`` in the FFmpeg extension
    module, which empties its cache of CUDA device contexts.
    """
    ffmpeg_ext = torchaudio.lib._torchaudio_ffmpeg
    ffmpeg_ext.clear_cuda_context_cache()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment