Unverified Commit 17aa81ea authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Update SpectralCentroid to accept window_fn and wkwargs (#1216)

parent 114461cc
......@@ -1087,8 +1087,9 @@ class SpectralCentroid(torch.nn.Module):
win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
pad (int, optional): Two sided padding of signal. (Default: ``0``)
window(Tensor, optional): A window tensor that is applied/multiplied to each frame.
(Default: ``torch.hann_window(win_length)``)
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
......@@ -1102,14 +1103,14 @@ class SpectralCentroid(torch.nn.Module):
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window: Optional[Tensor] = None) -> None:
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None) -> None:
super(SpectralCentroid, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
if window is None:
window = torch.hann_window(self.win_length)
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
self.pad = pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment