Commit a420cced authored by jamarshon, committed by cpuhrsch

Modifying modules and functions to support JIT and cuda (#118)

parent d92de5b9
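This commit turns torchaudio's transforms into torch.jit.ScriptModule subclasses (with forward() script methods) and decorates the functional ops with @torch.jit.script, so both can be compiled by TorchScript and run on CUDA tensors. A minimal usage sketch of what this enables — illustrative only, not taken from the diff, and assuming a CUDA-capable PyTorch build:

    import torch
    import torchaudio.transforms as transforms

    scale = transforms.Scale(factor=2**31)  # now a ScriptModule
    if torch.cuda.is_available():
        scale = scale.cuda()                      # transforms move to GPU like any module
        sig = torch.rand((10, 1), device="cuda")  # toy (samples, channels) input
        out = scale(sig)                          # forward() is a compiled script method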
@@ -52,8 +52,8 @@ pip install -r requirements.txt
 # Install the following only if running tests
 if [[ "$SKIP_TESTS" != "true" ]]; then
-    # PyTorch
-    conda install --yes pytorch -c pytorch
+    # PyTorch (nightly as 1.1 does not have Optional for type annotations)
+    conda install --yes pytorch-nightly-cpu -c pytorch
     # TorchAudio CPP Extensions
     pip install .
...
File mode changed from 100644 to 100755
import torch
import torchaudio.functional as F
import torchaudio.transforms as transforms
import unittest

RUN_CUDA = torch.cuda.is_available()
print('Run test with cuda:', RUN_CUDA)


class Test_JIT(unittest.TestCase):
    def _get_script_module(self, f, *args):
        # takes a transform function `f` and wraps it in a script module
        class MyModule(torch.jit.ScriptModule):
            def __init__(self):
                super(MyModule, self).__init__()
                self.module = f(*args)
                self.module.eval()

            @torch.jit.script_method
            def forward(self, tensor):
                return self.module(tensor)

        return MyModule()

    def _test_script_module(self, tensor, f, *args):
        # tests a script module that wraps a transform function `f` by feeding
        # the tensor into the forward function
        jit_out = self._get_script_module(f, *args).cuda()(tensor)
        py_out = f(*args).cuda()(tensor)

        self.assertTrue(torch.allclose(jit_out, py_out))

    def test_torchscript_scale(self):
        @torch.jit.script
        def jit_method(tensor, factor):
            # type: (Tensor, int) -> Tensor
            return F.scale(tensor, factor)

        tensor = torch.rand((10, 1))
        factor = 2

        jit_out = jit_method(tensor, factor)
        py_out = F.scale(tensor, factor)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_scale(self):
        tensor = torch.rand((10, 1), device="cuda")

        self._test_script_module(tensor, transforms.Scale)

    def test_torchscript_pad_trim(self):
        @torch.jit.script
        def jit_method(tensor, ch_dim, max_len, len_dim, fill_value):
            # type: (Tensor, int, int, int, float) -> Tensor
            return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)

        tensor = torch.rand((10, 1))
        ch_dim = 1
        max_len = 5
        len_dim = 0
        fill_value = 3.

        jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value)
        py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_pad_trim(self):
        tensor = torch.rand((1, 10), device="cuda")
        max_len = 5

        self._test_script_module(tensor, transforms.PadTrim, max_len)

    def test_torchscript_downmix_mono(self):
        @torch.jit.script
        def jit_method(tensor, ch_dim):
            # type: (Tensor, int) -> Tensor
            return F.downmix_mono(tensor, ch_dim)

        tensor = torch.rand((10, 1))
        ch_dim = 1

        jit_out = jit_method(tensor, ch_dim)
        py_out = F.downmix_mono(tensor, ch_dim)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_downmix_mono(self):
        tensor = torch.rand((1, 10), device="cuda")

        self._test_script_module(tensor, transforms.DownmixMono)

    def test_torchscript_LC2CL(self):
        @torch.jit.script
        def jit_method(tensor):
            # type: (Tensor) -> Tensor
            return F.LC2CL(tensor)

        tensor = torch.rand((10, 1))

        jit_out = jit_method(tensor)
        py_out = F.LC2CL(tensor)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_LC2CL(self):
        tensor = torch.rand((10, 1), device="cuda")

        self._test_script_module(tensor, transforms.LC2CL)

    def test_torchscript_spectrogram(self):
        @torch.jit.script
        def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
            # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
            return F.spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize)

        tensor = torch.rand((1, 1000))
        n_fft = 400
        ws = 400
        hop = 200
        pad = 0
        window = torch.hann_window(ws)
        power = 2
        normalize = False

        jit_out = jit_method(tensor, pad, window, n_fft, hop, ws, power, normalize)
        py_out = F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_Spectrogram(self):
        tensor = torch.rand((1, 1000), device="cuda")

        self._test_script_module(tensor, transforms.Spectrogram)

    def test_torchscript_create_fb_matrix(self):
        @torch.jit.script
        def jit_method(n_stft, f_min, f_max, n_mels):
            # type: (int, float, float, int) -> Tensor
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels)

        n_stft = 100
        f_min = 0.
        f_max = 20.
        n_mels = 10

        jit_out = jit_method(n_stft, f_min, f_max, n_mels)
        py_out = F.create_fb_matrix(n_stft, f_min, f_max, n_mels)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MelScale(self):
        spec_f = torch.rand((1, 6, 201), device="cuda")

        self._test_script_module(spec_f, transforms.MelScale)

    def test_torchscript_spectrogram_to_DB(self):
        @torch.jit.script
        def jit_method(spec, multiplier, amin, db_multiplier, top_db):
            # type: (Tensor, float, float, float, Optional[float]) -> Tensor
            return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)

        spec = torch.rand((10, 1))
        multiplier = 10.
        amin = 1e-10
        db_multiplier = 0.
        top_db = 80.

        jit_out = jit_method(spec, multiplier, amin, db_multiplier, top_db)
        py_out = F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_SpectrogramToDB(self):
        spec = torch.rand((10, 1), device="cuda")

        self._test_script_module(spec, transforms.SpectrogramToDB)

    def test_torchscript_create_dct(self):
        @torch.jit.script
        def jit_method(n_mfcc, n_mels, norm):
            # type: (int, int, Optional[str]) -> Tensor
            return F.create_dct(n_mfcc, n_mels, norm)

        n_mfcc = 40
        n_mels = 128
        norm = 'ortho'

        jit_out = jit_method(n_mfcc, n_mels, norm)
        py_out = F.create_dct(n_mfcc, n_mels, norm)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MFCC(self):
        tensor = torch.rand((1, 1000), device="cuda")

        self._test_script_module(tensor, transforms.MFCC)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MelSpectrogram(self):
        tensor = torch.rand((1, 1000), device="cuda")

        self._test_script_module(tensor, transforms.MelSpectrogram)

    def test_torchscript_BLC2CBL(self):
        @torch.jit.script
        def jit_method(tensor):
            # type: (Tensor) -> Tensor
            return F.BLC2CBL(tensor)

        tensor = torch.rand((10, 1000, 1))

        jit_out = jit_method(tensor)
        py_out = F.BLC2CBL(tensor)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_BLC2CBL(self):
        tensor = torch.rand((10, 1000, 1), device="cuda")

        self._test_script_module(tensor, transforms.BLC2CBL)

    def test_torchscript_mu_law_encoding(self):
        @torch.jit.script
        def jit_method(tensor, qc):
            # type: (Tensor, int) -> Tensor
            return F.mu_law_encoding(tensor, qc)

        tensor = torch.rand((10, 1))
        qc = 256

        jit_out = jit_method(tensor, qc)
        py_out = F.mu_law_encoding(tensor, qc)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MuLawEncoding(self):
        tensor = torch.rand((10, 1), device="cuda")

        self._test_script_module(tensor, transforms.MuLawEncoding)

    def test_torchscript_mu_law_expanding(self):
        @torch.jit.script
        def jit_method(tensor, qc):
            # type: (Tensor, int) -> Tensor
            return F.mu_law_expanding(tensor, qc)

        tensor = torch.rand((10, 1))
        qc = 256

        jit_out = jit_method(tensor, qc)
        py_out = F.mu_law_expanding(tensor, qc)

        self.assertTrue(torch.allclose(jit_out, py_out))

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MuLawExpanding(self):
        tensor = torch.rand((10, 1), device="cuda")

        self._test_script_module(tensor, transforms.MuLawExpanding)


if __name__ == '__main__':
    unittest.main()
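Throughout this suite (and in functional.py below), the `# type: (...) -> ...` comments are MyPy-style signatures that the TorchScript compiler reads in place of Python 3 annotations, which the pinned PyTorch version does not yet fully support. A self-contained sketch of the pattern, with a hypothetical function name:

    import torch

    @torch.jit.script
    def scale_by(tensor, factor):
        # type: (Tensor, int) -> Tensor
        # Without the type comment every argument would default to Tensor;
        # the comment tells the compiler that `factor` is an int.
        return tensor / factor

    print(scale_by(torch.ones(2), 4))  # tensor([0.2500, 0.2500])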
@@ -5,7 +5,7 @@ import unittest
 import test.common_utils

-class KaldiIOTest(unittest.TestCase):
+class Test_KaldiIO(unittest.TestCase):
     data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
     data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]
     test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
...
@@ -36,8 +36,9 @@ class Tester(unittest.TestCase):
         result = transforms.Scale()(audio_orig)
         self.assertTrue(result.min() >= -1. and result.max() <= 1.)

-        maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
+        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
         result = transforms.Scale(factor=maxminmax)(audio_orig)
         self.assertTrue((result.min() == -1. or result.max() == 1.) and
                         result.min() >= -1. and result.max() <= 1.)
@@ -99,7 +100,7 @@ class Tester(unittest.TestCase):
         audio_orig = self.sig.clone()
         length_orig = audio_orig.size(0)
         length_new = int(length_orig * 1.2)
-        maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
+        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
         tset = (transforms.Scale(factor=maxminmax),
                 transforms.PadTrim(max_len=length_new, channels_first=False))
...
@@ -4,7 +4,7 @@ import os.path
 import torch
 import _torch_sox

-from torchaudio import transforms, datasets, sox_effects, legacy
+from torchaudio import transforms, datasets, kaldi_io, sox_effects, legacy

 def check_input(src):
...
@@ -9,16 +9,15 @@ __all__ = [
     'LC2CL',
     'spectrogram',
     'create_fb_matrix',
-    'mel_scale',
     'spectrogram_to_DB',
     'create_dct',
-    'MFCC',
     'BLC2CBL',
     'mu_law_encoding',
     'mu_law_expanding'
 ]


+@torch.jit.script
 def scale(tensor, factor):
     # type: (Tensor, int) -> Tensor
     """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
@@ -32,12 +31,13 @@ def scale(tensor, factor):
     Outputs:
         Tensor: Scaled by the scale factor
     """
-    if not tensor.dtype.is_floating_point:
+    if not tensor.is_floating_point():
         tensor = tensor.to(torch.float32)

     return tensor / factor


+@torch.jit.script
 def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
     # type: (Tensor, int, int, int, float) -> Tensor
     """Pad/Trim a 2d-Tensor (Signal or Labels)
@@ -53,20 +53,21 @@ def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
         Tensor: Padded/trimmed tensor
     """
     if max_len > tensor.size(len_dim):
-        # tuple of (padding_left, padding_right, padding_top, padding_bottom)
+        # array of [padding_left, padding_right, padding_top, padding_bottom]
         # so pad similar to append (aka only right/bottom) and do not pad
         # the length dimension. assumes equal sizes of padding.
         padding = [max_len - tensor.size(len_dim)
                    if (i % 2 == 1) and (i // 2 != len_dim)
                    else 0
-                   for i in range(4)]
-        with torch.no_grad():
-            tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
+                   for i in [0, 1, 2, 3]]
+        # TODO add "with torch.no_grad():" back when JIT supports it
+        tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
     elif max_len < tensor.size(len_dim):
         tensor = tensor.narrow(len_dim, 0, max_len)
     return tensor


+@torch.jit.script
 def downmix_mono(tensor, ch_dim):
     # type: (Tensor, int) -> Tensor
     """Downmix any stereo signals to mono.
@@ -78,13 +79,14 @@ def downmix_mono(tensor, ch_dim):
     Outputs:
         Tensor: Mono signal
     """
-    if not tensor.dtype.is_floating_point:
+    if not tensor.is_floating_point():
         tensor = tensor.to(torch.float32)

     tensor = torch.mean(tensor, ch_dim, True)
     return tensor


+@torch.jit.script
 def LC2CL(tensor):
     # type: (Tensor) -> Tensor
     """Permute a 2d tensor from samples (n x c) to (c x n)
@@ -98,6 +100,12 @@ def LC2CL(tensor):
     return tensor.transpose(0, 1).contiguous()


+def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
+    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
+    return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
+
+
+@torch.jit.script
 def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
     # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
     """Create a spectrogram from a raw audio signal
@@ -123,21 +131,20 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
     assert sig.dim() == 2

     if pad > 0:
-        with torch.no_grad():
-            sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
+        # TODO add "with torch.no_grad():" back when JIT supports it
+        sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
+    window = window.to(sig.device)

     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = torch.stft(sig, n_fft, hop, ws,
-                        window, center=True,
-                        normalized=False, onesided=True,
-                        pad_mode='reflect').transpose(1, 2)
+    spec_f = _stft(sig, n_fft, hop, ws, window,
+                   True, 'reflect', False, True).transpose(1, 2)

     if normalize:
         spec_f /= window.pow(2).sum().sqrt()
     spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
     return spec_f


+@torch.jit.script
 def create_fb_matrix(n_stft, f_min, f_max, n_mels):
     # type: (int, float, float, int) -> Tensor
     """ Create a frequency bin conversion matrix.
@@ -150,57 +157,29 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
     Outputs:
         Tensor: triangular filter banks (fb matrix)
     """
-    def _hertz_to_mel(f):
-        # type: (float) -> Tensor
-        return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))
-
-    def _mel_to_hertz(mel):
-        # type: (Tensor) -> Tensor
-        return 700. * (10**(mel / 2595.) - 1.)
-
     # get stft freq bins
     stft_freqs = torch.linspace(f_min, f_max, n_stft)
     # calculate mel freq bins
-    m_min = 0. if f_min == 0 else _hertz_to_mel(f_min)
-    m_max = _hertz_to_mel(f_max)
+    # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
+    m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.))
+    m_max = 2595. * math.log10(1. + (f_max / 700.))
     m_pts = torch.linspace(m_min, m_max, n_mels + 2)
-    f_pts = _mel_to_hertz(m_pts)
+    # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
+    f_pts = 700. * (10**(m_pts / 2595.) - 1.)
     # calculate the difference between each mel point and each stft freq point in hertz
     f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
     slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)
     # create overlapping triangles
-    z = torch.tensor(0.)
+    z = torch.zeros(1)
     down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
     up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
     fb = torch.max(z, torch.min(down_slopes, up_slopes))
     return fb


-def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
-    # type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]
-    """ This turns a normal STFT into a mel frequency STFT, using a conversion
-    matrix. This uses triangular filter banks.
-    Inputs:
-        spec_f (Tensor): normal STFT
-        f_min (float): minimum frequency
-        f_max (float): maximum frequency
-        n_mels (int): number of mel bins
-        fb (Optional[Tensor]): triangular filter banks (fb matrix)
-    Outputs:
-        Tuple[Tensor, Tensor]: triangular filter banks (fb matrix) and mel frequency STFT
-    """
-    if fb is None:
-        fb = create_fb_matrix(spec_f.size(2), f_min, f_max, n_mels).to(spec_f.device)
-    else:
-        # need to ensure same device for dot product
-        fb = fb.to(spec_f.device)
-    spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
-    return fb, spec_m
-
-
+@torch.jit.script
 def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
     # type: (Tensor, float, float, float, Optional[float]) -> Tensor
     """Turns a spectrogram from the power/amplitude scale to the decibel scale.
@@ -224,12 +203,15 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
     spec_db -= multiplier * db_multiplier

     if top_db is not None:
-        spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - top_db))
+        new_spec_db_max = torch.tensor(float(spec_db.max()) - top_db, dtype=spec_db.dtype, device=spec_db.device)
+        spec_db = torch.max(spec_db, new_spec_db_max)
     return spec_db


+@torch.jit.script
 def create_dct(n_mfcc, n_mels, norm):
-    # type: (int, int, string) -> Tensor
+    # type: (int, int, Optional[str]) -> Tensor
     """
     Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
     normalized depending on norm
@@ -237,7 +219,7 @@ def create_dct(n_mfcc, n_mels, norm):
     Inputs:
         n_mfcc (int) : number of mfc coefficients to retain
         n_mels (int): number of MEL bins
-        norm (string) : norm to use
+        norm (Optional[str]) : norm to use (either 'ortho' or None)
     Outputs:
         Tensor: The transformation matrix, to be right-multiplied to row-wise data.
@@ -245,48 +227,19 @@ def create_dct(n_mfcc, n_mels, norm):
     outdim = n_mfcc
     dim = n_mels
     # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
-    n = torch.arange(dim, dtype=torch.get_default_dtype())
-    k = torch.arange(outdim, dtype=torch.get_default_dtype())[:, None]
-    dct = torch.cos(math.pi / dim * (n + 0.5) * k)
-    if norm == 'ortho':
-        dct[0] *= 1.0 / math.sqrt(2.0)
-        dct *= math.sqrt(2.0 / dim)
-    else:
-        dct *= 2
+    n = torch.arange(dim)
+    k = torch.arange(outdim)[:, None]
+    dct = torch.cos(math.pi / float(dim) * (n + 0.5) * k)
+    if norm is None:
+        dct *= 2.0
+    else:
+        assert norm == 'ortho'
+        dct[0] *= 1.0 / math.sqrt(2.0)
+        dct *= math.sqrt(2.0 / float(dim))
     return dct.t()


-def MFCC(sig, mel_spect, log_mels, s2db, dct_mat):
-    # type: (Tensor, MelSpectrogram, bool, SpectrogramToDB, Tensor) -> Tensor
-    """Create the Mel-frequency cepstrum coefficients from an audio signal
-    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
-    This is not the textbook implementation, but is implemented here to
-    give consistency with librosa.
-    This output depends on the maximum value in the input spectrogram, and so
-    may return different values for an audio clip split into snippets vs.
-    a full clip.
-    Inputs:
-        sig (Tensor): Tensor of audio of size (channels [c], samples [n])
-        mel_spect (MelSpectrogram): melspectrogram of sig
-        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
-        s2db (SpectrogramToDB): a SpectrogramToDB instance
-        dct_mat (Tensor): The transformation matrix (dct matrix), to be
-            right-multiplied to row-wise data
-    Outputs:
-        Tensor: Mel-frequency cepstrum coefficients
-    """
-    if log_mels:
-        log_offset = 1e-6
-        mel_spect = torch.log(mel_spect + log_offset)
-    else:
-        mel_spect = s2db(mel_spect)
-    mfcc = torch.matmul(mel_spect, dct_mat.to(mel_spect.device))
-    return mfcc


+@torch.jit.script
 def BLC2CBL(tensor):
     # type: (Tensor) -> Tensor
     """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
@@ -301,6 +254,7 @@ def BLC2CBL(tensor):
     return tensor.permute(2, 0, 1).contiguous()


+@torch.jit.script
 def mu_law_encoding(x, qc):
     # type: (Tensor, int) -> Tensor
     """Encode signal based on mu-law companding. For more info see the
@@ -318,7 +272,7 @@ def mu_law_encoding(x, qc):
     """
     assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor'
     mu = qc - 1.
-    if not x.dtype.is_floating_point:
+    if not x.is_floating_point():
         x = x.to(torch.float)
     mu = torch.tensor(mu, dtype=x.dtype)
     x_mu = torch.sign(x) * torch.log1p(mu *
@@ -327,6 +281,7 @@ def mu_law_encoding(x, qc):
     return x_mu


+@torch.jit.script
 def mu_law_expanding(x_mu, qc):
     # type: (Tensor, int) -> Tensor
     """Decode mu-law encoded signal. For more info see the
@@ -344,7 +299,7 @@ def mu_law_expanding(x_mu, qc):
     """
     assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor'
     mu = qc - 1.
-    if not x_mu.dtype.is_floating_point:
+    if not x_mu.is_floating_point():
         x_mu = x_mu.to(torch.float)
     mu = torch.tensor(mu, dtype=x_mu.dtype)
     x = ((x_mu) / mu) * 2 - 1.
...
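The mu-law hunks above are truncated mid-expression by the diff view. For orientation, the pair implements standard mu-law companding; here is a plain-PyTorch sketch written independently of the library (function names are illustrative, not from the diff):

    import math
    import torch

    def mu_law_encode(x, qc=256):
        # companding: y = sign(x) * log(1 + mu*|x|) / log(1 + mu), for x in [-1., 1.]
        mu = qc - 1.
        x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / math.log1p(mu)
        # quantize [-1., 1.] to integer codes in [0, qc - 1]
        return ((x_mu + 1.) / 2. * mu + 0.5).to(torch.int64)

    def mu_law_expand(x_mu, qc=256):
        # map codes back to [-1., 1.], then invert the companding curve:
        # x = sign(y) * ((1 + mu)**|y| - 1) / mu
        mu = qc - 1.
        x = x_mu.to(torch.float) / mu * 2. - 1.
        return torch.sign(x) * torch.expm1(torch.abs(x) * math.log1p(mu)) / mu

Round-tripping a signal in [-1., 1.] through encode then expand recovers it up to quantization error.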
@@ -2,9 +2,11 @@ from __future__ import division, print_function
 from warnings import warn
 import math
 import torch
+from typing import Optional

 from . import functional as F


+# TODO remove this class
 class Compose(object):
     """Composes several transforms together.
@@ -17,7 +19,6 @@ class Compose(object):
         >>>     transforms.PadTrim(max_len=16000),
         >>> ])
     """
-
     def __init__(self, transforms):
         self.transforms = transforms
@@ -35,7 +36,7 @@ class Compose(object):
         return format_string


-class Scale(object):
+class Scale(torch.jit.ScriptModule):
     """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
     to a floating point number between -1.0 and 1.0. Note the 16-bit number is
     called the "bit depth" or "precision", not to be confused with "bit rate".
@@ -44,11 +45,14 @@ class Scale(object):
         factor (int): maximum value of input tensor. default: 16-bit depth
     """
+    __constants__ = ['factor']

     def __init__(self, factor=2**31):
+        super(Scale, self).__init__()
         self.factor = factor

-    def __call__(self, tensor):
+    @torch.jit.script_method
+    def forward(self, tensor):
         """

         Args:
@@ -64,7 +68,7 @@ class Scale(object):
         return self.__class__.__name__ + '()'


-class PadTrim(object):
+class PadTrim(torch.jit.ScriptModule):
     """Pad/Trim a 2d-Tensor (Signal or Labels)

     Args:
@@ -73,13 +77,16 @@ class PadTrim(object):
         channels_first (bool): Pad for channels first tensors. Default: `True`
     """
+    __constants__ = ['max_len', 'fill_value', 'len_dim', 'ch_dim']

-    def __init__(self, max_len, fill_value=0, channels_first=True):
+    def __init__(self, max_len, fill_value=0., channels_first=True):
+        super(PadTrim, self).__init__()
         self.max_len = max_len
         self.fill_value = fill_value
         self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)

-    def __call__(self, tensor):
+    @torch.jit.script_method
+    def forward(self, tensor):
         """

         Returns:
@@ -92,7 +99,7 @@ class PadTrim(object):
         return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)


-class DownmixMono(object):
+class DownmixMono(torch.jit.ScriptModule):
     """Downmix any stereo signals to mono. Consider using a `SoxEffectsChain` with
     the `channels` effect instead of this transformation.
@@ -104,22 +111,29 @@ class DownmixMono(object):
         tensor (Tensor) (Samples x 1):
     """
+    __constants__ = ['ch_dim']

     def __init__(self, channels_first=None):
+        super(DownmixMono, self).__init__()
         self.ch_dim = int(not channels_first)

-    def __call__(self, tensor):
+    @torch.jit.script_method
+    def forward(self, tensor):
         return F.downmix_mono(tensor, self.ch_dim)

     def __repr__(self):
         return self.__class__.__name__ + '()'


-class LC2CL(object):
+class LC2CL(torch.jit.ScriptModule):
     """Permute a 2d tensor from samples (n x c) to (c x n)
     """

-    def __call__(self, tensor):
+    def __init__(self):
+        super(LC2CL, self).__init__()
+
+    @torch.jit.script_method
+    def forward(self, tensor):
         """

         Args:
@@ -139,7 +153,7 @@ def SPECTROGRAM(*args, **kwargs):
     return Spectrogram(*args, **kwargs)


-class Spectrogram(object):
+class Spectrogram(torch.jit.ScriptModule):
     """Create a spectrogram from a raw audio signal

     Args:
@@ -153,21 +167,25 @@ class Spectrogram(object):
         normalize (bool) : whether to normalize by magnitude after stft
         wkwargs (dict, optional): arguments for window function
     """
+    __constants__ = ['n_fft', 'ws', 'hop', 'pad', 'power', 'normalize']

     def __init__(self, n_fft=400, ws=None, hop=None,
                  pad=0, window=torch.hann_window,
                  power=2, normalize=False, wkwargs=None):
+        super(Spectrogram, self).__init__()
         self.n_fft = n_fft
         # number of fft bins. the returned STFT result will have n_fft // 2 + 1
         # number of frequecies due to onesided=True in torch.stft
         self.ws = ws if ws is not None else n_fft
         self.hop = hop if hop is not None else self.ws // 2
-        self.window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
+        window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
+        self.window = torch.jit.Attribute(window, torch.Tensor)
         self.pad = pad
         self.power = power
         self.normalize = normalize
-        self.wkwargs = wkwargs

-    def __call__(self, sig):
+    @torch.jit.script_method
+    def forward(self, sig):
         """
         Args:
             sig (Tensor): Tensor of audio of size (c, n)
@@ -188,10 +206,12 @@ def F2M(*args, **kwargs):
     return MelScale(*args, **kwargs)


-class MelScale(object):
+class MelScale(torch.jit.ScriptModule):
     """This turns a normal STFT into a mel frequency STFT, using a conversion
     matrix. This uses triangular filter banks.
+
+    User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

     Args:
         n_mels (int): number of mel bins
         sr (int): sample rate of audio signal
@@ -200,20 +220,30 @@ class MelScale(object):
         n_stft (int, optional): number of filter banks from stft. Calculated from first input
             if `None` is given. See `n_fft` in `Spectrogram`.
     """
+    __constants__ = ['n_mels', 'sr', 'f_min', 'f_max']

     def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None):
+        super(MelScale, self).__init__()
         self.n_mels = n_mels
         self.sr = sr
-        self.f_max = f_max if f_max is not None else sr // 2
+        self.f_max = f_max if f_max is not None else float(sr // 2)
         self.f_min = f_min
-        self.fb = F.create_fb_matrix(
-            n_stft, self.f_min, self.f_max, self.n_mels) if n_stft is not None else n_stft
+        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
+            n_stft, self.f_min, self.f_max, self.n_mels)
+        self.fb = torch.jit.Attribute(fb, torch.Tensor)

-    def __call__(self, spec_f):
-        self.fb, spec_m = F.mel_scale(spec_f, self.f_min, self.f_max, self.n_mels, self.fb)
+    @torch.jit.script_method
+    def forward(self, spec_f):
+        if self.fb.numel() == 0:
+            tmp_fb = F.create_fb_matrix(spec_f.size(2), self.f_min, self.f_max, self.n_mels)
+            # Attributes cannot be reassigned outside __init__ so workaround
+            self.fb.resize_(tmp_fb.size())
+            self.fb.copy_(tmp_fb)
+        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
         return spec_m


-class SpectrogramToDB(object):
+class SpectrogramToDB(torch.jit.ScriptModule):
     """Turns a spectrogram from the power/amplitude scale to the decibel scale.

     This output depends on the maximum value in the input spectrogram, and so
@@ -226,23 +256,27 @@ class SpectrogramToDB(object):
         top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
             is 80.
     """
+    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

     def __init__(self, stype="power", top_db=None):
-        self.stype = stype
+        super(SpectrogramToDB, self).__init__()
+        self.stype = torch.jit.Attribute(stype, str)
         if top_db is not None and top_db < 0:
             raise ValueError('top_db must be positive value')
-        self.top_db = top_db
+        self.top_db = torch.jit.Attribute(top_db, Optional[float])
         self.multiplier = 10. if stype == "power" else 20.
         self.amin = 1e-10
         self.ref_value = 1.
         self.db_multiplier = math.log10(max(self.amin, self.ref_value))

-    def __call__(self, spec):
+    @torch.jit.script_method
+    def forward(self, spec):
         # numerically stable implementation from librosa
         # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
         return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)


-class MFCC(object):
+class MFCC(torch.jit.ScriptModule):
     """Create the Mel-frequency cepstrum coefficients from an audio signal

     By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
@@ -257,21 +291,22 @@ class MFCC(object):
         sr (int) : sample rate of audio signal
         n_mfcc (int) : number of mfc coefficients to retain
         dct_type (int) : type of DCT (discrete cosine transform) to use
-        norm (string) : norm to use
+        norm (string, optional) : norm to use
         log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled
         melkwargs (dict, optional): arguments for MelSpectrogram
     """
+    __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

     def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                  melkwargs=None):
+        super(MFCC, self).__init__()
         supported_dct_types = [2]
         if dct_type not in supported_dct_types:
             raise ValueError('DCT type not supported'.format(dct_type))
         self.sr = sr
         self.n_mfcc = n_mfcc
         self.dct_type = dct_type
-        self.norm = norm
-        self.melkwargs = melkwargs
+        self.norm = torch.jit.Attribute(norm, Optional[str])
         self.top_db = 80.
         self.s2db = SpectrogramToDB("power", self.top_db)
@@ -282,10 +317,12 @@ class MFCC(object):
         if self.n_mfcc > self.MelSpectrogram.n_mels:
             raise ValueError('Cannot select more MFCC coefficients than # mel bins')
-        self.dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
+        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
+        self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor)
         self.log_mels = log_mels

-    def __call__(self, sig):
+    @torch.jit.script_method
+    def forward(self, sig):
         """
         Args:
             sig (Tensor): Tensor of audio of size (channels [c], samples [n])
@@ -295,10 +332,17 @@ class MFCC(object):
             is unchanged, hops is the number of hops, and n_mels is the
             number of mel bins.
         """
-        return F.MFCC(sig, self.MelSpectrogram(sig), self.log_mels, self.s2db, self.dct_mat)
+        mel_spect = self.MelSpectrogram(sig)
+        if self.log_mels:
+            log_offset = 1e-6
+            mel_spect = torch.log(mel_spect + log_offset)
+        else:
+            mel_spect = self.s2db(mel_spect)
+        mfcc = torch.matmul(mel_spect, self.dct_mat)
+        return mfcc


-class MelSpectrogram(object):
+class MelSpectrogram(torch.jit.ScriptModule):
     """Create MEL Spectrograms from a raw audio signal using the stft
     function in PyTorch.
@@ -323,27 +367,26 @@ class MelSpectrogram(object):
         >>> sig, sr = torchaudio.load("test.wav", normalization=True)
         >>> spec_mel = transforms.MelSpectrogram(sr)(sig)  # (c, l, m)
     """
+    __constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min']

     def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None,
                  pad=0, n_mels=128, window=torch.hann_window, wkwargs=None):
-        self.window = window
+        super(MelSpectrogram, self).__init__()
         self.sr = sr
         self.n_fft = n_fft
         self.ws = ws if ws is not None else n_fft
         self.hop = hop if hop is not None else self.ws // 2
         self.pad = pad
         self.n_mels = n_mels  # number of mel frequency bins
-        self.wkwargs = wkwargs
-        self.f_max = f_max
+        self.f_max = torch.jit.Attribute(f_max, Optional[float])
         self.f_min = f_min
         self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop,
-                                pad=self.pad, window=self.window, power=2,
-                                normalize=False, wkwargs=self.wkwargs)
+                                pad=self.pad, window=window, power=2,
+                                normalize=False, wkwargs=wkwargs)
         self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
-        self.transforms = Compose([
-            self.spec, self.fm
-        ])

-    def __call__(self, sig):
+    @torch.jit.script_method
+    def forward(self, sig):
         """
         Args:
             sig (Tensor): Tensor of audio of size (channels [c], samples [n])
@@ -354,8 +397,8 @@ class MelSpectrogram(object):
             is unchanged, hops is the number of hops, and n_mels is the
             number of mel bins.
         """
-        spec_mel = self.transforms(sig)
-
+        spec = self.spec(sig)
+        spec_mel = self.fm(spec)
         return spec_mel
@@ -363,12 +406,16 @@ def MEL(*args, **kwargs):
     raise DeprecationWarning("MEL has been removed from the library please use MelSpectrogram or librosa")


-class BLC2CBL(object):
+class BLC2CBL(torch.jit.ScriptModule):
     """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
     Bands x Samples length
     """

-    def __call__(self, tensor):
+    def __init__(self):
+        super(BLC2CBL, self).__init__()
+
+    @torch.jit.script_method
+    def forward(self, tensor):
         """

         Args:
@@ -384,7 +431,7 @@ class BLC2CBL(object):
         return self.__class__.__name__ + '()'


-class MuLawEncoding(object):
+class MuLawEncoding(torch.jit.ScriptModule):
     """Encode signal based on mu-law companding. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
@@ -395,11 +442,14 @@ class MuLawEncoding(object):
         quantization_channels (int): Number of channels. default: 256
     """
+    __constants__ = ['qc']

     def __init__(self, quantization_channels=256):
+        super(MuLawEncoding, self).__init__()
         self.qc = quantization_channels

-    def __call__(self, x):
+    @torch.jit.script_method
+    def forward(self, x):
         """

         Args:
@@ -415,7 +465,7 @@ class MuLawEncoding(object):
         return self.__class__.__name__ + '()'


-class MuLawExpanding(object):
+class MuLawExpanding(torch.jit.ScriptModule):
     """Decode mu-law encoded signal. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
@@ -426,11 +476,14 @@ class MuLawExpanding(object):
         quantization_channels (int): Number of channels. default: 256
     """
+    __constants__ = ['qc']

     def __init__(self, quantization_channels=256):
+        super(MuLawExpanding, self).__init__()
         self.qc = quantization_channels

-    def __call__(self, x_mu):
+    @torch.jit.script_method
+    def forward(self, x_mu):
         """

         Args:
...
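Because every transform is now a ScriptModule, it can be serialized without its Python class definition and reloaded elsewhere (including from C++). A sketch using the standard torch.jit serialization API (not something this diff itself adds, and assuming these modules script cleanly end to end):

    import torch
    import torchaudio.transforms as transforms

    mel = transforms.MelSpectrogram(sr=16000)
    mel.save("mel.pt")                  # ScriptModule serialization
    restored = torch.jit.load("mel.pt")
    sig = torch.rand((1, 16000))        # (channels, samples)
    spec = restored(sig)                # (channels, hops, n_mels)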