Unverified Commit feede97e authored by jayleverett's avatar jayleverett Committed by GitHub
Browse files

Put output tensor on proper device in `get_whitenoise()` (#1744)

* put output tensor on device in `get_whitenoise()`

* Update `get_spectrogram()` so that window uses same device as waveform

* put window on proper device in `test_griffinlim()`
parent 768432c3
......@@ -75,6 +75,9 @@ def get_whitenoise(
tensor.clamp_(-1.0, 1.0)
if not channels_first:
tensor = tensor.t()
tensor = tensor.to(device)
return convert_tensor_encoding(tensor, dtype)
......@@ -137,7 +140,7 @@ def get_spectrogram(
"""
hop_length = hop_length or n_fft // 4
win_length = win_length or n_fft
window = torch.hann_window(win_length) if window is None else window
window = torch.hann_window(win_length, device=waveform.device) if window is None else window
spec = torch.stft(
waveform,
n_fft=n_fft,
......
......@@ -33,7 +33,7 @@ class Functional(TestBaseMixin):
n_fft = 400
win_length = n_fft
hop_length = n_fft // 4
window = torch.hann_window(win_length)
window = torch.hann_window(win_length, device=self.device)
power = 1
# GriffinLim params
n_iter = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment