Commit b311c4cc authored by Jeremy Howard's avatar Jeremy Howard Committed by Soumith Chintala
Browse files

Bug fix: Use correct device for MEL2 functions so MEL2 works on CUDA tensors (#77)

parent d62d3c0b
......@@ -196,6 +196,7 @@ class SPECTROGRAM(object):
if self.pad > 0:
with torch.no_grad():
sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
self.window = self.window.to(sig.device)
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
self.window, center=False,
normalized=True, onesided=True).transpose(1, 2)
......@@ -225,7 +226,7 @@ class F2M(object):
def __call__(self, spec_f):
if self.fb is None:
self.fb = self._create_fb_matrix(spec_f.size(2))
self.fb = self._create_fb_matrix(spec_f.size(2)).to(spec_f.device)
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m
......@@ -280,7 +281,7 @@ class SPEC2DB(object):
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None:
spec_db = torch.max(spec_db, torch.tensor(self.top_db, dtype=spec_db.dtype))
spec_db = torch.max(spec_db, spec_db.new_full((1,),self.top_db))
return spec_db
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment