"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9b93e7df8c53631ef964cfff118772f4f9fa17bd"
Unverified Commit f9fbc104 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Allow cuda device to be passed without the index for GPU decoding (#5505)

parent d4146ef1
...@@ -30,7 +30,7 @@ class TestVideoGPUDecoder: ...@@ -30,7 +30,7 @@ class TestVideoGPUDecoder:
) )
def test_frame_reading(self, video_file): def test_frame_reading(self, video_file):
full_path = os.path.join(VIDEO_DIR, video_file) full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda:0") decoder = VideoReader(full_path, device="cuda")
with av.open(full_path) as container: with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]): for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
...@@ -54,7 +54,7 @@ class TestVideoGPUDecoder: ...@@ -54,7 +54,7 @@ class TestVideoGPUDecoder:
], ],
) )
def test_seek_reading(self, keyframes, full_path, duration): def test_seek_reading(self, keyframes, full_path, duration):
decoder = VideoReader(full_path, device="cuda:0") decoder = VideoReader(full_path, device="cuda")
time = duration / 2 time = duration / 2
decoder.seek(time, keyframes_only=keyframes) decoder.seek(time, keyframes_only=keyframes)
with av.open(full_path) as container: with av.open(full_path) as container:
...@@ -80,7 +80,7 @@ class TestVideoGPUDecoder: ...@@ -80,7 +80,7 @@ class TestVideoGPUDecoder:
) )
def test_metadata(self, video_file): def test_metadata(self, video_file):
full_path = os.path.join(VIDEO_DIR, video_file) full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda:0") decoder = VideoReader(full_path, device="cuda")
video_metadata = decoder.get_metadata()["video"] video_metadata = decoder.get_metadata()["video"]
with av.open(full_path) as container: with av.open(full_path) as container:
video = container.streams.video[0] video = container.streams.video[0]
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
/* Set cuda device, create cuda context and initialise the demuxer and decoder. /* Set cuda device, create cuda context and initialise the demuxer and decoder.
*/ */
GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
: demuxer(src_file.c_str()), device(dev) { : demuxer(src_file.c_str()) {
at::cuda::CUDAGuard device_guard(device); at::cuda::CUDAGuard device_guard(dev);
device = device_guard.current_device().index();
check_for_cuda_errors( check_for_cuda_errors(
cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__); cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec())); decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
...@@ -58,7 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder:: ...@@ -58,7 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
TORCH_LIBRARY(torchvision, m) { TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder") m.class_<GPUDecoder>("GPUDecoder")
.def(torch::init<std::string, int64_t>()) .def(torch::init<std::string, torch::Device>())
.def("seek", &GPUDecoder::seek) .def("seek", &GPUDecoder::seek)
.def("get_metadata", &GPUDecoder::get_metadata) .def("get_metadata", &GPUDecoder::get_metadata)
.def("next", &GPUDecoder::decode); .def("next", &GPUDecoder::decode);
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
class GPUDecoder : public torch::CustomClassHolder { class GPUDecoder : public torch::CustomClassHolder {
public: public:
GPUDecoder(std::string, int64_t); GPUDecoder(std::string, torch::Device);
~GPUDecoder(); ~GPUDecoder();
torch::Tensor decode(); torch::Tensor decode();
void seek(double, bool); void seek(double, bool);
......
...@@ -84,6 +84,7 @@ class VideoReader: ...@@ -84,6 +84,7 @@ class VideoReader:
will depend on the version of FFMPEG codecs supported. will depend on the version of FFMPEG codecs supported.
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``. device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
To use GPU decoding, pass ``device="cuda"``.
""" """
...@@ -95,9 +96,7 @@ class VideoReader: ...@@ -95,9 +96,7 @@ class VideoReader:
if not _HAS_GPU_VIDEO_DECODER: if not _HAS_GPU_VIDEO_DECODER:
raise RuntimeError("Not compiled with GPU decoder support.") raise RuntimeError("Not compiled with GPU decoder support.")
self.is_cuda = True self.is_cuda = True
if device.index is None: self._c = torch.classes.torchvision.GPUDecoder(path, device)
raise RuntimeError("Invalid cuda device!")
self._c = torch.classes.torchvision.GPUDecoder(path, device.index)
return return
if not _has_video_opt(): if not _has_video_opt():
raise RuntimeError( raise RuntimeError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment