Unverified commit b21e0bfb, authored by Prabhat Roy, committed by GitHub
Browse files

Add seek in GPU decoder (#5215)

* Add seek in GPU decoder

* Merge the two tests

* Refine unit test
parent 4c7a91ef
......@@ -37,6 +37,33 @@ class TestVideoGPUDecoder:
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
assert mean_delta < 0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("keyframes", [True, False])
@pytest.mark.parametrize(
    "full_path, duration",
    [
        (os.path.join(VIDEO_DIR, file_name), seconds)
        for file_name, seconds in [
            ("v_SoccerJuggling_g23_c01.avi", 8.0),
            ("v_SoccerJuggling_g24_c01.avi", 8.0),
            ("R6llTwEh07w.mp4", 10.0),
            ("SOX5yA1l24A.mp4", 11.0),
            ("WUzgd7C1pWA.mp4", 11.0),
        ]
    ],
)
def test_seek_reading(self, keyframes, full_path, duration):
    """Seek the GPU reader and PyAV to the clip midpoint and compare frames.

    Both decoders are positioned at ``duration / 2``. Every frame PyAV
    decodes from there on is then compared against the next frame yielded
    by the GPU reader; the mean absolute difference must stay below 0.75.

    Args:
        keyframes: forwarded as ``keyframes_only`` to the GPU reader and,
            negated, as ``any_frame`` to PyAV.
        full_path: path of the video file under test.
        duration: value from the parametrization table; half of it is used
            as the seek target.
    """
    gpu_reader = VideoReader(full_path, device="cuda:0")
    midpoint = duration / 2
    gpu_reader.seek(midpoint, keyframes_only=keyframes)

    with av.open(full_path) as container:
        # The seek target is scaled by 1e6 before being handed to PyAV
        # (presumably microseconds, PyAV's default time base -- confirm).
        container.seek(int(midpoint * 1000000), any_frame=not keyframes, backward=False)
        video_stream = container.streams.video[0]
        for reference_frame in container.decode(video_stream):
            expected = torch.tensor(reference_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
            actual = next(gpu_reader)["data"]
            mean_delta = torch.mean(torch.abs(expected.float() - actual.cpu().float()))
            assert mean_delta < 0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
def test_metadata(self):
for test_video in test_videos:
......
......@@ -218,6 +218,15 @@ class Demuxer {
frameCount++;
return true;
}
/* Seek the demuxer to `timestamp`.
 *
 * timestamp: target position; it is multiplied by AV_TIME_BASE and passed to
 *            av_seek_frame() with stream index -1 (callers in this change
 *            pass a value in seconds).
 * flag:      AVSEEK_FLAG_* value forwarded unchanged to av_seek_frame().
 *
 * Fails through TORCH_CHECK when av_seek_frame() returns a negative value.
 */
void seek(double timestamp, int flag) {
  // Explicit cast: the double -> int64_t truncation is intentional.
  const int64_t time = static_cast<int64_t>(timestamp * AV_TIME_BASE);
  TORCH_CHECK(
      0 <= av_seek_frame(fmtCtx, -1, time, flag),
      // Name the call that actually failed (was a copy-pasted
      // "avformat_open_input()" message).
      "av_seek_frame() failed at line ",
      __LINE__,
      " in demuxer.h\n");
}
};
inline cudaVideoCodec ffmpeg_to_codec(AVCodecID id) {
......
......@@ -38,6 +38,14 @@ torch::Tensor GPUDecoder::decode() {
return frame;
}
/* Seek the underlying demuxer to `timestamp`.
 *
 * `keyframes_only` selects the seek flag handed to the demuxer: no flag (0)
 * when true, AVSEEK_FLAG_ANY when false.
 */
void GPUDecoder::seek(double timestamp, bool keyframes_only) {
  if (keyframes_only) {
    demuxer.seek(timestamp, 0);
    return;
  }
  demuxer.seek(timestamp, AVSEEK_FLAG_ANY);
}
c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
get_metadata() const {
c10::Dict<std::string, c10::Dict<std::string, double>> metadata;
......@@ -51,6 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
// Registers GPUDecoder with the `torchvision` library under the class name
// "GPUDecoder", together with its constructor and the methods listed below.
TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder")
// Constructor taking (std::string, int64_t).
.def(torch::init<std::string, int64_t>())
// "seek" -> GPUDecoder::seek(double, bool).
.def("seek", &GPUDecoder::seek)
// "get_metadata" -> GPUDecoder::get_metadata().
.def("get_metadata", &GPUDecoder::get_metadata)
// Exposed as "next", but bound to GPUDecoder::decode.
.def("next", &GPUDecoder::decode);
}
......@@ -8,6 +8,7 @@ class GPUDecoder : public torch::CustomClassHolder {
GPUDecoder(std::string, int64_t);
~GPUDecoder();
torch::Tensor decode();
void seek(double, bool);
c10::Dict<std::string, c10::Dict<std::string, double>> get_metadata() const;
private:
......
......@@ -174,8 +174,6 @@ class VideoReader:
frame with the exact timestamp if it exists or
the first frame with timestamp larger than ``time_s``.
"""
if self.is_cuda:
raise RuntimeError("seek() not yet supported with GPU decoding.")
self._c.seek(time_s, keyframes_only)
return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment