Unverified commit 6854020f, authored by Aziz and committed by GitHub
Browse files

Apply codec-based data augmentation (#1200)

parent 4a3d2035
......@@ -8,6 +8,12 @@ from parameterized import parameterized
import itertools
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
skipIfNoExtension,
)
from torchaudio_unittest.backend.sox_io.common import name_func
from .functional_impl import Lfilter, Spectrogram
......@@ -53,6 +59,7 @@ class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
class TestComputeDeltas(common_utils.TorchaudioTestCase):
"""Test suite for correctness of compute_deltas"""
def test_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
......@@ -211,3 +218,48 @@ class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase):
assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
@skipIfNoExtension
class TestApplyCodec(TorchaudioTestCase):
    # Smoke tests for ``F.apply_codec``; every case runs against the "sox_io" backend.
    backend = "sox_io"

    def _smoke_test(self, format, compression, check_num_frames):
        """
        Run ``F.apply_codec`` on a seeded random waveform and sanity-check the result.

        The purpose of this test suite is to verify that apply_codec functionalities do not
        exhibit abnormal behaviors: the output must keep the input dtype and channel count
        and, when ``check_num_frames`` is true, the number of frames as well.

        Args:
            format (str): Format forwarded to ``F.apply_codec``.
            compression: Compression setting forwarded to ``F.apply_codec`` (may be ``None``).
            check_num_frames (bool): Whether the output frame count must equal the input's.
        """
        # Fixed seed so the random input is reproducible between runs.
        torch.random.manual_seed(42)
        sample_rate = 8000
        num_channels, num_frames = 2, 3 * sample_rate
        original = torch.rand(num_channels, num_frames)

        # ``channels_first=True`` matches the ``[channel, time]`` layout built above.
        result = F.apply_codec(
            original,
            sample_rate,
            format,
            channels_first=True,
            compression=compression,
        )

        assert result.dtype == original.dtype
        assert result.shape[0] == num_channels
        if check_num_frames:
            assert result.shape[1] == num_frames

    def test_wave(self):
        # WAV is the only case here in which the frame count is also verified.
        self._smoke_test("wav", compression=None, check_num_frames=True)

    @parameterized.expand(
        [(setting,) for setting in (96, 128, 160, 192, 224, 256, 320)],
        name_func=name_func,
    )
    def test_mp3(self, compression):
        self._smoke_test("mp3", compression=compression, check_num_frames=False)

    @parameterized.expand(
        [(setting,) for setting in (0, 1, 2, 3, 4, 5, 6, 7, 8)],
        name_func=name_func,
    )
    def test_flac(self, compression):
        self._smoke_test("flac", compression=compression, check_num_frames=False)

    @parameterized.expand(
        [(setting,) for setting in (-1, 0, 1, 2, 3, 3.6, 5, 10)],
        name_func=name_func,
    )
    def test_vorbis(self, compression):
        self._smoke_test("vorbis", compression=compression, check_num_frames=False)
......@@ -18,6 +18,7 @@ from .functional import (
sliding_window_cmn,
spectrogram,
spectral_centroid,
apply_codec,
)
from .filtering import (
allpass_biquad,
......@@ -84,4 +85,5 @@ __all__ = [
'riaa_biquad',
'treble_biquad',
'vad',
'apply_codec'
]
# -*- coding: utf-8 -*-
import io
import math
from typing import Optional, Tuple
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor
from torchaudio._internal import (
module_utils as _mod_utils,
)
import torchaudio
__all__ = [
"spectrogram",
......@@ -29,6 +34,7 @@ __all__ = [
'mask_along_axis_iid',
'sliding_window_cmn',
"spectral_centroid",
"apply_codec",
]
......@@ -994,6 +1000,52 @@ def spectral_centroid(
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
@_mod_utils.requires_module('torchaudio._torchaudio')
def apply_codec(
        waveform: Tensor,
        sample_rate: int,
        format: str,
        channels_first: bool = True,
        compression: Optional[float] = None,
        encoding: Optional[str] = None,
        bits_per_sample: Optional[int] = None,
) -> Tensor:
    r"""
    Apply codecs as a form of augmentation.

    The waveform is written to an in-memory buffer with
    :py:func:`torchaudio.backend.sox_io_backend.save` using the requested ``format``, and
    the buffer is then read back through ``apply_effects_file`` with a ``rate`` effect set
    to ``sample_rate``.

    Args:
        waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool):
            When True, both the input and output Tensor have dimension ``[channel, time]``.
            Otherwise, they have dimension ``[time, channel]``.
        compression (float, optional): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str, optional): Changes the encoding for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        bits_per_sample (int, optional): Changes the bit depth for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.

    Returns:
        torch.Tensor: Resulting Tensor.
        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
    """
    # The buffer is named ``buffer`` (not ``bytes``) so the builtin type is not shadowed,
    # and it is managed by ``with`` so it is released even if encoding/decoding raises.
    with io.BytesIO() as buffer:
        # Optional arguments are passed by keyword so this call does not depend on their
        # positional order in ``save``.
        torchaudio.backend.sox_io_backend.save(
            buffer,
            waveform,
            sample_rate,
            channels_first=channels_first,
            compression=compression,
            format=format,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
        )
        # Rewind so that decoding starts from the beginning of the encoded data.
        buffer.seek(0)
        augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
            buffer,
            effects=[["rate", f"{sample_rate}"]],
            channels_first=channels_first,
            format=format,
        )
    return augmented
def compute_kaldi_pitch(
waveform: torch.Tensor,
sample_rate: float,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment