"fs/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "903b1fc97f37fda25fd233ed853355acfc0f63cf"
Unverified Commit 6854020f authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Apply codec-based data augmentation (#1200)

parent 4a3d2035
...@@ -8,6 +8,12 @@ from parameterized import parameterized ...@@ -8,6 +8,12 @@ from parameterized import parameterized
import itertools import itertools
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
skipIfNoExtension,
)
from torchaudio_unittest.backend.sox_io.common import name_func
from .functional_impl import Lfilter, Spectrogram from .functional_impl import Lfilter, Spectrogram
...@@ -53,6 +59,7 @@ class TestCreateFBMatrix(common_utils.TorchaudioTestCase): ...@@ -53,6 +59,7 @@ class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
class TestComputeDeltas(common_utils.TorchaudioTestCase): class TestComputeDeltas(common_utils.TorchaudioTestCase):
"""Test suite for correctness of compute_deltas""" """Test suite for correctness of compute_deltas"""
def test_one_channel(self): def test_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]]) specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]) expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
...@@ -211,3 +218,48 @@ class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase): ...@@ -211,3 +218,48 @@ class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase):
assert mask_specgrams.size() == specgrams.size() assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
@skipIfNoExtension
class TestApplyCodec(TorchaudioTestCase):
backend = "sox_io"
def _smoke_test(self, format, compression, check_num_frames):
"""
The purpose of this test suite is to verify that apply_codec functionalities do not exhibit
abnormal behaviors.
"""
torch.random.manual_seed(42)
sample_rate = 8000
num_frames = 3 * sample_rate
num_channels = 2
waveform = torch.rand(num_channels, num_frames)
augmented = F.apply_codec(waveform,
sample_rate,
format,
True,
compression
)
assert augmented.dtype == waveform.dtype
assert augmented.shape[0] == num_channels
if check_num_frames:
assert augmented.shape[1] == num_frames
def test_wave(self):
self._smoke_test("wav", compression=None, check_num_frames=True)
@parameterized.expand([(96,), (128,), (160,), (192,), (224,), (256,), (320,)],
name_func=name_func)
def test_mp3(self, compression):
self._smoke_test("mp3", compression, check_num_frames=False)
@parameterized.expand([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)],
name_func=name_func)
def test_flac(self, compression):
self._smoke_test("flac", compression, check_num_frames=False)
@parameterized.expand([(-1,), (0,), (1,), (2,), (3,), (3.6,), (5,), (10,)],
name_func=name_func)
def test_vorbis(self, compression):
self._smoke_test("vorbis", compression, check_num_frames=False)
...@@ -18,6 +18,7 @@ from .functional import ( ...@@ -18,6 +18,7 @@ from .functional import (
sliding_window_cmn, sliding_window_cmn,
spectrogram, spectrogram,
spectral_centroid, spectral_centroid,
apply_codec,
) )
from .filtering import ( from .filtering import (
allpass_biquad, allpass_biquad,
...@@ -84,4 +85,5 @@ __all__ = [ ...@@ -84,4 +85,5 @@ __all__ = [
'riaa_biquad', 'riaa_biquad',
'treble_biquad', 'treble_biquad',
'vad', 'vad',
'apply_codec'
] ]
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import io
import math import math
from typing import Optional, Tuple
import warnings import warnings
from typing import Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from torchaudio._internal import (
module_utils as _mod_utils,
)
import torchaudio
__all__ = [ __all__ = [
"spectrogram", "spectrogram",
...@@ -29,6 +34,7 @@ __all__ = [ ...@@ -29,6 +34,7 @@ __all__ = [
'mask_along_axis_iid', 'mask_along_axis_iid',
'sliding_window_cmn', 'sliding_window_cmn',
"spectral_centroid", "spectral_centroid",
"apply_codec",
] ]
...@@ -994,6 +1000,52 @@ def spectral_centroid( ...@@ -994,6 +1000,52 @@ def spectral_centroid(
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
@_mod_utils.requires_module('torchaudio._torchaudio')
def apply_codec(
waveform: Tensor,
sample_rate: int,
format: str,
channels_first: bool = True,
compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
) -> Tensor:
r"""
Applies codecs as a form of augmentation
Args:
waveform (Tensor): Audio data. Must be 2 dimensional. See also ```channels_first```
sample_rate (int): Sample rate of the audio waveform
format (str): file format
channels_first (bool):
When True, both the input and output Tensor have dimension ``[channel, time]``.
Otherwise, they have dimension ``[time, channel]``.
compression (float): Used for formats other than WAV.
For mor details see :py:func:`torchaudio.backend.sox_io_backend.save`
encoding (str, optional): Changes the encoding for the supported formats.
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`
bits_per_sample (int, optional): Changes the bit depth for the supported formats.
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`
Returns:
torch.Tensor: Resulting Tensor.
If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``
"""
bytes = io.BytesIO()
torchaudio.backend.sox_io_backend.save(bytes,
waveform,
sample_rate,
channels_first,
compression,
format,
encoding,
bits_per_sample
)
bytes.seek(0)
augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format)
return augmented
def compute_kaldi_pitch( def compute_kaldi_pitch(
waveform: torch.Tensor, waveform: torch.Tensor,
sample_rate: float, sample_rate: float,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment