Unverified Commit db1e7da9 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Migrate TimeStretch and AmplitudeToDB to torch.nn.Module (#456)

* AmplitudeToDB to torch.nn.Module

* TimeStretch use torch.nn.Module
parent babc24af
......@@ -137,7 +137,7 @@ class GriffinLim(torch.nn.Module):
self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
class AmplitudeToDB(torch.jit.ScriptModule):
class AmplitudeToDB(torch.nn.Module):
r"""Turn a tensor from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input tensor, and so
......@@ -157,7 +157,7 @@ class AmplitudeToDB(torch.jit.ScriptModule):
self.stype = stype
if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value')
self.top_db = torch.jit.Attribute(top_db, Optional[float])
self.top_db = top_db
self.multiplier = 10.0 if stype == 'power' else 20.0
self.amin = 1e-10
self.ref_value = 1.0
......@@ -592,7 +592,7 @@ class ComputeDeltas(torch.nn.Module):
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
class TimeStretch(torch.jit.ScriptModule):
class TimeStretch(torch.nn.Module):
r"""Stretch stft in time without modifying pitch for a given rate.
Args:
......@@ -610,8 +610,7 @@ class TimeStretch(torch.jit.ScriptModule):
n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
self.phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment