Unverified Commit ea857940 authored by Jcaw's avatar Jcaw Committed by GitHub
Browse files

[BC-Breaking] Rename `sliding_window_cmn` arg for correctness (#1347)

* Rename the spectrogram argument that was misnamed `waveform`

`F.sliding_window_cmn` takes a spectrogram as input (of shape
`(..., freq, time)`). However, the argument holding this spectrogram is
named `waveform`. This appears to be an error, so rename the argument (and
the output tensor) to reflect that both are spectrograms.

* Correct tensor description in docstring

The output tensor of `F.sliding_window_cmn` is also a spectrogram.
Update the description to reflect this.
parent 2593e2e8
...@@ -934,7 +934,7 @@ def detect_pitch_frequency( ...@@ -934,7 +934,7 @@ def detect_pitch_frequency(
def sliding_window_cmn( def sliding_window_cmn(
waveform: Tensor, specgram: Tensor,
cmn_window: int = 600, cmn_window: int = 600,
min_cmn_window: int = 100, min_cmn_window: int = 100,
center: bool = False, center: bool = False,
...@@ -944,7 +944,7 @@ def sliding_window_cmn( ...@@ -944,7 +944,7 @@ def sliding_window_cmn(
Apply sliding-window cepstral mean (and optionally variance) normalization per utterance. Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.
Args: Args:
waveform (Tensor): Tensor of audio of dimension (..., freq, time) specgram (Tensor): Tensor of audio of dimension (..., freq, time)
cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600) cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start). min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
Only applicable if center == false, ignored if center==true (int, default = 100) Only applicable if center == false, ignored if center==true (int, default = 100)
...@@ -953,19 +953,19 @@ def sliding_window_cmn( ...@@ -953,19 +953,19 @@ def sliding_window_cmn(
norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false) norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
Returns: Returns:
Tensor: Tensor of freq of dimension (..., frame) Tensor: Tensor matching input shape (..., freq, time)
""" """
input_shape = waveform.shape input_shape = specgram.shape
num_frames, num_feats = input_shape[-2:] num_frames, num_feats = input_shape[-2:]
waveform = waveform.view(-1, num_frames, num_feats) specgram = specgram.view(-1, num_frames, num_feats)
num_channels = waveform.shape[0] num_channels = specgram.shape[0]
dtype = waveform.dtype dtype = specgram.dtype
device = waveform.device device = specgram.device
last_window_start = last_window_end = -1 last_window_start = last_window_end = -1
cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device) cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device) cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cmn_waveform = torch.zeros( cmn_specgram = torch.zeros(
num_channels, num_frames, num_feats, dtype=dtype, device=device) num_channels, num_frames, num_feats, dtype=dtype, device=device)
for t in range(num_frames): for t in range(num_frames):
window_start = 0 window_start = 0
...@@ -988,40 +988,40 @@ def sliding_window_cmn( ...@@ -988,40 +988,40 @@ def sliding_window_cmn(
if window_start < 0: if window_start < 0:
window_start = 0 window_start = 0
if last_window_start == -1: if last_window_start == -1:
input_part = waveform[:, window_start: window_end - window_start, :] input_part = specgram[:, window_start: window_end - window_start, :]
cur_sum += torch.sum(input_part, 1) cur_sum += torch.sum(input_part, 1)
if norm_vars: if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :] cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
else: else:
if window_start > last_window_start: if window_start > last_window_start:
frame_to_remove = waveform[:, last_window_start, :] frame_to_remove = specgram[:, last_window_start, :]
cur_sum -= frame_to_remove cur_sum -= frame_to_remove
if norm_vars: if norm_vars:
cur_sumsq -= (frame_to_remove ** 2) cur_sumsq -= (frame_to_remove ** 2)
if window_end > last_window_end: if window_end > last_window_end:
frame_to_add = waveform[:, last_window_end, :] frame_to_add = specgram[:, last_window_end, :]
cur_sum += frame_to_add cur_sum += frame_to_add
if norm_vars: if norm_vars:
cur_sumsq += (frame_to_add ** 2) cur_sumsq += (frame_to_add ** 2)
window_frames = window_end - window_start window_frames = window_end - window_start
last_window_start = window_start last_window_start = window_start
last_window_end = window_end last_window_end = window_end
cmn_waveform[:, t, :] = waveform[:, t, :] - cur_sum / window_frames cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
if norm_vars: if norm_vars:
if window_frames == 1: if window_frames == 1:
cmn_waveform[:, t, :] = torch.zeros( cmn_specgram[:, t, :] = torch.zeros(
num_channels, num_feats, dtype=dtype, device=device) num_channels, num_feats, dtype=dtype, device=device)
else: else:
variance = cur_sumsq variance = cur_sumsq
variance = variance / window_frames variance = variance / window_frames
variance -= ((cur_sum ** 2) / (window_frames ** 2)) variance -= ((cur_sum ** 2) / (window_frames ** 2))
variance = torch.pow(variance, -0.5) variance = torch.pow(variance, -0.5)
cmn_waveform[:, t, :] *= variance cmn_specgram[:, t, :] *= variance
cmn_waveform = cmn_waveform.view(input_shape[:-2] + (num_frames, num_feats)) cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
if len(input_shape) == 2: if len(input_shape) == 2:
cmn_waveform = cmn_waveform.squeeze(0) cmn_specgram = cmn_specgram.squeeze(0)
return cmn_waveform return cmn_specgram
def spectral_centroid( def spectral_centroid(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment