Unverified commit f9fbc104, authored by Prabhat Roy and committed by GitHub
Browse files

Allow cuda device to be passed without the index for GPU decoding (#5505)

parent d4146ef1
......@@ -30,7 +30,7 @@ class TestVideoGPUDecoder:
)
def test_frame_reading(self, video_file):
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda:0")
decoder = VideoReader(full_path, device="cuda")
with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
......@@ -54,7 +54,7 @@ class TestVideoGPUDecoder:
],
)
def test_seek_reading(self, keyframes, full_path, duration):
decoder = VideoReader(full_path, device="cuda:0")
decoder = VideoReader(full_path, device="cuda")
time = duration / 2
decoder.seek(time, keyframes_only=keyframes)
with av.open(full_path) as container:
......@@ -80,7 +80,7 @@ class TestVideoGPUDecoder:
)
def test_metadata(self, video_file):
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda:0")
decoder = VideoReader(full_path, device="cuda")
video_metadata = decoder.get_metadata()["video"]
with av.open(full_path) as container:
video = container.streams.video[0]
......
......@@ -3,9 +3,10 @@
/* Set cuda device, create cuda context and initialise the demuxer and decoder.
*/
GPUDecoder::GPUDecoder(std::string src_file, int64_t dev)
: demuxer(src_file.c_str()), device(dev) {
at::cuda::CUDAGuard device_guard(device);
GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
: demuxer(src_file.c_str()) {
at::cuda::CUDAGuard device_guard(dev);
device = device_guard.current_device().index();
check_for_cuda_errors(
cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
......@@ -58,7 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder")
.def(torch::init<std::string, int64_t>())
.def(torch::init<std::string, torch::Device>())
.def("seek", &GPUDecoder::seek)
.def("get_metadata", &GPUDecoder::get_metadata)
.def("next", &GPUDecoder::decode);
......
......@@ -5,7 +5,7 @@
class GPUDecoder : public torch::CustomClassHolder {
public:
GPUDecoder(std::string, int64_t);
GPUDecoder(std::string, torch::Device);
~GPUDecoder();
torch::Tensor decode();
void seek(double, bool);
......
......@@ -84,6 +84,7 @@ class VideoReader:
will depend on the version of FFMPEG codecs supported.
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
To use GPU decoding, pass ``device="cuda"``.
"""
......@@ -95,9 +96,7 @@ class VideoReader:
if not _HAS_GPU_VIDEO_DECODER:
raise RuntimeError("Not compiled with GPU decoder support.")
self.is_cuda = True
if device.index is None:
raise RuntimeError("Invalid cuda device!")
self._c = torch.classes.torchvision.GPUDecoder(path, device.index)
self._c = torch.classes.torchvision.GPUDecoder(path, device)
return
if not _has_video_opt():
raise RuntimeError(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment