OpenDAS / vision · Commits

Commit f20177b7 (unverified)
Authored Nov 04, 2022 by Nicolas Hug; committed by GitHub on Nov 04, 2022
Parent: dc11b1f6

[FBcode->GH] Revert "Pyav backend for VideoReader API (#6598)" (#6908)

This reverts commit 2e833520.

6 changed files with 107 additions and 210 deletions (+107, -210):

    test/test_video_gpu_decoder.py        +4   -9
    test/test_videoapi.py                 +40  -58
    torchvision/__init__.py               +4   -17
    torchvision/io/__init__.py            +5   -0
    torchvision/io/_load_gpu_decoder.py   +8   -0
    torchvision/io/video_reader.py        +46  -126
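For context, a minimal sketch of the user-facing difference this revert makes (the file name is a placeholder): before the revert GPU decoding was selected through the global backend switch, after it through the `device` argument of `VideoReader`.

    from torchvision.io import VideoReader

    # Before the revert, GPU decoding was chosen globally:
    #   torchvision.set_video_backend("cuda")
    #   reader = VideoReader("video.mp4")
    #
    # After the revert, it is requested per reader:
    reader = VideoReader("video.mp4", device="cuda")  # "video.mp4" is a placeholder path
    for frame in reader:
        data = frame["data"]  # decoded frame tensor, on the GPU here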
test/test_video_gpu_decoder.py  (+4 -9)

@@ -3,9 +3,7 @@ import os
 import pytest
 import torch
-import torchvision
-from torchvision import _HAS_GPU_VIDEO_DECODER
-from torchvision.io import VideoReader
+from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader

 try:
     import av
@@ -31,9 +29,8 @@ class TestVideoGPUDecoder:
         ],
     )
     def test_frame_reading(self, video_file):
-        torchvision.set_video_backend("cuda")
         full_path = os.path.join(VIDEO_DIR, video_file)
-        decoder = VideoReader(full_path)
+        decoder = VideoReader(full_path, device="cuda")
         with av.open(full_path) as container:
             for av_frame in container.decode(container.streams.video[0]):
                 av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
@@ -57,8 +54,7 @@ class TestVideoGPUDecoder:
         ],
     )
     def test_seek_reading(self, keyframes, full_path, duration):
-        torchvision.set_video_backend("cuda")
-        decoder = VideoReader(full_path)
+        decoder = VideoReader(full_path, device="cuda")
         time = duration / 2
         decoder.seek(time, keyframes_only=keyframes)
         with av.open(full_path) as container:
@@ -83,9 +79,8 @@ class TestVideoGPUDecoder:
         ],
     )
     def test_metadata(self, video_file):
-        torchvision.set_video_backend("cuda")
         full_path = os.path.join(VIDEO_DIR, video_file)
-        decoder = VideoReader(full_path)
+        decoder = VideoReader(full_path, device="cuda")
         video_metadata = decoder.get_metadata()["video"]
         with av.open(full_path) as container:
             video = container.streams.video[0]
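A rough sketch of the kind of cross-check these tests perform against PyAV, assuming a CUDA-enabled torchvision build and PyAV installed (the file name is a placeholder):

    import av
    import torch
    from torchvision.io import VideoReader

    path = "video.mp4"  # placeholder: any H.264 test clip

    decoder = VideoReader(path, device="cuda")
    vision_frames = [frame["data"] for frame in decoder]

    with av.open(path) as container:
        av_frames = [
            torch.tensor(f.to_rgb(src_colorspace="ITU709").to_ndarray())
            for f in container.decode(container.streams.video[0])
        ]

    # Both decoders should produce the same number of frames for the same file.
    assert len(vision_frames) == len(av_frames)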
test/test_videoapi.py  (+40 -58)

@@ -53,9 +53,7 @@ test_videos = {
 class TestVideoApi:
     @pytest.mark.skipif(av is None, reason="PyAV unavailable")
     @pytest.mark.parametrize("test_video", test_videos.keys())
-    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
-    def test_frame_reading(self, test_video, backend):
-        torchvision.set_video_backend(backend)
+    def test_frame_reading(self, test_video):
         full_path = os.path.join(VIDEO_DIR, test_video)
         with av.open(full_path) as av_reader:
             if av_reader.streams.video:
@@ -119,60 +117,50 @@ class TestVideoApi:
     @pytest.mark.parametrize("stream", ["video", "audio"])
     @pytest.mark.parametrize("test_video", test_videos.keys())
-    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
-    def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
-        torchvision.set_video_backend(backend)
+    def test_frame_reading_mem_vs_file(self, test_video, stream):
         full_path = os.path.join(VIDEO_DIR, test_video)
-        reader = VideoReader(full_path)
-        reader_md = reader.get_metadata()
-        if stream in reader_md:
-            # Test video reading from file vs from memory
-            vr_frames, vr_frames_mem = [], []
-            vr_pts, vr_pts_mem = [], []
-
-            # get vr frames
-            video_reader = VideoReader(full_path, stream)
-            for vr_frame in video_reader:
-                vr_frames.append(vr_frame["data"])
-                vr_pts.append(vr_frame["pts"])
-
-            # get vr frames = read from memory
-            f = open(full_path, "rb")
-            fbytes = f.read()
-            f.close()
-            video_reader_from_mem = VideoReader(fbytes, stream)
-
-            for vr_frame_from_mem in video_reader_from_mem:
-                vr_frames_mem.append(vr_frame_from_mem["data"])
-                vr_pts_mem.append(vr_frame_from_mem["pts"])
-
-            # same number of frames
-            assert len(vr_frames) == len(vr_frames_mem)
-            assert len(vr_pts) == len(vr_pts_mem)
-
-            # compare the frames and ptss
-            for i in range(len(vr_frames)):
-                assert vr_pts[i] == vr_pts_mem[i]
-                mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
-                # on average the difference is very small and caused
-                # by decoding (around 1%)
-                # TODO: asses empirically how to set this? atm it's 1%
-                # averaged over all frames
-                assert mean_delta.item() < 2.55
-
-            del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
-        else:
-            del reader, reader_md
+
+        # Test video reading from file vs from memory
+        vr_frames, vr_frames_mem = [], []
+        vr_pts, vr_pts_mem = [], []
+
+        # get vr frames
+        video_reader = VideoReader(full_path, stream)
+        for vr_frame in video_reader:
+            vr_frames.append(vr_frame["data"])
+            vr_pts.append(vr_frame["pts"])
+
+        # get vr frames = read from memory
+        f = open(full_path, "rb")
+        fbytes = f.read()
+        f.close()
+        video_reader_from_mem = VideoReader(fbytes, stream)
+
+        for vr_frame_from_mem in video_reader_from_mem:
+            vr_frames_mem.append(vr_frame_from_mem["data"])
+            vr_pts_mem.append(vr_frame_from_mem["pts"])
+
+        # same number of frames
+        assert len(vr_frames) == len(vr_frames_mem)
+        assert len(vr_pts) == len(vr_pts_mem)
+
+        # compare the frames and ptss
+        for i in range(len(vr_frames)):
+            assert vr_pts[i] == vr_pts_mem[i]
+            mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
+            # on average the difference is very small and caused
+            # by decoding (around 1%)
+            # TODO: asses empirically how to set this? atm it's 1%
+            # averaged over all frames
+            assert mean_delta.item() < 2.55
+
+        del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem

     @pytest.mark.parametrize("test_video,config", test_videos.items())
-    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
-    def test_metadata(self, test_video, config, backend):
+    def test_metadata(self, test_video, config):
         """
         Test that the metadata returned via pyav corresponds to the one returned
         by the new video decoder API
         """
-        torchvision.set_video_backend(backend)
         full_path = os.path.join(VIDEO_DIR, test_video)
         reader = VideoReader(full_path, "video")
         reader_md = reader.get_metadata()
@@ -180,9 +168,7 @@ class TestVideoApi:
         assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)

     @pytest.mark.parametrize("test_video", test_videos.keys())
-    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
-    def test_seek_start(self, test_video, backend):
-        torchvision.set_video_backend(backend)
+    def test_seek_start(self, test_video):
         full_path = os.path.join(VIDEO_DIR, test_video)
         video_reader = VideoReader(full_path, "video")
         num_frames = 0
@@ -208,9 +194,7 @@ class TestVideoApi:
         assert start_num_frames == num_frames

     @pytest.mark.parametrize("test_video", test_videos.keys())
-    @pytest.mark.parametrize("backend", ["video_reader"])
-    def test_accurateseek_middle(self, test_video, backend):
-        torchvision.set_video_backend(backend)
+    def test_accurateseek_middle(self, test_video):
         full_path = os.path.join(VIDEO_DIR, test_video)
         stream = "video"
         video_reader = VideoReader(full_path, stream)
@@ -249,9 +233,7 @@ class TestVideoApi:
     @pytest.mark.skipif(av is None, reason="PyAV unavailable")
     @pytest.mark.parametrize("test_video,config", test_videos.items())
-    @pytest.mark.parametrize("backend", ["pyav", "video_reader"])
-    def test_keyframe_reading(self, test_video, config, backend):
-        torchvision.set_video_backend(backend)
+    def test_keyframe_reading(self, test_video, config):
         full_path = os.path.join(VIDEO_DIR, test_video)
         av_reader = av.open(full_path)
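The file-vs-memory check above boils down to the following sketch (placeholder file name; assumes the video_reader extension is available):

    import torch
    from torchvision.io import VideoReader

    path = "video.mp4"  # placeholder test video

    # Decode from a file path.
    frames_from_file = [f["data"] for f in VideoReader(path, "video")]

    # Decode the same container from raw bytes held in memory.
    with open(path, "rb") as fh:
        raw = fh.read()
    frames_from_memory = [f["data"] for f in VideoReader(raw, "video")]

    assert len(frames_from_file) == len(frames_from_memory)
    for a, b in zip(frames_from_file, frames_from_memory):
        # Decoded pixels should match closely; the tests allow a mean difference of ~1% (2.55/255).
        assert torch.mean(torch.abs(a.float() - b.float())).item() < 2.55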
torchvision/__init__.py  (+4 -17)

 import os
 import warnings
 from modulefinder import Module

 import torch

 from torchvision import datasets, io, models, ops, transforms, utils
-from .extension import _HAS_OPS, _load_library
+from .extension import _HAS_OPS

 try:
     from .version import __version__  # noqa: F401
 except ImportError:
     pass

-try:
-    _load_library("Decoder")
-    _HAS_GPU_VIDEO_DECODER = True
-except (ImportError, OSError, ModuleNotFoundError):
-    _HAS_GPU_VIDEO_DECODER = False
-
 # Check if torchvision is being imported within the root folder
 if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
     os.path.realpath(os.getcwd()), "torchvision"
@@ -74,16 +66,11 @@ def set_video_backend(backend):
         backend, please compile torchvision from source.
     """
     global _video_backend
-    if backend not in ["pyav", "video_reader", "cuda"]:
-        raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
+    if backend not in ["pyav", "video_reader"]:
+        raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
     if backend == "video_reader" and not io._HAS_VIDEO_OPT:
-        # TODO: better messages
         message = "video_reader video backend is not available. Please compile torchvision from source and try again"
-        raise RuntimeError(message)
-    elif backend == "cuda" and not _HAS_GPU_VIDEO_DECODER:
-        # TODO: better messages
-        message = "cuda video backend is not available."
-        raise RuntimeError(message)
+        warnings.warn(message)
     else:
         _video_backend = backend
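After the revert, backend selection covers only the two CPU decoders; a quick sketch of the resulting behaviour:

    import torchvision

    torchvision.set_video_backend("pyav")          # pure-Python PyAV backend
    torchvision.set_video_backend("video_reader")  # only warns if torchvision was built without the video ops

    try:
        torchvision.set_video_backend("cuda")      # no longer a recognised backend after this revert
    except ValueError as err:
        print(err)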
torchvision/io/__init__.py  (+5 -0)

@@ -4,6 +4,10 @@ import torch

 from ..utils import _log_api_usage_once

+try:
+    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
+except ModuleNotFoundError:
+    _HAS_GPU_VIDEO_DECODER = False
 from ._video_opt import (
     _HAS_VIDEO_OPT,
     _probe_video_from_file,
@@ -43,6 +47,7 @@ __all__ = [
     "_read_video_timestamps_from_memory",
     "_probe_video_from_memory",
     "_HAS_VIDEO_OPT",
+    "_HAS_GPU_VIDEO_DECODER",
     "_read_video_clip_from_memory",
     "_read_video_meta_data",
     "VideoMetaData",
torchvision/io/_load_gpu_decoder.py  (new file, +8 -0)

+from ..extension import _load_library
+
+
+try:
+    _load_library("Decoder")
+    _HAS_GPU_VIDEO_DECODER = True
+except (ImportError, OSError):
+    _HAS_GPU_VIDEO_DECODER = False
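The flag re-exposed here is what the GPU tests (and callers that want to degrade gracefully) can key on; a minimal sketch, with a placeholder path:

    from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader

    if _HAS_GPU_VIDEO_DECODER:
        reader = VideoReader("video.mp4", device="cuda")  # hardware decoding
    else:
        reader = VideoReader("video.mp4")                 # fall back to CPU decoding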
torchvision/io/video_reader.py  (+46 -126)

-import io
 import warnings
 from typing import Any, Dict, Iterator, Optional

 import torch

 from ..utils import _log_api_usage_once

+try:
+    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
+except ModuleNotFoundError:
+    _HAS_GPU_VIDEO_DECODER = False
 from ._video_opt import _HAS_VIDEO_OPT

 if _HAS_VIDEO_OPT:
@@ -20,37 +22,11 @@ else:
         return False


-try:
-    import av
-
-    av.logging.set_level(av.logging.ERROR)
-    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
-        av = ImportError(
-            """\
-Your version of PyAV is too old for the necessary video operations in torchvision.
-If you are on Python 3.5, you will have to build from source (the conda-forge
-packages are not up-to-date). See
-https://github.com/mikeboers/PyAV#installation for instructions on how to
-install PyAV on your system.
-"""
-        )
-except ImportError:
-    av = ImportError(
-        """\
-PyAV is not installed, and is necessary for the video operations in torchvision.
-See https://github.com/mikeboers/PyAV#installation for instructions on how to
-install PyAV on your system.
-"""
-    )
-
-
 class VideoReader:
     """
     Fine-grained video-reading API.
     Supports frame-by-frame reading of various streams from a single video
-    container. Much like previous video_reader API it supports the following
-    backends: video_reader, pyav, and cuda.
-    Backends can be set via `torchvision.set_video_backend` function.
+    container.

     .. betastatus:: VideoReader class
@@ -112,11 +88,16 @@ class VideoReader:
             Default value (0) enables multithreading with codec-dependent heuristic. The performance
             will depend on the version of FFMPEG codecs supported.

+        device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
+            To use GPU decoding, pass ``device="cuda"``.
+
         path (str, optional):
             .. warning:
                 This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
                 Please use ``src`` instead.
     """
@@ -124,59 +105,45 @@ class VideoReader:
         src: str = "",
         stream: str = "video",
         num_threads: int = 0,
+        device: str = "cpu",
         path: Optional[str] = None,
     ) -> None:
         _log_api_usage_once(self)
-        from .. import get_video_backend
-
-        self.backend = get_video_backend()
-        if isinstance(src, str):
-            if src == "":
-                if path is None:
-                    raise TypeError("src cannot be empty")
-                src = path
-                warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
-        elif isinstance(src, bytes):
-            if self.backend in ["cuda"]:
-                raise RuntimeError(
-                    "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
-                )
-            elif self.backend == "pyav":
-                src = io.BytesIO(src)
-            else:
-                src = torch.frombuffer(src, dtype=torch.uint8)
-        elif isinstance(src, torch.Tensor):
-            if self.backend in ["cuda", "pyav"]:
-                raise RuntimeError(
-                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
-                )
-        else:
-            raise TypeError("`src` must be either string, Tensor or bytes object.")
-
-        if self.backend == "cuda":
-            device = torch.device("cuda")
-            self._c = torch.classes.torchvision.GPUDecoder(src, device)
-        elif self.backend == "video_reader":
-            if isinstance(src, str):
-                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
-            elif isinstance(src, torch.Tensor):
-                self._c = torch.classes.torchvision.Video("", "", 0)
-                self._c.init_from_memory(src, stream, num_threads)
-        elif self.backend == "pyav":
-            self.container = av.open(src, metadata_errors="ignore")
-            # TODO: load metadata
-            stream_type = stream.split(":")[0]
-            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
-            self.pyav_stream = {stream_type: stream_id}
-            self._c = self.container.decode(**self.pyav_stream)
-            # TODO: add extradata exception
-        else:
-            raise RuntimeError("Unknown video backend: {}".format(self.backend))
+        self.is_cuda = False
+        device = torch.device(device)
+        if device.type == "cuda":
+            if not _HAS_GPU_VIDEO_DECODER:
+                raise RuntimeError("Not compiled with GPU decoder support.")
+            self.is_cuda = True
+            self._c = torch.classes.torchvision.GPUDecoder(src, device)
+            return
+        if not _has_video_opt():
+            raise RuntimeError(
+                "Not compiled with video_reader support, "
+                + "to enable video_reader support, please install "
+                + "ffmpeg (version 4.2 is currently supported) and "
+                + "build torchvision from source."
+            )
+
+        if src == "":
+            if path is None:
+                raise TypeError("src cannot be empty")
+            src = path
+            warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
+        elif isinstance(src, bytes):
+            src = torch.frombuffer(src, dtype=torch.uint8)
+
+        if isinstance(src, str):
+            self._c = torch.classes.torchvision.Video(src, stream, num_threads)
+        elif isinstance(src, torch.Tensor):
+            if self.is_cuda:
+                raise RuntimeError("GPU VideoReader cannot be initialized from Tensor or bytes object.")
+            self._c = torch.classes.torchvision.Video("", "", 0)
+            self._c.init_from_memory(src, stream, num_threads)
+        else:
+            raise TypeError("`src` must be either string, Tensor or bytes object.")
@@ -189,29 +156,14 @@ class VideoReader:
             and corresponding timestamp (``pts``) in seconds

         """
-        if self.backend == "cuda":
+        if self.is_cuda:
             frame = self._c.next()
             if frame.numel() == 0:
                 raise StopIteration
-            return {"data": frame, "pts": None}
-        elif self.backend == "video_reader":
-            frame, pts = self._c.next()
-        else:
-            try:
-                frame = next(self._c)
-                pts = float(frame.pts * frame.time_base)
-                if "video" in self.pyav_stream:
-                    frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
-                elif "audio" in self.pyav_stream:
-                    frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
-                else:
-                    frame = None
-            except av.error.EOFError:
-                raise StopIteration
-
+            return {"data": frame}
+        frame, pts = self._c.next()
         if frame.numel() == 0:
             raise StopIteration
         return {"data": frame, "pts": pts}
@@ -230,18 +182,7 @@ class VideoReader:
             frame with the exact timestamp if it exists or
             the first frame with timestamp larger than ``time_s``.
         """
-        if self.backend in ["cuda", "video_reader"]:
-            self._c.seek(time_s, keyframes_only)
-        else:
-            # handle special case as pyav doesn't catch it
-            if time_s < 0:
-                time_s = 0
-            temp_str = self.container.streams.get(**self.pyav_stream)[0]
-            offset = int(round(time_s / temp_str.time_base))
-            if not keyframes_only:
-                warnings.warn("Accurate seek is not implemented for pyav backend")
-            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
-            self._c = self.container.decode(**self.pyav_stream)
+        self._c.seek(time_s, keyframes_only)
         return self
@@ -250,21 +191,6 @@ class VideoReader:
         Returns:
             (dict): dictionary containing duration and frame rate for every stream
         """
-        if self.backend == "pyav":
-            metadata = {}  # type: Dict[str, Any]
-            for stream in self.container.streams:
-                if stream.type not in metadata:
-                    if stream.type == "video":
-                        rate_n = "fps"
-                    else:
-                        rate_n = "framerate"
-                    metadata[stream.type] = {rate_n: [], "duration": []}
-                rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
-                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
-                metadata[stream.type][rate_n].append(float(rate))
-            return metadata
         return self._c.get_metadata()
@@ -284,12 +210,6 @@ class VideoReader:
         Returns:
             (bool): True on succes, False otherwise
         """
-        if self.backend == "cuda":
-            warnings.warn("GPU decoding only works with video stream.")
-        if self.backend == "pyav":
-            stream_type = stream.split(":")[0]
-            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
-            self.pyav_stream = {stream_type: stream_id}
-            self._c = self.container.decode(**self.pyav_stream)
-            return True
+        if self.is_cuda:
+            print("GPU decoding only works with video stream.")
         return self._c.set_current_stream(stream)
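Taken together, the restored CPU path is used roughly like the sketch below (placeholder file name; assumes torchvision was built with the video_reader extension, and the metadata keys follow the tests above):

    from torchvision.io import VideoReader

    reader = VideoReader("video.mp4", "video")   # placeholder path; CPU video_reader backend

    md = reader.get_metadata()
    duration = md["video"]["duration"][0]        # seconds, as used in test_metadata above

    # seek() returns the reader itself, so it can be chained with iteration.
    frames, last_pts = [], None
    for frame in reader.seek(duration / 2):
        frames.append(frame["data"])             # decoded frame tensor
        last_pts = frame["pts"]                  # presentation timestamp in seconds

    reader.set_current_stream("audio")           # switch streams if the container has audio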