Unverified Commit db1e7da9 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Migrate TimeStretch and AmplitudeToDB to torch.nn.Module (#456)

* AmplitudeToDB to torch.nn.Module

* TimeStretch use torch.nn.Module
parent babc24af
...@@ -137,7 +137,7 @@ class GriffinLim(torch.nn.Module): ...@@ -137,7 +137,7 @@ class GriffinLim(torch.nn.Module):
self.normalized, self.n_iter, self.momentum, self.length, self.rand_init) self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
class AmplitudeToDB(torch.jit.ScriptModule): class AmplitudeToDB(torch.nn.Module):
r"""Turn a tensor from the power/amplitude scale to the decibel scale. r"""Turn a tensor from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input tensor, and so This output depends on the maximum value in the input tensor, and so
...@@ -157,7 +157,7 @@ class AmplitudeToDB(torch.jit.ScriptModule): ...@@ -157,7 +157,7 @@ class AmplitudeToDB(torch.jit.ScriptModule):
self.stype = stype self.stype = stype
if top_db is not None and top_db < 0: if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value') raise ValueError('top_db must be positive value')
self.top_db = torch.jit.Attribute(top_db, Optional[float]) self.top_db = top_db
self.multiplier = 10.0 if stype == 'power' else 20.0 self.multiplier = 10.0 if stype == 'power' else 20.0
self.amin = 1e-10 self.amin = 1e-10
self.ref_value = 1.0 self.ref_value = 1.0
...@@ -592,7 +592,7 @@ class ComputeDeltas(torch.nn.Module): ...@@ -592,7 +592,7 @@ class ComputeDeltas(torch.nn.Module):
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
class TimeStretch(torch.jit.ScriptModule): class TimeStretch(torch.nn.Module):
r"""Stretch stft in time without modifying pitch for a given rate. r"""Stretch stft in time without modifying pitch for a given rate.
Args: Args:
...@@ -610,8 +610,7 @@ class TimeStretch(torch.jit.ScriptModule): ...@@ -610,8 +610,7 @@ class TimeStretch(torch.jit.ScriptModule):
n_fft = (n_freq - 1) * 2 n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2 hop_length = hop_length if hop_length is not None else n_fft // 2
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] self.phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
def forward(self, complex_specgrams, overriding_rate=None): def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor # type: (Tensor, Optional[float]) -> Tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment