Commit 01f29d73 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor common buffer utility header (#2962)

Summary:
Put the helper functions in an unnamed namespace.

Pull Request resolved: https://github.com/pytorch/audio/pull/2962

Reviewed By: carolineechen

Differential Revision: D42378781

Pulled By: mthrok

fbshipit-source-id: 74daf613f8b78f95141ae4e7c4682d8d0e97f72e
parent b4d55fa1
...@@ -85,13 +85,17 @@ torch::Tensor convert_audio(AVFrame* pFrame) { ...@@ -85,13 +85,17 @@ torch::Tensor convert_audio(AVFrame* pFrame) {
return t; return t;
} }
torch::Tensor get_buffer(at::IntArrayRef shape, const torch::Device& device) { namespace {
torch::Tensor get_buffer(
at::IntArrayRef shape,
const torch::Device& device = torch::Device(torch::kCPU)) {
auto options = torch::TensorOptions() auto options = torch::TensorOptions()
.dtype(torch::kUInt8) .dtype(torch::kUInt8)
.layout(torch::kStrided) .layout(torch::kStrided)
.device(device.type(), device.index()); .device(device.type(), device.index());
return torch::empty(shape, options); return torch::empty(shape, options);
} }
} // namespace
torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) { torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
auto fmt = static_cast<AVPixelFormat>(frame->format); auto fmt = static_cast<AVPixelFormat>(frame->format);
...@@ -126,6 +130,7 @@ torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) { ...@@ -126,6 +130,7 @@ torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
return get_buffer({1, frame->height, frame->width, channels}, device); return get_buffer({1, frame->height, frame->width, channels}, device);
} }
namespace {
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) { void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
auto ptr = frame.data_ptr<uint8_t>(); auto ptr = frame.data_ptr<uint8_t>();
uint8_t* buf = pFrame->data[0]; uint8_t* buf = pFrame->data[0];
...@@ -299,6 +304,8 @@ void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) { ...@@ -299,6 +304,8 @@ void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
} }
#endif #endif
} // namespace
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) { torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
// ref: // ref:
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179 // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
......
...@@ -11,18 +11,7 @@ namespace detail { ...@@ -11,18 +11,7 @@ namespace detail {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
torch::Tensor convert_audio(AVFrame* frame); torch::Tensor convert_audio(AVFrame* frame);
torch::Tensor get_buffer(
at::IntArrayRef shape,
const torch::Device& device = torch::Device(torch::kCPU));
torch::Tensor get_image_buffer(AVFrame* pFrame, const torch::Device& device); torch::Tensor get_image_buffer(AVFrame* pFrame, const torch::Device& device);
void write_yuv420p(AVFrame* pFrame, torch::Tensor& yuv);
void write_nv12_cpu(AVFrame* pFrame, torch::Tensor& yuv);
#ifdef USE_CUDA
void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv);
#endif
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame);
void write_planar_image(AVFrame* pFrame, torch::Tensor& frame);
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device); torch::Tensor convert_image(AVFrame* frame, const torch::Device& device);
} // namespace detail } // namespace detail
......
...@@ -16,11 +16,12 @@ class UnchunkedBuffer : public Buffer { ...@@ -16,11 +16,12 @@ class UnchunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here. // Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks; std::deque<torch::Tensor> chunks;
protected:
// The number of currently stored chunks // The number of currently stored chunks
// For video, one Tensor corresponds to one frame, but for audio, // For video, one Tensor corresponds to one frame, but for audio,
// one Tensor contains multiple samples, so we track here. // one Tensor contains multiple samples, so we track here.
int64_t num_buffered_frames = 0; int64_t num_buffered_frames = 0;
protected:
void push_tensor(const torch::Tensor& t); void push_tensor(const torch::Tensor& t);
public: public:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment