Commit ed5b2dc3 authored by Zhicheng Yan, committed by Francisco Massa
Browse files

extend video reader to support fast video probing (#1437)

* extend video reader to support fast video probing

* fix c++ lint

* small fix

* allow to accept input video of type torch.Tensor
parent 7ae1b8c9
......@@ -87,6 +87,22 @@ class Tester(unittest.TestCase):
self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_file(self):
    """Probe a synthetic 10-frame, 5 fps clip on disk and check its metadata."""
    with temp_video(10, 300, 300, 5) as (f_name, _data):
        probed = io._probe_video_from_file(f_name)
        # 10 frames at 5 fps -> roughly 2 seconds of video.
        self.assertAlmostEqual(probed["video_duration"], 2, delta=0.1)
        self.assertAlmostEqual(probed["video_fps"], 5, delta=0.1)
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_memory(self):
    """Same check as the file-based probe test, but feed raw bytes instead."""
    with temp_video(10, 300, 300, 5) as (f_name, _data):
        with open(f_name, "rb") as fp:
            buf = fp.read()
        probed = io._probe_video_from_memory(buf)
        self.assertAlmostEqual(probed["video_duration"], 2, delta=0.1)
        self.assertAlmostEqual(probed["video_fps"], 5, delta=0.1)
def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
if _video_backend == "pyav":
......
......@@ -31,6 +31,7 @@ from torchvision.io._video_opt import _HAS_VIDEO_OPT
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [
"duration",
"video_fps",
"audio_sample_rate",
# We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
......@@ -44,6 +45,7 @@ GroundTruth = collections.namedtuple(
)
all_check_config = GroundTruth(
duration=0,
video_fps=0,
audio_sample_rate=0,
check_aframes=True,
......@@ -52,36 +54,42 @@ all_check_config = GroundTruth(
test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0,
video_fps=30.0,
audio_sample_rate=44100,
# PyAv miss one audio frame at the beginning (pts=0)
......@@ -89,6 +97,7 @@ test_videos = {
check_aframe_pts=False,
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv miss one audio frame at the beginning (pts=0)
......@@ -96,6 +105,7 @@ test_videos = {
check_aframe_pts=False,
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv miss one audio frame at the beginning (pts=0)
......@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder
"""
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
# check if pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1):
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
......@@ -288,6 +307,20 @@ class TestVideoReader(unittest.TestCase):
for i in range(len(aframe_pts) - 1):
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)
def check_probe_result(self, result, config):
    """Validate metadata returned by the probe ops against ground truth.

    `result` is the 6-tensor list produced by probe_video_from_file/memory:
    (vtimebase, vfps, vduration, atimebase, asample_rate, aduration).
    """
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
    # Durations are reported in stream time-base units; convert to seconds.
    video_seconds = vduration.item() * Fraction(
        vtimebase[0].item(), vtimebase[1].item()
    )
    self.assertAlmostEqual(video_seconds, config.duration, delta=0.5)
    self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
    if asample_rate.numel() > 0:
        # An audio stream was found: check its sample rate and duration too.
        self.assertEqual(asample_rate.item(), config.audio_sample_rate)
        audio_seconds = aduration.item() * Fraction(
            atimebase[0].item(), atimebase[1].item()
        )
        self.assertAlmostEqual(audio_seconds, config.duration, delta=0.5)
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
"""
Compare decoding results from two sources.
......@@ -297,18 +330,17 @@ class TestVideoReader(unittest.TestCase):
decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker
"""
vframes, vframe_pts, vtimebase, _vfps, aframes, aframe_pts, atimebase, _asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, \
atimebase, _asample_rate, _aduration = tv_result
if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder
ref_result = DecoderResult(
vframes=ref_result[0],
vframe_pts=ref_result[1],
vtimebase=ref_result[2],
aframes=ref_result[4],
aframe_pts=ref_result[5],
atimebase=ref_result[6],
aframes=ref_result[5],
aframe_pts=ref_result[6],
atimebase=ref_result[7],
)
if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
......@@ -351,12 +383,12 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for i in range(num_iter):
for test_video, config in test_videos.items():
for _i in range(num_iter):
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
_ = torch.ops.video_reader.read_video_from_file(
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
......@@ -460,9 +492,8 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
......@@ -489,7 +520,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
......@@ -528,7 +559,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
......@@ -567,7 +598,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
......@@ -606,7 +637,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
......@@ -651,7 +682,7 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
......@@ -674,18 +705,17 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, a_sample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
if aframes.numel() > 0:
self.assertEqual(samples, a_sample_rate.item())
self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1))
# when audio stream is found
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
self.assertAlmostEqual(
aframes.size(0),
int(duration * a_sample_rate.item()),
delta=0.1 * a_sample_rate.item(),
int(duration * asample_rate.item()),
delta=0.1 * asample_rate.item(),
)
def test_compare_read_video_from_memory_and_file(self):
......@@ -859,7 +889,7 @@ class TestVideoReader(unittest.TestCase):
)
self.assertEqual(tv_result_pts_only[0].numel(), 0)
self.assertEqual(tv_result_pts_only[4].numel(), 0)
self.assertEqual(tv_result_pts_only[5].numel(), 0)
self.compare_decoding_result(tv_result, tv_result_pts_only)
def test_read_video_in_range_from_memory(self):
......@@ -899,9 +929,8 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = tv_result
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
for num_frames in [4, 8, 16, 32, 64, 128]:
......@@ -997,6 +1026,24 @@ class TestVideoReader(unittest.TestCase):
# and PyAv
self.compare_decoding_result(tv_result, pyav_result, config)
def test_probe_video_from_file(self):
    """
    Test the case when decoder probes a video file
    """
    for video_name, config in test_videos.items():
        video_path = os.path.join(VIDEO_DIR, video_name)
        result = torch.ops.video_reader.probe_video_from_file(video_path)
        self.check_probe_result(result, config)
def test_probe_video_from_memory(self):
    """
    Test the case when decoder probes a video in memory
    """
    for video_name, config in test_videos.items():
        _full_path, video_tensor = _get_video_tensor(VIDEO_DIR, video_name)
        result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
        self.check_probe_result(result, config)
if __name__ == '__main__':
unittest.main()
......@@ -49,6 +49,7 @@ void FfmpegAudioStream::updateStreamDecodeParams() {
mediaFormat_.format.audio.timeBaseDen =
inputCtx_->streams[index_]->time_base.den;
}
mediaFormat_.format.audio.duration = inputCtx_->streams[index_]->duration;
}
int FfmpegAudioStream::initFormat() {
......
......@@ -220,6 +220,30 @@ int FfmpegDecoder::decodeMemory(
return ret;
}
// Probe a video file on disk: gather stream metadata without decoding frames.
// Returns 0 on success, a negative error code on failure.
int FfmpegDecoder::probeFile(
    unique_ptr<DecoderParameters> params,
    const string& fileName,
    DecoderOutput& decoderOutput) {
  VLOG(1) << "probe file: " << fileName;
  FfmpegAvioContext avioCtx;
  return probeVideo(
      std::move(params), fileName, /*isDecodeFile=*/true, avioCtx, decoderOutput);
}
// Probe an encoded video held in memory: gather stream metadata only.
// Returns 0 on success, a negative error code on failure.
int FfmpegDecoder::probeMemory(
    unique_ptr<DecoderParameters> params,
    const uint8_t* buffer,
    int64_t size,
    DecoderOutput& decoderOutput) {
  VLOG(1) << "probe video data in memory";
  FfmpegAvioContext avioCtx;
  const int ret = avioCtx.initAVIOContext(buffer, size);
  if (ret != 0) {
    // Could not wrap the buffer in an AVIO context; propagate the error.
    return ret;
  }
  return probeVideo(
      std::move(params), string(""), /*isDecodeFile=*/false, avioCtx, decoderOutput);
}
void FfmpegDecoder::cleanUp() {
if (formatCtx_) {
for (auto& stream : streams_) {
......@@ -320,6 +344,16 @@ int FfmpegDecoder::decodeLoop(
return ret;
}
// Shared implementation behind probeFile/probeMemory: store the decoder
// parameters and initialize the container/stream contexts. No frames are
// decoded; stream metadata is populated into decoderOutput by init().
// Returns init()'s status code (0 on success, negative on failure).
int FfmpegDecoder::probeVideo(
    unique_ptr<DecoderParameters> params,
    const std::string& filename,
    bool isDecodeFile,
    FfmpegAvioContext& ioctx,
    DecoderOutput& decoderOutput) {
  params_ = std::move(params);
  return init(filename, isDecodeFile, ioctx, decoderOutput);
}
bool FfmpegDecoder::initStreams() {
for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) {
AVMediaType mediaType;
......
......@@ -75,6 +75,19 @@ class FfmpegDecoder {
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput);
// return 0 on success
// return negative number on failure
int probeFile(
std::unique_ptr<DecoderParameters> params,
const std::string& filename,
DecoderOutput& decoderOutput);
// return 0 on success
// return negative number on failure
int probeMemory(
std::unique_ptr<DecoderParameters> params,
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput);
void cleanUp();
......@@ -95,6 +108,13 @@ class FfmpegDecoder {
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput);
int probeVideo(
std::unique_ptr<DecoderParameters> params,
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput);
bool initStreams();
void flushStreams(DecoderOutput& decoderOutput);
......
......@@ -48,6 +48,7 @@ void FfmpegVideoStream::updateStreamDecodeParams() {
mediaFormat_.format.video.timeBaseDen =
inputCtx_->streams[index_]->time_base.den;
}
mediaFormat_.format.video.duration = inputCtx_->streams[index_]->duration;
}
int FfmpegVideoStream::initFormat() {
......
......@@ -48,6 +48,7 @@ struct VideoFormat {
int timeBaseNum{0};
int timeBaseDen{1}; // numerator and denominator of time base
float fps{0.0};
int64_t duration{0}; // duration of the stream, in stream time base
};
struct AudioFormat {
......@@ -60,6 +61,7 @@ struct AudioFormat {
int64_t startPts{0}, endPts{0}; // Start and end presentation timestamp
int timeBaseNum{0};
int timeBaseDen{1}; // numerator and denominator of time base
int64_t duration{0}; // duration of the stream, in stream time base
};
union FormatUnion {
......
......@@ -27,8 +27,6 @@ PyMODINIT_FUNC PyInit_video_reader(void) {
namespace video_reader {
bool glog_initialized = false;
class UnknownPixelFormatException : public exception {
const char* what() const throw() override {
return "Unknown pixel format";
......@@ -167,11 +165,6 @@ torch::List<torch::Tensor> readVideo(
int64_t audioEndPts,
int64_t audioTimeBaseNum,
int64_t audioTimeBaseDen) {
if (!glog_initialized) {
glog_initialized = true;
// google::InitGoogleLogging("VideoReader");
}
unique_ptr<DecoderParameters> params = util::getDecoderParams(
seekFrameMargin,
getPtsOnly,
......@@ -209,6 +202,8 @@ torch::List<torch::Tensor> readVideo(
torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
if (readVideoStream == 1) {
auto it = decoderOutput.media_data_.find(TYPE_VIDEO);
if (it != decoderOutput.media_data_.end()) {
......@@ -236,6 +231,10 @@ torch::List<torch::Tensor> readVideo(
videoFps = torch::zeros({1}, torch::kFloat);
float* videoFpsData = videoFps.data_ptr<float>();
videoFpsData[0] = it->second.format_.video.fps;
videoDuration = torch::zeros({1}, torch::kLong);
int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
videoDurationData[0] = it->second.format_.video.duration;
} else {
VLOG(1) << "Miss video stream";
}
......@@ -246,6 +245,7 @@ torch::List<torch::Tensor> readVideo(
torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
if (readAudioStream == 1) {
auto it = decoderOutput.media_data_.find(TYPE_AUDIO);
if (it != decoderOutput.media_data_.end()) {
......@@ -275,6 +275,10 @@ torch::List<torch::Tensor> readVideo(
audioSampleRate = torch::zeros({1}, torch::kInt);
int* audioSampleRateData = audioSampleRate.data_ptr<int>();
audioSampleRateData[0] = it->second.format_.audio.samples;
audioDuration = torch::zeros({1}, torch::kLong);
int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
audioDurationData[0] = it->second.format_.audio.duration;
} else {
VLOG(1) << "Miss audio stream";
}
......@@ -285,10 +289,12 @@ torch::List<torch::Tensor> readVideo(
result.push_back(std::move(videoFramePts));
result.push_back(std::move(videoTimeBase));
result.push_back(std::move(videoFps));
result.push_back(std::move(videoDuration));
result.push_back(std::move(audioFrame));
result.push_back(std::move(audioFramePts));
result.push_back(std::move(audioTimeBase));
result.push_back(std::move(audioSampleRate));
result.push_back(std::move(audioDuration));
return result;
}
......@@ -378,10 +384,117 @@ torch::List<torch::Tensor> readVideoFromFile(
audioTimeBaseDen);
}
// Probe a video (from file or from memory) without decoding any frames.
// Returns a list of 6 tensors:
//   0: videoTimeBase   (2 ints: num, den)            - empty if no video stream
//   1: videoFps        (1 float)                     - empty if no video stream
//   2: videoDuration   (1 long, time-base units)     - empty if no video stream
//   3: audioTimeBase   (2 ints: num, den)            - empty if no audio stream
//   4: audioSampleRate (1 int)                       - empty if no audio stream
//   5: audioDuration   (1 long, time-base units)     - empty if no audio stream
torch::List<torch::Tensor> probeVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath) {
  // Request both streams; the decode-window parameters are irrelevant for
  // probing and are left at neutral defaults.
  unique_ptr<DecoderParameters> params = util::getDecoderParams(
      0, // seekFrameMargin
      0, // getPtsOnly
      1, // readVideoStream
      0, // width
      0, // height
      0, // minDimension
      0, // videoStartPts
      0, // videoEndPts
      0, // videoTimeBaseNum
      1, // videoTimeBaseDen
      1, // readAudioStream
      0, // audioSamples
      0, // audioChannels
      0, // audioStartPts
      0, // audioEndPts
      0, // audioTimeBaseNum
      1 // audioTimeBaseDen
  );

  FfmpegDecoder decoder;
  DecoderOutput decoderOutput;
  int ret;
  if (isReadFile) {
    ret = decoder.probeFile(std::move(params), videoPath, decoderOutput);
  } else {
    ret = decoder.probeMemory(
        std::move(params),
        input_video.data_ptr<uint8_t>(),
        input_video.size(0),
        decoderOutput);
  }
  if (ret != 0) {
    // On failure media_data_ stays empty, so every field below comes back as
    // an empty tensor. Log so the failure is not completely silent.
    LOG(ERROR) << "Failed to probe video, error code: " << ret;
  }

  // video section
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);

  auto it = decoderOutput.media_data_.find(TYPE_VIDEO);
  if (it != decoderOutput.media_data_.end()) {
    VLOG(1) << "Find video stream";
    videoTimeBase = torch::zeros({2}, torch::kInt);
    int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
    videoTimeBaseData[0] = it->second.format_.video.timeBaseNum;
    videoTimeBaseData[1] = it->second.format_.video.timeBaseDen;

    videoFps = torch::zeros({1}, torch::kFloat);
    float* videoFpsData = videoFps.data_ptr<float>();
    videoFpsData[0] = it->second.format_.video.fps;

    videoDuration = torch::zeros({1}, torch::kLong);
    int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
    videoDurationData[0] = it->second.format_.video.duration;
  } else {
    VLOG(1) << "Miss video stream";
  }

  // audio section
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);

  it = decoderOutput.media_data_.find(TYPE_AUDIO);
  if (it != decoderOutput.media_data_.end()) {
    VLOG(1) << "Find audio stream";
    audioTimeBase = torch::zeros({2}, torch::kInt);
    int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
    audioTimeBaseData[0] = it->second.format_.audio.timeBaseNum;
    audioTimeBaseData[1] = it->second.format_.audio.timeBaseDen;

    audioSampleRate = torch::zeros({1}, torch::kInt);
    int* audioSampleRateData = audioSampleRate.data_ptr<int>();
    audioSampleRateData[0] = it->second.format_.audio.samples;

    audioDuration = torch::zeros({1}, torch::kLong);
    int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
    audioDurationData[0] = it->second.format_.audio.duration;
  } else {
    VLOG(1) << "Miss audio stream";
  }

  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));
  return result;
}
// Python-facing op: probe an encoded video held in a 1-D uint8 tensor.
torch::List<torch::Tensor> probeVideoFromMemory(torch::Tensor input_video) {
  return probeVideo(/*isReadFile=*/false, input_video, "");
}
// Python-facing op: probe a video file on disk.
torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
  // The in-memory tensor argument is unused on the file path; pass an
  // empty dummy tensor.
  torch::Tensor unusedVideoTensor = torch::ones({0});
  return probeVideo(/*isReadFile=*/true, unusedVideoTensor, videoPath);
}
} // namespace video_reader
static auto registry = torch::RegisterOperators()
.op("video_reader::read_video_from_memory",
&video_reader::readVideoFromMemory)
.op("video_reader::read_video_from_file",
&video_reader::readVideoFromFile);
&video_reader::readVideoFromFile)
.op("video_reader::probe_video_from_memory",
&video_reader::probeVideoFromMemory)
.op("video_reader::probe_video_from_file",
&video_reader::probeVideoFromFile);
......@@ -2,15 +2,17 @@ from .video import write_video, read_video, read_video_timestamps
from ._video_opt import (
_read_video_from_file,
_read_video_timestamps_from_file,
_probe_video_from_file,
_read_video_from_memory,
_read_video_timestamps_from_memory,
_probe_video_from_memory,
_HAS_VIDEO_OPT,
)
__all__ = [
'write_video', 'read_video', 'read_video_timestamps',
'_read_video_from_file', '_read_video_timestamps_from_file',
'_read_video_from_memory', '_read_video_timestamps_from_memory',
'_read_video_from_file', '_read_video_timestamps_from_file', '_probe_video_from_file',
'_read_video_from_memory', '_read_video_timestamps_from_memory', '_probe_video_from_memory',
'_HAS_VIDEO_OPT',
]
......@@ -26,14 +26,20 @@ def _validate_pts(pts_range):
start pts: %d and end pts: %d""" % (pts_range[0], pts_range[1])
def _fill_info(vtimebase, vfps, atimebase, asample_rate):
def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
info = {}
if vtimebase.numel() > 0:
info["video_timebase"] = Fraction(vtimebase[0].item(), vtimebase[1].item())
if vduration.numel() > 0:
video_duration = vduration.item() * info["video_timebase"]
info["video_duration"] = video_duration
if vfps.numel() > 0:
info["video_fps"] = vfps.item()
if atimebase.numel() > 0:
info["audio_timebase"] = Fraction(atimebase[0].item(), atimebase[1].item())
if aduration.numel() > 0:
audio_duration = aduration.item() * info["audio_timebase"]
info["audio_duration"] = audio_duration
if asample_rate.numel() > 0:
info["audio_sample_rate"] = asample_rate.item()
......@@ -141,8 +147,9 @@ def _read_video_from_file(
audio_timebase.numerator,
audio_timebase.denominator,
)
vframes, _vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = result
info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, \
asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
if aframes.numel() > 0:
# when audio stream is found
aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
......@@ -175,16 +182,30 @@ def _read_video_timestamps_from_file(filename):
0, # audio_timebase_num
1, # audio_timebase_den
)
_vframes, vframe_pts, vtimebase, vfps, _aframes, aframe_pts, atimebase, asample_rate = result
info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
_vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, \
asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist()
aframe_pts = aframe_pts.numpy().tolist()
return vframe_pts, aframe_pts, info
def _probe_video_from_file(filename):
    """
    Probe a video file.
    Return:
        info [dict]: contain video meta information, including video_timebase,
        video_duration, video_fps, audio_timebase, audio_duration, audio_sample_rate
    """
    probe_result = torch.ops.video_reader.probe_video_from_file(filename)
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = probe_result
    return _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
def _read_video_from_memory(
file_buffer,
video_data,
seek_frame_margin=0.25,
read_video_stream=1,
video_width=0,
......@@ -204,8 +225,8 @@ def _read_video_from_memory(
Args
----------
file_buffer : buffer
buffer of compressed video content
video_data : data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes
compressed video content stored in either 1) torch.Tensor 2) python bytes
seek_frame_margin: double, optional
seeking frame in the stream is imprecise. Thus, when video_start_pts is specified,
we seek the pts earlier by seek_frame_margin seconds
......@@ -252,10 +273,11 @@ def _read_video_from_memory(
_validate_pts(video_pts_range)
_validate_pts(audio_pts_range)
video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8))
if not isinstance(video_data, torch.Tensor):
video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
video_data,
seek_frame_margin,
0, # getPtsOnly
read_video_stream,
......@@ -275,24 +297,25 @@ def _read_video_from_memory(
audio_timebase.denominator,
)
vframes, _vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = result
info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
if aframes.numel() > 0:
# when audio stream is found
aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
return vframes, aframes, info
def _read_video_timestamps_from_memory(file_buffer):
def _read_video_timestamps_from_memory(video_data):
"""
Decode all frames in the video. Only pts (presentation timestamp) is returned.
The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
is much faster than read_video(...)
"""
video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8))
if not isinstance(video_data, torch.Tensor):
video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
video_data,
0, # seek_frame_margin
1, # getPtsOnly
1, # read_video_stream
......@@ -311,9 +334,25 @@ def _read_video_timestamps_from_memory(file_buffer):
0, # audio_timebase_num
1, # audio_timebase_den
)
_vframes, vframe_pts, vtimebase, vfps, _aframes, aframe_pts, atimebase, asample_rate = result
info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
_vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, \
atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist()
aframe_pts = aframe_pts.numpy().tolist()
return vframe_pts, aframe_pts, info
def _probe_video_from_memory(video_data):
    """
    Probe a video in memory.
    Return:
        info [dict]: contain video meta information, including video_timebase,
        video_duration, video_fps, audio_timebase, audio_duration, audio_sample_rate
    """
    # Accept either a torch.Tensor of encoded bytes or a python bytes-like
    # object; the latter is wrapped into a uint8 tensor without copying.
    if not isinstance(video_data, torch.Tensor):
        video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = \
        torch.ops.video_reader.probe_video_from_memory(video_data)
    return _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment