Unverified Commit 99c52600 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

make all functional torchscriptable. (#326)

parent 6d5f0b43
...@@ -442,7 +442,9 @@ def complex_norm(complex_tensor, power=1.0): ...@@ -442,7 +442,9 @@ def complex_norm(complex_tensor, power=1.0):
return torch.norm(complex_tensor, 2, -1).pow(power) return torch.norm(complex_tensor, 2, -1).pow(power)
@torch.jit.script
def angle(complex_tensor): def angle(complex_tensor):
# type: (Tensor) -> Tensor
r"""Compute the angle of complex tensor input. r"""Compute the angle of complex tensor input.
Args: Args:
...@@ -454,7 +456,9 @@ def angle(complex_tensor): ...@@ -454,7 +456,9 @@ def angle(complex_tensor):
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
@torch.jit.script
def magphase(complex_tensor, power=1.0): def magphase(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tuple[Tensor, Tensor]
r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase. r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase.
Args: Args:
...@@ -534,6 +538,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): ...@@ -534,6 +538,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
return complex_specgrams_stretch return complex_specgrams_stretch
@torch.jit.script
def lfilter(waveform, a_coeffs, b_coeffs): def lfilter(waveform, a_coeffs, b_coeffs):
# type: (Tensor, Tensor, Tensor) -> Tensor # type: (Tensor, Tensor, Tensor) -> Tensor
r""" r"""
...@@ -595,6 +600,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): ...@@ -595,6 +600,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])) return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]))
@torch.jit.script
def biquad(waveform, b0, b1, b2, a0, a1, a2): def biquad(waveform, b0, b1, b2, a0, a1, a2):
# type: (Tensor, float, float, float, float, float, float) -> Tensor # type: (Tensor, float, float, float, float, float, float) -> Tensor
r"""Performs a biquad filter of input tensor. Initial conditions set to 0. r"""Performs a biquad filter of input tensor. Initial conditions set to 0.
...@@ -625,11 +631,13 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): ...@@ -625,11 +631,13 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
def _dB2Linear(x): def _dB2Linear(x):
# type: (float) -> float
return math.exp(x * math.log(10) / 20.0) return math.exp(x * math.log(10) / 20.0)
@torch.jit.script
def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, Optional[float]) -> Tensor # type: (Tensor, int, float, float) -> Tensor
r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation. r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation.
Args: Args:
...@@ -642,10 +650,10 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -642,10 +650,10 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
output_waveform (torch.Tensor): Dimension of `(channel, time)` output_waveform (torch.Tensor): Dimension of `(channel, time)`
""" """
GAIN = 1 # TBD - add as a parameter GAIN = 1.
w0 = 2 * math.pi * cutoff_freq / sample_rate w0 = 2 * math.pi * cutoff_freq / sample_rate
A = math.exp(GAIN / 40.0 * math.log(10)) A = math.exp(GAIN / 40.0 * math.log(10))
alpha = math.sin(w0) / 2 / Q alpha = math.sin(w0) / 2. / Q
mult = _dB2Linear(max(GAIN, 0)) mult = _dB2Linear(max(GAIN, 0))
b0 = (1 + math.cos(w0)) / 2 b0 = (1 + math.cos(w0)) / 2
...@@ -657,8 +665,9 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -657,8 +665,9 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, Optional[float]) -> Tensor # type: (Tensor, int, float, float) -> Tensor
r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation. r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation.
Args: Args:
...@@ -671,7 +680,7 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -671,7 +680,7 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
output_waveform (torch.Tensor): Dimension of `(channel, time)` output_waveform (torch.Tensor): Dimension of `(channel, time)`
""" """
GAIN = 1 GAIN = 1.
w0 = 2 * math.pi * cutoff_freq / sample_rate w0 = 2 * math.pi * cutoff_freq / sample_rate
A = math.exp(GAIN / 40.0 * math.log(10)) A = math.exp(GAIN / 40.0 * math.log(10))
alpha = math.sin(w0) / 2 / Q alpha = math.sin(w0) / 2 / Q
...@@ -686,6 +695,7 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -686,6 +695,7 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
# @torch.jit.script
def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
# type: (Tensor, int, float, float, float) -> Tensor # type: (Tensor, int, float, float, float) -> Tensor
r"""Designs biquad peaking equalizer filter and performs filtering. Similar to SoX implementation. r"""Designs biquad peaking equalizer filter and performs filtering. Similar to SoX implementation.
...@@ -833,6 +843,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -833,6 +843,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
) / denom ) / denom
@torch.jit.script
def _compute_nccf(waveform, sample_rate, frame_time, freq_low): def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# type: (Tensor, int, float, int) -> Tensor # type: (Tensor, int, float, int) -> Tensor
r""" r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment