Unverified Commit bf580c75 authored by nateanl's avatar nateanl Committed by GitHub
Browse files

Refactor transforms.Fade on GPU computation (#1871)

parent 25a8adf6
...@@ -1025,12 +1025,15 @@ class Fade(torch.nn.Module): ...@@ -1025,12 +1025,15 @@ class Fade(torch.nn.Module):
""" """
waveform_length = waveform.size()[-1] waveform_length = waveform.size()[-1]
device = waveform.device device = waveform.device
return self._fade_in(waveform_length).to(device) * \ return (
self._fade_out(waveform_length).to(device) * waveform self._fade_in(waveform_length, device)
* self._fade_out(waveform_length, device)
* waveform
)
def _fade_in(self, waveform_length: int) -> Tensor: def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor:
fade = torch.linspace(0, 1, self.fade_in_len) fade = torch.linspace(0, 1, self.fade_in_len, device=device)
ones = torch.ones(waveform_length - self.fade_in_len) ones = torch.ones(waveform_length - self.fade_in_len, device=device)
if self.fade_shape == "linear": if self.fade_shape == "linear":
fade = fade fade = fade
...@@ -1049,9 +1052,9 @@ class Fade(torch.nn.Module): ...@@ -1049,9 +1052,9 @@ class Fade(torch.nn.Module):
return torch.cat((fade, ones)).clamp_(0, 1) return torch.cat((fade, ones)).clamp_(0, 1)
def _fade_out(self, waveform_length: int) -> Tensor: def _fade_out(self, waveform_length: int, device: torch.device) -> Tensor:
fade = torch.linspace(0, 1, self.fade_out_len) fade = torch.linspace(0, 1, self.fade_out_len, device=device)
ones = torch.ones(waveform_length - self.fade_out_len) ones = torch.ones(waveform_length - self.fade_out_len, device=device)
if self.fade_shape == "linear": if self.fade_shape == "linear":
fade = - fade + 1 fade = - fade + 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment