Commit 45651245 authored by moto, committed by Facebook GitHub Bot

Refactor a part of convert_image (#2940)

Summary:
Refactor the two helper functions that convert an AVFrame into a torch::Tensor, splitting each into a separate buffer-allocation step and a data-copy step.
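A minimal sketch of the call pattern this split enables (hypothetical: `decode_next_frame` and `process` are invented here for illustration, and the helpers live in an anonymous namespace, so the pattern only applies within the same translation unit):

// Hypothetical caller: allocate once, then copy each frame of identical
// geometry into the same buffer. Only the first frame pays the allocation.
torch::Tensor buffer; // undefined until the first frame arrives
while (AVFrame* pFrame = decode_next_frame()) { // invented decode helper
  if (!buffer.defined()) {
    buffer = get_interlaced_image_buffer(pFrame); // NHWC, uint8, CPU
  }
  write_interlaced_image(pFrame, buffer); // row-wise memcpy only
  process(buffer.permute({0, 3, 1, 2})); // invented consumer; NHWC -> NCHW
}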

Pull Request resolved: https://github.com/pytorch/audio/pull/2940

Reviewed By: carolineechen

Differential Revision: D42247915

Pulled By: mthrok

fbshipit-source-id: 2f504d48674088205e6039e8aadd8856b3fe5eee
parent 4699ef21
@@ -90,11 +90,9 @@ torch::Tensor convert_audio_tensor(AVFrame* pFrame) {
 // Helper functions - video
 //////////////////////////////////////////////////////////////////////////////
 namespace {
-torch::Tensor convert_interlaced_video(AVFrame* pFrame) {
+torch::Tensor get_interlaced_image_buffer(AVFrame* pFrame) {
   int width = pFrame->width;
   int height = pFrame->height;
-  uint8_t* buf = pFrame->data[0];
-  int linesize = pFrame->linesize[0];
   int channel = av_pix_fmt_desc_get(static_cast<AVPixelFormat>(pFrame->format))
                     ->nb_components;
@@ -103,18 +101,28 @@ torch::Tensor convert_interlaced_video(AVFrame* pFrame) {
                      .layout(torch::kStrided)
                      .device(torch::kCPU);
-  torch::Tensor frame = torch::empty({1, height, width, channel}, options);
+  return torch::empty({1, height, width, channel}, options);
+}
+
+void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
   auto ptr = frame.data_ptr<uint8_t>();
-  int stride = width * channel;
+  uint8_t* buf = pFrame->data[0];
+  size_t height = frame.size(1);
+  size_t stride = frame.size(2) * frame.size(3);
   for (int i = 0; i < height; ++i) {
     memcpy(ptr, buf, stride);
-    buf += linesize;
+    buf += pFrame->linesize[0];
     ptr += stride;
   }
+}
+
+torch::Tensor convert_interlaced_video(AVFrame* pFrame) {
+  torch::Tensor frame = get_interlaced_image_buffer(pFrame);
+  write_interlaced_image(pFrame, frame);
   return frame.permute({0, 3, 1, 2});
 }
 
-torch::Tensor convert_planar_video(AVFrame* pFrame) {
+torch::Tensor get_planar_image_buffer(AVFrame* pFrame) {
   int width = pFrame->width;
   int height = pFrame->height;
   int num_planes =
@@ -124,8 +132,13 @@ torch::Tensor convert_planar_video(AVFrame* pFrame) {
       .dtype(torch::kUInt8)
       .layout(torch::kStrided)
       .device(torch::kCPU);
-  torch::Tensor frame = torch::empty({1, num_planes, height, width}, options);
+  return torch::empty({1, num_planes, height, width}, options);
+}
+
+void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) {
+  int num_planes = static_cast<int>(frame.size(1));
+  int height = static_cast<int>(frame.size(2));
+  int width = static_cast<int>(frame.size(3));
   for (int i = 0; i < num_planes; ++i) {
     torch::Tensor plane = frame.index({0, i});
     uint8_t* tgt = plane.data_ptr<uint8_t>();
@@ -137,6 +150,11 @@ torch::Tensor convert_planar_video(AVFrame* pFrame) {
       src += linesize;
     }
   }
+}
+
+torch::Tensor convert_planar_video(AVFrame* pFrame) {
+  torch::Tensor frame = get_planar_image_buffer(pFrame);
+  write_planar_image(pFrame, frame);
   return frame;
 }
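Note that the new write helpers recover the image geometry from the tensor (`frame.size(1)`, `frame.size(2) * frame.size(3)`) rather than from the AVFrame, so the copy is driven by the buffer's packed shape while the source pointer still advances by `pFrame->linesize[0]`, which FFmpeg may pad for alignment. A worked sketch of that shape arithmetic, using invented example numbers:

// Shape bookkeeping for the interlaced (packed) path, assuming a
// hypothetical 1920x1080 RGB24 frame (numbers chosen for illustration).
// get_interlaced_image_buffer allocates {1, height, width, channel}:
torch::Tensor frame = torch::empty(
    {1, 1080, 1920, 3},
    torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU));

// write_interlaced_image takes its loop bounds from the tensor itself:
int64_t height = frame.size(1);                  // 1080 rows
int64_t stride = frame.size(2) * frame.size(3);  // 1920 * 3 = 5760 bytes/row
// pFrame->linesize[0] may exceed 5760 due to FFmpeg row padding, which is
// why the loop advances src by linesize[0] but dst by the packed stride.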