Unverified Commit 7ea108f8 authored by hwangjeff's avatar hwangjeff Committed by GitHub
Browse files

Support compression level in i/o dispatcher backend

Differential Revision: D50367721

Pull Request resolved: https://github.com/pytorch/audio/pull/3662
parent 671261c3
......@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import BinaryIO, Optional, Tuple, Union
from torch import Tensor
from torchaudio.io import CodecConfig
from .common import AudioMetaData
......@@ -37,6 +38,7 @@ class Backend(ABC):
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
raise NotImplementedError
......
......@@ -228,6 +228,7 @@ def save_audio(
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
ext = None
if hasattr(uri, "write"):
......@@ -250,6 +251,7 @@ def save_audio(
format=_get_sample_format(src.dtype),
encoder=encoder,
encoder_format=enc_fmt,
codec_config=compression,
)
with s.open():
s.write_audio_chunk(0, src)
......@@ -304,7 +306,13 @@ class FFmpegBackend(Backend):
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio(
uri,
src,
......@@ -314,6 +322,7 @@ class FFmpegBackend(Backend):
encoding,
bits_per_sample,
buffer_size,
compression,
)
@staticmethod
......
......@@ -2,6 +2,7 @@ import os
from typing import BinaryIO, Optional, Tuple, Union
import torch
from torchaudio.io import CodecConfig
from . import soundfile_backend
from .backend import Backend
......@@ -35,7 +36,11 @@ class SoundfileBackend(Backend):
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None:
if compression:
raise ValueError("soundfile backend does not support argument `compression`.")
soundfile_backend.save(
uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
)
......
......@@ -56,7 +56,13 @@ class SoXBackend(Backend):
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (float, int, type(None))):
raise ValueError(
"SoX backend expects non-`None` value for argument `compression` to be of ",
f"type `float` or `int`, but received value of type {type(compression)}",
)
if hasattr(uri, "write"):
raise ValueError(
"SoX backend does not support writing to file-like objects. ",
......@@ -68,7 +74,7 @@ class SoXBackend(Backend):
src,
sample_rate,
channels_first,
None,
compression,
format,
encoding,
bits_per_sample,
......
......@@ -5,6 +5,7 @@ from typing import BinaryIO, Dict, Optional, Tuple, Type, Union
import torch
from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext
from torchaudio.io import CodecConfig
from . import soundfile_backend
......@@ -229,6 +230,7 @@ def get_save_func():
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
backend: Optional[str] = None,
compression: Optional[Union[CodecConfig, float, int]] = None,
):
"""Save audio data to file.
......@@ -283,8 +285,32 @@ def get_save_func():
.. seealso::
:ref:`backend`
compression (CodecConfig, float, int, or None, optional):
Compression configuration to apply.
If the selected backend is FFmpeg, an instance of :py:class:`CodecConfig` must be provided.
Otherwise, if the selected backend is SoX, a float or int value corresponding to option ``-C`` of the
``sox`` command line interface must be provided. For instance:
``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
Refer to http://sox.sourceforge.net/soxformat.html for more details.
"""
backend = dispatcher(uri, format, backend)
return backend.save(uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size)
return backend.save(
uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression
)
return save
......@@ -105,7 +105,7 @@ class DispatcherTest(PytorchTestCase):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(filename, src, sample_rate, format=format)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096, None)
@parameterized.expand(
[
......@@ -126,4 +126,4 @@ class DispatcherTest(PytorchTestCase):
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size, None)
......@@ -4,11 +4,13 @@ import pathlib
import subprocess
import sys
from functools import partial
from typing import Optional
import torch
from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func
from torchaudio.io import CodecConfig
from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
......@@ -45,6 +47,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
self,
format: str,
*,
compression: Optional[CodecConfig] = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
......@@ -104,7 +107,15 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "fileobj":
ext = None
with open(tgt_path, "bw") as file_:
......@@ -112,6 +123,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
......@@ -123,6 +135,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
......@@ -198,11 +211,27 @@ class SaveTest(SaveTestBase):
# NOTE: Supported sample formats: s16 s32 (24 bits)
# [8, 16, 24],
[16, 24],
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, test_mode, bits_per_sample):
def test_save_flac(self, test_mode, bits_per_sample, compression_level):
# -acodec flac -sample_fmt s16
# 24 bits needs to be mapped to s32
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode=test_mode)
codec_config = CodecConfig(
compression_level=compression_level,
)
self.assert_save_consistency(
"flac", compression=codec_config, bits_per_sample=bits_per_sample, test_mode=test_mode
)
# @nested_params(
# ["path", "fileobj", "bytesio"],
......@@ -212,12 +241,25 @@ class SaveTest(SaveTestBase):
# self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)
@nested_params(
[
None,
-1,
0,
1,
2,
3,
5,
10,
],
["path", "fileobj", "bytesio"],
)
def test_save_vorbis(self, test_mode):
def test_save_vorbis(self, quality_level, test_mode):
# NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg"
# self.assert_save_consistency("vorbis", test_mode=test_mode)
self.assert_save_consistency("ogg", test_mode=test_mode)
codec_config = CodecConfig(
qscale=quality_level,
)
self.assert_save_consistency("ogg", compression=codec_config, test_mode=test_mode)
# @nested_params(
# ["path", "fileobj", "bytesio"],
......
......@@ -40,6 +40,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
......@@ -101,13 +102,16 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
self._save(
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
......@@ -118,6 +122,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
file_,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
......@@ -134,7 +139,9 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(src_path, sox_path, encoding=sox_encoding, bit_depth=bits_per_sample)
sox_utils.convert_audio_file(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
......@@ -175,15 +182,42 @@ class SaveTest(SaveTestBase):
@nested_params(
[8, 16, 24],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, bits_per_sample):
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode="path")
def test_save_flac(self, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode="path"
)
def test_save_htk(self):
self.assert_save_consistency("htk", test_mode="path", num_channels=1)
def test_save_vorbis(self):
self.assert_save_consistency("vorbis", test_mode="path")
@nested_params(
[
None,
-1,
0,
1,
2,
3,
3.6,
5,
10,
],
)
def test_save_vorbis(self, quality_level):
self.assert_save_consistency("vorbis", compression=quality_level, test_mode="path")
@nested_params(
[
......@@ -254,9 +288,22 @@ class SaveTest(SaveTestBase):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")
@nested_params(
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
],
)
@skipIfNoSoxEncoder("amr-nb")
def test_save_amr_nb(self):
self.assert_save_consistency("amr-nb", num_channels=1, test_mode="path")
def test_save_amr_nb(self, bit_rate):
self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode="path")
def test_save_gsm(self):
self.assert_save_consistency("gsm", num_channels=1, test_mode="path")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment