Unverified Commit 909e445b authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Make jit compilation optional for function and use nn.Module (#314)

* nn.Module.

* generalizing spectrogram test.

* adding test to compile functionals.

* add cuda/cpu compilation test.

* adding transform test.

* remove standalone jit file.

* update mel scale.

* remove script decorator.

* apply to augmentations too.
parent ef7c55ce
...@@ -16,6 +16,15 @@ if IMPORT_LIBROSA: ...@@ -16,6 +16,15 @@ if IMPORT_LIBROSA:
import librosa import librosa
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert torch.allclose(jit_out, py_out)
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)] data_sizes = [(2, 20), (3, 15), (4, 10)]
number_of_trials = 100 number_of_trials = 100
...@@ -25,6 +34,21 @@ class TestFunctional(unittest.TestCase): ...@@ -25,6 +34,21 @@ class TestFunctional(unittest.TestCase):
test_filepath = os.path.join(test_dirpath, 'assets', test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3') 'steam-train-whistle-daniel_simon.mp3')
def test_torchscript_spectrogram(self):
    """Check that F.spectrogram can be scripted and matches its eager output."""
    waveform = torch.rand((1, 1000))
    win_length = 400
    hann = torch.hann_window(win_length)
    # positional order: waveform, pad, window, n_fft, hop_length, win_length, power, normalized
    _test_torchscript_functional(
        F.spectrogram, waveform, 0, hann, 400, 200, win_length, 2, False
    )
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length) computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
...@@ -49,6 +73,7 @@ class TestFunctional(unittest.TestCase): ...@@ -49,6 +73,7 @@ class TestFunctional(unittest.TestCase):
specgram = torch.randn(channel, n_mfcc, time) specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length) computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
def test_batch_pitch(self): def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) waveform, sample_rate = torchaudio.load(self.test_filepath)
...@@ -63,6 +88,7 @@ class TestFunctional(unittest.TestCase): ...@@ -63,6 +88,7 @@ class TestFunctional(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected)) self.assertTrue(torch.allclose(computed, expected))
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original # trim sound for case when constructed signal is shorter than original
...@@ -424,6 +450,74 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): ...@@ -424,6 +450,74 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length):
assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5) assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)
def test_torchscript_create_fb_matrix(self):
    """Check that F.create_fb_matrix can be scripted and matches its eager output."""
    # positional order: n_freqs, f_min, f_max, n_mels, sample_rate
    fb_args = (100, 0.0, 20.0, 10, 16000)
    _test_torchscript_functional(F.create_fb_matrix, *fb_args)
def test_torchscript_amplitude_to_DB(self):
    """Check that F.amplitude_to_DB can be scripted and matches its eager output."""
    spectrogram = torch.rand((6, 201))
    # positional order after the input: multiplier, amin, db_multiplier, top_db
    db_args = (10.0, 1e-10, 0.0, 80.0)
    _test_torchscript_functional(F.amplitude_to_DB, spectrogram, *db_args)
def test_torchscript_create_dct(self):
    """Check that F.create_dct can be scripted and matches its eager output."""
    # positional order: n_mfcc, n_mels, norm
    _test_torchscript_functional(F.create_dct, 40, 128, "ortho")
def test_torchscript_mu_law_encoding(self):
    """Check that F.mu_law_encoding can be scripted and matches its eager output."""
    quantization_channels = 256
    signal = torch.rand((1, 10))
    _test_torchscript_functional(F.mu_law_encoding, signal, quantization_channels)
def test_torchscript_mu_law_decoding(self):
    """Check that F.mu_law_decoding can be scripted and matches its eager output."""
    quantization_channels = 256
    encoded = torch.rand((1, 10))
    _test_torchscript_functional(F.mu_law_decoding, encoded, quantization_channels)
def test_torchscript_complex_norm(self):
    """Check that F.complex_norm can be scripted and matches its eager output."""
    # No trailing comma after the call: it would wrap the tensor in a 1-tuple,
    # and complex_norm is declared to take a Tensor.
    complex_tensor = torch.randn(1, 2, 1025, 400, 2)
    # complex_norm declares ``power`` as a float.
    power = 2.0
    _test_torchscript_functional(F.complex_norm, complex_tensor, power)
def test_mask_along_axis(self):
    """Check that F.mask_along_axis can be scripted and matches its eager output."""
    # No trailing comma after the call: it would wrap the tensor in a 1-tuple,
    # and mask_along_axis is declared to take a Tensor.
    specgram = torch.randn(2, 1025, 400)
    mask_param = 100
    mask_value = 30.
    axis = 2

    _test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)
def test_mask_along_axis_iid(self):
    """Check that F.mask_along_axis_iid can be scripted and matches its eager output."""
    # No trailing comma after the call: it would wrap the tensor in a 1-tuple,
    # and mask_along_axis_iid is declared to take a Tensor.
    specgrams = torch.randn(4, 2, 1025, 400)
    mask_param = 100
    mask_value = 30.
    axis = 2

    _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
@pytest.mark.parametrize('complex_tensor', [ @pytest.mark.parametrize('complex_tensor', [
torch.randn(1, 2, 1025, 400, 2), torch.randn(1, 2, 1025, 400, 2),
......
...@@ -9,6 +9,15 @@ import common_utils ...@@ -9,6 +9,15 @@ import common_utils
import time import time
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert torch.allclose(jit_out, py_out)
class TestFunctionalFiltering(unittest.TestCase): class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir() test_dirpath, test_dir = common_utils.create_temp_assets_dir()
...@@ -79,6 +88,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -79,6 +88,7 @@ class TestFunctionalFiltering(unittest.TestCase):
assert len(output_waveform.size()) == 2 assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0) assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1) assert output_waveform.size(1) == waveform.size(1)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
def test_lfilter(self): def test_lfilter(self):
...@@ -116,6 +126,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -116,6 +126,7 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ) output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
def test_highpass(self): def test_highpass(self):
""" """
...@@ -135,6 +146,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -135,6 +146,7 @@ class TestFunctionalFiltering(unittest.TestCase):
# TBD - this fails at the 1e-4 level, debug why # TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
def test_equalizer(self): def test_equalizer(self):
""" """
...@@ -155,6 +167,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -155,6 +167,7 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q) output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)
def test_perf_biquad_filtering(self): def test_perf_biquad_filtering(self):
...@@ -183,6 +196,9 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -183,6 +196,9 @@ class TestFunctionalFiltering(unittest.TestCase):
_timing_lfilter_run_time = time.time() - _timing_lfilter_filtering _timing_lfilter_run_time = time.time() - _timing_lfilter_filtering
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
_test_torchscript_functional(
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
if __name__ == "__main__": if __name__ == "__main__":
......
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torchaudio.functional as F
import torchaudio.transforms as transforms
import unittest
# True when torch reports a usable CUDA device; the ScriptModule tests in this
# file are skipped when it is False.
RUN_CUDA = torch.cuda.is_available()
# Report the decision once, when the module is loaded.
print('Run test with cuda:', RUN_CUDA)
class Test_JIT(unittest.TestCase):
    """Compare scripted and eager results of torchaudio functionals and transforms.

    The ``test_torchscript_*`` methods script a thin wrapper around a
    functional and run on CPU. The ``test_scriptmodule_*`` methods wrap a
    transform in a ScriptModule, run on CUDA, and are skipped without it.
    """

    def _get_script_module(self, f, *args):
        """Return a ScriptModule whose ``forward`` delegates to the transform ``f(*args)``."""

        class TransformWrapper(torch.jit.ScriptModule):
            def __init__(self):
                super(TransformWrapper, self).__init__()
                # build the wrapped transform and put it in evaluation mode
                self.module = f(*args)
                self.module.eval()

            @torch.jit.script_method
            def forward(self, tensor):
                return self.module(tensor)

        return TransformWrapper()

    def _test_script_module(self, tensor, f, *args):
        """Feed ``tensor`` to the scripted wrapper and to ``f(*args)``, both moved to CUDA."""
        scripted_result = self._get_script_module(f, *args).cuda()(tensor)
        eager_result = f(*args).cuda()(tensor)

        self.assertTrue(torch.allclose(scripted_result, eager_result))

    def test_torchscript_spectrogram(self):
        """Scripted F.spectrogram must match the eager call."""

        @torch.jit.script
        def scripted(sig, pad, window, n_fft, hop, ws, power, normalize):
            # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
            return F.spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize)

        win_length = 400
        signal = torch.rand((1, 1000))
        # (sig, pad, window, n_fft, hop, ws, power, normalize)
        call_args = (signal, 0, torch.hann_window(win_length), 400, 200, win_length, 2, False)

        scripted_result = scripted(*call_args)
        eager_result = F.spectrogram(*call_args)

        self.assertTrue(torch.allclose(scripted_result, eager_result))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_Spectrogram(self):
        """Scripted Spectrogram transform must match the eager one on CUDA."""
        cuda_input = torch.rand((1, 1000), device="cuda")

        self._test_script_module(cuda_input, transforms.Spectrogram)

    def test_torchscript_create_fb_matrix(self):
        """Scripted F.create_fb_matrix must match the eager call."""

        @torch.jit.script
        def scripted(n_stft, f_min, f_max, n_mels, sample_rate):
            # type: (int, float, float, int, int) -> Tensor
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)

        # (n_stft, f_min, f_max, n_mels, sample_rate)
        call_args = (100, 0., 20., 10, 16000)

        scripted_result = scripted(*call_args)
        eager_result = F.create_fb_matrix(*call_args)

        self.assertTrue(torch.allclose(scripted_result, eager_result))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MelScale(self):
        """Scripted MelScale transform must match the eager one on CUDA."""
        cuda_input = torch.rand((1, 6, 201), device="cuda")

        self._test_script_module(cuda_input, transforms.MelScale)

    def test_torchscript_amplitude_to_DB(self):
        """Scripted F.amplitude_to_DB must match the eager call."""

        @torch.jit.script
        def scripted(spec, multiplier, amin, db_multiplier, top_db):
            # type: (Tensor, float, float, float, Optional[float]) -> Tensor
            return F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)

        # (spec, multiplier, amin, db_multiplier, top_db)
        call_args = (torch.rand((6, 201)), 10., 1e-10, 0., 80.)

        scripted_result = scripted(*call_args)
        eager_result = F.amplitude_to_DB(*call_args)

        self.assertTrue(torch.allclose(scripted_result, eager_result))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_AmplitudeToDB(self):
        """Scripted AmplitudeToDB transform must match the eager one on CUDA."""
        cuda_input = torch.rand((6, 201), device="cuda")

        self._test_script_module(cuda_input, transforms.AmplitudeToDB)

    def test_torchscript_create_dct(self):
        """Scripted F.create_dct must match the eager call."""

        @torch.jit.script
        def scripted(n_mfcc, n_mels, norm):
            # type: (int, int, Optional[str]) -> Tensor
            return F.create_dct(n_mfcc, n_mels, norm)

        # (n_mfcc, n_mels, norm)
        call_args = (40, 128, 'ortho')

        scripted_result = scripted(*call_args)
        eager_result = F.create_dct(*call_args)

        self.assertTrue(torch.allclose(scripted_result, eager_result))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MFCC(self):
        """Scripted MFCC transform must match the eager one on CUDA."""
        cuda_input = torch.rand((1, 1000), device="cuda")

        self._test_script_module(cuda_input, transforms.MFCC)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MelSpectrogram(self):
        """Scripted MelSpectrogram transform must match the eager one on CUDA."""
        cuda_input = torch.rand((1, 1000), device="cuda")

        self._test_script_module(cuda_input, transforms.MelSpectrogram)

    def test_torchscript_mu_law_encoding(self):
        """Scripted F.mu_law_encoding must match the eager call."""

        @torch.jit.script
        def scripted(tensor, qc):
            # type: (Tensor, int) -> Tensor
            return F.mu_law_encoding(tensor, qc)

        # (tensor, qc)
        call_args = (torch.rand((1, 10)), 256)

        scripted_result = scripted(*call_args)
        eager_result = F.mu_law_encoding(*call_args)

        self.assertTrue(torch.allclose(scripted_result, eager_result))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MuLawEncoding(self):
        """Scripted MuLawEncoding transform must match the eager one on CUDA."""
        cuda_input = torch.rand((1, 10), device="cuda")

        self._test_script_module(cuda_input, transforms.MuLawEncoding)

    def test_torchscript_mu_law_decoding(self):
        """Scripted F.mu_law_decoding must match the eager call."""

        @torch.jit.script
        def scripted(tensor, qc):
            # type: (Tensor, int) -> Tensor
            return F.mu_law_decoding(tensor, qc)

        # (tensor, qc)
        call_args = (torch.rand((1, 10)), 256)

        scripted_result = scripted(*call_args)
        eager_result = F.mu_law_decoding(*call_args)

        self.assertTrue(torch.allclose(scripted_result, eager_result))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MuLawDecoding(self):
        """Scripted MuLawDecoding transform must match the eager one on CUDA."""
        cuda_input = torch.rand((1, 10), device="cuda")

        self._test_script_module(cuda_input, transforms.MuLawDecoding)
if __name__ == '__main__':
unittest.main()
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import torch import torch
import torchaudio import torchaudio
import torchaudio.augmentations as A
import torchaudio.transforms as transforms import torchaudio.transforms as transforms
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
...@@ -16,6 +17,32 @@ if IMPORT_LIBROSA: ...@@ -16,6 +17,32 @@ if IMPORT_LIBROSA:
if IMPORT_SCIPY: if IMPORT_SCIPY:
import scipy import scipy
RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)
def _test_script_module(f, tensor, *args, **kwargs):
py_method = f(*args, **kwargs)
jit_method = torch.jit.script(py_method)
py_out = py_method(tensor)
jit_out = jit_method(tensor)
assert torch.allclose(jit_out, py_out)
if RUN_CUDA:
tensor = tensor.to("cuda")
py_method = py_method.cuda()
jit_method = torch.jit.script(py_method)
py_out = py_method(tensor)
jit_out = jit_method(tensor)
assert torch.allclose(jit_out, py_out)
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -37,6 +64,10 @@ class Tester(unittest.TestCase): ...@@ -37,6 +64,10 @@ class Tester(unittest.TestCase):
waveform = waveform.to(torch.get_default_dtype()) waveform = waveform.to(torch.get_default_dtype())
return waveform / factor return waveform / factor
def test_scriptmodule_Spectrogram(self):
    """Scripted and eager Spectrogram must agree on a random (1, 1000) input."""
    waveform = torch.rand((1, 1000))
    _test_script_module(transforms.Spectrogram, waveform)
def test_mu_law_companding(self): def test_mu_law_companding(self):
quantization_channels = 256 quantization_channels = 256
...@@ -51,6 +82,14 @@ class Tester(unittest.TestCase): ...@@ -51,6 +82,14 @@ class Tester(unittest.TestCase):
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu) waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
def test_scriptmodule_AmplitudeToDB(self):
    """Scripted and eager AmplitudeToDB must agree on a random (6, 201) input."""
    random_input = torch.rand((6, 201))
    _test_script_module(transforms.AmplitudeToDB, random_input)
def test_scriptmodule_MelScale(self):
    """Scripted and eager MelScale must agree on a random (1, 6, 201) input."""
    random_input = torch.rand((1, 6, 201))
    _test_script_module(transforms.MelScale, random_input)
def test_melscale_load_save(self): def test_melscale_load_save(self):
specgram = torch.ones(1, 1000, 100) specgram = torch.ones(1, 1000, 100)
melscale_transform = transforms.MelScale() melscale_transform = transforms.MelScale()
...@@ -65,6 +104,10 @@ class Tester(unittest.TestCase): ...@@ -65,6 +104,10 @@ class Tester(unittest.TestCase):
self.assertEqual(fb_copy.size(), (1000, 128)) self.assertEqual(fb_copy.size(), (1000, 128))
self.assertTrue(torch.allclose(fb, fb_copy)) self.assertTrue(torch.allclose(fb, fb_copy))
def test_scriptmodule_MelSpectrogram(self):
    """Scripted and eager MelSpectrogram must agree on a random (1, 1000) input."""
    random_input = torch.rand((1, 1000))
    _test_script_module(transforms.MelSpectrogram, random_input)
def test_melspectrogram_load_save(self): def test_melspectrogram_load_save(self):
waveform = self.waveform.float() waveform = self.waveform.float()
mel_spectrogram_transform = transforms.MelSpectrogram() mel_spectrogram_transform = transforms.MelSpectrogram()
...@@ -123,6 +166,10 @@ class Tester(unittest.TestCase): ...@@ -123,6 +166,10 @@ class Tester(unittest.TestCase):
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_scriptmodule_MFCC(self):
    """Scripted and eager MFCC must agree on a random (1, 1000) input."""
    random_input = torch.rand((1, 1000))
    _test_script_module(transforms.MFCC, random_input)
def test_mfcc(self): def test_mfcc(self):
audio_orig = self.waveform.clone() audio_orig = self.waveform.clone()
audio_scaled = self.scale(audio_orig) # (1, 16000) audio_scaled = self.scale(audio_orig) # (1, 16000)
...@@ -326,6 +373,14 @@ class Tester(unittest.TestCase): ...@@ -326,6 +373,14 @@ class Tester(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected)) self.assertTrue(torch.allclose(computed, expected))
def test_scriptmodule_MuLawEncoding(self):
    """Scripted and eager MuLawEncoding must agree on a random (1, 10) input."""
    random_input = torch.rand((1, 10))
    _test_script_module(transforms.MuLawEncoding, random_input)
def test_scriptmodule_MuLawDecoding(self):
    """Scripted and eager MuLawDecoding must agree on a random (1, 10) input."""
    random_input = torch.rand((1, 10))
    _test_script_module(transforms.MuLawDecoding, random_input)
def test_batch_mulaw(self): def test_batch_mulaw(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
...@@ -364,6 +419,21 @@ class Tester(unittest.TestCase): ...@@ -364,6 +419,21 @@ class Tester(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected)) self.assertTrue(torch.allclose(computed, expected))
def test_scriptmodule_TimeStretch(self):
    """Scripted and eager TimeStretch must agree on a random input."""
    # constructor arguments for the transform under test
    stretch_config = dict(n_freq=400, hop_length=512, fixed_rate=1.3)

    random_input = torch.rand((10, 2, stretch_config["n_freq"], 10, 2))
    _test_script_module(A.TimeStretch, random_input, **stretch_config)
def test_scriptmodule_FrequencyMasking(self):
    """Scripted and eager FrequencyMasking must agree on a random input."""
    random_input = torch.rand((10, 2, 50, 10, 2))
    masking_config = dict(freq_mask_param=60, iid_masks=False)
    _test_script_module(A.FrequencyMasking, random_input, **masking_config)
def test_scriptmodule_TimeMasking(self):
    """Scripted and eager TimeMasking must agree on a random input."""
    random_input = torch.rand((10, 2, 50, 10, 2))
    masking_config = dict(time_mask_param=30, iid_masks=False)
    _test_script_module(A.TimeMasking, random_input, **masking_config)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -24,14 +24,13 @@ class TimeStretch(torch.jit.ScriptModule): ...@@ -24,14 +24,13 @@ class TimeStretch(torch.jit.ScriptModule):
def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
super(TimeStretch, self).__init__() super(TimeStretch, self).__init__()
self.fixed_rate = fixed_rate
n_fft = (n_freq - 1) * 2 n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2 hop_length = hop_length if hop_length is not None else n_fft // 2
self.fixed_rate = fixed_rate
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
@torch.jit.script_method
def forward(self, complex_specgrams, overriding_rate=None): def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor # type: (Tensor, Optional[float]) -> Tensor
r""" r"""
...@@ -63,7 +62,7 @@ class TimeStretch(torch.jit.ScriptModule): ...@@ -63,7 +62,7 @@ class TimeStretch(torch.jit.ScriptModule):
return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
class _AxisMasking(torch.jit.ScriptModule): class _AxisMasking(torch.nn.Module):
r""" r"""
Apply masking to a spectrogram. Apply masking to a spectrogram.
Args: Args:
...@@ -80,7 +79,6 @@ class _AxisMasking(torch.jit.ScriptModule): ...@@ -80,7 +79,6 @@ class _AxisMasking(torch.jit.ScriptModule):
self.axis = axis self.axis = axis
self.iid_masks = iid_masks self.iid_masks = iid_masks
@torch.jit.script_method
def forward(self, specgram, mask_value=0.): def forward(self, specgram, mask_value=0.):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
r""" r"""
......
...@@ -221,7 +221,6 @@ def istft( ...@@ -221,7 +221,6 @@ def istft(
return y return y
@torch.jit.script
def spectrogram( def spectrogram(
waveform, pad, window, n_fft, hop_length, win_length, power, normalized waveform, pad, window, n_fft, hop_length, win_length, power, normalized
): ):
...@@ -274,7 +273,6 @@ def spectrogram( ...@@ -274,7 +273,6 @@ def spectrogram(
return spec_f return spec_f
@torch.jit.script
def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor # type: (Tensor, float, float, float, Optional[float]) -> Tensor
r""" r"""
...@@ -309,7 +307,6 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): ...@@ -309,7 +307,6 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
return x_db return x_db
@torch.jit.script
def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
# type: (int, float, float, int, int) -> Tensor # type: (int, float, float, int, int) -> Tensor
r""" r"""
...@@ -355,7 +352,6 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): ...@@ -355,7 +352,6 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
return fb return fb
@torch.jit.script
def create_dct(n_mfcc, n_mels, norm): def create_dct(n_mfcc, n_mels, norm):
# type: (int, int, Optional[str]) -> Tensor # type: (int, int, Optional[str]) -> Tensor
r""" r"""
...@@ -386,7 +382,6 @@ def create_dct(n_mfcc, n_mels, norm): ...@@ -386,7 +382,6 @@ def create_dct(n_mfcc, n_mels, norm):
return dct.t() return dct.t()
@torch.jit.script
def mu_law_encoding(x, quantization_channels): def mu_law_encoding(x, quantization_channels):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
r""" r"""
...@@ -414,7 +409,6 @@ def mu_law_encoding(x, quantization_channels): ...@@ -414,7 +409,6 @@ def mu_law_encoding(x, quantization_channels):
return x_mu return x_mu
@torch.jit.script
def mu_law_decoding(x_mu, quantization_channels): def mu_law_decoding(x_mu, quantization_channels):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
r""" r"""
...@@ -442,7 +436,6 @@ def mu_law_decoding(x_mu, quantization_channels): ...@@ -442,7 +436,6 @@ def mu_law_decoding(x_mu, quantization_channels):
return x return x
@torch.jit.script
def complex_norm(complex_tensor, power=1.0): def complex_norm(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
r"""Compute the norm of complex tensor input. r"""Compute the norm of complex tensor input.
...@@ -459,7 +452,6 @@ def complex_norm(complex_tensor, power=1.0): ...@@ -459,7 +452,6 @@ def complex_norm(complex_tensor, power=1.0):
return torch.norm(complex_tensor, 2, -1).pow(power) return torch.norm(complex_tensor, 2, -1).pow(power)
@torch.jit.script
def angle(complex_tensor): def angle(complex_tensor):
# type: (Tensor) -> Tensor # type: (Tensor) -> Tensor
r"""Compute the angle of complex tensor input. r"""Compute the angle of complex tensor input.
...@@ -473,7 +465,6 @@ def angle(complex_tensor): ...@@ -473,7 +465,6 @@ def angle(complex_tensor):
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
@torch.jit.script
def magphase(complex_tensor, power=1.0): def magphase(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tuple[Tensor, Tensor] # type: (Tensor, float) -> Tuple[Tensor, Tensor]
r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase. r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.
...@@ -490,7 +481,6 @@ def magphase(complex_tensor, power=1.0): ...@@ -490,7 +481,6 @@ def magphase(complex_tensor, power=1.0):
return mag, phase return mag, phase
@torch.jit.script
def phase_vocoder(complex_specgrams, rate, phase_advance): def phase_vocoder(complex_specgrams, rate, phase_advance):
# type: (Tensor, float, Tensor) -> Tensor # type: (Tensor, float, Tensor) -> Tensor
r"""Given a STFT tensor, speed up in time without modifying pitch by a r"""Given a STFT tensor, speed up in time without modifying pitch by a
...@@ -555,7 +545,6 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): ...@@ -555,7 +545,6 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
return complex_specgrams_stretch return complex_specgrams_stretch
@torch.jit.script
def lfilter(waveform, a_coeffs, b_coeffs): def lfilter(waveform, a_coeffs, b_coeffs):
# type: (Tensor, Tensor, Tensor) -> Tensor # type: (Tensor, Tensor, Tensor) -> Tensor
r""" r"""
...@@ -630,7 +619,6 @@ def lfilter(waveform, a_coeffs, b_coeffs): ...@@ -630,7 +619,6 @@ def lfilter(waveform, a_coeffs, b_coeffs):
return output return output
@torch.jit.script
def biquad(waveform, b0, b1, b2, a0, a1, a2): def biquad(waveform, b0, b1, b2, a0, a1, a2):
# type: (Tensor, float, float, float, float, float, float) -> Tensor # type: (Tensor, float, float, float, float, float, float) -> Tensor
r"""Performs a biquad filter of input tensor. Initial conditions set to 0. r"""Performs a biquad filter of input tensor. Initial conditions set to 0.
...@@ -665,7 +653,6 @@ def _dB2Linear(x): ...@@ -665,7 +653,6 @@ def _dB2Linear(x):
return math.exp(x * math.log(10) / 20.0) return math.exp(x * math.log(10) / 20.0)
@torch.jit.script
def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor # type: (Tensor, int, float, float) -> Tensor
r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation. r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation.
...@@ -695,7 +682,6 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -695,7 +682,6 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor # type: (Tensor, int, float, float) -> Tensor
r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation. r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation.
...@@ -725,7 +711,6 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -725,7 +711,6 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
# type: (Tensor, int, float, float, float) -> Tensor # type: (Tensor, int, float, float, float) -> Tensor
r"""Designs biquad peaking equalizer filter and performs filtering. Similar to SoX implementation. r"""Designs biquad peaking equalizer filter and performs filtering. Similar to SoX implementation.
...@@ -753,7 +738,6 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): ...@@ -753,7 +738,6 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor # type: (Tensor, int, float, int) -> Tensor
r""" r"""
...@@ -790,7 +774,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): ...@@ -790,7 +774,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
return specgrams return specgrams
@torch.jit.script
def mask_along_axis(specgram, mask_param, mask_value, axis): def mask_along_axis(specgram, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor # type: (Tensor, int, float, int) -> Tensor
r""" r"""
...@@ -825,7 +808,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): ...@@ -825,7 +808,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
return specgram return specgram
@torch.jit.script
def compute_deltas(specgram, win_length=5, mode="replicate"): def compute_deltas(specgram, win_length=5, mode="replicate"):
# type: (Tensor, int, str) -> Tensor # type: (Tensor, int, str) -> Tensor
r"""Compute delta coefficients of a tensor, usually a spectrogram: r"""Compute delta coefficients of a tensor, usually a spectrogram:
...@@ -878,7 +860,6 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -878,7 +860,6 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
return output return output
@torch.jit.script
def _compute_nccf(waveform, sample_rate, frame_time, freq_low): def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# type: (Tensor, int, float, int) -> Tensor # type: (Tensor, int, float, int) -> Tensor
r""" r"""
...@@ -993,7 +974,6 @@ def _median_smoothing(indices, win_length): ...@@ -993,7 +974,6 @@ def _median_smoothing(indices, win_length):
return values return values
@torch.jit.script
def detect_pitch_frequency( def detect_pitch_frequency(
waveform, waveform,
sample_rate, sample_rate,
...@@ -1021,7 +1001,7 @@ def detect_pitch_frequency( ...@@ -1021,7 +1001,7 @@ def detect_pitch_frequency(
dim = waveform.dim() dim = waveform.dim()
# pack batch # pack batch
shape = waveform.size() shape = list(waveform.size())
waveform = waveform.reshape([-1] + shape[-1:]) waveform = waveform.reshape([-1] + shape[-1:])
nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
...@@ -1033,6 +1013,6 @@ def detect_pitch_frequency( ...@@ -1033,6 +1013,6 @@ def detect_pitch_frequency(
freq = sample_rate / (EPSILON + indices.to(torch.float)) freq = sample_rate / (EPSILON + indices.to(torch.float))
# unpack batch # unpack batch
freq = freq.reshape(shape[:-1] + freq.shape[-1:]) freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))
return freq return freq
...@@ -20,7 +20,7 @@ __all__ = [ ...@@ -20,7 +20,7 @@ __all__ = [
] ]
class Spectrogram(torch.jit.ScriptModule): class Spectrogram(torch.nn.Module):
r"""Create a spectrogram from an audio signal r"""Create a spectrogram from an audio signal
Args: Args:
...@@ -48,12 +48,11 @@ class Spectrogram(torch.jit.ScriptModule): ...@@ -48,12 +48,11 @@ class Spectrogram(torch.jit.ScriptModule):
self.win_length = win_length if win_length is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.window = torch.jit.Attribute(window, torch.Tensor) self.window = window
self.pad = pad self.pad = pad
self.power = power self.power = power
self.normalized = normalized self.normalized = normalized
@torch.jit.script_method
def forward(self, waveform): def forward(self, waveform):
r""" r"""
Args: Args:
...@@ -85,7 +84,7 @@ class AmplitudeToDB(torch.jit.ScriptModule): ...@@ -85,7 +84,7 @@ class AmplitudeToDB(torch.jit.ScriptModule):
def __init__(self, stype='power', top_db=None): def __init__(self, stype='power', top_db=None):
super(AmplitudeToDB, self).__init__() super(AmplitudeToDB, self).__init__()
self.stype = torch.jit.Attribute(stype, str) self.stype = stype
if top_db is not None and top_db < 0: if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value') raise ValueError('top_db must be positive value')
self.top_db = torch.jit.Attribute(top_db, Optional[float]) self.top_db = torch.jit.Attribute(top_db, Optional[float])
...@@ -94,7 +93,6 @@ class AmplitudeToDB(torch.jit.ScriptModule): ...@@ -94,7 +93,6 @@ class AmplitudeToDB(torch.jit.ScriptModule):
self.ref_value = 1.0 self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value)) self.db_multiplier = math.log10(max(self.amin, self.ref_value))
@torch.jit.script_method
def forward(self, x): def forward(self, x):
r"""Numerically stable implementation from Librosa r"""Numerically stable implementation from Librosa
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
...@@ -108,7 +106,7 @@ class AmplitudeToDB(torch.jit.ScriptModule): ...@@ -108,7 +106,7 @@ class AmplitudeToDB(torch.jit.ScriptModule):
return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db) return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
class MelScale(torch.jit.ScriptModule): class MelScale(torch.nn.Module):
r"""This turns a normal STFT into a mel frequency STFT, using a conversion r"""This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks. matrix. This uses triangular filter banks.
...@@ -129,13 +127,14 @@ class MelScale(torch.jit.ScriptModule): ...@@ -129,13 +127,14 @@ class MelScale(torch.jit.ScriptModule):
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_max = f_max if f_max is not None else float(sample_rate // 2)
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
self.f_min = f_min self.f_min = f_min
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.fb = torch.jit.Attribute(fb, torch.Tensor) self.fb = fb
@torch.jit.script_method
def forward(self, specgram): def forward(self, specgram):
r""" r"""
Args: Args:
...@@ -156,7 +155,7 @@ class MelScale(torch.jit.ScriptModule): ...@@ -156,7 +155,7 @@ class MelScale(torch.jit.ScriptModule):
return mel_specgram return mel_specgram
class MelSpectrogram(torch.jit.ScriptModule): class MelSpectrogram(torch.nn.Module):
r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
and MelScale. and MelScale.
...@@ -194,7 +193,7 @@ class MelSpectrogram(torch.jit.ScriptModule): ...@@ -194,7 +193,7 @@ class MelSpectrogram(torch.jit.ScriptModule):
self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
self.pad = pad self.pad = pad
self.n_mels = n_mels # number of mel frequency bins self.n_mels = n_mels # number of mel frequency bins
self.f_max = torch.jit.Attribute(f_max, Optional[float]) self.f_max = f_max
self.f_min = f_min self.f_min = f_min
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
hop_length=self.hop_length, hop_length=self.hop_length,
...@@ -202,7 +201,6 @@ class MelSpectrogram(torch.jit.ScriptModule): ...@@ -202,7 +201,6 @@ class MelSpectrogram(torch.jit.ScriptModule):
normalized=False, wkwargs=wkwargs) normalized=False, wkwargs=wkwargs)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1) self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
@torch.jit.script_method
def forward(self, waveform): def forward(self, waveform):
r""" r"""
Args: Args:
...@@ -216,7 +214,7 @@ class MelSpectrogram(torch.jit.ScriptModule): ...@@ -216,7 +214,7 @@ class MelSpectrogram(torch.jit.ScriptModule):
return mel_specgram return mel_specgram
class MFCC(torch.jit.ScriptModule): class MFCC(torch.nn.Module):
r"""Create the Mel-frequency cepstrum coefficients from an audio signal r"""Create the Mel-frequency cepstrum coefficients from an audio signal
By default, this calculates the MFCC on the DB-scaled Mel spectrogram. By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
...@@ -247,7 +245,7 @@ class MFCC(torch.jit.ScriptModule): ...@@ -247,7 +245,7 @@ class MFCC(torch.jit.ScriptModule):
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_mfcc = n_mfcc self.n_mfcc = n_mfcc
self.dct_type = dct_type self.dct_type = dct_type
self.norm = torch.jit.Attribute(norm, Optional[str]) self.norm = norm
self.top_db = 80.0 self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db) self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
...@@ -259,10 +257,9 @@ class MFCC(torch.jit.ScriptModule): ...@@ -259,10 +257,9 @@ class MFCC(torch.jit.ScriptModule):
if self.n_mfcc > self.MelSpectrogram.n_mels: if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins') raise ValueError('Cannot select more MFCC coefficients than # mel bins')
dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor) self.dct_mat = dct_mat
self.log_mels = log_mels self.log_mels = log_mels
@torch.jit.script_method
def forward(self, waveform): def forward(self, waveform):
r""" r"""
Args: Args:
...@@ -283,7 +280,7 @@ class MFCC(torch.jit.ScriptModule): ...@@ -283,7 +280,7 @@ class MFCC(torch.jit.ScriptModule):
return mfcc return mfcc
class MuLawEncoding(torch.jit.ScriptModule): class MuLawEncoding(torch.nn.Module):
r"""Encode signal based on mu-law companding. For more info see the r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...@@ -299,7 +296,6 @@ class MuLawEncoding(torch.jit.ScriptModule): ...@@ -299,7 +296,6 @@ class MuLawEncoding(torch.jit.ScriptModule):
super(MuLawEncoding, self).__init__() super(MuLawEncoding, self).__init__()
self.quantization_channels = quantization_channels self.quantization_channels = quantization_channels
@torch.jit.script_method
def forward(self, x): def forward(self, x):
r""" r"""
Args: Args:
...@@ -311,7 +307,7 @@ class MuLawEncoding(torch.jit.ScriptModule): ...@@ -311,7 +307,7 @@ class MuLawEncoding(torch.jit.ScriptModule):
return F.mu_law_encoding(x, self.quantization_channels) return F.mu_law_encoding(x, self.quantization_channels)
class MuLawDecoding(torch.jit.ScriptModule): class MuLawDecoding(torch.nn.Module):
r"""Decode mu-law encoded signal. For more info see the r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...@@ -327,7 +323,6 @@ class MuLawDecoding(torch.jit.ScriptModule): ...@@ -327,7 +323,6 @@ class MuLawDecoding(torch.jit.ScriptModule):
super(MuLawDecoding, self).__init__() super(MuLawDecoding, self).__init__()
self.quantization_channels = quantization_channels self.quantization_channels = quantization_channels
@torch.jit.script_method
def forward(self, x_mu): def forward(self, x_mu):
r""" r"""
Args: Args:
...@@ -368,7 +363,7 @@ class Resample(torch.nn.Module): ...@@ -368,7 +363,7 @@ class Resample(torch.nn.Module):
raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
class ComplexNorm(torch.jit.ScriptModule): class ComplexNorm(torch.nn.Module):
r"""Compute the norm of complex tensor input r"""Compute the norm of complex tensor input
Args: Args:
power (float): Power of the norm. Defaults to `1.0`. power (float): Power of the norm. Defaults to `1.0`.
...@@ -379,7 +374,6 @@ class ComplexNorm(torch.jit.ScriptModule): ...@@ -379,7 +374,6 @@ class ComplexNorm(torch.jit.ScriptModule):
super(ComplexNorm, self).__init__() super(ComplexNorm, self).__init__()
self.power = power self.power = power
@torch.jit.script_method
def forward(self, complex_tensor): def forward(self, complex_tensor):
r""" r"""
Args: Args:
...@@ -390,7 +384,7 @@ class ComplexNorm(torch.jit.ScriptModule): ...@@ -390,7 +384,7 @@ class ComplexNorm(torch.jit.ScriptModule):
return F.complex_norm(complex_tensor, self.power) return F.complex_norm(complex_tensor, self.power)
class ComputeDeltas(torch.jit.ScriptModule): class ComputeDeltas(torch.nn.Module):
r"""Compute delta coefficients of a tensor, usually a spectrogram. r"""Compute delta coefficients of a tensor, usually a spectrogram.
See `torchaudio.functional.compute_deltas` for more details. See `torchaudio.functional.compute_deltas` for more details.
...@@ -403,9 +397,8 @@ class ComputeDeltas(torch.jit.ScriptModule): ...@@ -403,9 +397,8 @@ class ComputeDeltas(torch.jit.ScriptModule):
def __init__(self, win_length=5, mode="replicate"): def __init__(self, win_length=5, mode="replicate"):
super(ComputeDeltas, self).__init__() super(ComputeDeltas, self).__init__()
self.win_length = win_length self.win_length = win_length
self.mode = torch.jit.Attribute(mode, str) self.mode = mode
@torch.jit.script_method
def forward(self, specgram): def forward(self, specgram):
r""" r"""
Args: Args:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment