Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b21e0bfb
Unverified
Commit
b21e0bfb
authored
Jan 24, 2022
by
Prabhat Roy
Committed by
GitHub
Jan 24, 2022
Browse files
Add seek in GPU decoder (#5215)
* Add seek in GPU decoder * Merge the two tests * Refine unit test
parent
4c7a91ef
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
46 additions
and
2 deletions
+46
-2
test/test_video_gpu_decoder.py
test/test_video_gpu_decoder.py
+27
-0
torchvision/csrc/io/decoder/gpu/demuxer.h
torchvision/csrc/io/decoder/gpu/demuxer.h
+9
-0
torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
+9
-0
torchvision/csrc/io/decoder/gpu/gpu_decoder.h
torchvision/csrc/io/decoder/gpu/gpu_decoder.h
+1
-0
torchvision/io/__init__.py
torchvision/io/__init__.py
+0
-2
No files found.
test/test_video_gpu_decoder.py
View file @
b21e0bfb
...
...
@@ -37,6 +37,33 @@ class TestVideoGPUDecoder:
mean_delta
=
torch
.
mean
(
torch
.
abs
(
av_frames
.
float
()
-
vision_frames
.
cpu
().
float
()))
assert
mean_delta
<
0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("keyframes", [True, False])
@pytest.mark.parametrize(
    "full_path, duration",
    [
        (os.path.join(VIDEO_DIR, x), y)
        for x, y in [
            ("v_SoccerJuggling_g23_c01.avi", 8.0),
            ("v_SoccerJuggling_g24_c01.avi", 8.0),
            ("R6llTwEh07w.mp4", 10.0),
            ("SOX5yA1l24A.mp4", 11.0),
            ("WUzgd7C1pWA.mp4", 11.0),
        ]
    ],
)
def test_seek_reading(self, keyframes, full_path, duration):
    """Compare GPU-decoded frames against PyAV after seeking to mid-clip.

    Both the CUDA ``VideoReader`` and a PyAV container are seeked to
    ``duration / 2``; every frame PyAV yields from there on is paired with
    the next frame from the GPU reader, and the mean absolute difference of
    each pair must stay below 0.75.

    Args:
        keyframes: forwarded as ``keyframes_only`` to the GPU reader; its
            negation is passed to PyAV as ``any_frame``.
        full_path: path of the video under test.
        duration: clip length used to derive the seek target
            (presumably seconds -- confirm against the fixtures).
    """
    seek_target = duration / 2

    gpu_reader = VideoReader(full_path, device="cuda:0")
    gpu_reader.seek(seek_target, keyframes_only=keyframes)

    with av.open(full_path) as reference:
        # PyAV takes an integer offset; the target is scaled by 1e6 here.
        reference_offset = int(seek_target * 1000000)
        reference.seek(reference_offset, any_frame=not keyframes, backward=False)
        for reference_frame in reference.decode(reference.streams.video[0]):
            expected = torch.tensor(reference_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
            actual = next(gpu_reader)["data"]
            frame_error = (expected.float() - actual.cpu().float()).abs().mean()
            assert frame_error < 0.75
@
pytest
.
mark
.
skipif
(
av
is
None
,
reason
=
"PyAV unavailable"
)
def
test_metadata
(
self
):
for
test_video
in
test_videos
:
...
...
torchvision/csrc/io/decoder/gpu/demuxer.h
View file @
b21e0bfb
...
...
@@ -218,6 +218,15 @@ class Demuxer {
frameCount
++
;
return
true
;
}
/// Seek the underlying format context to `timestamp`.
///
/// @param timestamp Target position; it is multiplied by AV_TIME_BASE before
///                  being handed to FFmpeg (i.e. presumably given in seconds).
/// @param flag      AVSEEK_FLAG_* bitmask forwarded unchanged to
///                  av_seek_frame().
///
/// Raises through TORCH_CHECK when av_seek_frame() returns a negative value.
void seek(double timestamp, int flag) {
  // Make the double -> integer truncation explicit.
  const int64_t target = static_cast<int64_t>(timestamp * AV_TIME_BASE);
  // Stream index -1 asks FFmpeg to pick the default stream; per the FFmpeg
  // documentation the timestamp is then interpreted in AV_TIME_BASE units.
  const int status = av_seek_frame(fmtCtx, -1, target, flag);
  TORCH_CHECK(
      0 <= status,
      "av_seek_frame() failed at line ",
      __LINE__,
      " in demuxer.h\n");
}
};
inline
cudaVideoCodec
ffmpeg_to_codec
(
AVCodecID
id
)
{
...
...
torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
View file @
b21e0bfb
...
...
@@ -38,6 +38,14 @@ torch::Tensor GPUDecoder::decode() {
return
frame
;
}
/* Move the decoder to the given timestamp. `keyframes_only` selects the seek
 * flag passed to the demuxer: no flag when true, AVSEEK_FLAG_ANY when false.
 */
void GPUDecoder::seek(double timestamp, bool keyframes_only) {
  int seek_flags = 0;
  if (!keyframes_only) {
    seek_flags = AVSEEK_FLAG_ANY;
  }
  demuxer.seek(timestamp, seek_flags);
}
c10
::
Dict
<
std
::
string
,
c10
::
Dict
<
std
::
string
,
double
>>
GPUDecoder
::
get_metadata
()
const
{
c10
::
Dict
<
std
::
string
,
c10
::
Dict
<
std
::
string
,
double
>>
metadata
;
...
...
@@ -51,6 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
// Registers GPUDecoder as a custom class in the `torchvision` library,
// exposing its (std::string, int64_t) constructor together with the
// "seek", "get_metadata" and "next" methods.
TORCH_LIBRARY(torchvision, m) {
  auto gpu_decoder_class = m.class_<GPUDecoder>("GPUDecoder");
  gpu_decoder_class.def(torch::init<std::string, int64_t>());
  gpu_decoder_class.def("seek", &GPUDecoder::seek);
  gpu_decoder_class.def("get_metadata", &GPUDecoder::get_metadata);
  // "next" is backed by GPUDecoder::decode.
  gpu_decoder_class.def("next", &GPUDecoder::decode);
}
torchvision/csrc/io/decoder/gpu/gpu_decoder.h
View file @
b21e0bfb
...
...
@@ -8,6 +8,7 @@ class GPUDecoder : public torch::CustomClassHolder {
GPUDecoder
(
std
::
string
,
int64_t
);
~
GPUDecoder
();
torch
::
Tensor
decode
();
void
seek
(
double
,
bool
);
c10
::
Dict
<
std
::
string
,
c10
::
Dict
<
std
::
string
,
double
>>
get_metadata
()
const
;
private:
...
...
torchvision/io/__init__.py
View file @
b21e0bfb
...
...
@@ -174,8 +174,6 @@ class VideoReader:
frame with the exact timestamp if it exists or
the first frame with timestamp larger than ``time_s``.
"""
if
self
.
is_cuda
:
raise
RuntimeError
(
"seek() not yet supported with GPU decoding."
)
self
.
_c
.
seek
(
time_s
,
keyframes_only
)
return
self
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment