#include "video_reader.h" #ifdef USE_PYTHON #include #endif #include "../decoder/memory_buffer.h" #include "../decoder/sync_decoder.h" #ifdef USE_PYTHON // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension #ifdef _WIN32 PyMODINIT_FUNC PyInit_video_reader(void) { // No need to do anything. return NULL; } #endif #endif // USE_PYTHONs using namespace ffmpeg; namespace vision { namespace video_reader { namespace { const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24; const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT; const AVRational timeBaseQ = AVRational{1, AV_TIME_BASE}; const size_t decoderTimeoutMs = 600000; // A jitter can be added to the end of the range to avoid conversion/rounding // error, small value 100us won't be enough to select the next frame, but enough // to compensate rounding error due to the multiple conversions. const size_t timeBaseJitterUs = 100; DecoderParameters getDecoderParams( int64_t videoStartUs, int64_t videoEndUs, double seekFrameMarginUs, int64_t getPtsOnly, int64_t readVideoStream, int videoWidth, int videoHeight, int videoMinDimension, int videoMaxDimension, int64_t readAudioStream, int audioSamples, int audioChannels) { DecoderParameters params; params.headerOnly = getPtsOnly != 0; params.seekAccuracy = seekFrameMarginUs; params.startOffset = videoStartUs; params.endOffset = videoEndUs; params.timeoutMs = decoderTimeoutMs; params.preventStaleness = false; if (readVideoStream == 1) { MediaFormat videoFormat(0); videoFormat.type = TYPE_VIDEO; videoFormat.format.video.format = defaultVideoPixelFormat; videoFormat.format.video.width = videoWidth; videoFormat.format.video.height = videoHeight; videoFormat.format.video.minDimension = videoMinDimension; videoFormat.format.video.maxDimension = videoMaxDimension; params.formats.insert(videoFormat); } if (readAudioStream == 1) { MediaFormat audioFormat; audioFormat.type = TYPE_AUDIO; audioFormat.format.audio.format = defaultAudioSampleFormat; audioFormat.format.audio.samples = audioSamples; audioFormat.format.audio.channels = audioChannels; params.formats.insert(audioFormat); } return params; } // returns number of written bytes template size_t fillTensor( std::vector& msgs, torch::Tensor& frame, torch::Tensor& framePts, int64_t num, int64_t den) { if (msgs.empty()) { return 0; } T* frameData = frame.numel() > 0 ? 
// returns number of written bytes
template <typename T>
size_t fillTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& frame,
    torch::Tensor& framePts,
    int64_t num,
    int64_t den) {
  if (msgs.empty()) {
    return 0;
  }
  T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
  int64_t* framePtsData = framePts.data_ptr<int64_t>();
  CHECK_EQ(framePts.size(0), (int64_t)msgs.size());
  size_t avgElementsInFrame = frame.numel() / msgs.size();

  size_t offset = 0;
  for (size_t i = 0; i < msgs.size(); ++i) {
    const auto& msg = msgs[i];
    // convert pts into original time_base
    AVRational avr = AVRational{(int)num, (int)den};
    framePtsData[i] = av_rescale_q(msg.header.pts, timeBaseQ, avr);
    VLOG(2) << "PTS type: " << sizeof(T) << ", us: " << msg.header.pts
            << ", original: " << framePtsData[i];

    if (frameData) {
      auto sizeInBytes = msg.payload->length();
      memcpy(frameData + offset, msg.payload->data(), sizeInBytes);
      if (sizeof(T) == sizeof(uint8_t)) {
        // Video - move by allocated frame size
        offset += avgElementsInFrame / sizeof(T);
      } else {
        // Audio - move by number of samples
        offset += sizeInBytes / sizeof(T);
      }
    }
  }
  return offset * sizeof(T);
}

size_t fillVideoTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& videoFrame,
    torch::Tensor& videoFramePts,
    int64_t num,
    int64_t den) {
  return fillTensor<uint8_t>(msgs, videoFrame, videoFramePts, num, den);
}

size_t fillAudioTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& audioFrame,
    torch::Tensor& audioFramePts,
    int64_t num,
    int64_t den) {
  return fillTensor<float>(msgs, audioFrame, audioFramePts, num, den);
}
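// Worked example of the pts <-> microsecond conversions used below: a
// 29.97 fps stream has time base {num, den} = {1001, 30000}, so a frame
// with pts 1 maps to av_rescale_q(1, {1001, 30000}, {1, 1000000}) ~= 33367
// us. Converting back only recovers pts 1 up to rounding, which is why
// timeBaseJitterUs is added to the end offsets computed in offsetsToUs().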
void offsetsToUs(
    double& seekFrameMargin,
    int64_t readVideoStream,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen,
    int64_t& videoStartUs,
    int64_t& videoEndUs) {
  seekFrameMargin *= AV_TIME_BASE;
  videoStartUs = 0;
  videoEndUs = -1;

  if (readVideoStream) {
    AVRational vr = AVRational{(int)videoTimeBaseNum, (int)videoTimeBaseDen};
    if (videoStartPts > 0) {
      videoStartUs = av_rescale_q(videoStartPts, vr, timeBaseQ);
    }
    if (videoEndPts > 0) {
      // Add jitter to the end of the range to avoid conversion/rounding
      // error. A small value of 100us is not enough to select the next frame,
      // but is enough to compensate for rounding error due to the multiple
      // conversions.
      videoEndUs = timeBaseJitterUs + av_rescale_q(videoEndPts, vr, timeBaseQ);
    }
  } else if (readAudioStream) {
    AVRational ar = AVRational{(int)audioTimeBaseNum, (int)audioTimeBaseDen};
    if (audioStartPts > 0) {
      videoStartUs = av_rescale_q(audioStartPts, ar, timeBaseQ);
    }
    if (audioEndPts > 0) {
      // Add jitter to the end of the range to avoid conversion/rounding
      // error. A small value of 100us is not enough to select the next frame,
      // but is enough to compensate for rounding error due to the multiple
      // conversions.
      videoEndUs = timeBaseJitterUs + av_rescale_q(audioEndPts, ar, timeBaseQ);
    }
  }
}

torch::List<torch::Tensor> readVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  int64_t videoStartUs, videoEndUs;

  offsetsToUs(
      seekFrameMargin,
      readVideoStream,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen,
      videoStartUs,
      videoEndUs);

  DecoderParameters params = getDecoderParams(
      videoStartUs, // videoStartUs
      videoEndUs, // videoEndUs
      seekFrameMargin, // seekFrameMargin
      getPtsOnly, // getPtsOnly
      readVideoStream, // readVideoStream
      width, // width
      height, // height
      minDimension, // minDimension
      maxDimension, // maxDimension
      readAudioStream, // readAudioStream
      audioSamples, // audioSamples
      audioChannels // audioChannels
  );

  SyncDecoder decoder;
  std::vector<DecoderOutputMessage> audioMessages, videoMessages;
  DecoderInCallback callback = nullptr;
  std::string logMessage, logType;
  if (isReadFile) {
    params.uri = videoPath;
    logType = "file";
    logMessage = videoPath;
  } else {
    callback = MemoryBuffer::getCallback(
        input_video.data_ptr<uint8_t>(), input_video.size(0));
    logType = "memory";
    logMessage = std::to_string(input_video.size(0));
  }

  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] has started";

  const auto now = std::chrono::system_clock::now();

  bool succeeded;
  DecoderMetadata audioMetadata, videoMetadata;
  std::vector<DecoderMetadata> metadata;
  if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
    for (const auto& header : metadata) {
      if (header.format.type == TYPE_VIDEO) {
        videoMetadata = header;
      } else if (header.format.type == TYPE_AUDIO) {
        audioMetadata = header;
      }
    }
    int res;
    DecoderOutputMessage msg;
    while (0 == (res = decoder.decode(&msg, decoderTimeoutMs))) {
      if (msg.header.format.type == TYPE_VIDEO) {
        videoMessages.push_back(std::move(msg));
      }
      if (msg.header.format.type == TYPE_AUDIO) {
        audioMessages.push_back(std::move(msg));
      }
      msg.payload.reset();
    }
  } else {
    LOG(ERROR) << "Decoder initialization has failed";
  }
  const auto then = std::chrono::system_clock::now();
  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] has finished, "
          << std::chrono::duration_cast<std::chrono::microseconds>(then - now)
                 .count()
          << " us";

  decoder.shutdown();
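  // All output tensors start out empty. They are reallocated and filled only
  // when decoding succeeded, the corresponding stream was requested, and at
  // least one message arrived, so a caller can detect a missing stream by a
  // zero-sized tensor.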
  // video section
  torch::Tensor videoFrame = torch::zeros({0}, torch::kByte);
  torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
  if (succeeded && readVideoStream == 1) {
    if (!videoMessages.empty()) {
      const auto& header = videoMetadata;
      const auto& format = header.format.format.video;
      int numVideoFrames = videoMessages.size();
      int outHeight = format.height;
      int outWidth = format.width;
      int numChannels = 3; // decoder guarantees the default AV_PIX_FMT_RGB24

      size_t expectedWrittenBytes = 0;
      if (getPtsOnly == 0) {
        videoFrame = torch::zeros(
            {numVideoFrames, outHeight, outWidth, numChannels}, torch::kByte);
        expectedWrittenBytes =
            (size_t)numVideoFrames * outHeight * outWidth * numChannels;
      }

      videoFramePts = torch::zeros({numVideoFrames}, torch::kLong);

      VLOG(2) << "video duration: " << header.duration
              << ", fps: " << header.fps << ", num: " << header.num
              << ", den: " << header.den << ", num frames: " << numVideoFrames;

      auto numberWrittenBytes = fillVideoTensor(
          videoMessages, videoFrame, videoFramePts, header.num, header.den);

      CHECK_EQ(numberWrittenBytes, expectedWrittenBytes);

      videoTimeBase = torch::zeros({2}, torch::kInt);
      int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
      videoTimeBaseData[0] = header.num;
      videoTimeBaseData[1] = header.den;

      videoFps = torch::zeros({1}, torch::kFloat);
      float* videoFpsData = videoFps.data_ptr<float>();
      videoFpsData[0] = header.fps;

      videoDuration = torch::zeros({1}, torch::kLong);
      int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
      AVRational vr = AVRational{(int)header.num, (int)header.den};
      videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, vr);
      VLOG(1) << "Video decoding from " << logType << " [" << logMessage
              << "] filled video tensors";
    } else {
      VLOG(1) << "Miss video stream";
    }
  }

  // audio section
  torch::Tensor audioFrame = torch::zeros({0}, torch::kFloat);
  torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
  if (succeeded && readAudioStream == 1) {
    if (!audioMessages.empty()) {
      const auto& header = audioMetadata;
      const auto& format = header.format.format.audio;

      int64_t outAudioChannels = format.channels;
      int bytesPerSample =
          av_get_bytes_per_sample(static_cast<AVSampleFormat>(format.format));

      int numAudioFrames = audioMessages.size();
      int64_t numAudioSamples = 0;
      if (getPtsOnly == 0) {
        int64_t frameSizeTotal = 0;
        for (auto const& audioMessage : audioMessages) {
          frameSizeTotal += audioMessage.payload->length();
        }

        CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
        numAudioSamples = frameSizeTotal / (outAudioChannels * bytesPerSample);

        audioFrame =
            torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
      }
      audioFramePts = torch::zeros({numAudioFrames}, torch::kLong);

      VLOG(2) << "audio duration: " << header.duration
              << ", channels: " << format.channels
              << ", sample rate: " << format.samples << ", num: " << header.num
              << ", den: " << header.den;

      auto numberWrittenBytes = fillAudioTensor(
          audioMessages, audioFrame, audioFramePts, header.num, header.den);
      CHECK_EQ(
          numberWrittenBytes,
          numAudioSamples * outAudioChannels * sizeof(float));

      audioTimeBase = torch::zeros({2}, torch::kInt);
      int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
      audioTimeBaseData[0] = header.num;
      audioTimeBaseData[1] = header.den;

      audioSampleRate = torch::zeros({1}, torch::kInt);
      int* audioSampleRateData = audioSampleRate.data_ptr<int>();
      audioSampleRateData[0] = format.samples;

      audioDuration = torch::zeros({1}, torch::kLong);
      int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
      AVRational ar = AVRational{(int)header.num, (int)header.den};
      audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, ar);
      VLOG(1) << "Video decoding from " << logType << " [" << logMessage
              << "] filled audio tensors";
    } else {
      VLOG(1) << "Miss audio stream";
    }
  }
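  // Pack the result as a flat 10-tensor list:
  //   [0] videoFrame (N x H x W x 3, kByte)   [5] audioFrame (S x C, kFloat)
  //   [1] videoFramePts (N, kLong)            [6] audioFramePts (F, kLong)
  //   [2] videoTimeBase ({num, den}, kInt)    [7] audioTimeBase ({num, den}, kInt)
  //   [3] videoFps (1, kFloat)                [8] audioSampleRate (1, kInt)
  //   [4] videoDuration (1, kLong)            [9] audioDuration (1, kLong)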
  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoFrame));
  result.push_back(std::move(videoFramePts));
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioFrame));
  result.push_back(std::move(audioFramePts));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));

  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] about to return";

  return result;
}

torch::List<torch::Tensor> probeVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath) {
  DecoderParameters params = getDecoderParams(
      0, // videoStartUs
      -1, // videoEndUs
      0, // seekFrameMargin
      1, // getPtsOnly
      1, // readVideoStream
      0, // width
      0, // height
      0, // minDimension
      0, // maxDimension
      1, // readAudioStream
      0, // audioSamples
      0 // audioChannels
  );

  SyncDecoder decoder;
  DecoderInCallback callback = nullptr;
  std::string logMessage, logType;
  if (isReadFile) {
    params.uri = videoPath;
    logType = "file";
    logMessage = videoPath;
  } else {
    callback = MemoryBuffer::getCallback(
        input_video.data_ptr<uint8_t>(), input_video.size(0));
    logType = "memory";
    logMessage = std::to_string(input_video.size(0));
  }

  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] has started";

  const auto now = std::chrono::system_clock::now();

  bool succeeded;
  bool gotAudio = false, gotVideo = false;
  DecoderMetadata audioMetadata, videoMetadata;
  std::vector<DecoderMetadata> metadata;
  if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
    for (const auto& header : metadata) {
      if (header.format.type == TYPE_VIDEO) {
        gotVideo = true;
        videoMetadata = header;
      } else if (header.format.type == TYPE_AUDIO) {
        gotAudio = true;
        audioMetadata = header;
      }
    }
    const auto then = std::chrono::system_clock::now();
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] has finished, "
            << std::chrono::duration_cast<std::chrono::microseconds>(
                   then - now)
                   .count()
            << " us";
  } else {
    LOG(ERROR) << "Decoder initialization has failed";
  }

  decoder.shutdown();

  // video section
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);

  if (succeeded && gotVideo) {
    videoTimeBase = torch::zeros({2}, torch::kInt);
    int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
    const auto& header = videoMetadata;

    videoTimeBaseData[0] = header.num;
    videoTimeBaseData[1] = header.den;

    videoFps = torch::zeros({1}, torch::kFloat);
    float* videoFpsData = videoFps.data_ptr<float>();
    videoFpsData[0] = header.fps;

    videoDuration = torch::zeros({1}, torch::kLong);
    int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
    AVRational avr = AVRational{(int)header.num, (int)header.den};
    videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);

    VLOG(2) << "Probe fps: " << header.fps
            << ", duration: " << header.duration << ", num: " << header.num
            << ", den: " << header.den;
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled video tensors";
  } else {
    LOG(ERROR) << "Miss video stream";
  }

  // audio section
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);

  if (succeeded && gotAudio) {
    audioTimeBase = torch::zeros({2}, torch::kInt);
    int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
    const auto& header = audioMetadata;
    const auto& media = header.format;
    const auto& format = media.format.audio;

    audioTimeBaseData[0] = header.num;
    audioTimeBaseData[1] = header.den;

    audioSampleRate = torch::zeros({1}, torch::kInt);
    int* audioSampleRateData = audioSampleRate.data_ptr<int>();
    audioSampleRateData[0] = format.samples;

    audioDuration = torch::zeros({1}, torch::kLong);
    int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
    AVRational avr = AVRational{(int)header.num, (int)header.den};
    audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);

    VLOG(2) << "Probe sample rate: " << format.samples
            << ", duration: " << header.duration << ", num: " << header.num
            << ", den: " << header.den;
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled audio tensors";
  } else {
    VLOG(1) << "Miss audio stream";
  }

  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));

  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] is about to return";

  return result;
}

} // namespace
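// Exported operators. These wrappers only add one-time API-usage logging and
// forward to readVideo()/probeVideo(): the *_from_file variants pass a dummy
// empty tensor, and the *_from_memory variants pass an empty path.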
torch::List<torch::Tensor> read_video_from_memory(
    torch::Tensor input_video,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.read_video_from_memory");
  return readVideo(
      false,
      input_video,
      "", // videoPath
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
}

torch::List<torch::Tensor> read_video_from_file(
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.read_video_from_file");
  torch::Tensor dummy_input_video = torch::ones({0});
  return readVideo(
      true,
      dummy_input_video,
      videoPath,
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
}

torch::List<torch::Tensor> probe_video_from_memory(torch::Tensor input_video) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_memory");
  return probeVideo(false, input_video, "");
}

torch::List<torch::Tensor> probe_video_from_file(std::string videoPath) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_file");
  torch::Tensor dummy_input_video = torch::ones({0});
  return probeVideo(true, dummy_input_video, videoPath);
}

TORCH_LIBRARY_FRAGMENT(video_reader, m) {
  m.def("read_video_from_memory", read_video_from_memory);
  m.def("read_video_from_file", read_video_from_file);
  m.def("probe_video_from_memory", probe_video_from_memory);
  m.def("probe_video_from_file", probe_video_from_file);
}

} // namespace video_reader
} // namespace vision
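// Usage sketch (illustrative only; the path below is a placeholder and this
// snippet is not part of this translation unit):
//
//   auto probe = vision::video_reader::probe_video_from_file("/tmp/clip.mp4");
//   // probe[0] = video time base {num, den} (kInt)
//   // probe[1] = video fps (kFloat)
//   // probe[2] = video duration in time-base units (kLong)
//   // probe[3..5] = audio time base, sample rate, and duration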