Unverified commit f9fbc104, authored by Prabhat Roy and committed by GitHub
Browse files

Allow cuda device to be passed without the index for GPU decoding (#5505)

parent d4146ef1
...@@ -30,7 +30,7 @@ class TestVideoGPUDecoder: ...@@ -30,7 +30,7 @@ class TestVideoGPUDecoder:
) )
def test_frame_reading(self, video_file): def test_frame_reading(self, video_file):
full_path = os.path.join(VIDEO_DIR, video_file) full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda:0") decoder = VideoReader(full_path, device="cuda")
with av.open(full_path) as container: with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]): for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
...@@ -54,7 +54,7 @@ class TestVideoGPUDecoder: ...@@ -54,7 +54,7 @@ class TestVideoGPUDecoder:
], ],
) )
def test_seek_reading(self, keyframes, full_path, duration): def test_seek_reading(self, keyframes, full_path, duration):
decoder = VideoReader(full_path, device="cuda:0") decoder = VideoReader(full_path, device="cuda")
time = duration / 2 time = duration / 2
decoder.seek(time, keyframes_only=keyframes) decoder.seek(time, keyframes_only=keyframes)
with av.open(full_path) as container: with av.open(full_path) as container:
...@@ -80,7 +80,7 @@ class TestVideoGPUDecoder: ...@@ -80,7 +80,7 @@ class TestVideoGPUDecoder:
) )
def test_metadata(self, video_file): def test_metadata(self, video_file):
full_path = os.path.join(VIDEO_DIR, video_file) full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda:0") decoder = VideoReader(full_path, device="cuda")
video_metadata = decoder.get_metadata()["video"] video_metadata = decoder.get_metadata()["video"]
with av.open(full_path) as container: with av.open(full_path) as container:
video = container.streams.video[0] video = container.streams.video[0]
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
/* Set cuda device, create cuda context and initialise the demuxer and decoder. /* Set cuda device, create cuda context and initialise the demuxer and decoder.
*/ */
GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
: demuxer(src_file.c_str()), device(dev) { : demuxer(src_file.c_str()) {
at::cuda::CUDAGuard device_guard(device); at::cuda::CUDAGuard device_guard(dev);
device = device_guard.current_device().index();
check_for_cuda_errors( check_for_cuda_errors(
cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__); cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec())); decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
...@@ -58,7 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder:: ...@@ -58,7 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
TORCH_LIBRARY(torchvision, m) { TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder") m.class_<GPUDecoder>("GPUDecoder")
.def(torch::init<std::string, int64_t>()) .def(torch::init<std::string, torch::Device>())
.def("seek", &GPUDecoder::seek) .def("seek", &GPUDecoder::seek)
.def("get_metadata", &GPUDecoder::get_metadata) .def("get_metadata", &GPUDecoder::get_metadata)
.def("next", &GPUDecoder::decode); .def("next", &GPUDecoder::decode);
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
class GPUDecoder : public torch::CustomClassHolder { class GPUDecoder : public torch::CustomClassHolder {
public: public:
GPUDecoder(std::string, int64_t); GPUDecoder(std::string, torch::Device);
~GPUDecoder(); ~GPUDecoder();
torch::Tensor decode(); torch::Tensor decode();
void seek(double, bool); void seek(double, bool);
......
...@@ -84,6 +84,7 @@ class VideoReader: ...@@ -84,6 +84,7 @@ class VideoReader:
will depend on the version of FFMPEG codecs supported. will depend on the version of FFMPEG codecs supported.
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``. device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
To use GPU decoding, pass ``device="cuda"``.
""" """
...@@ -95,9 +96,7 @@ class VideoReader: ...@@ -95,9 +96,7 @@ class VideoReader:
if not _HAS_GPU_VIDEO_DECODER: if not _HAS_GPU_VIDEO_DECODER:
raise RuntimeError("Not compiled with GPU decoder support.") raise RuntimeError("Not compiled with GPU decoder support.")
self.is_cuda = True self.is_cuda = True
if device.index is None: self._c = torch.classes.torchvision.GPUDecoder(path, device)
raise RuntimeError("Invalid cuda device!")
self._c = torch.classes.torchvision.GPUDecoder(path, device.index)
return return
if not _has_video_opt(): if not _has_video_opt():
raise RuntimeError( raise RuntimeError(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment