Unverified Commit feede97e authored by jayleverett's avatar jayleverett Committed by GitHub
Browse files

Put output tensor on proper device in `get_whitenoise()` (#1744)

* put output tensor on device in `get_whitenoise()`

* Update `get_spectrogram()` so that window uses same device as waveform

* put window on proper device in `test_griffinlim()`
parent 768432c3
...@@ -75,6 +75,9 @@ def get_whitenoise( ...@@ -75,6 +75,9 @@ def get_whitenoise(
tensor.clamp_(-1.0, 1.0) tensor.clamp_(-1.0, 1.0)
if not channels_first: if not channels_first:
tensor = tensor.t() tensor = tensor.t()
tensor = tensor.to(device)
return convert_tensor_encoding(tensor, dtype) return convert_tensor_encoding(tensor, dtype)
...@@ -137,7 +140,7 @@ def get_spectrogram( ...@@ -137,7 +140,7 @@ def get_spectrogram(
""" """
hop_length = hop_length or n_fft // 4 hop_length = hop_length or n_fft // 4
win_length = win_length or n_fft win_length = win_length or n_fft
window = torch.hann_window(win_length) if window is None else window window = torch.hann_window(win_length, device=waveform.device) if window is None else window
spec = torch.stft( spec = torch.stft(
waveform, waveform,
n_fft=n_fft, n_fft=n_fft,
......
...@@ -33,7 +33,7 @@ class Functional(TestBaseMixin): ...@@ -33,7 +33,7 @@ class Functional(TestBaseMixin):
n_fft = 400 n_fft = 400
win_length = n_fft win_length = n_fft
hop_length = n_fft // 4 hop_length = n_fft // 4
window = torch.hann_window(win_length) window = torch.hann_window(win_length, device=self.device)
power = 1 power = 1
# GriffinLim params # GriffinLim params
n_iter = 8 n_iter = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment