Commit db4898f3 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor audio conversion (#3143)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3143

Similar to https://github.com/pytorch/audio/pull/3140,
only provide objects which are semantically related to the
operation performed by AudioConverter.

Reviewed By: xiaohui-zhang

Differential Revision: D43781012

fbshipit-source-id: 4795e20f56272af5cfda8a5f46083e60d1890c3e
parent 26acdbff
......@@ -4,34 +4,9 @@ namespace torchaudio::io {
namespace {
AVFramePtr get_audio_frame(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx,
int default_frame_size) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
}
return frame;
}
void validate_audio_input(
enum AVSampleFormat fmt,
AVCodecContext* ctx,
const torch::Tensor& t) {
void validate_audio_input(AVFrame* buffer, const torch::Tensor& t) {
auto dtype = t.dtype().toScalarType();
switch (fmt) {
switch (static_cast<AVSampleFormat>(buffer->format)) {
case AV_SAMPLE_FMT_U8:
TORCH_CHECK(
dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
......@@ -65,9 +40,9 @@ void validate_audio_input(
TORCH_CHECK(t.dim() == 2, "Input Tensor has to be 2D.");
const auto num_channels = t.size(1);
TORCH_CHECK(
num_channels == ctx->channels,
num_channels == buffer->channels,
"Expected waveform with ",
ctx->channels,
buffer->channels,
" channels. Found ",
num_channels);
}
......@@ -87,18 +62,13 @@ void convert_func_(const torch::Tensor& chunk, AVFrame* buffer) {
} // namespace
AudioTensorConverter::AudioTensorConverter(
enum AVSampleFormat src_fmt_,
AVCodecContext* codec_ctx_,
int default_frame_size)
: src_fmt(src_fmt_),
codec_ctx(codec_ctx_),
buffer(get_audio_frame(src_fmt_, codec_ctx_, default_frame_size)),
buffer_size(buffer->nb_samples),
convert_func(convert_func_) {}
AVFrame* buffer_,
const int64_t buffer_size_)
: buffer(buffer_), buffer_size(buffer_size_), convert_func(convert_func_) {}
SlicingTensorConverter AudioTensorConverter::convert(
const torch::Tensor& frames) {
validate_audio_input(src_fmt, codec_ctx, frames);
validate_audio_input(buffer, frames);
return SlicingTensorConverter{
frames.contiguous(),
buffer,
......
......@@ -11,17 +11,12 @@ namespace torchaudio::io {
// conversion process (InitFunc and ConvertFunc) based on the input sample
// format information, and own them.
class AudioTensorConverter {
enum AVSampleFormat src_fmt;
AVCodecContext* codec_ctx;
AVFramePtr buffer;
AVFrame* buffer;
const int64_t buffer_size;
SlicingTensorConverter::ConvertFunc convert_func;
public:
AudioTensorConverter(
enum AVSampleFormat src_fmt,
AVCodecContext* codec_ctx,
int default_frame_size = 10000);
AudioTensorConverter(AVFrame* buffer, const int64_t buffer_size);
SlicingTensorConverter convert(const torch::Tensor& frames);
};
} // namespace torchaudio::io
......@@ -29,6 +29,28 @@ FilterGraph get_audio_filter(
return p;
}
AVFramePtr get_audio_frame(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx,
int default_frame_size = 10000) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
}
return frame;
}
} // namespace
AudioOutputStream::AudioOutputStream(
......@@ -39,7 +61,8 @@ AudioOutputStream::AudioOutputStream(
format_ctx,
codec_ctx_,
get_audio_filter(src_fmt, codec_ctx_)),
converter(src_fmt, codec_ctx_),
buffer(get_audio_frame(src_fmt, codec_ctx_)),
converter(buffer, buffer->nb_samples),
codec_ctx(std::move(codec_ctx_)) {}
void AudioOutputStream::write_chunk(const torch::Tensor& waveform) {
......
......@@ -5,8 +5,8 @@
namespace torchaudio::io {
struct AudioOutputStream : OutputStream {
AVFramePtr buffer;
AudioTensorConverter converter;
int64_t frame_capacity;
AVCodecContextPtr codec_ctx;
AudioOutputStream(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment