Unverified commit b21e0bfb, authored by Prabhat Roy and committed by GitHub
Browse files

Add seek in GPU decoder (#5215)

* Add seek in GPU decoder

* Merge the two tests

* Refine unit test
parent 4c7a91ef
...@@ -37,6 +37,33 @@ class TestVideoGPUDecoder: ...@@ -37,6 +37,33 @@ class TestVideoGPUDecoder:
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
assert mean_delta < 0.75 assert mean_delta < 0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("keyframes", [True, False])
@pytest.mark.parametrize(
    "full_path, duration",
    [
        (os.path.join(VIDEO_DIR, file_name), length_s)
        for file_name, length_s in (
            ("v_SoccerJuggling_g23_c01.avi", 8.0),
            ("v_SoccerJuggling_g24_c01.avi", 8.0),
            ("R6llTwEh07w.mp4", 10.0),
            ("SOX5yA1l24A.mp4", 11.0),
            ("WUzgd7C1pWA.mp4", 11.0),
        )
    ],
)
def test_seek_reading(self, keyframes, full_path, duration):
    """Check that frames read after a seek on the CUDA reader match PyAV.

    Both the GPU ``VideoReader`` and a PyAV container are seeked to half of
    ``duration``; every frame PyAV then decodes is compared with the next
    frame yielded by the GPU reader.

    Args:
        keyframes: forwarded as ``keyframes_only`` to the GPU reader and,
            negated, as ``any_frame`` to PyAV.
        full_path: path of the video file under test.
        duration: nominal length of the video; the seek target is half of it.
    """
    gpu_reader = VideoReader(full_path, device="cuda:0")
    seek_target = duration / 2
    gpu_reader.seek(seek_target, keyframes_only=keyframes)

    with av.open(full_path) as container:
        # NOTE(review): the offset is scaled by 1e6, presumably because PyAV's
        # container-level seek works in microseconds -- confirm against PyAV docs.
        container.seek(int(seek_target * 1000000), any_frame=not keyframes, backward=False)
        video_stream = container.streams.video[0]
        for reference_frame in container.decode(video_stream):
            expected = torch.tensor(reference_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
            actual = next(gpu_reader)["data"]
            abs_error = torch.abs(expected.float() - actual.cpu().float())
            # Tolerate small per-pixel differences between the two decoders.
            assert torch.mean(abs_error) < 0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.skipif(av is None, reason="PyAV unavailable")
def test_metadata(self): def test_metadata(self):
for test_video in test_videos: for test_video in test_videos:
......
...@@ -218,6 +218,15 @@ class Demuxer { ...@@ -218,6 +218,15 @@ class Demuxer {
frameCount++; frameCount++;
return true; return true;
} }
/* Seek the opened input to `timestamp`.
 *
 * `timestamp` is multiplied by AV_TIME_BASE before being handed to
 * av_seek_frame(); the stream index passed is -1, for which FFmpeg documents
 * the timestamp as being expressed in AV_TIME_BASE units.
 * `flag` is forwarded unchanged as the av_seek_frame() flags argument
 * (e.g. 0 or AVSEEK_FLAG_ANY).
 *
 * Raises via TORCH_CHECK when av_seek_frame() returns a negative value.
 */
void seek(double timestamp, int flag) {
  // Make the double -> integer truncation explicit instead of implicit.
  const int64_t time = static_cast<int64_t>(timestamp * AV_TIME_BASE);
  TORCH_CHECK(
      0 <= av_seek_frame(fmtCtx, -1, time, flag),
      // Name the call that actually failed; the previous text blamed
      // avformat_open_input(), which is not invoked here.
      "av_seek_frame() failed at line ",
      __LINE__,
      " in demuxer.h\n");
}
}; };
inline cudaVideoCodec ffmpeg_to_codec(AVCodecID id) { inline cudaVideoCodec ffmpeg_to_codec(AVCodecID id) {
......
...@@ -38,6 +38,14 @@ torch::Tensor GPUDecoder::decode() { ...@@ -38,6 +38,14 @@ torch::Tensor GPUDecoder::decode() {
return frame; return frame;
} }
/* Move the demuxer to the given timestamp.
 * With `keyframes_only` set, the demuxer is called with no seek flags (0);
 * otherwise AVSEEK_FLAG_ANY is passed so that the seek is not restricted to
 * keyframes.
 */
void GPUDecoder::seek(double timestamp, bool keyframes_only) {
  if (keyframes_only) {
    demuxer.seek(timestamp, 0);
    return;
  }
  demuxer.seek(timestamp, AVSEEK_FLAG_ANY);
}
c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder:: c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
get_metadata() const { get_metadata() const {
c10::Dict<std::string, c10::Dict<std::string, double>> metadata; c10::Dict<std::string, c10::Dict<std::string, double>> metadata;
...@@ -51,6 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder:: ...@@ -51,6 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
TORCH_LIBRARY(torchvision, m) { TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder") m.class_<GPUDecoder>("GPUDecoder")
.def(torch::init<std::string, int64_t>()) .def(torch::init<std::string, int64_t>())
.def("seek", &GPUDecoder::seek)
.def("get_metadata", &GPUDecoder::get_metadata) .def("get_metadata", &GPUDecoder::get_metadata)
.def("next", &GPUDecoder::decode); .def("next", &GPUDecoder::decode);
} }
...@@ -8,6 +8,7 @@ class GPUDecoder : public torch::CustomClassHolder { ...@@ -8,6 +8,7 @@ class GPUDecoder : public torch::CustomClassHolder {
GPUDecoder(std::string, int64_t); GPUDecoder(std::string, int64_t);
~GPUDecoder(); ~GPUDecoder();
torch::Tensor decode(); torch::Tensor decode();
void seek(double, bool);
c10::Dict<std::string, c10::Dict<std::string, double>> get_metadata() const; c10::Dict<std::string, c10::Dict<std::string, double>> get_metadata() const;
private: private:
......
...@@ -174,8 +174,6 @@ class VideoReader: ...@@ -174,8 +174,6 @@ class VideoReader:
frame with the exact timestamp if it exists or frame with the exact timestamp if it exists or
the first frame with timestamp larger than ``time_s``. the first frame with timestamp larger than ``time_s``.
""" """
if self.is_cuda:
raise RuntimeError("seek() not yet supported with GPU decoding.")
self._c.seek(time_s, keyframes_only) self._c.seek(time_s, keyframes_only)
return self return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment