Commit 898db8c7 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Update slicing conversion code (#3129)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3129

- Add step parameter to support audio slicing
- Rename to `SlicingTensorConverter` (`Generator` is too generic.)

Reviewed By: xiaohui-zhang

Differential Revision: D43704926

fbshipit-source-id: c4bf0ff766e0ae1b5d46b159a6367492ef68f9cd
parent b0faecb2
......@@ -2,21 +2,28 @@
namespace torchaudio::io {
using Iterator = Generator::Iterator;
using ConvertFunc = Generator::ConvertFunc;
using Iterator = SlicingTensorConverter::Iterator;
using ConvertFunc = SlicingTensorConverter::ConvertFunc;
////////////////////////////////////////////////////////////////////////////////
// Generator
// SlicingTensorConverter
////////////////////////////////////////////////////////////////////////////////
Generator::Generator(torch::Tensor frames_, AVFrame* buff, ConvertFunc& func)
: frames(std::move(frames_)), buffer(buff), convert_func(func) {}
Iterator Generator::begin() const {
return Iterator{frames, buffer, convert_func};
SlicingTensorConverter::SlicingTensorConverter(
torch::Tensor frames_,
AVFrame* buff,
ConvertFunc& func,
int64_t step_)
: frames(std::move(frames_)),
buffer(buff),
convert_func(func),
step(step_) {}
Iterator SlicingTensorConverter::begin() const {
return Iterator{frames, buffer, convert_func, step};
}
int64_t Generator::end() const {
int64_t SlicingTensorConverter::end() const {
return frames.size(0);
}
......@@ -27,21 +34,28 @@ int64_t Generator::end() const {
Iterator::Iterator(
const torch::Tensor frames_,
AVFrame* buffer_,
ConvertFunc& convert_func_)
: frames(frames_), buffer(buffer_), convert_func(convert_func_) {}
ConvertFunc& convert_func_,
int64_t step_)
: frames(frames_),
buffer(buffer_),
convert_func(convert_func_),
step(step_) {}
Iterator& Iterator::operator++() {
++i;
i += step;
return *this;
}
AVFrame* Iterator::operator*() const {
convert_func(frames.index({i}), buffer);
using namespace torch::indexing;
convert_func(frames.index({Slice{i, i + step}}), buffer);
return buffer;
}
bool Iterator::operator!=(const int64_t other) const {
return i != other;
bool Iterator::operator!=(const int64_t end) const {
// This is used for detecting the end of iteration.
// For audio, `i` advances by `step` and can overshoot `end`,
// so we check `i < end` rather than `i != end`.
return i < end;
}
} // namespace torchaudio::io
......@@ -6,17 +6,15 @@
namespace torchaudio::io {
//////////////////////////////////////////////////////////////////////////////
// Generator
// SlicingTensorConverter
//////////////////////////////////////////////////////////////////////////////
// Generator class is responsible for implementing an interface compatible with
// range-based for loop interface (begin and end), and initialization of frame
// data (channel reordering and ensuring the contiguous-ness).
class Generator {
// SlicingTensorConverter class is responsible for implementing an interface
// compatible with range-based for loop interface (begin and end).
class SlicingTensorConverter {
public:
// Convert function writes input frame Tensor to destination AVFrame
// both tensor input and AVFrame are expected to be valid and properly
// allocated. (i.e. glorified copy)
// It is one-to-one conversion. Performed in Iterator.
// allocated. (i.e. glorified copy). It is used in Iterator.
using ConvertFunc = std::function<void(const torch::Tensor&, AVFrame*)>;
////////////////////////////////////////////////////////////////////////////
......@@ -26,7 +24,9 @@ class Generator {
// increment, comparison against, and dereference (applying conversion
// function in it).
class Iterator {
// Input tensor, has to be NCHW or NHWC, uint8, CPU or CUDA
// Tensor to be sliced
// - audio: NC, CPU, uint8|int16|float|double
// - video: NCHW or NHWC, CPU or CUDA, uint8
// It will be sliced at dereference time.
const torch::Tensor frames;
// Output buffer (not owned, but modified by Iterator)
......@@ -35,13 +35,15 @@ class Generator {
ConvertFunc& convert_func;
// Index
int64_t step;
int64_t i = 0;
public:
Iterator(
const torch::Tensor tensor,
AVFrame* buffer,
ConvertFunc& convert_func);
ConvertFunc& convert_func,
int64_t step);
Iterator& operator++();
AVFrame* operator*() const;
......@@ -49,8 +51,9 @@ class Generator {
};
private:
// Tensor representing video frames provided by client code
// Expected (and validated) to be NCHW, uint8.
// Input Tensor:
// - video: NCHW, CPU|CUDA, uint8,
// - audio: NC, CPU, uint8|int16|int32|int64|float32|double
torch::Tensor frames;
// Output buffer (not owned, passed to iterator)
......@@ -59,8 +62,14 @@ class Generator {
// ops: not owned.
ConvertFunc& convert_func;
int64_t step;
public:
Generator(torch::Tensor frames, AVFrame* buffer, ConvertFunc& convert_func);
SlicingTensorConverter(
torch::Tensor frames,
AVFrame* buffer,
ConvertFunc& convert_func,
int64_t step = 1);
[[nodiscard]] Iterator begin() const;
[[nodiscard]] int64_t end() const;
......
......@@ -11,7 +11,7 @@ namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
using InitFunc = VideoTensorConverter::InitFunc;
using ConvertFunc = Generator::ConvertFunc;
using ConvertFunc = SlicingTensorConverter::ConvertFunc;
namespace {
......@@ -27,9 +27,12 @@ namespace {
// ...
// H: RGB RGB ... RGB PAD ... PAD
void write_interlaced_video(const torch::Tensor& frame, AVFrame* buffer) {
const auto height = frame.size(0);
const auto width = frame.size(1);
const auto num_channels = frame.size(2);
TORCH_INTERNAL_ASSERT(
frame.size(0) == 1,
"The first dimension of the image dimension must be one.");
const auto height = frame.size(1);
const auto width = frame.size(2);
const auto num_channels = frame.size(3);
size_t stride = width * num_channels;
// TODO: writable
......@@ -70,15 +73,18 @@ void write_planar_video(
const torch::Tensor& frame,
AVFrame* buffer,
int num_planes) {
const auto height = frame.size(1);
const auto width = frame.size(2);
TORCH_INTERNAL_ASSERT(
frame.size(0) == 1,
"The first dimension of the image dimension must be one.");
const auto height = frame.size(2);
const auto width = frame.size(3);
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
for (int j = 0; j < num_planes; ++j) {
uint8_t* src = frame.index({j}).data_ptr<uint8_t>();
uint8_t* src = frame.index({0, j}).data_ptr<uint8_t>();
uint8_t* dst = buffer->data[j];
for (int h = 0; h < height; ++h) {
memcpy(dst, src, width);
......@@ -97,9 +103,12 @@ void write_interlaced_video_cuda(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else
const auto height = frame.size(0);
const auto width = frame.size(1);
const auto num_channels = frame.size(2) + (pad_extra ? 1 : 0);
TORCH_INTERNAL_ASSERT(
frame.size(0) == 1,
"The first dimension of the image dimension must be one.");
const auto height = frame.size(1);
const auto width = frame.size(2);
const auto num_channels = frame.size(3) + (pad_extra ? 1 : 0);
size_t spitch = width * num_channels;
if (cudaSuccess !=
cudaMemcpy2D(
......@@ -124,14 +133,17 @@ void write_planar_video_cuda(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else
const auto height = frame.size(1);
const auto width = frame.size(2);
TORCH_INTERNAL_ASSERT(
frame.size(0) == 1,
"The first dimension of the image dimension must be one.");
const auto height = frame.size(2);
const auto width = frame.size(3);
for (int j = 0; j < num_planes; ++j) {
if (cudaSuccess !=
cudaMemcpy2D(
(void*)(buffer->data[j]),
buffer->linesize[j],
(const void*)(frame.index({j}).data_ptr<uint8_t>()),
(const void*)(frame.index({0, j}).data_ptr<uint8_t>()),
width,
width,
height,
......@@ -269,9 +281,10 @@ VideoTensorConverter::VideoTensorConverter(
std::tie(init_func, convert_func) = get_func(src_fmt, codec_ctx->sw_pix_fmt);
}
Generator VideoTensorConverter::convert(const torch::Tensor& frames) {
SlicingTensorConverter VideoTensorConverter::convert(
const torch::Tensor& frames) {
validate_video_input(src_fmt, codec_ctx, frames);
return Generator{init_func(frames), buffer, convert_func};
return SlicingTensorConverter{init_func(frames), buffer, convert_func};
}
} // namespace torchaudio::io
......@@ -22,10 +22,10 @@ class VideoTensorConverter {
AVFramePtr buffer;
InitFunc init_func{};
Generator::ConvertFunc convert_func{};
SlicingTensorConverter::ConvertFunc convert_func{};
public:
VideoTensorConverter(enum AVPixelFormat src_fmt, AVCodecContext* codec_ctx);
Generator convert(const torch::Tensor& frames);
SlicingTensorConverter convert(const torch::Tensor& frames);
};
} // namespace torchaudio::io
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment