Unverified Commit 04e68471 authored by moto's avatar moto Committed by GitHub
Browse files

Fix Fade device compatibility (#508)

* Fix Fade device compatibility
parent adef7b94
...@@ -723,8 +723,9 @@ class Fade(torch.nn.Module): ...@@ -723,8 +723,9 @@ class Fade(torch.nn.Module):
Tensor: Tensor of audio of dimension (..., time). Tensor: Tensor of audio of dimension (..., time).
""" """
waveform_length = waveform.size()[-1] waveform_length = waveform.size()[-1]
device = waveform.device
return self._fade_in(waveform_length) * self._fade_out(waveform_length) * waveform return self._fade_in(waveform_length).to(device) * \
self._fade_out(waveform_length).to(device) * waveform
def _fade_in(self, waveform_length: int) -> Tensor: def _fade_in(self, waveform_length: int) -> Tensor:
fade = torch.linspace(0, 1, self.fade_in_len) fade = torch.linspace(0, 1, self.fade_in_len)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment