Commit 898db8c7 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Update slicing conversion code (#3129)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3129

- Add step parameter to support audio slicing
- Rename to `SlicingTensorConverter` (`Generator` is too generic.)

Reviewed By: xiaohui-zhang

Differential Revision: D43704926

fbshipit-source-id: c4bf0ff766e0ae1b5d46b159a6367492ef68f9cd
parent b0faecb2
...@@ -2,21 +2,28 @@ ...@@ -2,21 +2,28 @@
namespace torchaudio::io { namespace torchaudio::io {
using Iterator = Generator::Iterator; using Iterator = SlicingTensorConverter::Iterator;
using ConvertFunc = Generator::ConvertFunc; using ConvertFunc = SlicingTensorConverter::ConvertFunc;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Generator // SlicingTensorConverter
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
Generator::Generator(torch::Tensor frames_, AVFrame* buff, ConvertFunc& func) SlicingTensorConverter::SlicingTensorConverter(
: frames(std::move(frames_)), buffer(buff), convert_func(func) {} torch::Tensor frames_,
AVFrame* buff,
Iterator Generator::begin() const { ConvertFunc& func,
return Iterator{frames, buffer, convert_func}; int64_t step_)
: frames(std::move(frames_)),
buffer(buff),
convert_func(func),
step(step_) {}
Iterator SlicingTensorConverter::begin() const {
return Iterator{frames, buffer, convert_func, step};
} }
int64_t Generator::end() const { int64_t SlicingTensorConverter::end() const {
return frames.size(0); return frames.size(0);
} }
...@@ -27,21 +34,28 @@ int64_t Generator::end() const { ...@@ -27,21 +34,28 @@ int64_t Generator::end() const {
Iterator::Iterator(
    const torch::Tensor frames_,
    AVFrame* buffer_,
    ConvertFunc& convert_func_,
    int64_t step_)
    : frames(frames_),
      buffer(buffer_),
      convert_func(convert_func_),
      step(step_) {}

// Advance by `step` frames (1 for video; >1 for audio, where one AVFrame
// holds multiple samples).
Iterator& Iterator::operator++() {
  i += step;
  return *this;
}

// Slice out the next chunk of `step` frames, convert it into the borrowed
// AVFrame buffer, and hand the buffer to the caller.
AVFrame* Iterator::operator*() const {
  using namespace torch::indexing;
  convert_func(frames.index({Slice{i, i + step}}), buffer);
  return buffer;
}

bool Iterator::operator!=(const int64_t end) const {
  // This is used for detecting the end of iteration.
  // For audio, iteration is done by incrementing the index by `step`, so the
  // index can jump past `end` without ever being exactly equal to it;
  // `<` is required instead of `!=`.
  return i < end;
}

} // namespace torchaudio::io
...@@ -6,17 +6,15 @@ ...@@ -6,17 +6,15 @@
namespace torchaudio::io {

//////////////////////////////////////////////////////////////////////////////
// SlicingTensorConverter
//////////////////////////////////////////////////////////////////////////////
// SlicingTensorConverter class is responsible for implementing an interface
// compatible with range-based for loop interface (begin and end).
class SlicingTensorConverter {
 public:
  // Convert function writes input frame Tensor to destination AVFrame.
  // Both tensor input and AVFrame are expected to be valid and properly
  // allocated. (i.e. glorified copy). It is used in Iterator.
  using ConvertFunc = std::function<void(const torch::Tensor&, AVFrame*)>;

  ////////////////////////////////////////////////////////////////////////////
...@@ -26,7 +24,9 @@ class Generator { ...@@ -26,7 +24,9 @@ class Generator {
// increment, comaprison against, and dereference (applying conversion // increment, comaprison against, and dereference (applying conversion
// function in it). // function in it).
class Iterator { class Iterator {
// Input tensor, has to be NCHW or NHWC, uint8, CPU or CUDA // Tensor to be sliced
// - audio: NC, CPU, uint8|int16|float|double
// - video: NCHW or NHWC, CPU or CUDA, uint8
// It will be sliced at dereference time. // It will be sliced at dereference time.
const torch::Tensor frames; const torch::Tensor frames;
// Output buffer (not owned, but modified by Iterator) // Output buffer (not owned, but modified by Iterator)
...@@ -35,13 +35,15 @@ class Generator { ...@@ -35,13 +35,15 @@ class Generator {
ConvertFunc& convert_func; ConvertFunc& convert_func;
// Index // Index
int64_t step;
int64_t i = 0; int64_t i = 0;
public: public:
Iterator( Iterator(
const torch::Tensor tensor, const torch::Tensor tensor,
AVFrame* buffer, AVFrame* buffer,
ConvertFunc& convert_func); ConvertFunc& convert_func,
int64_t step);
Iterator& operator++(); Iterator& operator++();
AVFrame* operator*() const; AVFrame* operator*() const;
...@@ -49,8 +51,9 @@ class Generator { ...@@ -49,8 +51,9 @@ class Generator {
}; };
private: private:
// Tensor representing video frames provided by client code // Input Tensor:
// Expected (and validated) to be NCHW, uint8. // - video: NCHW, CPU|CUDA, uint8,
// - audio: NC, CPU, uin8|int16|int32|in64|float32|double
torch::Tensor frames; torch::Tensor frames;
// Output buffer (not owned, passed to iterator) // Output buffer (not owned, passed to iterator)
...@@ -59,8 +62,14 @@ class Generator { ...@@ -59,8 +62,14 @@ class Generator {
// ops: not owned. // ops: not owned.
ConvertFunc& convert_func; ConvertFunc& convert_func;
int64_t step;
public: public:
Generator(torch::Tensor frames, AVFrame* buffer, ConvertFunc& convert_func); SlicingTensorConverter(
torch::Tensor frames,
AVFrame* buffer,
ConvertFunc& convert_func,
int64_t step = 1);
[[nodiscard]] Iterator begin() const; [[nodiscard]] Iterator begin() const;
[[nodiscard]] int64_t end() const; [[nodiscard]] int64_t end() const;
......
...@@ -11,7 +11,7 @@ namespace torchaudio::io { ...@@ -11,7 +11,7 @@ namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
using InitFunc = VideoTensorConverter::InitFunc;
using ConvertFunc = SlicingTensorConverter::ConvertFunc;

namespace {
...@@ -27,9 +27,12 @@ namespace { ...@@ -27,9 +27,12 @@ namespace {
// ...
// H: RGB RGB ... RGB PAD ... PAD
void write_interlaced_video(const torch::Tensor& frame, AVFrame* buffer) {
  // The converter slices the input NHWC tensor one frame at a time, so a
  // leading batch dimension of exactly 1 is expected here.
  TORCH_INTERNAL_ASSERT(
      frame.size(0) == 1,
      "The first dimension of the image tensor must be one.");
  const auto height = frame.size(1);
  const auto width = frame.size(2);
  const auto num_channels = frame.size(3);

  size_t stride = width * num_channels;
  // TODO: writable
...@@ -70,15 +73,18 @@ void write_planar_video( ...@@ -70,15 +73,18 @@ void write_planar_video(
const torch::Tensor& frame, const torch::Tensor& frame,
AVFrame* buffer, AVFrame* buffer,
int num_planes) { int num_planes) {
const auto height = frame.size(1); TORCH_INTERNAL_ASSERT(
const auto width = frame.size(2); frame.size(0) == 1,
"The first dimension of the image dimension must be one.");
const auto height = frame.size(2);
const auto width = frame.size(3);
// TODO: writable // TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472 // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable."); TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
for (int j = 0; j < num_planes; ++j) { for (int j = 0; j < num_planes; ++j) {
uint8_t* src = frame.index({j}).data_ptr<uint8_t>(); uint8_t* src = frame.index({0, j}).data_ptr<uint8_t>();
uint8_t* dst = buffer->data[j]; uint8_t* dst = buffer->data[j];
for (int h = 0; h < height; ++h) { for (int h = 0; h < height; ++h) {
memcpy(dst, src, width); memcpy(dst, src, width);
...@@ -97,9 +103,12 @@ void write_interlaced_video_cuda( ...@@ -97,9 +103,12 @@ void write_interlaced_video_cuda(
false, false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available."); "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else #else
const auto height = frame.size(0); TORCH_INTERNAL_ASSERT(
const auto width = frame.size(1); frame.size(0) == 1,
const auto num_channels = frame.size(2) + (pad_extra ? 1 : 0); "The first dimension of the image dimension must be one.");
const auto height = frame.size(1);
const auto width = frame.size(2);
const auto num_channels = frame.size(3) + (pad_extra ? 1 : 0);
size_t spitch = width * num_channels; size_t spitch = width * num_channels;
if (cudaSuccess != if (cudaSuccess !=
cudaMemcpy2D( cudaMemcpy2D(
...@@ -124,14 +133,17 @@ void write_planar_video_cuda( ...@@ -124,14 +133,17 @@ void write_planar_video_cuda(
false, false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available."); "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else #else
const auto height = frame.size(1); TORCH_INTERNAL_ASSERT(
const auto width = frame.size(2); frame.size(0) == 1,
"The first dimension of the image dimension must be one.");
const auto height = frame.size(2);
const auto width = frame.size(3);
for (int j = 0; j < num_planes; ++j) { for (int j = 0; j < num_planes; ++j) {
if (cudaSuccess != if (cudaSuccess !=
cudaMemcpy2D( cudaMemcpy2D(
(void*)(buffer->data[j]), (void*)(buffer->data[j]),
buffer->linesize[j], buffer->linesize[j],
(const void*)(frame.index({j}).data_ptr<uint8_t>()), (const void*)(frame.index({0, j}).data_ptr<uint8_t>()),
width, width,
width, width,
height, height,
...@@ -269,9 +281,10 @@ VideoTensorConverter::VideoTensorConverter( ...@@ -269,9 +281,10 @@ VideoTensorConverter::VideoTensorConverter(
std::tie(init_func, convert_func) = get_func(src_fmt, codec_ctx->sw_pix_fmt); std::tie(init_func, convert_func) = get_func(src_fmt, codec_ctx->sw_pix_fmt);
} }
Generator VideoTensorConverter::convert(const torch::Tensor& frames) { SlicingTensorConverter VideoTensorConverter::convert(
const torch::Tensor& frames) {
validate_video_input(src_fmt, codec_ctx, frames); validate_video_input(src_fmt, codec_ctx, frames);
return Generator{init_func(frames), buffer, convert_func}; return SlicingTensorConverter{init_func(frames), buffer, convert_func};
} }
} // namespace torchaudio::io } // namespace torchaudio::io
...@@ -22,10 +22,10 @@ class VideoTensorConverter { ...@@ -22,10 +22,10 @@ class VideoTensorConverter {
AVFramePtr buffer; AVFramePtr buffer;
InitFunc init_func{}; InitFunc init_func{};
Generator::ConvertFunc convert_func{}; SlicingTensorConverter::ConvertFunc convert_func{};
public: public:
VideoTensorConverter(enum AVPixelFormat src_fmt, AVCodecContext* codec_ctx); VideoTensorConverter(enum AVPixelFormat src_fmt, AVCodecContext* codec_ctx);
Generator convert(const torch::Tensor& frames); SlicingTensorConverter convert(const torch::Tensor& frames);
}; };
} // namespace torchaudio::io } // namespace torchaudio::io
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment