Commit 34f3c12e authored by Charles J.Y. Yoon's avatar Charles J.Y. Yoon Committed by Vincent QB
Browse files

Module GPU test fixes (#369)

* Fixed GPU tests
parent f3365ecf
......@@ -51,7 +51,7 @@ class Spectrogram(torch.nn.Module):
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.window = window
self.register_buffer('window', window)
self.pad = pad
self.power = power
self.normalized = normalized
......@@ -136,7 +136,7 @@ class MelScale(torch.nn.Module):
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.fb = fb
self.register_buffer('fb', fb)
def forward(self, specgram):
r"""
......@@ -260,7 +260,7 @@ class MFCC(torch.nn.Module):
if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins')
dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
self.dct_mat = dct_mat
self.register_buffer('dct_mat', dct_mat)
self.log_mels = log_mels
def forward(self, waveform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment