Commit 9133f2a0 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Extract audio conversion into separate class (#3130)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3130

Similar to https://github.com/pytorch/audio/pull/3120
Adopt the generator style slicing conversion to audio encoding
process.

Reviewed By: nateanl

Differential Revision: D43685380

fbshipit-source-id: 3e95655783e5c5d768486f8af6e6b47b0072999b
parent fbf05f28
#include <torchaudio/csrc/ffmpeg/stream_writer/audio_converter.h>
namespace torchaudio::io {
namespace {
AVFramePtr get_audio_frame(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx,
int default_frame_size) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
}
return frame;
}
void validate_audio_input(
enum AVSampleFormat fmt,
AVCodecContext* ctx,
const torch::Tensor& t) {
auto dtype = t.dtype().toScalarType();
switch (fmt) {
case AV_SAMPLE_FMT_U8:
TORCH_CHECK(
dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
break;
case AV_SAMPLE_FMT_S16:
TORCH_CHECK(
dtype == c10::ScalarType::Short, "Expected Tensor of int16 type.");
break;
case AV_SAMPLE_FMT_S32:
TORCH_CHECK(
dtype == c10::ScalarType::Int, "Expected Tensor of int32 type.");
break;
case AV_SAMPLE_FMT_S64:
TORCH_CHECK(
dtype == c10::ScalarType::Long, "Expected Tensor of int64 type.");
break;
case AV_SAMPLE_FMT_FLT:
TORCH_CHECK(
dtype == c10::ScalarType::Float, "Expected Tensor of float32 type.");
break;
case AV_SAMPLE_FMT_DBL:
TORCH_CHECK(
dtype == c10::ScalarType::Double, "Expected Tensor of float64 type.");
break;
default:
TORCH_CHECK(
false,
"Internal error: Audio encoding stream is not properly configured.");
}
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
TORCH_CHECK(t.dim() == 2, "Input Tensor has to be 2D.");
const auto num_channels = t.size(1);
TORCH_CHECK(
num_channels == ctx->channels,
"Expected waveform with ",
ctx->channels,
" channels. Found ",
num_channels);
}
// 2D (time, channel) and contiguous.
void convert_func_(const torch::Tensor& chunk, AVFrame* buffer) {
auto num_frames = chunk.size(0);
auto byte_size = chunk.numel() * chunk.element_size();
// TODO: make writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
TORCH_CHECK(av_frame_is_writable(buffer), "frame is not writable.");
memcpy(buffer->data[0], chunk.data_ptr(), byte_size);
buffer->nb_samples = static_cast<int>(num_frames);
}
} // namespace
AudioTensorConverter::AudioTensorConverter(
enum AVSampleFormat src_fmt_,
AVCodecContext* codec_ctx_,
int default_frame_size)
: src_fmt(src_fmt_),
codec_ctx(codec_ctx_),
buffer(get_audio_frame(src_fmt_, codec_ctx_, default_frame_size)),
buffer_size(buffer->nb_samples),
convert_func(convert_func_) {}
SlicingTensorConverter AudioTensorConverter::convert(
const torch::Tensor& frames) {
validate_audio_input(src_fmt, codec_ctx, frames);
return SlicingTensorConverter{
frames.contiguous(),
buffer,
convert_func,
buffer_size,
};
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/converter.h>
namespace torchaudio::io {
////////////////////////////////////////////////////////////////////////////////
// AudioTensorConverter
////////////////////////////////////////////////////////////////////////////////
// AudioTensorConverter is responsible for picking up the right set of
// conversion process (InitFunc and ConvertFunc) based on the input sample
// format information, and own them.
class AudioTensorConverter {
enum AVSampleFormat src_fmt;
AVCodecContext* codec_ctx;
AVFramePtr buffer;
const int64_t buffer_size;
SlicingTensorConverter::ConvertFunc convert_func;
public:
AudioTensorConverter(
enum AVSampleFormat src_fmt,
AVCodecContext* codec_ctx,
int default_frame_size = 10000);
SlicingTensorConverter convert(const torch::Tensor& frames);
};
} // namespace torchaudio::io
...@@ -29,28 +29,6 @@ FilterGraph get_audio_filter( ...@@ -29,28 +29,6 @@ FilterGraph get_audio_filter(
return p; return p;
} }
AVFramePtr get_audio_frame(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx,
int default_frame_size = 10000) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
}
return frame;
}
} // namespace } // namespace
AudioOutputStream::AudioOutputStream( AudioOutputStream::AudioOutputStream(
...@@ -61,83 +39,15 @@ AudioOutputStream::AudioOutputStream( ...@@ -61,83 +39,15 @@ AudioOutputStream::AudioOutputStream(
format_ctx, format_ctx,
codec_ctx_, codec_ctx_,
get_audio_filter(src_fmt, codec_ctx_)), get_audio_filter(src_fmt, codec_ctx_)),
src_frame(get_audio_frame(src_fmt, codec_ctx_)), converter(src_fmt, codec_ctx_),
frame_capacity(src_frame->nb_samples),
codec_ctx(std::move(codec_ctx_)) {} codec_ctx(std::move(codec_ctx_)) {}
namespace {
void validate_audio_input(
enum AVSampleFormat fmt,
AVCodecContext* ctx,
const torch::Tensor& t) {
auto dtype = t.dtype().toScalarType();
switch (fmt) {
case AV_SAMPLE_FMT_U8:
TORCH_CHECK(
dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
break;
case AV_SAMPLE_FMT_S16:
TORCH_CHECK(
dtype == c10::ScalarType::Short, "Expected Tensor of int16 type.");
break;
case AV_SAMPLE_FMT_S32:
TORCH_CHECK(
dtype == c10::ScalarType::Int, "Expected Tensor of int32 type.");
break;
case AV_SAMPLE_FMT_S64:
TORCH_CHECK(
dtype == c10::ScalarType::Long, "Expected Tensor of int64 type.");
break;
case AV_SAMPLE_FMT_FLT:
TORCH_CHECK(
dtype == c10::ScalarType::Float, "Expected Tensor of float32 type.");
break;
case AV_SAMPLE_FMT_DBL:
TORCH_CHECK(
dtype == c10::ScalarType::Double, "Expected Tensor of float64 type.");
break;
default:
TORCH_CHECK(
false,
"Internal error: Audio encoding stream is not properly configured.");
}
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
TORCH_CHECK(t.dim() == 2, "Input Tensor has to be 2D.");
const auto num_channels = t.size(1);
TORCH_CHECK(
num_channels == ctx->channels,
"Expected waveform with ",
ctx->channels,
" channels. Found ",
num_channels);
}
} // namespace
void AudioOutputStream::write_chunk(const torch::Tensor& waveform) { void AudioOutputStream::write_chunk(const torch::Tensor& waveform) {
validate_audio_input(
static_cast<AVSampleFormat>(src_frame->format), codec_ctx, waveform);
AVRational time_base{1, codec_ctx->sample_rate}; AVRational time_base{1, codec_ctx->sample_rate};
for (const auto& frame : converter.convert(waveform)) {
using namespace torch::indexing; process_frame(frame);
for (int64_t i = 0; i < waveform.size(0); i += frame_capacity) { frame->pts +=
auto chunk = waveform.index({Slice(i, i + frame_capacity), Slice()}); av_rescale_q(frame->nb_samples, time_base, codec_ctx->time_base);
auto num_frames = chunk.size(0);
auto byte_size = chunk.numel() * chunk.element_size();
chunk = chunk.reshape({-1}).contiguous();
// TODO: make writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
TORCH_CHECK(
av_frame_is_writable(src_frame),
"Internal Error: frame is not writable.");
memcpy(src_frame->data[0], chunk.data_ptr(), byte_size);
src_frame->nb_samples = num_frames;
process_frame(src_frame);
src_frame->pts += av_rescale_q(num_frames, time_base, codec_ctx->time_base);
} }
} }
......
#pragma once #pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/audio_converter.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/output_stream.h> #include <torchaudio/csrc/ffmpeg/stream_writer/output_stream.h>
namespace torchaudio::io { namespace torchaudio::io {
struct AudioOutputStream : OutputStream { struct AudioOutputStream : OutputStream {
AVFramePtr src_frame; AudioTensorConverter converter;
int64_t frame_capacity; int64_t frame_capacity;
AVCodecContextPtr codec_ctx; AVCodecContextPtr codec_ctx;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment