Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
f9fbc104
Unverified
Commit
f9fbc104
authored
Mar 01, 2022
by
Prabhat Roy
Committed by
GitHub
Mar 01, 2022
Browse files
Allow cuda device to be passed without the index for GPU decoding (#5505)
parent
d4146ef1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
11 additions
and
11 deletions
+11
-11
test/test_video_gpu_decoder.py
test/test_video_gpu_decoder.py
+3
-3
torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
+5
-4
torchvision/csrc/io/decoder/gpu/gpu_decoder.h
torchvision/csrc/io/decoder/gpu/gpu_decoder.h
+1
-1
torchvision/io/video_reader.py
torchvision/io/video_reader.py
+2
-3
No files found.
test/test_video_gpu_decoder.py
View file @
f9fbc104
...
...
@@ -30,7 +30,7 @@ class TestVideoGPUDecoder:
)
def
test_frame_reading
(
self
,
video_file
):
full_path
=
os
.
path
.
join
(
VIDEO_DIR
,
video_file
)
decoder
=
VideoReader
(
full_path
,
device
=
"cuda
:0
"
)
decoder
=
VideoReader
(
full_path
,
device
=
"cuda"
)
with
av
.
open
(
full_path
)
as
container
:
for
av_frame
in
container
.
decode
(
container
.
streams
.
video
[
0
]):
av_frames
=
torch
.
tensor
(
av_frame
.
to_rgb
(
src_colorspace
=
"ITU709"
).
to_ndarray
())
...
...
@@ -54,7 +54,7 @@ class TestVideoGPUDecoder:
],
)
def
test_seek_reading
(
self
,
keyframes
,
full_path
,
duration
):
decoder
=
VideoReader
(
full_path
,
device
=
"cuda
:0
"
)
decoder
=
VideoReader
(
full_path
,
device
=
"cuda"
)
time
=
duration
/
2
decoder
.
seek
(
time
,
keyframes_only
=
keyframes
)
with
av
.
open
(
full_path
)
as
container
:
...
...
@@ -80,7 +80,7 @@ class TestVideoGPUDecoder:
)
def
test_metadata
(
self
,
video_file
):
full_path
=
os
.
path
.
join
(
VIDEO_DIR
,
video_file
)
decoder
=
VideoReader
(
full_path
,
device
=
"cuda
:0
"
)
decoder
=
VideoReader
(
full_path
,
device
=
"cuda"
)
video_metadata
=
decoder
.
get_metadata
()[
"video"
]
with
av
.
open
(
full_path
)
as
container
:
video
=
container
.
streams
.
video
[
0
]
...
...
torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
View file @
f9fbc104
...
...
@@ -3,9 +3,10 @@
/* Set cuda device, create cuda context and initialise the demuxer and decoder.
*/
GPUDecoder
::
GPUDecoder
(
std
::
string
src_file
,
int64_t
dev
)
:
demuxer
(
src_file
.
c_str
()),
device
(
dev
)
{
at
::
cuda
::
CUDAGuard
device_guard
(
device
);
GPUDecoder
::
GPUDecoder
(
std
::
string
src_file
,
torch
::
Device
dev
)
:
demuxer
(
src_file
.
c_str
())
{
at
::
cuda
::
CUDAGuard
device_guard
(
dev
);
device
=
device_guard
.
current_device
().
index
();
check_for_cuda_errors
(
cuDevicePrimaryCtxRetain
(
&
ctx
,
device
),
__LINE__
,
__FILE__
);
decoder
.
init
(
ctx
,
ffmpeg_to_codec
(
demuxer
.
get_video_codec
()));
...
...
@@ -58,7 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
TORCH_LIBRARY
(
torchvision
,
m
)
{
m
.
class_
<
GPUDecoder
>
(
"GPUDecoder"
)
.
def
(
torch
::
init
<
std
::
string
,
int64_t
>
())
.
def
(
torch
::
init
<
std
::
string
,
torch
::
Device
>
())
.
def
(
"seek"
,
&
GPUDecoder
::
seek
)
.
def
(
"get_metadata"
,
&
GPUDecoder
::
get_metadata
)
.
def
(
"next"
,
&
GPUDecoder
::
decode
);
...
...
torchvision/csrc/io/decoder/gpu/gpu_decoder.h
View file @
f9fbc104
...
...
@@ -5,7 +5,7 @@
class
GPUDecoder
:
public
torch
::
CustomClassHolder
{
public:
GPUDecoder
(
std
::
string
,
int64_t
);
GPUDecoder
(
std
::
string
,
torch
::
Device
);
~
GPUDecoder
();
torch
::
Tensor
decode
();
void
seek
(
double
,
bool
);
...
...
torchvision/io/video_reader.py
View file @
f9fbc104
...
...
@@ -84,6 +84,7 @@ class VideoReader:
will depend on the version of FFMPEG codecs supported.
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
To use GPU decoding, pass ``device="cuda"``.
"""
...
...
@@ -95,9 +96,7 @@ class VideoReader:
if
not
_HAS_GPU_VIDEO_DECODER
:
raise
RuntimeError
(
"Not compiled with GPU decoder support."
)
self
.
is_cuda
=
True
if
device
.
index
is
None
:
raise
RuntimeError
(
"Invalid cuda device!"
)
self
.
_c
=
torch
.
classes
.
torchvision
.
GPUDecoder
(
path
,
device
.
index
)
self
.
_c
=
torch
.
classes
.
torchvision
.
GPUDecoder
(
path
,
device
)
return
if
not
_has_video_opt
():
raise
RuntimeError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment