Unverified Commit f1a5503e authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

phase_advance should be a buffer so it moves device correctly (#457)



* phase_advance should be a buffer so it moves device correctly

* flake8
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent db1e7da9
...@@ -610,7 +610,7 @@ class TimeStretch(torch.nn.Module): ...@@ -610,7 +610,7 @@ class TimeStretch(torch.nn.Module):
n_fft = (n_freq - 1) * 2 n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2 hop_length = hop_length if hop_length is not None else n_fft // 2
self.phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
def forward(self, complex_specgrams, overriding_rate=None): def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor # type: (Tensor, Optional[float]) -> Tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment