"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8369196703e07a42e7835a65b223c42d4e993276"
Unverified Commit 17aa81ea authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Update SpectralCentroid to accept window_fn and wkwargs (#1216)

parent 114461cc
...@@ -1087,8 +1087,9 @@ class SpectralCentroid(torch.nn.Module): ...@@ -1087,8 +1087,9 @@ class SpectralCentroid(torch.nn.Module):
win_length (int or None, optional): Window size. (Default: ``n_fft``) win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
pad (int, optional): Two sided padding of signal. (Default: ``0``) pad (int, optional): Two sided padding of signal. (Default: ``0``)
window(Tensor, optional): A window tensor that is applied/multiplied to each frame. window_fn (Callable[..., Tensor], optional): A function to create a window tensor
(Default: ``torch.hann_window(win_length)``) that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
Example Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True) >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...@@ -1102,14 +1103,14 @@ class SpectralCentroid(torch.nn.Module): ...@@ -1102,14 +1103,14 @@ class SpectralCentroid(torch.nn.Module):
win_length: Optional[int] = None, win_length: Optional[int] = None,
hop_length: Optional[int] = None, hop_length: Optional[int] = None,
pad: int = 0, pad: int = 0,
window: Optional[Tensor] = None) -> None: window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None) -> None:
super(SpectralCentroid, self).__init__() super(SpectralCentroid, self).__init__()
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_fft = n_fft self.n_fft = n_fft
self.win_length = win_length if win_length is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
if window is None: window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
window = torch.hann_window(self.win_length)
self.register_buffer('window', window) self.register_buffer('window', window)
self.pad = pad self.pad = pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment