Unverified Commit 954d5121 authored by moto's avatar moto Committed by GitHub
Browse files

Make sliding_window_cmn batch-aware (#570)

parent 38287a75
......@@ -72,6 +72,13 @@ class TestFunctional(unittest.TestCase):
waveform = torch.rand(2, 100) - 0.5
_test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
def test_sliding_window_cmn(self):
    """Run the `_test_batch` check on `F.sliding_window_cmn`.

    Every combination of the boolean `center` and `norm_vars` flags is
    exercised against the same random input.
    """
    waveform = torch.randn(2, 1024) - 0.5
    flag_values = (True, False)
    # Nested order matches the original enumeration:
    # (True, True), (True, False), (False, True), (False, False).
    for center in flag_values:
        for norm_vars in flag_values:
            _test_batch(
                F.sliding_window_cmn,
                waveform,
                center=center,
                norm_vars=norm_vars,
            )
class TestTransforms(unittest.TestCase):
"""Test suite for classes defined in `transforms` module"""
......
......@@ -1713,14 +1713,18 @@ def sliding_window_cmn(
Returns:
Tensor: Tensor of freq of dimension (..., frame)
"""
input_shape = waveform.shape
num_frames, num_feats = input_shape[-2:]
waveform = waveform.view(-1, num_frames, num_feats)
num_channels = waveform.shape[0]
dtype = waveform.dtype
device = waveform.device
last_window_start = last_window_end = -1
num_frames, num_feats = waveform.shape
cur_sum = torch.zeros(num_feats, dtype=dtype, device=device)
cur_sumsq = torch.zeros(num_feats, dtype=dtype, device=device)
cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cmn_waveform = torch.zeros(
num_frames, num_feats, dtype=dtype, device=device)
num_channels, num_frames, num_feats, dtype=dtype, device=device)
for t in range(num_frames):
window_start = 0
window_end = 0
......@@ -1742,33 +1746,37 @@ def sliding_window_cmn(
if window_start < 0:
window_start = 0
if last_window_start == -1:
input_part = waveform[window_start: window_end - window_start]
cur_sum += torch.sum(input_part, 0)
input_part = waveform[:, window_start: window_end - window_start, :]
cur_sum += torch.sum(input_part, 1)
if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 0)[-1]
cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
else:
if window_start > last_window_start:
frame_to_remove = waveform[last_window_start]
frame_to_remove = waveform[:, last_window_start, :]
cur_sum -= frame_to_remove
if norm_vars:
cur_sumsq -= (frame_to_remove ** 2)
if window_end > last_window_end:
frame_to_add = waveform[last_window_end]
frame_to_add = waveform[:, last_window_end, :]
cur_sum += frame_to_add
if norm_vars:
cur_sumsq += (frame_to_add ** 2)
window_frames = window_end - window_start
last_window_start = window_start
last_window_end = window_end
cmn_waveform[t] = waveform[t] - cur_sum / window_frames
cmn_waveform[:, t, :] = waveform[:, t, :] - cur_sum / window_frames
if norm_vars:
if window_frames == 1:
cmn_waveform[t] = torch.zeros(
num_feats, dtype=dtype, device=device)
cmn_waveform[:, t, :] = torch.zeros(
num_channels, num_feats, dtype=dtype, device=device)
else:
variance = cur_sumsq
variance = variance / window_frames
variance -= ((cur_sum ** 2) / (window_frames ** 2))
variance = torch.pow(variance, -0.5)
cmn_waveform[t] *= variance
cmn_waveform[:, t, :] *= variance
cmn_waveform = cmn_waveform.view(input_shape[:-2] + (num_frames, num_feats))
if len(input_shape) == 2:
cmn_waveform = cmn_waveform.squeeze(0)
return cmn_waveform
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment