Commit ed5b2dc3 authored by Zhicheng Yan's avatar Zhicheng Yan Committed by Francisco Massa
Browse files

extend video reader to support fast video probing (#1437)

* extend video reader to support fast video probing

* fix c++ lint

* small fix

* allow to accept input video of type torch.Tensor
parent 7ae1b8c9
...@@ -87,6 +87,22 @@ class Tester(unittest.TestCase): ...@@ -87,6 +87,22 @@ class Tester(unittest.TestCase):
self.assertTrue(data.equal(lv)) self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5) self.assertEqual(info["video_fps"], 5)
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_file(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
video_info = io._probe_video_from_file(f_name)
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_memory(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
with open(f_name, "rb") as fp:
filebuffer = fp.read()
video_info = io._probe_video_from_memory(filebuffer)
self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
def test_read_timestamps(self): def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data): with temp_video(10, 300, 300, 5) as (f_name, data):
if _video_backend == "pyav": if _video_backend == "pyav":
......
...@@ -31,6 +31,7 @@ from torchvision.io._video_opt import _HAS_VIDEO_OPT ...@@ -31,6 +31,7 @@ from torchvision.io._video_opt import _HAS_VIDEO_OPT
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [ CheckerConfig = [
"duration",
"video_fps", "video_fps",
"audio_sample_rate", "audio_sample_rate",
# We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are # We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
...@@ -44,6 +45,7 @@ GroundTruth = collections.namedtuple( ...@@ -44,6 +45,7 @@ GroundTruth = collections.namedtuple(
) )
all_check_config = GroundTruth( all_check_config = GroundTruth(
duration=0,
video_fps=0, video_fps=0,
audio_sample_rate=0, audio_sample_rate=0,
check_aframes=True, check_aframes=True,
...@@ -52,36 +54,42 @@ all_check_config = GroundTruth( ...@@ -52,36 +54,42 @@ all_check_config = GroundTruth(
test_videos = { test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth( "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0,
video_fps=30.0, video_fps=30.0,
audio_sample_rate=None, audio_sample_rate=None,
check_aframes=True, check_aframes=True,
check_aframe_pts=True, check_aframe_pts=True,
), ),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth( "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0,
video_fps=30.0, video_fps=30.0,
audio_sample_rate=None, audio_sample_rate=None,
check_aframes=True, check_aframes=True,
check_aframe_pts=True, check_aframe_pts=True,
), ),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth( "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0,
video_fps=30.0, video_fps=30.0,
audio_sample_rate=None, audio_sample_rate=None,
check_aframes=True, check_aframes=True,
check_aframe_pts=True, check_aframe_pts=True,
), ),
"v_SoccerJuggling_g23_c01.avi": GroundTruth( "v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97, video_fps=29.97,
audio_sample_rate=None, audio_sample_rate=None,
check_aframes=True, check_aframes=True,
check_aframe_pts=True, check_aframe_pts=True,
), ),
"v_SoccerJuggling_g24_c01.avi": GroundTruth( "v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97, video_fps=29.97,
audio_sample_rate=None, audio_sample_rate=None,
check_aframes=True, check_aframes=True,
check_aframe_pts=True, check_aframe_pts=True,
), ),
"R6llTwEh07w.mp4": GroundTruth( "R6llTwEh07w.mp4": GroundTruth(
duration=10.0,
video_fps=30.0, video_fps=30.0,
audio_sample_rate=44100, audio_sample_rate=44100,
# PyAv miss one audio frame at the beginning (pts=0) # PyAv miss one audio frame at the beginning (pts=0)
...@@ -89,6 +97,7 @@ test_videos = { ...@@ -89,6 +97,7 @@ test_videos = {
check_aframe_pts=False, check_aframe_pts=False,
), ),
"SOX5yA1l24A.mp4": GroundTruth( "SOX5yA1l24A.mp4": GroundTruth(
duration=11.0,
video_fps=29.97, video_fps=29.97,
audio_sample_rate=48000, audio_sample_rate=48000,
# PyAv miss one audio frame at the beginning (pts=0) # PyAv miss one audio frame at the beginning (pts=0)
...@@ -96,6 +105,7 @@ test_videos = { ...@@ -96,6 +105,7 @@ test_videos = {
check_aframe_pts=False, check_aframe_pts=False,
), ),
"WUzgd7C1pWA.mp4": GroundTruth( "WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0,
video_fps=29.97, video_fps=29.97,
audio_sample_rate=48000, audio_sample_rate=48000,
# PyAv miss one audio frame at the beginning (pts=0) # PyAv miss one audio frame at the beginning (pts=0)
...@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase): ...@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config): def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder """check the decoding results from TorchVision decoder
""" """
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
tv_result atimebase, asample_rate, aduration = tv_result
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
) )
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5) self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0: if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate) self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
# check if pts of video frames are sorted in ascending order # check if pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1): for i in range(len(vframe_pts) - 1):
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True) self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
...@@ -288,6 +307,20 @@ class TestVideoReader(unittest.TestCase): ...@@ -288,6 +307,20 @@ class TestVideoReader(unittest.TestCase):
for i in range(len(aframe_pts) - 1): for i in range(len(aframe_pts) - 1):
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True) self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)
def check_probe_result(self, result, config):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config): def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
""" """
Compare decoding results from two sources. Compare decoding results from two sources.
...@@ -297,18 +330,17 @@ class TestVideoReader(unittest.TestCase): ...@@ -297,18 +330,17 @@ class TestVideoReader(unittest.TestCase):
decoder or TorchVision decoder with getPtsOnly = 1 decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker config: config of decoding results checker
""" """
vframes, vframe_pts, vtimebase, _vfps, aframes, aframe_pts, atimebase, _asample_rate = ( vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, \
tv_result atimebase, _asample_rate, _aduration = tv_result
)
if isinstance(ref_result, list): if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder # the ref_result is from new video_reader decoder
ref_result = DecoderResult( ref_result = DecoderResult(
vframes=ref_result[0], vframes=ref_result[0],
vframe_pts=ref_result[1], vframe_pts=ref_result[1],
vtimebase=ref_result[2], vtimebase=ref_result[2],
aframes=ref_result[4], aframes=ref_result[5],
aframe_pts=ref_result[5], aframe_pts=ref_result[6],
atimebase=ref_result[6], atimebase=ref_result[7],
) )
if vframes.numel() > 0 and ref_result.vframes.numel() > 0: if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
...@@ -351,12 +383,12 @@ class TestVideoReader(unittest.TestCase): ...@@ -351,12 +383,12 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for i in range(num_iter): for _i in range(num_iter):
for test_video, config in test_videos.items(): for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder # pass 1: decode all frames using new decoder
_ = torch.ops.video_reader.read_video_from_file( torch.ops.video_reader.read_video_from_file(
full_path, full_path,
seek_frame_margin, seek_frame_margin,
0, # getPtsOnly 0, # getPtsOnly
...@@ -460,9 +492,8 @@ class TestVideoReader(unittest.TestCase): ...@@ -460,9 +492,8 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
tv_result atimebase, asample_rate, aduration = tv_result
)
self.assertEqual(vframes.numel() > 0, readVideoStream) self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream) self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
...@@ -489,7 +520,7 @@ class TestVideoReader(unittest.TestCase): ...@@ -489,7 +520,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file( tv_result = torch.ops.video_reader.read_video_from_file(
...@@ -528,7 +559,7 @@ class TestVideoReader(unittest.TestCase): ...@@ -528,7 +559,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file( tv_result = torch.ops.video_reader.read_video_from_file(
...@@ -567,7 +598,7 @@ class TestVideoReader(unittest.TestCase): ...@@ -567,7 +598,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file( tv_result = torch.ops.video_reader.read_video_from_file(
...@@ -606,7 +637,7 @@ class TestVideoReader(unittest.TestCase): ...@@ -606,7 +637,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file( tv_result = torch.ops.video_reader.read_video_from_file(
...@@ -651,7 +682,7 @@ class TestVideoReader(unittest.TestCase): ...@@ -651,7 +682,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file( tv_result = torch.ops.video_reader.read_video_from_file(
...@@ -674,18 +705,17 @@ class TestVideoReader(unittest.TestCase): ...@@ -674,18 +705,17 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, a_sample_rate = ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
tv_result atimebase, asample_rate, aduration = tv_result
)
if aframes.numel() > 0: if aframes.numel() > 0:
self.assertEqual(samples, a_sample_rate.item()) self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1)) self.assertEqual(1, aframes.size(1))
# when audio stream is found # when audio stream is found
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1]) duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
self.assertAlmostEqual( self.assertAlmostEqual(
aframes.size(0), aframes.size(0),
int(duration * a_sample_rate.item()), int(duration * asample_rate.item()),
delta=0.1 * a_sample_rate.item(), delta=0.1 * asample_rate.item(),
) )
def test_compare_read_video_from_memory_and_file(self): def test_compare_read_video_from_memory_and_file(self):
...@@ -859,7 +889,7 @@ class TestVideoReader(unittest.TestCase): ...@@ -859,7 +889,7 @@ class TestVideoReader(unittest.TestCase):
) )
self.assertEqual(tv_result_pts_only[0].numel(), 0) self.assertEqual(tv_result_pts_only[0].numel(), 0)
self.assertEqual(tv_result_pts_only[4].numel(), 0) self.assertEqual(tv_result_pts_only[5].numel(), 0)
self.compare_decoding_result(tv_result, tv_result_pts_only) self.compare_decoding_result(tv_result, tv_result_pts_only)
def test_read_video_in_range_from_memory(self): def test_read_video_in_range_from_memory(self):
...@@ -899,9 +929,8 @@ class TestVideoReader(unittest.TestCase): ...@@ -899,9 +929,8 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
tv_result atimebase, asample_rate, aduration = tv_result
)
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01) self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
for num_frames in [4, 8, 16, 32, 64, 128]: for num_frames in [4, 8, 16, 32, 64, 128]:
...@@ -997,6 +1026,24 @@ class TestVideoReader(unittest.TestCase): ...@@ -997,6 +1026,24 @@ class TestVideoReader(unittest.TestCase):
# and PyAv # and PyAv
self.compare_decoding_result(tv_result, pyav_result, config) self.compare_decoding_result(tv_result, pyav_result, config)
def test_probe_video_from_file(self):
"""
Test the case when decoder probes a video file
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory(self):
"""
Test the case when decoder probes a video in memory
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -49,6 +49,7 @@ void FfmpegAudioStream::updateStreamDecodeParams() { ...@@ -49,6 +49,7 @@ void FfmpegAudioStream::updateStreamDecodeParams() {
mediaFormat_.format.audio.timeBaseDen = mediaFormat_.format.audio.timeBaseDen =
inputCtx_->streams[index_]->time_base.den; inputCtx_->streams[index_]->time_base.den;
} }
mediaFormat_.format.audio.duration = inputCtx_->streams[index_]->duration;
} }
int FfmpegAudioStream::initFormat() { int FfmpegAudioStream::initFormat() {
......
...@@ -220,6 +220,30 @@ int FfmpegDecoder::decodeMemory( ...@@ -220,6 +220,30 @@ int FfmpegDecoder::decodeMemory(
return ret; return ret;
} }
int FfmpegDecoder::probeFile(
unique_ptr<DecoderParameters> params,
const string& fileName,
DecoderOutput& decoderOutput) {
VLOG(1) << "probe file: " << fileName;
FfmpegAvioContext ioctx;
return probeVideo(std::move(params), fileName, true, ioctx, decoderOutput);
}
int FfmpegDecoder::probeMemory(
unique_ptr<DecoderParameters> params,
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput) {
VLOG(1) << "probe video data in memory";
FfmpegAvioContext ioctx;
int ret = ioctx.initAVIOContext(buffer, size);
if (ret == 0) {
ret =
probeVideo(std::move(params), string(""), false, ioctx, decoderOutput);
}
return ret;
}
void FfmpegDecoder::cleanUp() { void FfmpegDecoder::cleanUp() {
if (formatCtx_) { if (formatCtx_) {
for (auto& stream : streams_) { for (auto& stream : streams_) {
...@@ -320,6 +344,16 @@ int FfmpegDecoder::decodeLoop( ...@@ -320,6 +344,16 @@ int FfmpegDecoder::decodeLoop(
return ret; return ret;
} }
int FfmpegDecoder::probeVideo(
unique_ptr<DecoderParameters> params,
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput) {
params_ = std::move(params);
return init(filename, isDecodeFile, ioctx, decoderOutput);
}
bool FfmpegDecoder::initStreams() { bool FfmpegDecoder::initStreams() {
for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) { for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) {
AVMediaType mediaType; AVMediaType mediaType;
......
...@@ -75,6 +75,19 @@ class FfmpegDecoder { ...@@ -75,6 +75,19 @@ class FfmpegDecoder {
const uint8_t* buffer, const uint8_t* buffer,
int64_t size, int64_t size,
DecoderOutput& decoderOutput); DecoderOutput& decoderOutput);
// return 0 on success
// return negative number on failure
int probeFile(
std::unique_ptr<DecoderParameters> params,
const std::string& filename,
DecoderOutput& decoderOutput);
// return 0 on success
// return negative number on failure
int probeMemory(
std::unique_ptr<DecoderParameters> params,
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput);
void cleanUp(); void cleanUp();
...@@ -95,6 +108,13 @@ class FfmpegDecoder { ...@@ -95,6 +108,13 @@ class FfmpegDecoder {
FfmpegAvioContext& ioctx, FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput); DecoderOutput& decoderOutput);
int probeVideo(
std::unique_ptr<DecoderParameters> params,
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput);
bool initStreams(); bool initStreams();
void flushStreams(DecoderOutput& decoderOutput); void flushStreams(DecoderOutput& decoderOutput);
......
...@@ -48,6 +48,7 @@ void FfmpegVideoStream::updateStreamDecodeParams() { ...@@ -48,6 +48,7 @@ void FfmpegVideoStream::updateStreamDecodeParams() {
mediaFormat_.format.video.timeBaseDen = mediaFormat_.format.video.timeBaseDen =
inputCtx_->streams[index_]->time_base.den; inputCtx_->streams[index_]->time_base.den;
} }
mediaFormat_.format.video.duration = inputCtx_->streams[index_]->duration;
} }
int FfmpegVideoStream::initFormat() { int FfmpegVideoStream::initFormat() {
......
...@@ -48,6 +48,7 @@ struct VideoFormat { ...@@ -48,6 +48,7 @@ struct VideoFormat {
int timeBaseNum{0}; int timeBaseNum{0};
int timeBaseDen{1}; // numerator and denominator of time base int timeBaseDen{1}; // numerator and denominator of time base
float fps{0.0}; float fps{0.0};
int64_t duration{0}; // duration of the stream, in stream time base
}; };
struct AudioFormat { struct AudioFormat {
...@@ -60,6 +61,7 @@ struct AudioFormat { ...@@ -60,6 +61,7 @@ struct AudioFormat {
int64_t startPts{0}, endPts{0}; // Start and end presentation timestamp int64_t startPts{0}, endPts{0}; // Start and end presentation timestamp
int timeBaseNum{0}; int timeBaseNum{0};
int timeBaseDen{1}; // numerator and denominator of time base int timeBaseDen{1}; // numerator and denominator of time base
int64_t duration{0}; // duration of the stream, in stream time base
}; };
union FormatUnion { union FormatUnion {
......
...@@ -27,8 +27,6 @@ PyMODINIT_FUNC PyInit_video_reader(void) { ...@@ -27,8 +27,6 @@ PyMODINIT_FUNC PyInit_video_reader(void) {
namespace video_reader { namespace video_reader {
bool glog_initialized = false;
class UnknownPixelFormatException : public exception { class UnknownPixelFormatException : public exception {
const char* what() const throw() override { const char* what() const throw() override {
return "Unknown pixel format"; return "Unknown pixel format";
...@@ -167,11 +165,6 @@ torch::List<torch::Tensor> readVideo( ...@@ -167,11 +165,6 @@ torch::List<torch::Tensor> readVideo(
int64_t audioEndPts, int64_t audioEndPts,
int64_t audioTimeBaseNum, int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen) { int64_t audioTimeBaseDen) {
if (!glog_initialized) {
glog_initialized = true;
// google::InitGoogleLogging("VideoReader");
}
unique_ptr<DecoderParameters> params = util::getDecoderParams( unique_ptr<DecoderParameters> params = util::getDecoderParams(
seekFrameMargin, seekFrameMargin,
getPtsOnly, getPtsOnly,
...@@ -209,6 +202,8 @@ torch::List<torch::Tensor> readVideo( ...@@ -209,6 +202,8 @@ torch::List<torch::Tensor> readVideo(
torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong); torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt); torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor videoFps = torch::zeros({0}, torch::kFloat); torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
if (readVideoStream == 1) { if (readVideoStream == 1) {
auto it = decoderOutput.media_data_.find(TYPE_VIDEO); auto it = decoderOutput.media_data_.find(TYPE_VIDEO);
if (it != decoderOutput.media_data_.end()) { if (it != decoderOutput.media_data_.end()) {
...@@ -236,6 +231,10 @@ torch::List<torch::Tensor> readVideo( ...@@ -236,6 +231,10 @@ torch::List<torch::Tensor> readVideo(
videoFps = torch::zeros({1}, torch::kFloat); videoFps = torch::zeros({1}, torch::kFloat);
float* videoFpsData = videoFps.data_ptr<float>(); float* videoFpsData = videoFps.data_ptr<float>();
videoFpsData[0] = it->second.format_.video.fps; videoFpsData[0] = it->second.format_.video.fps;
videoDuration = torch::zeros({1}, torch::kLong);
int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
videoDurationData[0] = it->second.format_.video.duration;
} else { } else {
VLOG(1) << "Miss video stream"; VLOG(1) << "Miss video stream";
} }
...@@ -246,6 +245,7 @@ torch::List<torch::Tensor> readVideo( ...@@ -246,6 +245,7 @@ torch::List<torch::Tensor> readVideo(
torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong); torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt); torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt); torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
if (readAudioStream == 1) { if (readAudioStream == 1) {
auto it = decoderOutput.media_data_.find(TYPE_AUDIO); auto it = decoderOutput.media_data_.find(TYPE_AUDIO);
if (it != decoderOutput.media_data_.end()) { if (it != decoderOutput.media_data_.end()) {
...@@ -275,6 +275,10 @@ torch::List<torch::Tensor> readVideo( ...@@ -275,6 +275,10 @@ torch::List<torch::Tensor> readVideo(
audioSampleRate = torch::zeros({1}, torch::kInt); audioSampleRate = torch::zeros({1}, torch::kInt);
int* audioSampleRateData = audioSampleRate.data_ptr<int>(); int* audioSampleRateData = audioSampleRate.data_ptr<int>();
audioSampleRateData[0] = it->second.format_.audio.samples; audioSampleRateData[0] = it->second.format_.audio.samples;
audioDuration = torch::zeros({1}, torch::kLong);
int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
audioDurationData[0] = it->second.format_.audio.duration;
} else { } else {
VLOG(1) << "Miss audio stream"; VLOG(1) << "Miss audio stream";
} }
...@@ -285,10 +289,12 @@ torch::List<torch::Tensor> readVideo( ...@@ -285,10 +289,12 @@ torch::List<torch::Tensor> readVideo(
result.push_back(std::move(videoFramePts)); result.push_back(std::move(videoFramePts));
result.push_back(std::move(videoTimeBase)); result.push_back(std::move(videoTimeBase));
result.push_back(std::move(videoFps)); result.push_back(std::move(videoFps));
result.push_back(std::move(videoDuration));
result.push_back(std::move(audioFrame)); result.push_back(std::move(audioFrame));
result.push_back(std::move(audioFramePts)); result.push_back(std::move(audioFramePts));
result.push_back(std::move(audioTimeBase)); result.push_back(std::move(audioTimeBase));
result.push_back(std::move(audioSampleRate)); result.push_back(std::move(audioSampleRate));
result.push_back(std::move(audioDuration));
return result; return result;
} }
...@@ -378,10 +384,117 @@ torch::List<torch::Tensor> readVideoFromFile( ...@@ -378,10 +384,117 @@ torch::List<torch::Tensor> readVideoFromFile(
audioTimeBaseDen); audioTimeBaseDen);
} }
torch::List<torch::Tensor> probeVideo(
bool isReadFile,
const torch::Tensor& input_video,
std::string videoPath) {
unique_ptr<DecoderParameters> params = util::getDecoderParams(
0, // seekFrameMargin
0, // getPtsOnly
1, // readVideoStream
0, // width
0, // height
0, // minDimension
0, // videoStartPts
0, // videoEndPts
0, // videoTimeBaseNum
1, // videoTimeBaseDen
1, // readAudioStream
0, // audioSamples
0, // audioChannels
0, // audioStartPts
0, // audioEndPts
0, // audioTimeBaseNum
1 // audioTimeBaseDen
);
FfmpegDecoder decoder;
DecoderOutput decoderOutput;
if (isReadFile) {
decoder.probeFile(std::move(params), videoPath, decoderOutput);
} else {
decoder.probeMemory(
std::move(params),
input_video.data_ptr<uint8_t>(),
input_video.size(0),
decoderOutput);
}
// video section
torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
auto it = decoderOutput.media_data_.find(TYPE_VIDEO);
if (it != decoderOutput.media_data_.end()) {
VLOG(1) << "Find video stream";
videoTimeBase = torch::zeros({2}, torch::kInt);
int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
videoTimeBaseData[0] = it->second.format_.video.timeBaseNum;
videoTimeBaseData[1] = it->second.format_.video.timeBaseDen;
videoFps = torch::zeros({1}, torch::kFloat);
float* videoFpsData = videoFps.data_ptr<float>();
videoFpsData[0] = it->second.format_.video.fps;
videoDuration = torch::zeros({1}, torch::kLong);
int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
videoDurationData[0] = it->second.format_.video.duration;
} else {
VLOG(1) << "Miss video stream";
}
// audio section
torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
it = decoderOutput.media_data_.find(TYPE_AUDIO);
if (it != decoderOutput.media_data_.end()) {
VLOG(1) << "Find audio stream";
audioTimeBase = torch::zeros({2}, torch::kInt);
int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
audioTimeBaseData[0] = it->second.format_.audio.timeBaseNum;
audioTimeBaseData[1] = it->second.format_.audio.timeBaseDen;
audioSampleRate = torch::zeros({1}, torch::kInt);
int* audioSampleRateData = audioSampleRate.data_ptr<int>();
audioSampleRateData[0] = it->second.format_.audio.samples;
audioDuration = torch::zeros({1}, torch::kLong);
int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
audioDurationData[0] = it->second.format_.audio.duration;
} else {
VLOG(1) << "Miss audio stream";
}
torch::List<torch::Tensor> result;
result.push_back(std::move(videoTimeBase));
result.push_back(std::move(videoFps));
result.push_back(std::move(videoDuration));
result.push_back(std::move(audioTimeBase));
result.push_back(std::move(audioSampleRate));
result.push_back(std::move(audioDuration));
return result;
}
torch::List<torch::Tensor> probeVideoFromMemory(torch::Tensor input_video) {
return probeVideo(false, input_video, "");
}
torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
torch::Tensor dummy_input_video = torch::ones({0});
return probeVideo(true, dummy_input_video, videoPath);
}
} // namespace video_reader } // namespace video_reader
static auto registry = torch::RegisterOperators() static auto registry = torch::RegisterOperators()
.op("video_reader::read_video_from_memory", .op("video_reader::read_video_from_memory",
&video_reader::readVideoFromMemory) &video_reader::readVideoFromMemory)
.op("video_reader::read_video_from_file", .op("video_reader::read_video_from_file",
&video_reader::readVideoFromFile); &video_reader::readVideoFromFile)
.op("video_reader::probe_video_from_memory",
&video_reader::probeVideoFromMemory)
.op("video_reader::probe_video_from_file",
&video_reader::probeVideoFromFile);
...@@ -2,15 +2,17 @@ from .video import write_video, read_video, read_video_timestamps ...@@ -2,15 +2,17 @@ from .video import write_video, read_video, read_video_timestamps
from ._video_opt import ( from ._video_opt import (
_read_video_from_file, _read_video_from_file,
_read_video_timestamps_from_file, _read_video_timestamps_from_file,
_probe_video_from_file,
_read_video_from_memory, _read_video_from_memory,
_read_video_timestamps_from_memory, _read_video_timestamps_from_memory,
_probe_video_from_memory,
_HAS_VIDEO_OPT, _HAS_VIDEO_OPT,
) )
__all__ = [ __all__ = [
'write_video', 'read_video', 'read_video_timestamps', 'write_video', 'read_video', 'read_video_timestamps',
'_read_video_from_file', '_read_video_timestamps_from_file', '_read_video_from_file', '_read_video_timestamps_from_file', '_probe_video_from_file',
'_read_video_from_memory', '_read_video_timestamps_from_memory', '_read_video_from_memory', '_read_video_timestamps_from_memory', '_probe_video_from_memory',
'_HAS_VIDEO_OPT', '_HAS_VIDEO_OPT',
] ]
...@@ -26,14 +26,20 @@ def _validate_pts(pts_range): ...@@ -26,14 +26,20 @@ def _validate_pts(pts_range):
start pts: %d and end pts: %d""" % (pts_range[0], pts_range[1]) start pts: %d and end pts: %d""" % (pts_range[0], pts_range[1])
def _fill_info(vtimebase, vfps, atimebase, asample_rate): def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
info = {} info = {}
if vtimebase.numel() > 0: if vtimebase.numel() > 0:
info["video_timebase"] = Fraction(vtimebase[0].item(), vtimebase[1].item()) info["video_timebase"] = Fraction(vtimebase[0].item(), vtimebase[1].item())
if vduration.numel() > 0:
video_duration = vduration.item() * info["video_timebase"]
info["video_duration"] = video_duration
if vfps.numel() > 0: if vfps.numel() > 0:
info["video_fps"] = vfps.item() info["video_fps"] = vfps.item()
if atimebase.numel() > 0: if atimebase.numel() > 0:
info["audio_timebase"] = Fraction(atimebase[0].item(), atimebase[1].item()) info["audio_timebase"] = Fraction(atimebase[0].item(), atimebase[1].item())
if aduration.numel() > 0:
audio_duration = aduration.item() * info["audio_timebase"]
info["audio_duration"] = audio_duration
if asample_rate.numel() > 0: if asample_rate.numel() > 0:
info["audio_sample_rate"] = asample_rate.item() info["audio_sample_rate"] = asample_rate.item()
...@@ -141,8 +147,9 @@ def _read_video_from_file( ...@@ -141,8 +147,9 @@ def _read_video_from_file(
audio_timebase.numerator, audio_timebase.numerator,
audio_timebase.denominator, audio_timebase.denominator,
) )
vframes, _vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = result vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, \
info = _fill_info(vtimebase, vfps, atimebase, asample_rate) asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
if aframes.numel() > 0: if aframes.numel() > 0:
# when audio stream is found # when audio stream is found
aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range) aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
...@@ -175,16 +182,30 @@ def _read_video_timestamps_from_file(filename): ...@@ -175,16 +182,30 @@ def _read_video_timestamps_from_file(filename):
0, # audio_timebase_num 0, # audio_timebase_num
1, # audio_timebase_den 1, # audio_timebase_den
) )
_vframes, vframe_pts, vtimebase, vfps, _aframes, aframe_pts, atimebase, asample_rate = result _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, \
info = _fill_info(vtimebase, vfps, atimebase, asample_rate) asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist() vframe_pts = vframe_pts.numpy().tolist()
aframe_pts = aframe_pts.numpy().tolist() aframe_pts = aframe_pts.numpy().tolist()
return vframe_pts, aframe_pts, info return vframe_pts, aframe_pts, info
def _probe_video_from_file(filename):
    """
    Probe a video file on disk without decoding its frames.

    Args:
        filename (str): path to the video file.

    Returns:
        dict: video metadata, which may include video_timebase,
        video_duration, video_fps, audio_timebase, audio_duration and
        audio_sample_rate (a key is present only when the corresponding
        stream information was found).
    """
    probe_result = torch.ops.video_reader.probe_video_from_file(filename)
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = probe_result
    return _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
def _read_video_from_memory( def _read_video_from_memory(
file_buffer, video_data,
seek_frame_margin=0.25, seek_frame_margin=0.25,
read_video_stream=1, read_video_stream=1,
video_width=0, video_width=0,
...@@ -204,8 +225,8 @@ def _read_video_from_memory( ...@@ -204,8 +225,8 @@ def _read_video_from_memory(
Args Args
---------- ----------
file_buffer : buffer video_data : data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes
buffer of compressed video content compressed video content stored in either 1) torch.Tensor 2) python bytes
seek_frame_margin: double, optional seek_frame_margin: double, optional
seeking frame in the stream is imprecise. Thus, when video_start_pts is specified, seeking frame in the stream is imprecise. Thus, when video_start_pts is specified,
we seek the pts earlier by seek_frame_margin seconds we seek the pts earlier by seek_frame_margin seconds
...@@ -252,10 +273,11 @@ def _read_video_from_memory( ...@@ -252,10 +273,11 @@ def _read_video_from_memory(
_validate_pts(video_pts_range) _validate_pts(video_pts_range)
_validate_pts(audio_pts_range) _validate_pts(audio_pts_range)
video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8)) if not isinstance(video_data, torch.Tensor):
video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
result = torch.ops.video_reader.read_video_from_memory( result = torch.ops.video_reader.read_video_from_memory(
video_tensor, video_data,
seek_frame_margin, seek_frame_margin,
0, # getPtsOnly 0, # getPtsOnly
read_video_stream, read_video_stream,
...@@ -275,24 +297,25 @@ def _read_video_from_memory( ...@@ -275,24 +297,25 @@ def _read_video_from_memory(
audio_timebase.denominator, audio_timebase.denominator,
) )
vframes, _vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = result vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
info = _fill_info(vtimebase, vfps, atimebase, asample_rate) atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
if aframes.numel() > 0: if aframes.numel() > 0:
# when audio stream is found # when audio stream is found
aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range) aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
return vframes, aframes, info return vframes, aframes, info
def _read_video_timestamps_from_memory(file_buffer): def _read_video_timestamps_from_memory(video_data):
""" """
Decode all frames in the video. Only pts (presentation timestamp) is returned. Decode all frames in the video. Only pts (presentation timestamp) is returned.
The actual frame pixel data is not copied. Thus, read_video_timestamps(...) The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
is much faster than read_video(...) is much faster than read_video(...)
""" """
if not isinstance(video_data, torch.Tensor):
video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8)) video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
result = torch.ops.video_reader.read_video_from_memory( result = torch.ops.video_reader.read_video_from_memory(
video_tensor, video_data,
0, # seek_frame_margin 0, # seek_frame_margin
1, # getPtsOnly 1, # getPtsOnly
1, # read_video_stream 1, # read_video_stream
...@@ -311,9 +334,25 @@ def _read_video_timestamps_from_memory(file_buffer): ...@@ -311,9 +334,25 @@ def _read_video_timestamps_from_memory(file_buffer):
0, # audio_timebase_num 0, # audio_timebase_num
1, # audio_timebase_den 1, # audio_timebase_den
) )
_vframes, vframe_pts, vtimebase, vfps, _aframes, aframe_pts, atimebase, asample_rate = result _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, \
info = _fill_info(vtimebase, vfps, atimebase, asample_rate) atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist() vframe_pts = vframe_pts.numpy().tolist()
aframe_pts = aframe_pts.numpy().tolist() aframe_pts = aframe_pts.numpy().tolist()
return vframe_pts, aframe_pts, info return vframe_pts, aframe_pts, info
def _probe_video_from_memory(video_data):
    """
    Probe a compressed video held in memory without decoding its frames.

    Args:
        video_data: compressed video content, either a torch.Tensor or a
            python bytes-like buffer (the latter is wrapped into a uint8
            tensor before being handed to the decoder).

    Returns:
        dict: video metadata, which may include video_timebase,
        video_duration, video_fps, audio_timebase, audio_duration and
        audio_sample_rate (a key is present only when the corresponding
        stream information was found).
    """
    if not isinstance(video_data, torch.Tensor):
        # zero-copy view over the raw bytes, then wrap as a tensor
        video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
    probe_result = torch.ops.video_reader.probe_video_from_memory(video_data)
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = probe_result
    return _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment