Unverified commit 7ea108f8, authored by hwangjeff and committed by GitHub
Browse files

Support compression level in i/o dispatcher backend

Differential Revision: D50367721

Pull Request resolved: https://github.com/pytorch/audio/pull/3662
parent 671261c3
...@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod ...@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import BinaryIO, Optional, Tuple, Union from typing import BinaryIO, Optional, Tuple, Union
from torch import Tensor from torch import Tensor
from torchaudio.io import CodecConfig
from .common import AudioMetaData from .common import AudioMetaData
...@@ -37,6 +38,7 @@ class Backend(ABC): ...@@ -37,6 +38,7 @@ class Backend(ABC):
encoding: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -228,6 +228,7 @@ def save_audio( ...@@ -228,6 +228,7 @@ def save_audio(
encoding: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None: ) -> None:
ext = None ext = None
if hasattr(uri, "write"): if hasattr(uri, "write"):
...@@ -250,6 +251,7 @@ def save_audio( ...@@ -250,6 +251,7 @@ def save_audio(
format=_get_sample_format(src.dtype), format=_get_sample_format(src.dtype),
encoder=encoder, encoder=encoder,
encoder_format=enc_fmt, encoder_format=enc_fmt,
codec_config=compression,
) )
with s.open(): with s.open():
s.write_audio_chunk(0, src) s.write_audio_chunk(0, src)
...@@ -304,7 +306,13 @@ class FFmpegBackend(Backend): ...@@ -304,7 +306,13 @@ class FFmpegBackend(Backend):
encoding: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None: ) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio( save_audio(
uri, uri,
src, src,
...@@ -314,6 +322,7 @@ class FFmpegBackend(Backend): ...@@ -314,6 +322,7 @@ class FFmpegBackend(Backend):
encoding, encoding,
bits_per_sample, bits_per_sample,
buffer_size, buffer_size,
compression,
) )
@staticmethod @staticmethod
......
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
from typing import BinaryIO, Optional, Tuple, Union from typing import BinaryIO, Optional, Tuple, Union
import torch import torch
from torchaudio.io import CodecConfig
from . import soundfile_backend from . import soundfile_backend
from .backend import Backend from .backend import Backend
...@@ -35,7 +36,11 @@ class SoundfileBackend(Backend): ...@@ -35,7 +36,11 @@ class SoundfileBackend(Backend):
encoding: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
compression: Optional[Union[CodecConfig, float, int]] = None,
) -> None: ) -> None:
if compression:
raise ValueError("soundfile backend does not support argument `compression`.")
soundfile_backend.save( soundfile_backend.save(
uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
) )
......
...@@ -56,7 +56,13 @@ class SoXBackend(Backend): ...@@ -56,7 +56,13 @@ class SoXBackend(Backend):
encoding: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None: ) -> None:
if not isinstance(compression, (float, int, type(None))):
raise ValueError(
"SoX backend expects non-`None` value for argument `compression` to be of ",
f"type `float` or `int`, but received value of type {type(compression)}",
)
if hasattr(uri, "write"): if hasattr(uri, "write"):
raise ValueError( raise ValueError(
"SoX backend does not support writing to file-like objects. ", "SoX backend does not support writing to file-like objects. ",
...@@ -68,7 +74,7 @@ class SoXBackend(Backend): ...@@ -68,7 +74,7 @@ class SoXBackend(Backend):
src, src,
sample_rate, sample_rate,
channels_first, channels_first,
None, compression,
format, format,
encoding, encoding,
bits_per_sample, bits_per_sample,
......
...@@ -5,6 +5,7 @@ from typing import BinaryIO, Dict, Optional, Tuple, Type, Union ...@@ -5,6 +5,7 @@ from typing import BinaryIO, Dict, Optional, Tuple, Type, Union
import torch import torch
from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext
from torchaudio.io import CodecConfig
from . import soundfile_backend from . import soundfile_backend
...@@ -229,6 +230,7 @@ def get_save_func(): ...@@ -229,6 +230,7 @@ def get_save_func():
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
backend: Optional[str] = None, backend: Optional[str] = None,
compression: Optional[Union[CodecConfig, float, int]] = None,
): ):
"""Save audio data to file. """Save audio data to file.
...@@ -283,8 +285,32 @@ def get_save_func(): ...@@ -283,8 +285,32 @@ def get_save_func():
.. seealso:: .. seealso::
:ref:`backend` :ref:`backend`
compression (CodecConfig, float, int, or None, optional):
Compression configuration to apply.
If the selected backend is FFmpeg, the value must be ``None`` or an instance of :py:class:`CodecConfig`.
If the selected backend is SoX, the value must be ``None`` or a float or int corresponding to option ``-C`` of the
``sox`` command line interface. The soundfile backend does not support this argument. SoX values by format:
``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
Refer to http://sox.sourceforge.net/soxformat.html for more details.
""" """
backend = dispatcher(uri, format, backend) backend = dispatcher(uri, format, backend)
return backend.save(uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size) return backend.save(
uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression
)
return save return save
...@@ -105,7 +105,7 @@ class DispatcherTest(PytorchTestCase): ...@@ -105,7 +105,7 @@ class DispatcherTest(PytorchTestCase):
f"torchaudio._backend.utils.{expected_backend.__name__}.save" f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save: ) as mock_save:
get_save_func()(filename, src, sample_rate, format=format) get_save_func()(filename, src, sample_rate, format=format)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096) mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096, None)
@parameterized.expand( @parameterized.expand(
[ [
...@@ -126,4 +126,4 @@ class DispatcherTest(PytorchTestCase): ...@@ -126,4 +126,4 @@ class DispatcherTest(PytorchTestCase):
f"torchaudio._backend.utils.{expected_backend.__name__}.save" f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save: ) as mock_save:
get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size) get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size) mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size, None)
...@@ -4,11 +4,13 @@ import pathlib ...@@ -4,11 +4,13 @@ import pathlib
import subprocess import subprocess
import sys import sys
from functools import partial from functools import partial
from typing import Optional
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func from torchaudio._backend.utils import get_save_func
from torchaudio.io import CodecConfig
from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
...@@ -45,6 +47,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -45,6 +47,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
self, self,
format: str, format: str,
*, *,
compression: Optional[CodecConfig] = None,
encoding: str = None, encoding: str = None,
bits_per_sample: int = None, bits_per_sample: int = None,
sample_rate: float = 8000, sample_rate: float = 8000,
...@@ -104,7 +107,15 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -104,7 +107,15 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
data = load_wav(src_path, normalize=False)[0] data = load_wav(src_path, normalize=False)[0]
if test_mode == "path": if test_mode == "path":
ext = format ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample) self._save(
tgt_path,
data,
sample_rate,
compression=compression,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "fileobj": elif test_mode == "fileobj":
ext = None ext = None
with open(tgt_path, "bw") as file_: with open(tgt_path, "bw") as file_:
...@@ -112,6 +123,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -112,6 +123,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
file_, file_,
data, data,
sample_rate, sample_rate,
compression=compression,
format=format, format=format,
encoding=encoding, encoding=encoding,
bits_per_sample=bits_per_sample, bits_per_sample=bits_per_sample,
...@@ -123,6 +135,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -123,6 +135,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
file_, file_,
data, data,
sample_rate, sample_rate,
compression=compression,
format=format, format=format,
encoding=encoding, encoding=encoding,
bits_per_sample=bits_per_sample, bits_per_sample=bits_per_sample,
...@@ -198,11 +211,27 @@ class SaveTest(SaveTestBase): ...@@ -198,11 +211,27 @@ class SaveTest(SaveTestBase):
# NOTE: Supported sample formats: s16 s32 (24 bits) # NOTE: Supported sample formats: s16 s32 (24 bits)
# [8, 16, 24], # [8, 16, 24],
[16, 24], [16, 24],
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
) )
def test_save_flac(self, test_mode, bits_per_sample): def test_save_flac(self, test_mode, bits_per_sample, compression_level):
# -acodec flac -sample_fmt s16 # -acodec flac -sample_fmt s16
# 24 bits needs to be mapped to s32 # 24 bits needs to be mapped to s32
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode=test_mode) codec_config = CodecConfig(
compression_level=compression_level,
)
self.assert_save_consistency(
"flac", compression=codec_config, bits_per_sample=bits_per_sample, test_mode=test_mode
)
# @nested_params( # @nested_params(
# ["path", "fileobj", "bytesio"], # ["path", "fileobj", "bytesio"],
...@@ -212,12 +241,25 @@ class SaveTest(SaveTestBase): ...@@ -212,12 +241,25 @@ class SaveTest(SaveTestBase):
# self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1) # self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)
@nested_params( @nested_params(
[
None,
-1,
0,
1,
2,
3,
5,
10,
],
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
) )
def test_save_vorbis(self, test_mode): def test_save_vorbis(self, quality_level, test_mode):
# NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg" # NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg"
# self.assert_save_consistency("vorbis", test_mode=test_mode) # self.assert_save_consistency("vorbis", test_mode=test_mode)
self.assert_save_consistency("ogg", test_mode=test_mode) codec_config = CodecConfig(
qscale=quality_level,
)
self.assert_save_consistency("ogg", compression=codec_config, test_mode=test_mode)
# @nested_params( # @nested_params(
# ["path", "fileobj", "bytesio"], # ["path", "fileobj", "bytesio"],
......
...@@ -40,6 +40,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -40,6 +40,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
self, self,
format: str, format: str,
*, *,
compression: float = None,
encoding: str = None, encoding: str = None,
bits_per_sample: int = None, bits_per_sample: int = None,
sample_rate: float = 8000, sample_rate: float = 8000,
...@@ -101,13 +102,16 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -101,13 +102,16 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
# 2.1. Convert the original wav to target format with torchaudio # 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0] data = load_wav(src_path, normalize=False)[0]
if test_mode == "path": if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample) self._save(
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
)
elif test_mode == "fileobj": elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_: with open(tgt_path, "bw") as file_:
self._save( self._save(
file_, file_,
data, data,
sample_rate, sample_rate,
compression=compression,
format=format, format=format,
encoding=encoding, encoding=encoding,
bits_per_sample=bits_per_sample, bits_per_sample=bits_per_sample,
...@@ -118,6 +122,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -118,6 +122,7 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
file_, file_,
data, data,
sample_rate, sample_rate,
compression=compression,
format=format, format=format,
encoding=encoding, encoding=encoding,
bits_per_sample=bits_per_sample, bits_per_sample=bits_per_sample,
...@@ -134,7 +139,9 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -134,7 +139,9 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
# 3.1. Convert the original wav to target format with sox # 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding) sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(src_path, sox_path, encoding=sox_encoding, bit_depth=bits_per_sample) sox_utils.convert_audio_file(
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
)
# 3.2. Convert the target format to wav with sox # 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy # 3.3. Load with SciPy
...@@ -175,15 +182,42 @@ class SaveTest(SaveTestBase): ...@@ -175,15 +182,42 @@ class SaveTest(SaveTestBase):
@nested_params( @nested_params(
[8, 16, 24], [8, 16, 24],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode="path"
) )
def test_save_flac(self, bits_per_sample):
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode="path")
def test_save_htk(self): def test_save_htk(self):
self.assert_save_consistency("htk", test_mode="path", num_channels=1) self.assert_save_consistency("htk", test_mode="path", num_channels=1)
def test_save_vorbis(self): @nested_params(
self.assert_save_consistency("vorbis", test_mode="path") [
None,
-1,
0,
1,
2,
3,
3.6,
5,
10,
],
)
def test_save_vorbis(self, quality_level):
self.assert_save_consistency("vorbis", compression=quality_level, test_mode="path")
@nested_params( @nested_params(
[ [
...@@ -254,9 +288,22 @@ class SaveTest(SaveTestBase): ...@@ -254,9 +288,22 @@ class SaveTest(SaveTestBase):
encoding, bits_per_sample = enc_params encoding, bits_per_sample = enc_params
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path") self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")
@nested_params(
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
],
)
@skipIfNoSoxEncoder("amr-nb") @skipIfNoSoxEncoder("amr-nb")
def test_save_amr_nb(self): def test_save_amr_nb(self, bit_rate):
self.assert_save_consistency("amr-nb", num_channels=1, test_mode="path") self.assert_save_consistency("amr-nb", compression=bit_rate, num_channels=1, test_mode="path")
def test_save_gsm(self): def test_save_gsm(self):
self.assert_save_consistency("gsm", num_channels=1, test_mode="path") self.assert_save_consistency("gsm", num_channels=1, test_mode="path")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment