Unverified Commit b42d6100 authored by wanglong001's avatar wanglong001 Committed by GitHub
Browse files

add cmvn (#540)



* add cmvn

* Update transforms.rst

add cmvn

* Correct the format

* Correct the format

* Correct the format

* add unit test and change cmvn to cmn

* fix bug
Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
parent 8a742e0f
......@@ -147,3 +147,8 @@ Functions to perform common audio operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: detect_pitch_frequency
:hidden:`sliding_window_cmn`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: sliding_window_cmn
\ No newline at end of file
......@@ -128,3 +128,10 @@ Transforms are common audio transforms. They can be chained together using :clas
.. autoclass:: Vol
.. automethod:: forward
:hidden:`SlidingWindowCmn`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SlidingWindowCmn
.. automethod:: forward
......@@ -406,6 +406,41 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_sliding_window_cmn(self):
    """Check that functional.sliding_window_cmn runs consistently on this device."""
    def func(tensor):
        # Fixed input values; the incoming `tensor` only supplies device/dtype.
        frames = torch.tensor(
            [
                [-1.915875792503357, 1.147700309753418],
                [1.8242558240890503, 1.3869990110397339],
            ],
            device=tensor.device,
            dtype=tensor.dtype,
        )
        return F.sliding_window_cmn(frames, 600, 100, False, False)

    # Expected mean-removed output for the hard-coded 2x2 input above.
    expected = torch.tensor(
        [
            [-1.8701, -0.1196],
            [1.8701, 0.1196],
        ]
    )
    self._assert_consistency(func, expected)
def test_contrast(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -416,7 +451,6 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
class _TransformsTestMixin:
"""Implements test for Transforms that are performed for different devices"""
device = None
......@@ -496,6 +530,10 @@ class _TransformsTestMixin:
waveform, _ = torchaudio.load(test_filepath)
self._assert_consistency(T.Vol(1.1), waveform)
def test_SlidingWindowCmn(self):
    """Check that the SlidingWindowCmn transform runs consistently on this device."""
    spectrogram = torch.rand((1000, 10))
    self._assert_consistency(T.SlidingWindowCmn(), spectrogram)
class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on CPU"""
......
......@@ -33,7 +33,8 @@ __all__ = [
"biquad",
"contrast",
'mask_along_axis',
'mask_along_axis_iid'
'mask_along_axis_iid',
'sliding_window_cmn',
]
......@@ -1643,3 +1644,86 @@ def detect_pitch_frequency(
freq = freq.view(shape[:-1] + list(freq.shape[-1:]))
return freq
def sliding_window_cmn(
    waveform: Tensor,
    cmn_window: int = 600,
    min_cmn_window: int = 100,
    center: bool = False,
    norm_vars: bool = False,
) -> Tensor:
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        waveform (Tensor): Tensor of audio features of dimension (time, freq).
            The implementation unpacks exactly two dimensions (frames along dim 0,
            features along dim 1), so 2-D input is required.
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

    Returns:
        Tensor: Normalized features of the same shape as ``waveform``.
    """
    dtype = waveform.dtype
    device = waveform.device
    last_window_start = last_window_end = -1
    num_frames, num_feats = waveform.shape
    # Running per-feature sums over the current window, updated incrementally
    # as the window slides one frame at a time.
    cur_sum = torch.zeros(num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_feats, dtype=dtype, device=device)
    cmn_waveform = torch.zeros(
        num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        # Compute the [window_start, window_end) frame range for frame t.
        if center:
            window_start = t - cmn_window // 2
            window_end = window_start + cmn_window
        else:
            window_start = t - cmn_window
            window_end = t + 1
        if window_start < 0:
            # Shift the window right to start at frame 0, preserving its length.
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                # Ensure at least min_cmn_window frames at the start of decoding.
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            # Shift the window left to end at the last frame, preserving its length.
            window_start -= (window_end - num_frames)
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            # First iteration: compute the sums from scratch. window_start is
            # always 0 here (t == 0 clamps it), so this slice is the full window.
            input_part = waveform[window_start:window_end]
            cur_sum += torch.sum(input_part, 0)
            if norm_vars:
                # Plain sum of squares (equivalent to the last cumsum entry,
                # without materializing the whole cumulative tensor).
                cur_sumsq += torch.sum(input_part ** 2, 0)
        else:
            # Later iterations: each window edge advances by at most one frame,
            # so update the running sums incrementally instead of recomputing.
            if window_start > last_window_start:
                frame_to_remove = waveform[last_window_start]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= (frame_to_remove ** 2)
            if window_end > last_window_end:
                frame_to_add = waveform[last_window_end]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += (frame_to_add ** 2)
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        # Subtract the window mean from the current frame.
        cmn_waveform[t] = waveform[t] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                # A single-frame window has zero variance; define the output as zero.
                cmn_waveform[t] = torch.zeros(
                    num_feats, dtype=dtype, device=device)
            else:
                # variance = E[x^2] - E[x]^2; scale by 1/sqrt(variance).
                variance = cur_sumsq
                variance = variance / window_frames
                variance -= ((cur_sum ** 2) / (window_frames ** 2))
                variance = torch.pow(variance, -0.5)
                cmn_waveform[t] *= variance
    return cmn_waveform
......@@ -26,6 +26,7 @@ __all__ = [
'Fade',
'FrequencyMasking',
'TimeMasking',
'SlidingWindowCmn',
]
......@@ -869,3 +870,40 @@ class Vol(torch.nn.Module):
waveform = F.gain(waveform, 10 * math.log10(self.gain))
return torch.clamp(waveform, -1, 1)
class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
    """

    def __init__(self,
                 cmn_window: int = 600,
                 min_cmn_window: int = 100,
                 center: bool = False,
                 norm_vars: bool = False) -> None:
        super().__init__()
        # Parameters are stored unchanged and forwarded verbatim to the
        # functional implementation on every call.
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        # Thin wrapper: all work happens in F.sliding_window_cmn.
        return F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment