Commit 1bd3dbb9 authored by Tongzhou Wang, committed by Soumith Chintala
Browse files

update torch audio for new stft signature (#55)

parent 18c01bef
......@@ -192,7 +192,9 @@ class SPECTROGRAM(object):
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
self.n_fft = n_fft # number of fft bins
# Number of FFT bins. The returned STFT result will have n_fft // 2 + 1
# frequencies due to onesided=True in torch.stft.
self.n_fft = (n_fft - 1) * 2 if n_fft is not None else ws
self.pad = pad
self.wkwargs = wkwargs
......@@ -212,8 +214,17 @@ class SPECTROGRAM(object):
assert sig.dim() == 2
spec_f = torch.stft(sig, self.ws, self.hop, self.n_fft,
True, True, self.window, self.pad) # (c, l, n_fft, 2)
if self.pad > 0:
c, n = sig.size()
new_sig = sig.new_empty(c, n + self.pad * 2)
new_sig[:, :self.pad].zero_()
new_sig[:, -self.pad:].zero_()
new_sig.narrow(1, self.pad, n).copy_(sig)
sig = new_sig
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
self.window, center=False,
normalized=True, onesided=True).transpose(1, 2)
spec_f /= self.window.pow(2).sum().sqrt()
spec_f = spec_f.pow(2).sum(-1) # get power of "complex" tensor (c, l, n_fft)
return spec_f if is_variable else spec_f.data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment