Commit d8a37a21 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Properly set #samples passed to encoder (#3204)

Summary:
Some audio encoders expect specific, exact number of samples described as in `AVCodecContext.frame_size`.

The `AVFrame.nb_samples` is set for the frames passed to `AVFilterGraph`,
but frames coming out of the graph do not necessarily have the same numbr of frames.

This causes issues with encoding OPUS (among others).

This commit fixes it by inserting `asetnsamples` to filter graph if a fixed number of samples is requested.

Note:
It turned out that FFmpeg 4.1 has issue with OPUS encoding. It does not properly discard some sample.
We should probably move the minimum required FFmpeg to 4.2, but I am not sure if we can enforce it via ABI.
Work around will be to issue an warning if encoding OPUS with 4.1. (follow-up)

Pull Request resolved: https://github.com/pytorch/audio/pull/3204

Reviewed By: nateanl

Differential Revision: D44374668

Pulled By: mthrok

fbshipit-source-id: 10ef5333dc0677dfb83c8e40b78edd8ded1b21dc
parent 583174ac
...@@ -318,28 +318,58 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -318,28 +318,58 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
pass pass
@nested_params( @nested_params(
["wav", "mp3", "flac"], ["wav", "flac"],
[8000, 16000, 44100], [8000, 16000, 44100],
[1, 2], [1, 2],
) )
def test_audio_num_frames(self, ext, sample_rate, num_channels): def test_audio_num_frames_lossless(self, ext, sample_rate, num_channels):
"""""" """Lossless format preserves the data"""
filename = f"test.{ext}" filename = f"test.{ext}"
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, dtype="int16", channels_first=False)
# Write data # Write data
dst = self.get_dst(filename) dst = self.get_dst(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext) s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels) s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, format="s16")
with s.open():
s.write_audio_chunk(0, data)
freq = 300 if self.test_fileobj:
duration = 60 dst.flush()
theta = torch.linspace(0, freq * 2 * 3.14 * duration, sample_rate * duration)
if num_channels == 1: # Load data
chunk = torch.sin(theta).unsqueeze(-1) s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
else: s.add_audio_stream(-1)
chunk = torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1) s.process_all_packets()
(saved,) = s.pop_chunks()
self.assertEqual(saved, data)
@parameterized.expand(
[
("mp3", 1, 8000),
("mp3", 1, 16000),
("mp3", 1, 44100),
("mp3", 2, 8000),
("mp3", 2, 16000),
("mp3", 2, 44100),
("opus", 1, 48000),
]
)
def test_audio_num_frames_lossy(self, ext, num_channels, sample_rate):
"""Saving audio preserves the number of channels and frames"""
filename = f"test.{ext}"
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
# Write data
dst = self.get_dst(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
with s.open(): with s.open():
s.write_audio_chunk(0, chunk) s.write_audio_chunk(0, data)
if self.test_fileobj: if self.test_fileobj:
dst.flush() dst.flush()
...@@ -349,9 +379,21 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -349,9 +379,21 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s.process_all_packets() s.process_all_packets()
(saved,) = s.pop_chunks() (saved,) = s.pop_chunks()
assert saved.shape == chunk.shape # This test fails for OPUS if FFmpeg is 4.1, but it passes for 4.2+
if format in ["wav", "flac"]: # 4.1 produces 48312 samples (extra 312)
self.assertEqual(saved, chunk) # Probably this commit fixes it.
# https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c
# TODO: issue warning if 4.1?
if ext == "opus":
ver = torchaudio.utils.ffmpeg_utils.get_versions()["libavcodec"]
# 5.1 libavcodec 59. 18.100
# 4.4 libavcodec 58.134.100
# 4.3 libavcodec 58. 91.100
# 4.2 libavcodec 58. 54.100
# 4.1 libavcodec 58. 35.100
if ver[0] < 59 and ver[1] < 54:
return
self.assertEqual(saved.shape, data.shape)
def test_preserve_fps(self): def test_preserve_fps(self):
"""Decimal point frame rate is properly saved """Decimal point frame rate is properly saved
......
...@@ -233,10 +233,19 @@ FilterGraph get_audio_filter( ...@@ -233,10 +233,19 @@ FilterGraph get_audio_filter(
AVCodecContext* codec_ctx) { AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string { auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->sample_fmt) { if (src_fmt == codec_ctx->sample_fmt) {
return "anull"; if (!codec_ctx->frame_size) {
return "anull";
} else {
std::stringstream ss;
ss << "asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
return ss.str();
}
} else { } else {
std::stringstream ss; std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt); ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
if (codec_ctx->frame_size) {
ss << ",asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
}
return ss.str(); return ss.str();
} }
}(); }();
......
...@@ -40,9 +40,12 @@ void convert_func_(const torch::Tensor& chunk, AVFrame* buffer) { ...@@ -40,9 +40,12 @@ void convert_func_(const torch::Tensor& chunk, AVFrame* buffer) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.dim() == 2); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.dim() == 2);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.size(1) == buffer->channels); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.size(1) == buffer->channels);
// TODO: make writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334 // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
TORCH_CHECK(av_frame_is_writable(buffer), "frame is not writable."); if (!av_frame_is_writable(buffer)) {
int ret = av_frame_make_writable(buffer);
TORCH_INTERNAL_ASSERT(
ret >= 0, "Failed to make frame writable: ", av_err2string(ret));
}
auto byte_size = chunk.numel() * chunk.element_size(); auto byte_size = chunk.numel() * chunk.element_size();
memcpy(buffer->data[0], chunk.data_ptr(), byte_size); memcpy(buffer->data[0], chunk.data_ptr(), byte_size);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment