"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "5d1372c0251d9b961e1d550ced8a07260426ff30"
Unverified Commit 4c7a91ef, authored by Prabhat Roy, committed by GitHub
Browse files

Add support for get_metadata() in GPU decoder (#5256)

parent 44ae1e51
import math
import os
import pytest
......@@ -36,6 +37,18 @@ class TestVideoGPUDecoder:
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
assert mean_delta < 0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
def test_metadata(self):
    """Compare the GPU decoder's video metadata with PyAV for each test video.

    For every entry in ``test_videos`` the "duration" and "fps" values
    returned by ``VideoReader.get_metadata()`` on ``cuda:0`` must be within
    1% (relative) of the values PyAV reports for the first video stream.
    """
    for video_name in test_videos:
        video_path = os.path.join(VIDEO_DIR, video_name)
        gpu_reader = VideoReader(video_path, device="cuda:0")
        gpu_video_info = gpu_reader.get_metadata()["video"]
        with av.open(video_path) as av_container:
            av_stream = av_container.streams.video[0]
            # duration is scaled by the stream time base before comparing
            # (presumably yielding seconds -- confirm against the decoder).
            expected_duration = float(av_stream.duration * av_stream.time_base)
            assert math.isclose(gpu_video_info["duration"], expected_duration, rel_tol=1e-2)
            assert math.isclose(gpu_video_info["fps"], av_stream.base_rate, rel_tol=1e-2)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])
......@@ -142,6 +142,14 @@ class Demuxer {
return eVideoCodec;
}
double get_duration() const {
return (double)fmtCtx->duration / AV_TIME_BASE;
}
double get_fps() const {
return av_q2d(fmtCtx->streams[iVideoStream]->r_frame_rate);
}
bool demux(uint8_t** video, unsigned long* videoBytes) {
if (!fmtCtx) {
return false;
......
......@@ -38,8 +38,19 @@ torch::Tensor GPUDecoder::decode() {
return frame;
}
// Builds the metadata dictionary exposed to callers:
//   { "video": { "duration": <demuxer duration>, "fps": <demuxer fps> } }
// Only a "video" entry is produced.
c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
    get_metadata() const {
  using StreamInfo = c10::Dict<std::string, double>;

  // Per-stream values are read straight from the demuxer.
  StreamInfo videoInfo;
  videoInfo.insert("duration", demuxer.get_duration());
  videoInfo.insert("fps", demuxer.get_fps());

  c10::Dict<std::string, StreamInfo> result;
  result.insert("video", videoInfo);
  return result;
}
// Registers GPUDecoder as a custom class named "GPUDecoder" in the
// "torchvision" library namespace, together with its constructor and the
// methods listed below.
TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder")
// Constructor taking (std::string, int64_t); presumably the source path and
// the CUDA device index -- confirm against the GPUDecoder constructor.
.def(torch::init<std::string, int64_t>())
// Exposes GPUDecoder::get_metadata under the same name.
.def("get_metadata", &GPUDecoder::get_metadata)
// GPUDecoder::decode is exposed under the name "next".
.def("next", &GPUDecoder::decode);
}
......@@ -8,6 +8,7 @@ class GPUDecoder : public torch::CustomClassHolder {
GPUDecoder(std::string, int64_t);
~GPUDecoder();
torch::Tensor decode();
c10::Dict<std::string, c10::Dict<std::string, double>> get_metadata() const;
private:
Demuxer demuxer;
......
......@@ -185,8 +185,6 @@ class VideoReader:
Returns:
(dict): dictionary containing duration and frame rate for every stream
"""
if self.is_cuda:
raise RuntimeError("get_metadata() not yet supported with GPU decoding.")
return self._c.get_metadata()
def set_current_stream(self, stream: str) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment