Unverified commit 954d5121, authored by moto and committed by GitHub
Browse files

Make sliding_window_cmn batch-aware (#570)

parent 38287a75
...@@ -72,6 +72,13 @@ class TestFunctional(unittest.TestCase): ...@@ -72,6 +72,13 @@ class TestFunctional(unittest.TestCase):
waveform = torch.rand(2, 100) - 0.5 waveform = torch.rand(2, 100) - 0.5
_test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05) _test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
def test_sliding_window_cmn(self):
    """Run `_test_batch` on `F.sliding_window_cmn` for every `center` / `norm_vars` combination."""
    waveform = torch.randn(2, 1024) - 0.5
    # Same call order as spelling the four cases out by hand:
    # (True, True), (True, False), (False, True), (False, False).
    for center in (True, False):
        for norm_vars in (True, False):
            _test_batch(F.sliding_window_cmn, waveform, center=center, norm_vars=norm_vars)
class TestTransforms(unittest.TestCase): class TestTransforms(unittest.TestCase):
"""Test suite for classes defined in `transforms` module""" """Test suite for classes defined in `transforms` module"""
......
...@@ -1713,14 +1713,18 @@ def sliding_window_cmn( ...@@ -1713,14 +1713,18 @@ def sliding_window_cmn(
Returns: Returns:
Tensor: Tensor of freq of dimension (..., frame) Tensor: Tensor of freq of dimension (..., frame)
""" """
input_shape = waveform.shape
num_frames, num_feats = input_shape[-2:]
waveform = waveform.view(-1, num_frames, num_feats)
num_channels = waveform.shape[0]
dtype = waveform.dtype dtype = waveform.dtype
device = waveform.device device = waveform.device
last_window_start = last_window_end = -1 last_window_start = last_window_end = -1
num_frames, num_feats = waveform.shape cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cur_sum = torch.zeros(num_feats, dtype=dtype, device=device) cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cur_sumsq = torch.zeros(num_feats, dtype=dtype, device=device)
cmn_waveform = torch.zeros( cmn_waveform = torch.zeros(
num_frames, num_feats, dtype=dtype, device=device) num_channels, num_frames, num_feats, dtype=dtype, device=device)
for t in range(num_frames): for t in range(num_frames):
window_start = 0 window_start = 0
window_end = 0 window_end = 0
...@@ -1742,33 +1746,37 @@ def sliding_window_cmn( ...@@ -1742,33 +1746,37 @@ def sliding_window_cmn(
if window_start < 0: if window_start < 0:
window_start = 0 window_start = 0
if last_window_start == -1: if last_window_start == -1:
input_part = waveform[window_start: window_end - window_start] input_part = waveform[:, window_start: window_end - window_start, :]
cur_sum += torch.sum(input_part, 0) cur_sum += torch.sum(input_part, 1)
if norm_vars: if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 0)[-1] cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
else: else:
if window_start > last_window_start: if window_start > last_window_start:
frame_to_remove = waveform[last_window_start] frame_to_remove = waveform[:, last_window_start, :]
cur_sum -= frame_to_remove cur_sum -= frame_to_remove
if norm_vars: if norm_vars:
cur_sumsq -= (frame_to_remove ** 2) cur_sumsq -= (frame_to_remove ** 2)
if window_end > last_window_end: if window_end > last_window_end:
frame_to_add = waveform[last_window_end] frame_to_add = waveform[:, last_window_end, :]
cur_sum += frame_to_add cur_sum += frame_to_add
if norm_vars: if norm_vars:
cur_sumsq += (frame_to_add ** 2) cur_sumsq += (frame_to_add ** 2)
window_frames = window_end - window_start window_frames = window_end - window_start
last_window_start = window_start last_window_start = window_start
last_window_end = window_end last_window_end = window_end
cmn_waveform[t] = waveform[t] - cur_sum / window_frames cmn_waveform[:, t, :] = waveform[:, t, :] - cur_sum / window_frames
if norm_vars: if norm_vars:
if window_frames == 1: if window_frames == 1:
cmn_waveform[t] = torch.zeros( cmn_waveform[:, t, :] = torch.zeros(
num_feats, dtype=dtype, device=device) num_channels, num_feats, dtype=dtype, device=device)
else: else:
variance = cur_sumsq variance = cur_sumsq
variance = variance / window_frames variance = variance / window_frames
variance -= ((cur_sum ** 2) / (window_frames ** 2)) variance -= ((cur_sum ** 2) / (window_frames ** 2))
variance = torch.pow(variance, -0.5) variance = torch.pow(variance, -0.5)
cmn_waveform[t] *= variance cmn_waveform[:, t, :] *= variance
cmn_waveform = cmn_waveform.view(input_shape[:-2] + (num_frames, num_feats))
if len(input_shape) == 2:
cmn_waveform = cmn_waveform.squeeze(0)
return cmn_waveform return cmn_waveform
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment