Unverified Commit bf580c75 authored by nateanl's avatar nateanl Committed by GitHub
Browse files

Refactor transforms.Fade on GPU computation (#1871)

parent 25a8adf6
......@@ -1025,12 +1025,15 @@ class Fade(torch.nn.Module):
"""
waveform_length = waveform.size()[-1]
device = waveform.device
return self._fade_in(waveform_length).to(device) * \
self._fade_out(waveform_length).to(device) * waveform
return (
self._fade_in(waveform_length, device)
* self._fade_out(waveform_length, device)
* waveform
)
def _fade_in(self, waveform_length: int) -> Tensor:
fade = torch.linspace(0, 1, self.fade_in_len)
ones = torch.ones(waveform_length - self.fade_in_len)
def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor:
fade = torch.linspace(0, 1, self.fade_in_len, device=device)
ones = torch.ones(waveform_length - self.fade_in_len, device=device)
if self.fade_shape == "linear":
fade = fade
......@@ -1049,9 +1052,9 @@ class Fade(torch.nn.Module):
return torch.cat((fade, ones)).clamp_(0, 1)
def _fade_out(self, waveform_length: int) -> Tensor:
fade = torch.linspace(0, 1, self.fade_out_len)
ones = torch.ones(waveform_length - self.fade_out_len)
def _fade_out(self, waveform_length: int, device: torch.device) -> Tensor:
fade = torch.linspace(0, 1, self.fade_out_len, device=device)
ones = torch.ones(waveform_length - self.fade_out_len, device=device)
if self.fade_shape == "linear":
fade = - fade + 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment