Unverified commit 909e445b, authored by Vincent QB and committed by GitHub
Browse files

Make jit compilation optional for function and use nn.Module (#314)

* nn.Module.

* generalizing spectrogram test.

* adding test to compile functionals.

* add cuda/cpu compilation test.

* adding transform test.

* remove standalone jit file.

* update mel scale.

* remove script decorator.

* apply to augmentations too.
parent ef7c55ce
......@@ -16,6 +16,15 @@ if IMPORT_LIBROSA:
import librosa
def _test_torchscript_functional(py_method, *args, **kwargs):
    """Assert that ``py_method`` returns the same tensor eagerly and under TorchScript.

    Args:
        py_method: Plain Python callable to compile with ``torch.jit.script``.
        *args: Positional arguments forwarded unchanged to both variants.
        **kwargs: Keyword arguments forwarded unchanged to both variants.

    Raises:
        AssertionError: If the two outputs are not ``torch.allclose``.
    """
    seed = 0
    jit_method = torch.jit.script(py_method)

    # The two variants are invoked one after the other, so a callable that
    # draws random numbers would see a different RNG state on each call.
    # Re-seeding right before each invocation makes the outputs comparable.
    torch.manual_seed(seed)
    jit_out = jit_method(*args, **kwargs)

    torch.manual_seed(seed)
    py_out = py_method(*args, **kwargs)

    assert torch.allclose(jit_out, py_out)
class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)]
number_of_trials = 100
......@@ -25,6 +34,21 @@ class TestFunctional(unittest.TestCase):
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')
def test_torchscript_spectrogram(self):
    """Compare scripted and eager ``F.spectrogram`` on a random waveform."""
    win_length = 400
    waveform = torch.rand((1, 1000))
    hann = torch.hann_window(win_length)

    _test_torchscript_functional(
        F.spectrogram,
        waveform,
        0,           # pad
        hann,        # window
        400,         # n_fft
        200,         # hop
        win_length,  # ws
        2,           # power
        False,       # normalize
    )
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
......@@ -49,6 +73,7 @@ class TestFunctional(unittest.TestCase):
specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
......@@ -63,6 +88,7 @@ class TestFunctional(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
......@@ -424,6 +450,74 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length):
assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)
def test_torchscript_create_fb_matrix(self):
    """Compare scripted and eager ``F.create_fb_matrix`` for one fixed configuration."""
    # Positional order: n_stft, f_min, f_max, n_mels, sample_rate.
    fb_args = (100, 0.0, 20.0, 10, 16000)
    _test_torchscript_functional(F.create_fb_matrix, *fb_args)
def test_torchscript_amplitude_to_DB(self):
    """Compare scripted and eager ``F.amplitude_to_DB`` on a random input."""
    spectrogram = torch.rand((6, 201))
    # Positional order: multiplier, amin, db_multiplier, top_db.
    db_args = (10.0, 1e-10, 0.0, 80.0)
    _test_torchscript_functional(F.amplitude_to_DB, spectrogram, *db_args)
def test_torchscript_create_dct(self):
    """Compare scripted and eager ``F.create_dct`` with ``"ortho"`` normalization."""
    # Positional order: n_mfcc, n_mels, norm.
    dct_args = (40, 128, "ortho")
    _test_torchscript_functional(F.create_dct, *dct_args)
def test_torchscript_mu_law_encoding(self):
    """Compare scripted and eager ``F.mu_law_encoding`` with 256 quantization channels."""
    signal = torch.rand((1, 10))
    _test_torchscript_functional(F.mu_law_encoding, signal, 256)
def test_torchscript_mu_law_decoding(self):
    """Compare scripted and eager ``F.mu_law_decoding`` with 256 quantization channels."""
    signal = torch.rand((1, 10))
    _test_torchscript_functional(F.mu_law_decoding, signal, 256)
def test_torchscript_complex_norm(self):
    """Compare scripted and eager ``F.complex_norm`` on a random complex tensor."""
    # Must be a bare Tensor: a trailing comma after the call would turn this
    # into a 1-tuple, which is not a valid input for the functional.
    complex_tensor = torch.randn(1, 2, 1025, 400, 2)
    power = 2
    _test_torchscript_functional(F.complex_norm, complex_tensor, power)
def test_mask_along_axis(self):
    """Compare scripted and eager ``F.mask_along_axis`` on a random spectrogram."""
    # Must be a bare Tensor: a trailing comma after the call would turn this
    # into a 1-tuple, which is not a valid input for the functional.
    specgram = torch.randn(2, 1025, 400)
    mask_param = 100
    mask_value = 30.
    axis = 2
    _test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)
def test_mask_along_axis_iid(self):
    """Compare scripted and eager ``F.mask_along_axis_iid`` on a random batch."""
    # Must be a bare Tensor: a trailing comma after the call would turn this
    # into a 1-tuple, which is not a valid input for the functional.
    specgrams = torch.randn(4, 2, 1025, 400)
    mask_param = 100
    mask_value = 30.
    axis = 2
    _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
@pytest.mark.parametrize('complex_tensor', [
torch.randn(1, 2, 1025, 400, 2),
......
......@@ -9,6 +9,15 @@ import common_utils
import time
def _test_torchscript_functional(py_method, *args, **kwargs):
    """Assert that ``py_method`` returns the same tensor eagerly and under TorchScript.

    Args:
        py_method: Plain Python callable to compile with ``torch.jit.script``.
        *args: Positional arguments forwarded unchanged to both variants.
        **kwargs: Keyword arguments forwarded unchanged to both variants.

    Raises:
        AssertionError: If the two outputs are not ``torch.allclose``.
    """
    seed = 0
    jit_method = torch.jit.script(py_method)

    # Both variants run back to back; seeding before each call keeps the
    # comparison meaningful even if the callable consumes random numbers.
    torch.manual_seed(seed)
    jit_out = jit_method(*args, **kwargs)

    torch.manual_seed(seed)
    py_out = py_method(*args, **kwargs)

    assert torch.allclose(jit_out, py_out)
class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
......@@ -79,6 +88,7 @@ class TestFunctionalFiltering(unittest.TestCase):
assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
def test_lfilter(self):
......@@ -116,6 +126,7 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
def test_highpass(self):
"""
......@@ -135,6 +146,7 @@ class TestFunctionalFiltering(unittest.TestCase):
# TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
def test_equalizer(self):
"""
......@@ -155,6 +167,7 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)
def test_perf_biquad_filtering(self):
......@@ -183,6 +196,9 @@ class TestFunctionalFiltering(unittest.TestCase):
_timing_lfilter_run_time = time.time() - _timing_lfilter_filtering
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
_test_torchscript_functional(
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
if __name__ == "__main__":
......
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torchaudio.functional as F
import torchaudio.transforms as transforms
import unittest
# True when a CUDA device is usable; the CUDA-only tests below are skipped otherwise.
RUN_CUDA = torch.cuda.is_available()
# Report the detected setting once at import time so test logs show which path ran.
print('Run test with cuda:', RUN_CUDA)
class Test_JIT(unittest.TestCase):
    """Check that functionals and transforms give equal results eagerly and under TorchScript."""

    def _get_script_module(self, f, *args):
        # Wrap the transform built by `f(*args)` in a ScriptModule whose
        # forward simply delegates to it.
        class MyModule(torch.jit.ScriptModule):
            def __init__(self):
                super(MyModule, self).__init__()
                self.module = f(*args)
                self.module.eval()

            @torch.jit.script_method
            def forward(self, tensor):
                return self.module(tensor)

        return MyModule()

    def _test_script_module(self, tensor, f, *args):
        # Run the scripted wrapper and the plain transform on CUDA with the
        # same input and require matching outputs.
        scripted = self._get_script_module(f, *args).cuda()
        jit_out = scripted(tensor)
        eager = f(*args).cuda()
        py_out = eager(tensor)
        self.assertTrue(torch.allclose(jit_out, py_out))

    def test_torchscript_spectrogram(self):
        @torch.jit.script
        def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
            # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
            return F.spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize)

        ws = 400
        waveform = torch.rand((1, 1000))
        # Positional order: sig, pad, window, n_fft, hop, ws, power, normalize.
        call_args = (waveform, 0, torch.hann_window(ws), 400, 200, ws, 2, False)

        jit_out = jit_method(*call_args)
        py_out = F.spectrogram(*call_args)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_Spectrogram(self):
        waveform = torch.rand((1, 1000), device="cuda")
        self._test_script_module(waveform, transforms.Spectrogram)

    def test_torchscript_create_fb_matrix(self):
        @torch.jit.script
        def jit_method(n_stft, f_min, f_max, n_mels, sample_rate):
            # type: (int, float, float, int, int) -> Tensor
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)

        # Positional order: n_stft, f_min, f_max, n_mels, sample_rate.
        call_args = (100, 0., 20., 10, 16000)

        jit_out = jit_method(*call_args)
        py_out = F.create_fb_matrix(*call_args)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MelScale(self):
        spectrogram = torch.rand((1, 6, 201), device="cuda")
        self._test_script_module(spectrogram, transforms.MelScale)

    def test_torchscript_amplitude_to_DB(self):
        @torch.jit.script
        def jit_method(spec, multiplier, amin, db_multiplier, top_db):
            # type: (Tensor, float, float, float, Optional[float]) -> Tensor
            return F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)

        # Positional order: spec, multiplier, amin, db_multiplier, top_db.
        call_args = (torch.rand((6, 201)), 10., 1e-10, 0., 80.)

        jit_out = jit_method(*call_args)
        py_out = F.amplitude_to_DB(*call_args)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_AmplitudeToDB(self):
        spectrogram = torch.rand((6, 201), device="cuda")
        self._test_script_module(spectrogram, transforms.AmplitudeToDB)

    def test_torchscript_create_dct(self):
        @torch.jit.script
        def jit_method(n_mfcc, n_mels, norm):
            # type: (int, int, Optional[str]) -> Tensor
            return F.create_dct(n_mfcc, n_mels, norm)

        # Positional order: n_mfcc, n_mels, norm.
        call_args = (40, 128, 'ortho')

        jit_out = jit_method(*call_args)
        py_out = F.create_dct(*call_args)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MFCC(self):
        waveform = torch.rand((1, 1000), device="cuda")
        self._test_script_module(waveform, transforms.MFCC)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MelSpectrogram(self):
        waveform = torch.rand((1, 1000), device="cuda")
        self._test_script_module(waveform, transforms.MelSpectrogram)

    def test_torchscript_mu_law_encoding(self):
        @torch.jit.script
        def jit_method(tensor, qc):
            # type: (Tensor, int) -> Tensor
            return F.mu_law_encoding(tensor, qc)

        # Positional order: tensor, qc.
        call_args = (torch.rand((1, 10)), 256)

        jit_out = jit_method(*call_args)
        py_out = F.mu_law_encoding(*call_args)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MuLawEncoding(self):
        signal = torch.rand((1, 10), device="cuda")
        self._test_script_module(signal, transforms.MuLawEncoding)

    def test_torchscript_mu_law_decoding(self):
        @torch.jit.script
        def jit_method(tensor, qc):
            # type: (Tensor, int) -> Tensor
            return F.mu_law_decoding(tensor, qc)

        # Positional order: tensor, qc.
        call_args = (torch.rand((1, 10)), 256)

        jit_out = jit_method(*call_args)
        py_out = F.mu_law_decoding(*call_args)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MuLawDecoding(self):
        signal = torch.rand((1, 10), device="cuda")
        self._test_script_module(signal, transforms.MuLawDecoding)
if __name__ == '__main__':
unittest.main()
......@@ -4,6 +4,7 @@ import os
import torch
import torchaudio
import torchaudio.augmentations as A
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
......@@ -16,6 +17,32 @@ if IMPORT_LIBROSA:
if IMPORT_SCIPY:
import scipy
# True when a CUDA device is usable; _test_script_module repeats its check on GPU in that case.
RUN_CUDA = torch.cuda.is_available()
# Report the detected setting once at import time so test logs show which path ran.
print("Run test with cuda:", RUN_CUDA)
def _test_script_module(f, tensor, *args, **kwargs):
    """Assert that the module built by ``f`` matches its TorchScript compilation.

    The module ``f(*args, **kwargs)`` is scripted and both variants are applied
    to ``tensor``. When ``RUN_CUDA`` is set, the module and the input are moved
    to CUDA, the module is scripted again and the comparison is repeated.

    Raises:
        AssertionError: If eager and scripted outputs are not ``torch.allclose``.
    """

    def _assert_script_matches(module, inp):
        # Script first, then evaluate eager followed by scripted.
        scripted = torch.jit.script(module)
        expected = module(inp)
        actual = scripted(inp)
        assert torch.allclose(actual, expected)

    py_method = f(*args, **kwargs)
    _assert_script_matches(py_method, tensor)

    if not RUN_CUDA:
        return

    cuda_tensor = tensor.to("cuda")
    cuda_method = py_method.cuda()
    _assert_script_matches(cuda_method, cuda_tensor)
class Tester(unittest.TestCase):
......@@ -37,6 +64,10 @@ class Tester(unittest.TestCase):
waveform = waveform.to(torch.get_default_dtype())
return waveform / factor
def test_scriptmodule_Spectrogram(self):
    """A scripted default ``Spectrogram`` must match the eager module on a random waveform."""
    waveform = torch.rand((1, 1000))
    _test_script_module(transforms.Spectrogram, waveform)
def test_mu_law_companding(self):
quantization_channels = 256
......@@ -51,6 +82,14 @@ class Tester(unittest.TestCase):
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
def test_scriptmodule_AmplitudeToDB(self):
    """A scripted default ``AmplitudeToDB`` must match the eager module on a random input."""
    spectrogram = torch.rand((6, 201))
    _test_script_module(transforms.AmplitudeToDB, spectrogram)
def test_scriptmodule_MelScale(self):
    """A scripted default ``MelScale`` must match the eager module on a random input."""
    spectrogram = torch.rand((1, 6, 201))
    _test_script_module(transforms.MelScale, spectrogram)
def test_melscale_load_save(self):
specgram = torch.ones(1, 1000, 100)
melscale_transform = transforms.MelScale()
......@@ -65,6 +104,10 @@ class Tester(unittest.TestCase):
self.assertEqual(fb_copy.size(), (1000, 128))
self.assertTrue(torch.allclose(fb, fb_copy))
def test_scriptmodule_MelSpectrogram(self):
    """A scripted default ``MelSpectrogram`` must match the eager module on a random waveform."""
    waveform = torch.rand((1, 1000))
    _test_script_module(transforms.MelSpectrogram, waveform)
def test_melspectrogram_load_save(self):
waveform = self.waveform.float()
mel_spectrogram_transform = transforms.MelSpectrogram()
......@@ -123,6 +166,10 @@ class Tester(unittest.TestCase):
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_scriptmodule_MFCC(self):
    """A scripted default ``MFCC`` must match the eager module on a random waveform."""
    waveform = torch.rand((1, 1000))
    _test_script_module(transforms.MFCC, waveform)
def test_mfcc(self):
audio_orig = self.waveform.clone()
audio_scaled = self.scale(audio_orig) # (1, 16000)
......@@ -326,6 +373,14 @@ class Tester(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_scriptmodule_MuLawEncoding(self):
    """A scripted default ``MuLawEncoding`` must match the eager module on a random input."""
    signal = torch.rand((1, 10))
    _test_script_module(transforms.MuLawEncoding, signal)
def test_scriptmodule_MuLawDecoding(self):
    """A scripted default ``MuLawDecoding`` must match the eager module on a random input."""
    signal = torch.rand((1, 10))
    _test_script_module(transforms.MuLawDecoding, signal)
def test_batch_mulaw(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
......@@ -364,6 +419,21 @@ class Tester(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_scriptmodule_TimeStretch(self):
    """A scripted ``TimeStretch`` with a fixed rate must match the eager module."""
    n_freq = 400
    stretch_kwargs = dict(n_freq=n_freq, hop_length=512, fixed_rate=1.3)
    # The third dimension of the input is sized with the same n_freq given to the module.
    complex_specgrams = torch.rand((10, 2, n_freq, 10, 2))
    _test_script_module(A.TimeStretch, complex_specgrams, **stretch_kwargs)
def test_scriptmodule_FrequencyMasking(self):
    """A scripted ``FrequencyMasking`` (non-iid masks) must match the eager module."""
    masking_kwargs = dict(freq_mask_param=60, iid_masks=False)
    specgram = torch.rand((10, 2, 50, 10, 2))
    _test_script_module(A.FrequencyMasking, specgram, **masking_kwargs)
def test_scriptmodule_TimeMasking(self):
    """A scripted ``TimeMasking`` (non-iid masks) must match the eager module."""
    masking_kwargs = dict(time_mask_param=30, iid_masks=False)
    specgram = torch.rand((10, 2, 50, 10, 2))
    _test_script_module(A.TimeMasking, specgram, **masking_kwargs)
if __name__ == '__main__':
unittest.main()
......@@ -24,14 +24,13 @@ class TimeStretch(torch.jit.ScriptModule):
def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
super(TimeStretch, self).__init__()
self.fixed_rate = fixed_rate
n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2
self.fixed_rate = fixed_rate
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
@torch.jit.script_method
def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor
r"""
......@@ -63,7 +62,7 @@ class TimeStretch(torch.jit.ScriptModule):
return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
class _AxisMasking(torch.jit.ScriptModule):
class _AxisMasking(torch.nn.Module):
r"""
Apply masking to a spectrogram.
Args:
......@@ -80,7 +79,6 @@ class _AxisMasking(torch.jit.ScriptModule):
self.axis = axis
self.iid_masks = iid_masks
@torch.jit.script_method
def forward(self, specgram, mask_value=0.):
# type: (Tensor, float) -> Tensor
r"""
......
......@@ -221,7 +221,6 @@ def istft(
return y
@torch.jit.script
def spectrogram(
waveform, pad, window, n_fft, hop_length, win_length, power, normalized
):
......@@ -274,7 +273,6 @@ def spectrogram(
return spec_f
@torch.jit.script
def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
r"""
......@@ -309,7 +307,6 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
return x_db
@torch.jit.script
def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
# type: (int, float, float, int, int) -> Tensor
r"""
......@@ -355,7 +352,6 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
return fb
@torch.jit.script
def create_dct(n_mfcc, n_mels, norm):
# type: (int, int, Optional[str]) -> Tensor
r"""
......@@ -386,7 +382,6 @@ def create_dct(n_mfcc, n_mels, norm):
return dct.t()
@torch.jit.script
def mu_law_encoding(x, quantization_channels):
# type: (Tensor, int) -> Tensor
r"""
......@@ -414,7 +409,6 @@ def mu_law_encoding(x, quantization_channels):
return x_mu
@torch.jit.script
def mu_law_decoding(x_mu, quantization_channels):
# type: (Tensor, int) -> Tensor
r"""
......@@ -442,7 +436,6 @@ def mu_law_decoding(x_mu, quantization_channels):
return x
@torch.jit.script
def complex_norm(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tensor
r"""Compute the norm of complex tensor input.
......@@ -459,7 +452,6 @@ def complex_norm(complex_tensor, power=1.0):
return torch.norm(complex_tensor, 2, -1).pow(power)
@torch.jit.script
def angle(complex_tensor):
# type: (Tensor) -> Tensor
r"""Compute the angle of complex tensor input.
......@@ -473,7 +465,6 @@ def angle(complex_tensor):
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
@torch.jit.script
def magphase(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tuple[Tensor, Tensor]
r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.
......@@ -490,7 +481,6 @@ def magphase(complex_tensor, power=1.0):
return mag, phase
@torch.jit.script
def phase_vocoder(complex_specgrams, rate, phase_advance):
# type: (Tensor, float, Tensor) -> Tensor
r"""Given a STFT tensor, speed up in time without modifying pitch by a
......@@ -555,7 +545,6 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
return complex_specgrams_stretch
@torch.jit.script
def lfilter(waveform, a_coeffs, b_coeffs):
# type: (Tensor, Tensor, Tensor) -> Tensor
r"""
......@@ -630,7 +619,6 @@ def lfilter(waveform, a_coeffs, b_coeffs):
return output
@torch.jit.script
def biquad(waveform, b0, b1, b2, a0, a1, a2):
# type: (Tensor, float, float, float, float, float, float) -> Tensor
r"""Performs a biquad filter of input tensor. Initial conditions set to 0.
......@@ -665,7 +653,6 @@ def _dB2Linear(x):
return math.exp(x * math.log(10) / 20.0)
@torch.jit.script
def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor
r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation.
......@@ -695,7 +682,6 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor
r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation.
......@@ -725,7 +711,6 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
# type: (Tensor, int, float, float, float) -> Tensor
r"""Designs biquad peaking equalizer filter and performs filtering. Similar to SoX implementation.
......@@ -753,7 +738,6 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
r"""
......@@ -790,7 +774,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
return specgrams
@torch.jit.script
def mask_along_axis(specgram, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
r"""
......@@ -825,7 +808,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
return specgram
@torch.jit.script
def compute_deltas(specgram, win_length=5, mode="replicate"):
# type: (Tensor, int, str) -> Tensor
r"""Compute delta coefficients of a tensor, usually a spectrogram:
......@@ -878,7 +860,6 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
return output
@torch.jit.script
def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# type: (Tensor, int, float, int) -> Tensor
r"""
......@@ -993,7 +974,6 @@ def _median_smoothing(indices, win_length):
return values
@torch.jit.script
def detect_pitch_frequency(
waveform,
sample_rate,
......@@ -1021,7 +1001,7 @@ def detect_pitch_frequency(
dim = waveform.dim()
# pack batch
shape = waveform.size()
shape = list(waveform.size())
waveform = waveform.reshape([-1] + shape[-1:])
nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
......@@ -1033,6 +1013,6 @@ def detect_pitch_frequency(
freq = sample_rate / (EPSILON + indices.to(torch.float))
# unpack batch
freq = freq.reshape(shape[:-1] + freq.shape[-1:])
freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))
return freq
......@@ -20,7 +20,7 @@ __all__ = [
]
class Spectrogram(torch.jit.ScriptModule):
class Spectrogram(torch.nn.Module):
r"""Create a spectrogram from a audio signal
Args:
......@@ -48,12 +48,11 @@ class Spectrogram(torch.jit.ScriptModule):
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.window = torch.jit.Attribute(window, torch.Tensor)
self.window = window
self.pad = pad
self.power = power
self.normalized = normalized
@torch.jit.script_method
def forward(self, waveform):
r"""
Args:
......@@ -85,7 +84,7 @@ class AmplitudeToDB(torch.jit.ScriptModule):
def __init__(self, stype='power', top_db=None):
super(AmplitudeToDB, self).__init__()
self.stype = torch.jit.Attribute(stype, str)
self.stype = stype
if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value')
self.top_db = torch.jit.Attribute(top_db, Optional[float])
......@@ -94,7 +93,6 @@ class AmplitudeToDB(torch.jit.ScriptModule):
self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value))
@torch.jit.script_method
def forward(self, x):
r"""Numerically stable implementation from Librosa
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
......@@ -108,7 +106,7 @@ class AmplitudeToDB(torch.jit.ScriptModule):
return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
class MelScale(torch.jit.ScriptModule):
class MelScale(torch.nn.Module):
r"""This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.
......@@ -129,13 +127,14 @@ class MelScale(torch.jit.ScriptModule):
self.n_mels = n_mels
self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
self.f_min = f_min
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.fb = torch.jit.Attribute(fb, torch.Tensor)
self.fb = fb
@torch.jit.script_method
def forward(self, specgram):
r"""
Args:
......@@ -156,7 +155,7 @@ class MelScale(torch.jit.ScriptModule):
return mel_specgram
class MelSpectrogram(torch.jit.ScriptModule):
class MelSpectrogram(torch.nn.Module):
r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
and MelScale.
......@@ -194,7 +193,7 @@ class MelSpectrogram(torch.jit.ScriptModule):
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
self.pad = pad
self.n_mels = n_mels # number of mel frequency bins
self.f_max = torch.jit.Attribute(f_max, Optional[float])
self.f_max = f_max
self.f_min = f_min
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
hop_length=self.hop_length,
......@@ -202,7 +201,6 @@ class MelSpectrogram(torch.jit.ScriptModule):
normalized=False, wkwargs=wkwargs)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
@torch.jit.script_method
def forward(self, waveform):
r"""
Args:
......@@ -216,7 +214,7 @@ class MelSpectrogram(torch.jit.ScriptModule):
return mel_specgram
class MFCC(torch.jit.ScriptModule):
class MFCC(torch.nn.Module):
r"""Create the Mel-frequency cepstrum coefficients from an audio signal
By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
......@@ -247,7 +245,7 @@ class MFCC(torch.jit.ScriptModule):
self.sample_rate = sample_rate
self.n_mfcc = n_mfcc
self.dct_type = dct_type
self.norm = torch.jit.Attribute(norm, Optional[str])
self.norm = norm
self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
......@@ -259,10 +257,9 @@ class MFCC(torch.jit.ScriptModule):
if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins')
dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor)
self.dct_mat = dct_mat
self.log_mels = log_mels
@torch.jit.script_method
def forward(self, waveform):
r"""
Args:
......@@ -283,7 +280,7 @@ class MFCC(torch.jit.ScriptModule):
return mfcc
class MuLawEncoding(torch.jit.ScriptModule):
class MuLawEncoding(torch.nn.Module):
r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -299,7 +296,6 @@ class MuLawEncoding(torch.jit.ScriptModule):
super(MuLawEncoding, self).__init__()
self.quantization_channels = quantization_channels
@torch.jit.script_method
def forward(self, x):
r"""
Args:
......@@ -311,7 +307,7 @@ class MuLawEncoding(torch.jit.ScriptModule):
return F.mu_law_encoding(x, self.quantization_channels)
class MuLawDecoding(torch.jit.ScriptModule):
class MuLawDecoding(torch.nn.Module):
r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -327,7 +323,6 @@ class MuLawDecoding(torch.jit.ScriptModule):
super(MuLawDecoding, self).__init__()
self.quantization_channels = quantization_channels
@torch.jit.script_method
def forward(self, x_mu):
r"""
Args:
......@@ -368,7 +363,7 @@ class Resample(torch.nn.Module):
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
class ComplexNorm(torch.jit.ScriptModule):
class ComplexNorm(torch.nn.Module):
r"""Compute the norm of complex tensor input
Args:
power (float): Power of the norm. Defaults to `1.0`.
......@@ -379,7 +374,6 @@ class ComplexNorm(torch.jit.ScriptModule):
super(ComplexNorm, self).__init__()
self.power = power
@torch.jit.script_method
def forward(self, complex_tensor):
r"""
Args:
......@@ -390,7 +384,7 @@ class ComplexNorm(torch.jit.ScriptModule):
return F.complex_norm(complex_tensor, self.power)
class ComputeDeltas(torch.jit.ScriptModule):
class ComputeDeltas(torch.nn.Module):
r"""Compute delta coefficients of a tensor, usually a spectrogram.
See `torchaudio.functional.compute_deltas` for more details.
......@@ -403,9 +397,8 @@ class ComputeDeltas(torch.jit.ScriptModule):
def __init__(self, win_length=5, mode="replicate"):
super(ComputeDeltas, self).__init__()
self.win_length = win_length
self.mode = torch.jit.Attribute(mode, str)
self.mode = mode
@torch.jit.script_method
def forward(self, specgram):
r"""
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment