Commit b29a4639 authored by jamarshon, committed by cpuhrsch

[BC] Standardization of Transforms/Functionals (#152)

parent 2271a7ae
@@ -2,6 +2,8 @@ import math
 import torch
 import torchaudio
+import torchaudio.functional as F
+import pytest
 import unittest
 import test.common_utils
@@ -11,10 +13,6 @@ if IMPORT_LIBROSA:
     import numpy as np
     import librosa
-import pytest
-import torchaudio.functional as F
-xfail = pytest.mark.xfail
 class TestFunctional(unittest.TestCase):
     data_sizes = [(2, 20), (3, 15), (4, 10)]
@@ -197,54 +195,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):
     return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
-@pytest.mark.parametrize('fft_length', [512])
-@pytest.mark.parametrize('hop_length', [256])
-@pytest.mark.parametrize('waveform', [
-    (torch.randn(1, 100000)),
-    (torch.randn(1, 2, 100000)),
-    pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),
-])
-@pytest.mark.parametrize('pad_mode', [
-    # 'constant',
-    'reflect',
-])
-@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
-def test_stft(waveform, fft_length, hop_length, pad_mode):
-    """
-    Test STFT for multi-channel signals.
-    Padding: Value in having padding outside of torch.stft?
-    """
-    pad = fft_length // 2
-    window = torch.hann_window(fft_length)
-    complex_spec = F.stft(waveform,
-                          fft_length=fft_length,
-                          hop_length=hop_length,
-                          window=window,
-                          pad_mode=pad_mode)
-    mag_spec, phase_spec = F.magphase(complex_spec)
-    # == Test shape
-    expected_size = list(waveform.size()[:-1])
-    expected_size += [fft_length // 2 + 1, _num_stft_bins(
-        waveform.size(-1), fft_length, hop_length, pad), 2]
-    assert complex_spec.dim() == waveform.dim() + 2
-    assert complex_spec.size() == torch.Size(expected_size)
-    # == Test values
-    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)
-    # note that librosa *automatically* pad with fft_length // 2.
-    expected_complex_spec = np.apply_along_axis(librosa.stft, -1,
-                                                waveform.numpy(), **fft_config)
-    expected_mag_spec, _ = librosa.magphase(expected_complex_spec)
-    # Convert torch to np.complex
-    complex_spec = complex_spec.numpy()
-    complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]
-    assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)
-    assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)
 @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
 @pytest.mark.parametrize('complex_specgrams', [
     torch.randn(1, 2, 1025, 400, 2),
...
@@ -30,40 +30,18 @@ class Test_JIT(unittest.TestCase):
         self.assertTrue(torch.allclose(jit_out, py_out))
-    def test_torchscript_scale(self):
-        @torch.jit.script
-        def jit_method(tensor, factor):
-            # type: (Tensor, int) -> Tensor
-            return F.scale(tensor, factor)
-        tensor = torch.rand((10, 1))
-        factor = 2
-        jit_out = jit_method(tensor, factor)
-        py_out = F.scale(tensor, factor)
-        self.assertTrue(torch.allclose(jit_out, py_out))
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_scale(self):
-        tensor = torch.rand((10, 1), device="cuda")
-        self._test_script_module(tensor, transforms.Scale)
     def test_torchscript_pad_trim(self):
         @torch.jit.script
-        def jit_method(tensor, ch_dim, max_len, len_dim, fill_value):
-            # type: (Tensor, int, int, int, float) -> Tensor
-            return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        def jit_method(tensor, max_len, fill_value):
+            # type: (Tensor, int, float) -> Tensor
+            return F.pad_trim(tensor, max_len, fill_value)
-        tensor = torch.rand((10, 1))
-        ch_dim = 1
+        tensor = torch.rand((1, 10))
         max_len = 5
-        len_dim = 0
         fill_value = 3.
-        jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value)
-        py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        jit_out = jit_method(tensor, max_len, fill_value)
+        py_out = F.pad_trim(tensor, max_len, fill_value)
         self.assertTrue(torch.allclose(jit_out, py_out))
@@ -74,45 +52,6 @@ class Test_JIT(unittest.TestCase):
         self._test_script_module(tensor, transforms.PadTrim, max_len)
-    def test_torchscript_downmix_mono(self):
-        @torch.jit.script
-        def jit_method(tensor, ch_dim):
-            # type: (Tensor, int) -> Tensor
-            return F.downmix_mono(tensor, ch_dim)
-        tensor = torch.rand((10, 1))
-        ch_dim = 1
-        jit_out = jit_method(tensor, ch_dim)
-        py_out = F.downmix_mono(tensor, ch_dim)
-        self.assertTrue(torch.allclose(jit_out, py_out))
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_downmix_mono(self):
-        tensor = torch.rand((1, 10), device="cuda")
-        self._test_script_module(tensor, transforms.DownmixMono)
-    def test_torchscript_LC2CL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.LC2CL(tensor)
-        tensor = torch.rand((10, 1))
-        jit_out = jit_method(tensor)
-        py_out = F.LC2CL(tensor)
-        self.assertTrue(torch.allclose(jit_out, py_out))
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_LC2CL(self):
-        tensor = torch.rand((10, 1), device="cuda")
-        self._test_script_module(tensor, transforms.LC2CL)
     def test_torchscript_spectrogram(self):
         @torch.jit.script
         def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
@@ -167,7 +106,7 @@ class Test_JIT(unittest.TestCase):
             # type: (Tensor, float, float, float, Optional[float]) -> Tensor
             return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)
-        spec = torch.rand((10, 1))
+        spec = torch.rand((6, 201))
         multiplier = 10.
         amin = 1e-10
         db_multiplier = 0.
@@ -180,7 +119,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_SpectrogramToDB(self):
-        spec = torch.rand((10, 1), device="cuda")
+        spec = torch.rand((6, 201), device="cuda")
         self._test_script_module(spec, transforms.SpectrogramToDB)
@@ -211,32 +150,13 @@ class Test_JIT(unittest.TestCase):
         self._test_script_module(tensor, transforms.MelSpectrogram)
-    def test_torchscript_BLC2CBL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.BLC2CBL(tensor)
-        tensor = torch.rand((10, 1000, 1))
-        jit_out = jit_method(tensor)
-        py_out = F.BLC2CBL(tensor)
-        self.assertTrue(torch.allclose(jit_out, py_out))
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_BLC2CBL(self):
-        tensor = torch.rand((10, 1000, 1), device="cuda")
-        self._test_script_module(tensor, transforms.BLC2CBL)
     def test_torchscript_mu_law_encoding(self):
         @torch.jit.script
         def jit_method(tensor, qc):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_encoding(tensor, qc)
-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256
         jit_out = jit_method(tensor, qc)
@@ -246,7 +166,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_MuLawEncoding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")
         self._test_script_module(tensor, transforms.MuLawEncoding)
@@ -256,7 +176,7 @@ class Test_JIT(unittest.TestCase):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_expanding(tensor, qc)
-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256
         jit_out = jit_method(tensor, qc)
@@ -266,7 +186,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_MuLawExpanding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")
         self._test_script_module(tensor, transforms.MuLawExpanding)
...
@@ -19,191 +19,123 @@ if IMPORT_SCIPY:
 class Tester(unittest.TestCase):
     # create a sinewave signal for testing
-    sr = 16000
+    sample_rate = 16000
     freq = 440
     volume = .3
-    sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
-    sig.unsqueeze_(1)  # (64000, 1)
-    sig = (sig * volume * 2**31).long()
+    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
+    waveform.unsqueeze_(0)  # (1, 64000)
+    waveform = (waveform * volume * 2**31).long()
     # file for stereo stft test
     test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
-    test_filepath = os.path.join(test_dirpath, "assets",
-                                 "steam-train-whistle-daniel_simon.mp3")
+    test_filepath = os.path.join(test_dirpath, 'assets',
+                                 'steam-train-whistle-daniel_simon.mp3')
-    def test_scale(self):
-        audio_orig = self.sig.clone()
-        result = transforms.Scale()(audio_orig)
-        self.assertTrue(result.min() >= -1. and result.max() <= 1.)
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-        result = transforms.Scale(factor=maxminmax)(audio_orig)
-        self.assertTrue((result.min() == -1. or result.max() == 1.) and
-                        result.min() >= -1. and result.max() <= 1.)
-        repr_test = transforms.Scale()
-        self.assertTrue(repr_test.__repr__())
+    def scale(self, waveform, factor=float(2**31)):
+        # scales a waveform by a factor
+        if not waveform.is_floating_point():
+            waveform = waveform.to(torch.get_default_dtype())
+        return waveform / factor
     def test_pad_trim(self):
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
+        waveform = self.waveform.clone()
+        length_orig = waveform.size(1)
         length_new = int(length_orig * 1.2)
-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-        self.assertEqual(result.size(0), length_new)
-        result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
+        result = transforms.PadTrim(max_len=length_new)(waveform)
         self.assertEqual(result.size(1), length_new)
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
         length_new = int(length_orig * 0.8)
-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-        self.assertEqual(result.size(0), length_new)
-        repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
-        self.assertTrue(repr_test.__repr__())
+        result = transforms.PadTrim(max_len=length_new)(waveform)
+        self.assertEqual(result.size(1), length_new)
-    def test_downmix_mono(self):
-        audio_L = self.sig.clone()
-        audio_R = self.sig.clone()
-        R_idx = int(audio_R.size(0) * 0.1)
-        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))
-        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)
-        self.assertTrue(audio_Stereo.size(1) == 2)
-        result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
-        self.assertTrue(result.size(1) == 1)
-        repr_test = transforms.DownmixMono(channels_first=False)
-        self.assertTrue(repr_test.__repr__())
-    def test_lc2cl(self):
-        audio = self.sig.clone()
-        result = transforms.LC2CL()(audio)
-        self.assertTrue(result.size()[::-1] == audio.size())
-        repr_test = transforms.LC2CL()
-        self.assertTrue(repr_test.__repr__())
-    def test_compose(self):
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
-        length_new = int(length_orig * 1.2)
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-        tset = (transforms.Scale(factor=maxminmax),
-                transforms.PadTrim(max_len=length_new, channels_first=False))
-        result = transforms.Compose(tset)(audio_orig)
-        self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)
-        self.assertTrue(result.size(0) == length_new)
-        repr_test = transforms.Compose(tset)
-        self.assertTrue(repr_test.__repr__())
     def test_mu_law_companding(self):
         quantization_channels = 256
-        sig = self.sig.clone()
-        sig = sig / torch.abs(sig).max()
-        self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)
-        sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
-        self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels)
-        sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
-        self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
-        repr_test = transforms.MuLawEncoding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
-        repr_test = transforms.MuLawExpanding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
+        waveform = self.waveform.clone()
+        waveform /= torch.abs(waveform).max()
+        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)
+        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
+        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
+        waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
+        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
     def test_mel2(self):
         top_db = 80.
-        s2db = transforms.SpectrogramToDB("power", top_db)
-        audio_orig = self.sig.clone()  # (16000, 1)
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        s2db = transforms.SpectrogramToDB('power', top_db)
+        waveform = self.waveform.clone()  # (1, 16000)
+        waveform_scaled = self.scale(waveform)  # (1, 16000)
         mel_transform = transforms.MelSpectrogram()
         # check defaults
-        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
+        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
         self.assertTrue(spectrogram_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
         # check correctness of filterbank conversion matrix
-        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
         # check options
-        kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
+        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
+                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
         mel_transform2 = transforms.MelSpectrogram(**kwargs)
-        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
+        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
         self.assertTrue(spectrogram2_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
-        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
+        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
         # check on multi-channel audio
-        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
-        spectrogram_stereo = s2db(mel_transform(x_stereo))
+        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
+        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
         self.assertTrue(spectrogram_stereo.dim() == 3)
         self.assertTrue(spectrogram_stereo.size(0) == 2)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
         # check filterbank matrix creation
-        fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
+        fb_matrix_transform = transforms.MelScale(
+            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
         self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
         self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
         self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
     def test_mfcc(self):
-        audio_orig = self.sig.clone()
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        audio_orig = self.waveform.clone()
+        audio_scaled = self.scale(audio_orig)  # (1, 16000)
         sample_rate = 16000
         n_mfcc = 40
         n_mels = 128
-        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
+        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho')
         # check defaults
-        torch_mfcc = mfcc_transform(audio_scaled)
+        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
         self.assertTrue(torch_mfcc.dim() == 3)
-        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
-        self.assertTrue(torch_mfcc.shape[1] == 321)
+        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
+        self.assertTrue(torch_mfcc.shape[2] == 321)
         # check melkwargs are passed through
-        melkwargs = {'ws': 200}
-        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
+        melkwargs = {'win_length': 200}
+        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                      n_mfcc=n_mfcc,
                                                      norm='ortho',
                                                      melkwargs=melkwargs)
-        torch_mfcc2 = mfcc_transform2(audio_scaled)
-        self.assertTrue(torch_mfcc2.shape[1] == 641)
+        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
+        self.assertTrue(torch_mfcc2.shape[2] == 641)
         # check norms work correctly
-        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
+        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                               n_mfcc=n_mfcc,
                                                               norm=None)
-        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)
+        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)
         norm_check = torch_mfcc.clone()
-        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
-        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2
+        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
+        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
         self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
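Note: under the renamed API the MFCC transform takes `sample_rate=` and returns (channel, n_mfcc, time). A minimal doctest-style sketch (input length and frame count are illustrative, assuming the n_fft=400 / hop_length=200 defaults shown later in this diff):
    >>> mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=40, norm='ortho')
    >>> mfcc_transform(torch.rand(1, 16000)).shape  # (channel, n_mfcc, time)
    torch.Size([1, 40, 81])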
...@@ -212,45 +144,45 @@ class Tester(unittest.TestCase): ...@@ -212,45 +144,45 @@ class Tester(unittest.TestCase):
def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path) sound, sample_rate = torchaudio.load(input_path)
sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram # test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
n_fft=n_fft, n_fft=n_fft,
hop_length=hop_length, hop_length=hop_length,
power=2) power=2)
out_torch = spect_transform(sound).squeeze().cpu().t() out_torch = spect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)) self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
# test mel spectrogram # test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, melspect_transform = torchaudio.transforms.MelSpectrogram(
hop=hop_length, n_mels=n_mels, n_fft=n_fft) sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
htk=True, norm=None) htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel) librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu().t() torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
# test s2db # test s2db
db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.)
db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
db_librosa = librosa.core.spectrum.power_to_db(out_librosa) db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t() db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa) db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))
# test MFCC # test MFCC
melkwargs = {'hop': hop_length, 'n_fft': n_fft} melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm='ortho', norm='ortho',
melkwargs=melkwargs) melkwargs=melkwargs)
...@@ -271,7 +203,7 @@ class Tester(unittest.TestCase): ...@@ -271,7 +203,7 @@ class Tester(unittest.TestCase):
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc] librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc) librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu().t() torch_mfcc = mfcc_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)) self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))
@@ -308,27 +240,27 @@ class Tester(unittest.TestCase):
     def test_resample_size(self):
         input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
-        sound, sample_rate = torchaudio.load(input_path)
+        waveform, sample_rate = torchaudio.load(input_path)
         upsample_rate = sample_rate * 2
         downsample_rate = sample_rate // 2
         invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
-        self.assertRaises(ValueError, invalid_resample, sound)
+        self.assertRaises(ValueError, invalid_resample, waveform)
         upsample_resample = torchaudio.transforms.Resample(
             sample_rate, upsample_rate, resampling_method='sinc_interpolation')
-        up_sampled = upsample_resample(sound)
+        up_sampled = upsample_resample(waveform)
         # we expect the upsampled signal to have twice as many samples
-        self.assertTrue(up_sampled.size(-1) == sound.size(-1) * 2)
+        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
         downsample_resample = torchaudio.transforms.Resample(
             sample_rate, downsample_rate, resampling_method='sinc_interpolation')
-        down_sampled = downsample_resample(sound)
+        down_sampled = downsample_resample(waveform)
         # we expect the downsampled signal to have half as many samples
-        self.assertTrue(down_sampled.size(-1) == sound.size(-1) // 2)
+        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
 if __name__ == '__main__':
     unittest.main()
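Note: Resample keeps its signature in this commit; only the variable naming in the test changed. A usage sketch grounded in the assertions above (rates illustrative):
    >>> resample = torchaudio.transforms.Resample(16000, 32000, resampling_method='sinc_interpolation')
    >>> resample(torch.rand(1, 16000)).shape  # twice as many samples
    torch.Size([1, 32000])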
...
@@ -3,109 +3,48 @@ import torch
 __all__ = [
-    'scale',
     'pad_trim',
-    'downmix_mono',
-    'LC2CL',
     'istft',
     'spectrogram',
     'create_fb_matrix',
     'spectrogram_to_DB',
     'create_dct',
-    'BLC2CBL',
     'mu_law_encoding',
-    'mu_law_expanding'
+    'mu_law_expanding',
+    'complex_norm',
+    'angle',
+    'magphase',
+    'phase_vocoder',
 ]
 @torch.jit.script
-def scale(tensor, factor):
-    # type: (Tensor, int) -> Tensor
-    r"""Scale audio tensor from a 16-bit integer (represented as a
-    :class:`torch.FloatTensor`) to a floating point number between -1.0 and 1.0.
-    Note the 16-bit number is called the "bit depth" or "precision", not to be
-    confused with "bit rate".
-    Args:
-        tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
-        factor (int): Maximum value of input tensor
-    Returns:
-        torch.Tensor: Scaled by the scale factor
-    """
-    if not tensor.is_floating_point():
-        tensor = tensor.to(torch.float32)
-    return tensor / factor
-@torch.jit.script
-def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
-    # type: (Tensor, int, int, int, float) -> Tensor
-    r"""Pad/trim a 2D tensor (signal or labels).
+def pad_trim(waveform, max_len, fill_value):
+    # type: (Tensor, int, float) -> Tensor
+    r"""Pad/trim a 2D tensor
     Args:
-        tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
-        ch_dim (int): Dimension of channel (not size)
-        max_len (int): Length to which the tensor will be padded
-        len_dim (int): Dimension of length (not size)
+        waveform (torch.Tensor): Tensor of audio of size (c, n)
+        max_len (int): Length to which the waveform will be padded
         fill_value (float): Value to fill in
     Returns:
        torch.Tensor: Padded/trimmed tensor
    """
-    if max_len > tensor.size(len_dim):
-        # array of [padding_left, padding_right, padding_top, padding_bottom]
-        # so pad similar to append (aka only right/bottom) and do not pad
-        # the length dimension. assumes equal sizes of padding.
-        padding = [max_len - tensor.size(len_dim)
-                   if (i % 2 == 1) and (i // 2 != len_dim)
-                   else 0
-                   for i in [0, 1, 2, 3]]
+    n = waveform.size(1)
+    if max_len > n:
         # TODO add "with torch.no_grad():" back when JIT supports it
-        tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
-    elif max_len < tensor.size(len_dim):
-        tensor = tensor.narrow(len_dim, 0, max_len)
-    return tensor
+        waveform = torch.nn.functional.pad(waveform, (0, max_len - n), 'constant', fill_value)
+    else:
+        waveform = waveform[:, :max_len]
+    return waveform
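Note: the new `pad_trim` drops `ch_dim`/`len_dim` and assumes a channels-first (c, n) input; it pads on the right or slices down to `max_len`. A doctest-style sketch of the new call, following the code above:
    >>> waveform = torch.rand(1, 3)        # (c, n)
    >>> F.pad_trim(waveform, 5, 0.).shape  # pads (0, max_len - n) on the right
    torch.Size([1, 5])
    >>> F.pad_trim(waveform, 2, 0.).shape  # trims to max_len
    torch.Size([1, 2])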
-@torch.jit.script
-def downmix_mono(tensor, ch_dim):
-    # type: (Tensor, int) -> Tensor
-    r"""Downmix any stereo signals to mono.
-    Args:
-        tensor (torch.Tensor): Tensor of audio of size (c, n) or (n, c)
-        ch_dim (int): Dimension of channel (not size)
-    Returns:
-        torch.Tensor: Mono signal
-    """
-    if not tensor.is_floating_point():
-        tensor = tensor.to(torch.float32)
-    tensor = torch.mean(tensor, ch_dim, True)
-    return tensor
-@torch.jit.script
-def LC2CL(tensor):
-    # type: (Tensor) -> Tensor
-    r"""Permute a 2D tensor from samples (n, c) to (c, n).
-    Args:
-        tensor (torch.Tensor): Tensor of audio signal with shape (n, c)
-    Returns:
-        torch.Tensor: Tensor of audio signal with shape (c, n)
-    """
-    return tensor.transpose(0, 1).contiguous()
 # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
 @torch.jit.ignore
-def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
+def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
     # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
-    return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
+    return torch.stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
 def istft(stft_matrix,  # type: Tensor
@@ -149,8 +88,8 @@ def istft(stft_matrix,  # type: Tensor
         IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
     Args:
-        stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each
-            column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or (
+        stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
+            column is a window. it has a shape of either (channel, fft_size, n_frames, 2) or (
             fft_size, n_frames, 2)
         n_fft (int): Size of Fourier transform
         hop_length (Optional[int]): The distance between neighboring sliding window frames.
@@ -168,20 +107,20 @@ def istft(stft_matrix,  # type: Tensor
     Returns:
         torch.Tensor: Least squares estimation of the original signal of size
-            (batch, signal_length) or (signal_length)
+            (channel, signal_length) or (signal_length)
     """
     stft_matrix_dim = stft_matrix.dim()
     assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))
     if stft_matrix_dim == 3:
-        # add a batch dimension
+        # add a channel dimension
         stft_matrix = stft_matrix.unsqueeze(0)
     device = stft_matrix.device
     fft_size = stft_matrix.size(1)
     assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), (
-        'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. '
-        + 'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))
+        'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' +
+        'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))
     # use stft defaults for Optionals
     if win_length is None:
@@ -206,16 +145,16 @@ def istft(stft_matrix,  # type: Tensor
     assert window.size(0) == n_fft
     # win_length and n_fft are synonymous from here on
-    stft_matrix = stft_matrix.transpose(1, 2)  # size (batch, n_frames, fft_size, 2)
+    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frames, fft_size, 2)
     stft_matrix = torch.irfft(stft_matrix, 1, normalized,
-                              onesided, signal_sizes=(n_fft,))  # size (batch, n_frames, n_fft)
+                              onesided, signal_sizes=(n_fft,))  # size (channel, n_frames, n_fft)
     assert stft_matrix.size(2) == n_fft
     n_frames = stft_matrix.size(1)
-    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (batch, n_frames, n_fft)
-    # each column of a batch is a frame which needs to be overlap added at the right place
-    ytmp = ytmp.transpose(1, 2)  # size (batch, n_fft, n_frames)
+    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frames, n_fft)
+    # each column of a channel is a frame which needs to be overlap added at the right place
+    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frames)
     eye = torch.eye(n_fft, requires_grad=False,
                     device=device).unsqueeze(1)  # size (n_fft, 1, n_fft)
@@ -223,7 +162,7 @@ def istft(stft_matrix,  # type: Tensor
     # this does overlap add where the frames of ytmp are added such that the i'th frame of
     # ytmp is added starting at i*hop_length in the output
     y = torch.nn.functional.conv_transpose1d(
-        ytmp, eye, stride=hop_length, padding=0)  # size (batch, 1, expected_signal_len)
+        ytmp, eye, stride=hop_length, padding=0)  # size (channel, 1, expected_signal_len)
     # do the same for the window function
     window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)  # size (1, n_fft, n_frames)
@@ -246,67 +185,70 @@ def istft(stft_matrix,  # type: Tensor
     window_envelop_lowest = window_envelop.abs().min()
     assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest))
-    y = (y / window_envelop).squeeze(1)  # size (batch, expected_signal_len)
+    y = (y / window_envelop).squeeze(1)  # size (channel, expected_signal_len)
-    if stft_matrix_dim == 3:  # remove the batch dimension
+    if stft_matrix_dim == 3:  # remove the channel dimension
         y = y.squeeze(0)
     return y
 @torch.jit.script
-def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
+def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized):
     # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
     r"""Create a spectrogram from a raw audio signal.
     Args:
-        sig (torch.Tensor): Tensor of audio of size (c, n)
+        waveform (torch.Tensor): Tensor of audio of size (c, n)
         pad (int): Two sided padding of signal
-        window (torch.Tensor): Window_tensor
+        window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of fft
-        hop (int): Length of hop between STFT windows
-        ws (int): Window size
-        power (int) : Exponent for the magnitude spectrogram,
+        hop_length (int): Length of hop between STFT windows
+        win_length (int): Window size
+        power (int): Exponent for the magnitude spectrogram,
             (must be > 0) e.g., 1 for energy, 2 for power, etc.
-        normalize (bool) : Whether to normalize by magnitude after stft
+        normalized (bool): Whether to normalize by magnitude after stft
     Returns:
-        torch.Tensor: Channels x hops x n_fft (c, l, f), where channels
-            is unchanged, hops is the number of hops, and n_fft is the
-            number of fourier bins, which should be the window size divided
-            by 2 plus 1.
+        torch.Tensor: Channels x frequency x time (c, f, t), where channels
            is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
            fourier bins, and time is the number of window hops (n_frames).
    """
-    assert sig.dim() == 2
+    assert waveform.dim() == 2
     if pad > 0:
         # TODO add "with torch.no_grad():" back when JIT supports it
-        sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
+        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = _stft(sig, n_fft, hop, ws, window,
-                   True, 'reflect', False, True).transpose(1, 2)
+    spec_f = _stft(waveform, n_fft, hop_length, win_length, window,
+                   True, 'reflect', False, True)
-    if normalize:
+    if normalized:
         spec_f /= window.pow(2).sum().sqrt()
-    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
+    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor
     return spec_f
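Note: `spectrogram` now returns (c, f, t) directly; the old `.transpose(1, 2)` is gone. A sketch with the renamed positional arguments (values illustrative):
    >>> waveform = torch.rand(1, 16000)  # (c, n)
    >>> window = torch.hann_window(400)
    >>> specgram = F.spectrogram(waveform, 0, window, 400, 200, 400, 2, False)
    >>> specgram.shape  # (c, n_fft // 2 + 1, n_frames)
    torch.Size([1, 201, 81])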
 @torch.jit.script
-def create_fb_matrix(n_stft, f_min, f_max, n_mels):
+def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
     # type: (int, float, float, int) -> Tensor
     r""" Create a frequency bin conversion matrix.
     Args:
-        n_stft (int): Number of filter banks from spectrogram
+        n_freqs (int): Number of frequencies to highlight/apply
         f_min (float): Minimum frequency
         f_max (float): Maximum frequency
-        n_mels (int): Number of mel bins
+        n_mels (int): Number of mel filterbanks
     Returns:
-        torch.Tensor: Triangular filter banks (fb matrix)
+        torch.Tensor: Triangular filter banks (fb matrix) of size (`n_freqs`, `n_mels`)
+        meaning number of frequencies to highlight/apply to x the number of filterbanks.
+        Each column is a filterbank so that assuming there is a matrix A of
+        size (..., `n_freqs`), the applied result would be
+        `A * create_fb_matrix(A.size(-1), ...)`.
     """
-    # get stft freq bins
-    stft_freqs = torch.linspace(f_min, f_max, n_stft)
+    # freq bins
+    freqs = torch.linspace(f_min, f_max, n_freqs)
     # calculate mel freq bins
     # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
     m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.))
@@ -316,17 +258,17 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
     f_pts = 700. * (10**(m_pts / 2595.) - 1.)
     # calculate the difference between each mel point and each stft freq point in hertz
     f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
-    slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)
+    slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
     # create overlapping triangles
-    z = torch.zeros(1)
-    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
-    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
-    fb = torch.max(z, torch.min(down_slopes, up_slopes))
+    zero = torch.zeros(1)
+    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
+    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
+    fb = torch.max(zero, torch.min(down_slopes, up_slopes))
     return fb
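Note: per the new docstring, the (n_freqs, n_mels) filterbank is right-multiplied against data whose last dimension is `n_freqs`. A sketch (shapes illustrative):
    >>> specgram = torch.rand(1, 201, 81)             # (c, n_freqs, t)
    >>> fb = F.create_fb_matrix(201, 0., 8000., 128)  # (n_freqs, n_mels)
    >>> torch.matmul(specgram.transpose(1, 2), fb).shape  # (c, t, n_mels)
    torch.Size([1, 81, 128])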
 @torch.jit.script
-def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
+def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None):
     # type: (Tensor, float, float, float, Optional[float]) -> Tensor
     r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
@@ -335,72 +277,57 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
     a full clip.
     Args:
-        spec (torch.Tensor): Normal STFT
+        specgram (torch.Tensor): Normal STFT of size (c, f, t)
         multiplier (float): Use 10. for power and 20. for amplitude
-        amin (float): Number to clamp spec
+        amin (float): Number to clamp specgram
         db_multiplier (float): Log10(max(reference value and amin))
         top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
             is 80.
     Returns:
-        torch.Tensor: Spectrogram in DB
+        torch.Tensor: Spectrogram in DB of size (c, f, t)
     """
-    spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
-    spec_db -= multiplier * db_multiplier
+    specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin))
+    specgram_db -= multiplier * db_multiplier
     if top_db is not None:
-        new_spec_db_max = torch.tensor(float(spec_db.max()) - top_db, dtype=spec_db.dtype, device=spec_db.device)
-        spec_db = torch.max(spec_db, new_spec_db_max)
+        new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db,
+                                       dtype=specgram_db.dtype, device=specgram_db.device)
+        specgram_db = torch.max(specgram_db, new_spec_db_max)
-    return spec_db
+    return specgram_db
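Note: `spectrogram_to_DB` is shape-preserving on the renamed (c, f, t) layout; with a power spectrogram and a reference value of 1.0 the arguments below mirror the JIT test earlier in this diff:
    >>> specgram = torch.rand(1, 201, 81)
    >>> F.spectrogram_to_DB(specgram, 10., 1e-10, 0., 80.).shape
    torch.Size([1, 201, 81])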
 @torch.jit.script
 def create_dct(n_mfcc, n_mels, norm):
     # type: (int, int, Optional[str]) -> Tensor
-    r"""Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
+    r"""Creates a DCT transformation matrix with shape (`n_mels`, `n_mfcc`),
     normalized depending on norm.
     Args:
-        n_mfcc (int) : Number of mfc coefficients to retain
-        n_mels (int): Number of MEL bins
-        norm (Optional[str]) : Norm to use (either 'ortho' or None)
+        n_mfcc (int): Number of mfc coefficients to retain
+        n_mels (int): Number of mel filterbanks
+        norm (Optional[str]): Norm to use (either 'ortho' or None)
     Returns:
-        torch.Tensor: The transformation matrix, to be right-multiplied to row-wise data.
+        torch.Tensor: The transformation matrix, to be right-multiplied to
+        row-wise data of size (`n_mels`, `n_mfcc`).
     """
-    outdim = n_mfcc
-    dim = n_mels
     # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
-    n = torch.arange(dim)
-    k = torch.arange(outdim)[:, None]
-    dct = torch.cos(math.pi / float(dim) * (n + 0.5) * k)
+    n = torch.arange(float(n_mels))
+    k = torch.arange(float(n_mfcc)).unsqueeze(1)
+    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
     if norm is None:
         dct *= 2.0
     else:
         assert norm == 'ortho'
         dct[0] *= 1.0 / math.sqrt(2.0)
-        dct *= math.sqrt(2.0 / float(dim))
+        dct *= math.sqrt(2.0 / float(n_mels))
     return dct.t()
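Note: the returned matrix is (n_mels, n_mfcc) and is right-multiplied against row-wise mel data, so a (c, n_mels, t) mel spectrogram is transposed first. A sketch:
    >>> dct_mat = F.create_dct(40, 128, 'ortho')  # (n_mels, n_mfcc)
    >>> mel_specgram = torch.rand(1, 128, 81)     # (c, n_mels, t)
    >>> torch.matmul(mel_specgram.transpose(1, 2), dct_mat).shape  # (c, t, n_mfcc)
    torch.Size([1, 81, 40])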
 @torch.jit.script
-def BLC2CBL(tensor):
-    # type: (Tensor) -> Tensor
-    r"""Permute a 3D tensor from Bands x Sample length x Channels to Channels x
-    Bands x Samples length.
-    Args:
-        tensor (torch.Tensor): Tensor of spectrogram with shape (b, l, c)
-    Returns:
-        torch.Tensor: Tensor of spectrogram with shape (c, b, l)
-    """
-    return tensor.permute(2, 0, 1).contiguous()
-@torch.jit.script
-def mu_law_encoding(x, qc):
+def mu_law_encoding(x, quantization_channels):
     # type: (Tensor, int) -> Tensor
     r"""Encode signal based on mu-law companding. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
@@ -410,13 +337,12 @@ def mu_law_encoding(x, qc):
     Args:
         x (torch.Tensor): Input tensor
-        qc (int): Number of channels (i.e. quantization channels)
+        quantization_channels (int): Number of channels
     Returns:
         torch.Tensor: Input after mu-law companding
     """
-    assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor'
-    mu = qc - 1.
+    mu = quantization_channels - 1.
     if not x.is_floating_point():
         x = x.to(torch.float)
     mu = torch.tensor(mu, dtype=x.dtype)
@@ -427,7 +353,7 @@ def mu_law_encoding(x, qc):
 @torch.jit.script
-def mu_law_expanding(x_mu, qc):
+def mu_law_expanding(x_mu, quantization_channels):
     # type: (Tensor, int) -> Tensor
     r"""Decode mu-law encoded signal. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
@@ -437,13 +363,12 @@ def mu_law_expanding(x_mu, qc):
     Args:
         x_mu (torch.Tensor): Input tensor
-        qc (int): Number of channels (i.e. quantization channels)
+        quantization_channels (int): Number of channels
     Returns:
         torch.Tensor: Input after decoding
     """
-    assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor'
-    mu = qc - 1.
+    mu = quantization_channels - 1.
     if not x_mu.is_floating_point():
         x_mu = x_mu.to(torch.float)
     mu = torch.tensor(mu, dtype=x_mu.dtype)
@@ -452,71 +377,15 @@ def mu_law_expanding(x_mu, qc):
     return x
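Note: both companding functions now spell out `quantization_channels` and drop the isinstance asserts. A round-trip sketch (the encode/decode bodies are elided in this hunk; behavior assumed from the docstrings, with values in [-1, 1] companded to codes in [0, quantization_channels - 1]):
    >>> x = torch.linspace(-1., 1., steps=5)
    >>> x_mu = F.mu_law_encoding(x, 256)       # integer codes in [0, 255]
    >>> x_hat = F.mu_law_expanding(x_mu, 256)  # approximately x, up to quantization error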
-def stft(waveforms, fft_length, hop_length=None, win_length=None, window=None,
-         center=True, pad_mode='reflect', normalized=False, onesided=True):
-    """Compute a short time Fourier transform of the input waveform(s).
-    It wraps `torch.stft` after reshaping the input audio to allow for `waveforms` that `.dim()` >= 3.
-    It follows most of the `torch.stft` default values, but for `window`, which defaults to hann window.
-    Args:
-        waveforms (torch.Tensor): Audio signal of size `(*, channel, time)`
-        fft_length (int): FFT size [sample].
-        hop_length (int): Hop size [sample] between STFT frames.
-            (Defaults to `fft_length // 4`, 75%-overlapping windows by `torch.stft`).
-        win_length (int): Size of STFT window. (Defaults to `fft_length` by `torch.stft`).
-        window (torch.Tensor): window function. (Defaults to Hann Window of size `win_length` *unlike* `torch.stft`).
-        center (bool): Whether to pad `waveforms` on both sides so that the `t`-th frame is centered
-            at time `t * hop_length`. (Defaults to `True` by `torch.stft`)
-        pad_mode (str): padding method (see `torch.nn.functional.pad`). (Defaults to `'reflect'` by `torch.stft`).
-        normalized (bool): Whether the results are normalized. (Defaults to `False` by `torch.stft`).
-        onesided (bool): Whether the half + 1 frequency bins are returned to remove the symmetric part of STFT
-            of real-valued signal. (Defaults to `True` by `torch.stft`).
-    Returns:
-        torch.Tensor: `(*, channel, num_freqs, time, complex=2)`
-    Example:
-        >>> waveforms = torch.randn(16, 2, 10000)  # (batch, channel, time)
-        >>> x = stft(waveforms, 2048, 512)
-        >>> x.shape
-        torch.Size([16, 2, 1025, 20])
-    """
-    leading_dims = waveforms.shape[:-1]
-    waveforms = waveforms.reshape(-1, waveforms.size(-1))
-    if window is None:
-        if win_length is None:
-            window = torch.hann_window(fft_length)
-        else:
-            window = torch.hann_window(win_length)
-    complex_specgrams = torch.stft(waveforms,
-                                   n_fft=fft_length,
-                                   hop_length=hop_length,
-                                   win_length=win_length,
-                                   window=window,
-                                   center=center,
-                                   pad_mode=pad_mode,
-                                   normalized=normalized,
-                                   onesided=onesided)
-    complex_specgrams = complex_specgrams.reshape(
-        leading_dims +
-        complex_specgrams.shape[1:])
-    return complex_specgrams
 def complex_norm(complex_tensor, power=1.0):
-    """Compute the norm of complex tensor input
+    r"""Compute the norm of complex tensor input.
     Args:
-        complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
-        power (float): Power of the norm. Defaults to `1.0`.
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+        power (float): Power of the norm. (Default: `1.0`).
     Returns:
-        Tensor: power of the normed input tensor, shape of `(*, )`
+        torch.Tensor: Power of the normed input tensor. Shape of `(*, )`
     """
     if power == 1.0:
         return torch.norm(complex_tensor, 2, -1)
@@ -524,16 +393,26 @@ def complex_norm(complex_tensor, power=1.0):
 def angle(complex_tensor):
-    """
-    Return angle of a complex tensor with shape (*, 2).
+    r"""Compute the angle of complex tensor input.
+    Args:
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+    Return:
+        torch.Tensor: Angle of a complex tensor. Shape of `(*, )`
     """
     return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
 def magphase(complex_tensor, power=1.):
-    """
-    Separate a complex-valued spectrogram with shape (*,2)
-    into its magnitude and phase.
+    r"""Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase.
+    Args:
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+        power (float): Power of the norm. (Default: `1.0`)
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex_tensor
     """
     mag = complex_norm(complex_tensor, power)
     phase = angle(complex_tensor)
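Note: with the pseudo-complex (*, 2) layout used throughout this diff, `magphase` splits a stacked real/imaginary spectrogram into two tensors with the trailing dimension dropped. A sketch:
    >>> complex_spec = torch.randn(1, 201, 81, 2)  # (*, complex=2)
    >>> mag, phase = F.magphase(complex_spec, power=1.)
    >>> mag.shape, phase.shape
    (torch.Size([1, 201, 81]), torch.Size([1, 201, 81]))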
@@ -541,20 +420,16 @@ def magphase(complex_tensor, power=1.):
 def phase_vocoder(complex_specgrams, rate, phase_advance):
-    """
-    Phase vocoder. Given a STFT tensor, speed up in time
-    without modifying pitch by a factor of `rate`.
+    r"""Given a STFT tensor, speed up in time without modifying pitch by a
+    factor of `rate`.
     Args:
-        complex_specgrams (Tensor):
-            (*, channel, num_freqs, time, complex=2)
-        rate (float): Speed-up factor.
-        phase_advance (Tensor): Expected phase advance in
-            each bin. (num_freqs, 1).
+        complex_specgrams (torch.Tensor): Size of (*, c, f, t, complex=2)
+        rate (float): Speed-up factor
+        phase_advance (torch.Tensor): Expected phase advance in each bin. Size of (f, 1)
     Returns:
-        complex_specgrams_stretch (Tensor):
-            (*, channel, num_freqs, ceil(time/rate), complex=2).
+        complex_specgrams_stretch (torch.Tensor): Size of (*, c, f, ceil(t/rate), complex=2)
     Example:
         >>> num_freqs, hop_length = 1025, 512
...
@@ -7,278 +7,169 @@ from . import functional as F
 from .compliance import kaldi
-# TODO remove this class
-class Compose(object):
-    """Composes several transforms together.
-    Args:
-        transforms (list of ``Transform`` objects): list of transforms to compose.
-    Example:
-        >>> transforms.Compose([
-        >>>     transforms.Scale(),
-        >>>     transforms.PadTrim(max_len=16000),
-        >>> ])
-    """
-    def __init__(self, transforms):
-        self.transforms = transforms
-    def __call__(self, audio):
-        for t in self.transforms:
-            audio = t(audio)
-        return audio
-    def __repr__(self):
-        format_string = self.__class__.__name__ + '('
-        for t in self.transforms:
-            format_string += '\n'
-            format_string += '    {0}'.format(t)
-        format_string += '\n)'
-        return format_string
-class Scale(torch.jit.ScriptModule):
-    """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
-    to a floating point number between -1.0 and 1.0. Note the 16-bit number is
-    called the "bit depth" or "precision", not to be confused with "bit rate".
-    Args:
-        factor (int): maximum value of input tensor. default: 16-bit depth
-    """
-    __constants__ = ['factor']
-    def __init__(self, factor=2**31):
-        super(Scale, self).__init__()
-        self.factor = factor
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-        Args:
-            tensor (Tensor): Tensor of audio of size (Samples x Channels)
-        Returns:
-            Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
-        """
-        return F.scale(tensor, self.factor)
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
class PadTrim(torch.jit.ScriptModule): class PadTrim(torch.jit.ScriptModule):
"""Pad/Trim a 2d-Tensor (Signal or Labels) r"""Pad/Trim a 2D tensor
Args: Args:
tensor (Tensor): Tensor of audio of size (n x c) or (c x n) max_len (int): Length to which the waveform will be padded
max_len (int): Length to which the tensor will be padded fill_value (float): Value to fill in
channels_first (bool): Pad for channels first tensors. Default: `True`
""" """
__constants__ = ['max_len', 'fill_value', 'len_dim', 'ch_dim'] __constants__ = ['max_len', 'fill_value']
def __init__(self, max_len, fill_value=0., channels_first=True): def __init__(self, max_len, fill_value=0.):
super(PadTrim, self).__init__() super(PadTrim, self).__init__()
self.max_len = max_len self.max_len = max_len
self.fill_value = fill_value self.fill_value = fill_value
self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)
@torch.jit.script_method @torch.jit.script_method
def forward(self, tensor): def forward(self, waveform):
""" r"""
Returns:
Tensor: (c x n) or (n x c)
"""
return F.pad_trim(tensor, self.ch_dim, self.max_len, self.len_dim, self.fill_value)
def __repr__(self):
return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)
class DownmixMono(torch.jit.ScriptModule):
"""Downmix any stereo signals to mono. Consider using a `SoxEffectsChain` with
the `channels` effect instead of this transformation.
Inputs:
tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
channels_first (bool): Downmix across channels dimension. Default: `True`
Returns:
tensor (Tensor) (Samples x 1):
"""
__constants__ = ['ch_dim']
def __init__(self, channels_first=None):
super(DownmixMono, self).__init__()
self.ch_dim = int(not channels_first)
@torch.jit.script_method
def forward(self, tensor):
return F.downmix_mono(tensor, self.ch_dim)
def __repr__(self):
return self.__class__.__name__ + '()'
class LC2CL(torch.jit.ScriptModule):
"""Permute a 2d tensor from samples (n x c) to (c x n)
"""
def __init__(self):
super(LC2CL, self).__init__()
@torch.jit.script_method
def forward(self, tensor):
"""
Args: Args:
tensor (Tensor): Tensor of audio signal with shape (LxC) waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns: Returns:
tensor (Tensor): Tensor of audio signal with shape (CxL) Tensor: Tensor of size (c, `max_len`)
""" """
return F.LC2CL(tensor) return F.pad_trim(waveform, self.max_len, self.fill_value)
def __repr__(self):
return self.__class__.__name__ + '()'
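A hedged sketch of the simplified, channels-first-only PadTrim (the input shape is illustrative):
>>> import torch
>>> from torchaudio import transforms
>>> waveform = torch.randn(1, 10000)  # (c, n)
>>> padded = transforms.PadTrim(max_len=16000, fill_value=0.)(waveform)
>>> padded.size()  # samples beyond n are filled with fill_value
torch.Size([1, 16000])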
def SPECTROGRAM(*args, **kwargs):
warn("SPECTROGRAM has been renamed to Spectrogram")
return Spectrogram(*args, **kwargs)
class Spectrogram(torch.jit.ScriptModule): class Spectrogram(torch.jit.ScriptModule):
"""Create a spectrogram from a raw audio signal r"""Create a spectrogram from a audio signal
Args: Args:
n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
ws (int): window size. default: n_fft win_length (int): Window size. (Default: `n_fft`)
hop (int, optional): length of hop between STFT windows. default: ws // 2 hop_length (int, optional): Length of hop between STFT windows. (
pad (int): two sided padding of signal Default: `win_length // 2`)
window (torch windowing function): default: torch.hann_window pad (int): Two sided padding of signal. (Default: 0)
power (int > 0 ) : Exponent for the magnitude spectrogram, window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
e.g., 1 for energy, 2 for power, etc. that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
normalize (bool) : whether to normalize by magnitude after stft power (int): Exponent for the magnitude spectrogram,
wkwargs (dict, optional): arguments for window function (must be > 0) e.g., 1 for energy, 2 for power, etc.
""" normalized (bool): Whether to normalize by magnitude after stft. (Default: `False`)
__constants__ = ['n_fft', 'ws', 'hop', 'pad', 'power', 'normalize'] wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
"""
def __init__(self, n_fft=400, ws=None, hop=None, __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
pad=0, window=torch.hann_window,
power=2, normalize=False, wkwargs=None): def __init__(self, n_fft=400, win_length=None, hop_length=None,
pad=0, window_fn=torch.hann_window,
power=2, normalized=False, wkwargs=None):
super(Spectrogram, self).__init__() super(Spectrogram, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
# number of fft bins. the returned STFT result will have n_fft // 2 + 1 # number of fft bins. the returned STFT result will have n_fft // 2 + 1
# number of frequencies due to onesided=True in torch.stft # number of frequencies due to onesided=True in torch.stft
self.ws = ws if ws is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop = hop if hop is not None else self.ws // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs) window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.window = torch.jit.Attribute(window, torch.Tensor) self.window = torch.jit.Attribute(window, torch.Tensor)
self.pad = pad self.pad = pad
self.power = power self.power = power
self.normalize = normalize self.normalized = normalized
@torch.jit.script_method @torch.jit.script_method
def forward(self, sig): def forward(self, waveform):
""" r"""
Args: Args:
sig (Tensor): Tensor of audio of size (c, n) waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns: Returns:
spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels torch.Tensor: Channels x frequency x time (c, f, t), where channels
is unchanged, hops is the number of hops, and n_fft is the is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
number of Fourier bins, which should be the window size divided Fourier bins, and time is the number of window hops (n_frames).
by 2 plus 1.
""" """
return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop, return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
self.ws, self.power, self.normalize) self.win_length, self.power, self.normalized)
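A minimal sketch of the renamed parameters (the waveform is illustrative; the number of time frames depends on torch.stft's centering and padding):
>>> import torch
>>> from torchaudio import transforms
>>> waveform = torch.randn(1, 16000)  # (c, n)
>>> specgram = transforms.Spectrogram(n_fft=400, hop_length=200)(waveform)  # (c, f, t)
>>> specgram.size(1)  # n_fft // 2 + 1 frequency bins
201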
def F2M(*args, **kwargs):
warn("F2M has been renamed to MelScale")
return MelScale(*args, **kwargs)
class MelScale(torch.jit.ScriptModule): class MelScale(torch.jit.ScriptModule):
"""This turns a normal STFT into a mel frequency STFT, using a conversion r"""This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks. matrix. This uses triangular filter banks.
The user can control which device the filter bank (`fb`) is on (e.g. fb.to(spec_f.device)). The user can control which device the filter bank (`fb`) is on (e.g. fb.to(spec_f.device)).
Args: Args:
n_mels (int): number of mel bins n_mels (int): Number of mel filterbanks. (Default: 128)
sr (int): sample rate of audio signal sample_rate (int): Sample rate of audio signal. (Default: 16000)
f_max (float, optional): maximum frequency. default: `sr` // 2 f_min (float): Minimum frequency. (Default: 0.)
f_min (float): minimum frequency. default: 0 f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`)
n_stft (int, optional): number of filter banks from stft. Calculated from first input n_stft (int, optional): Number of bins in STFT. Calculated from first input
if `None` is given. See `n_fft` in `Spectrogram`. if `None` is given. See `n_fft` in `Spectrogram`.
""" """
__constants__ = ['n_mels', 'sr', 'f_min', 'f_max'] __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None): def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
super(MelScale, self).__init__() super(MelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sr = sr self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sr // 2) self.f_max = f_max if f_max is not None else float(sample_rate // 2)
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
self.f_min = f_min self.f_min = f_min
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels) n_stft, self.f_min, self.f_max, self.n_mels)
self.fb = torch.jit.Attribute(fb, torch.Tensor) self.fb = torch.jit.Attribute(fb, torch.Tensor)
@torch.jit.script_method @torch.jit.script_method
def forward(self, spec_f): def forward(self, specgram):
r"""
Args:
specgram (torch.Tensor): a spectrogram STFT of size (c, f, t)
Returns:
torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
"""
if self.fb.numel() == 0: if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(spec_f.size(2), self.f_min, self.f_max, self.n_mels) tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels)
# Attributes cannot be reassigned outside __init__ so workaround # Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size()) self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb) self.fb.copy_(tmp_fb)
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...)
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
return mel_specgram
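A sketch of the lazy filter-bank path: with n_stft=None, fb is created on the first call from specgram.size(1) (the input shape is illustrative):
>>> import torch
>>> from torchaudio import transforms
>>> specgram = torch.randn(1, 201, 100)  # (c, f, t), e.g. from Spectrogram(n_fft=400)
>>> mel_specgram = transforms.MelScale(n_mels=128, sample_rate=16000)(specgram)
>>> mel_specgram.size()
torch.Size([1, 128, 100])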
class SpectrogramToDB(torch.jit.ScriptModule): class SpectrogramToDB(torch.jit.ScriptModule):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale. r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. may return different values for an audio clip split into snippets vs.
a full clip. a full clip.
Args: Args:
stype (str): scale of input spectrogram ("power" or "magnitude"). The stype (str): scale of input spectrogram ('power' or 'magnitude'). The
power being the elementwise square of the magnitude. default: "power" power being the elementwise square of the magnitude. (Default: 'power')
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is 80. is 80.
""" """
__constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']
def __init__(self, stype="power", top_db=None): def __init__(self, stype='power', top_db=None):
super(SpectrogramToDB, self).__init__() super(SpectrogramToDB, self).__init__()
self.stype = torch.jit.Attribute(stype, str) self.stype = torch.jit.Attribute(stype, str)
if top_db is not None and top_db < 0: if top_db is not None and top_db < 0:
raise ValueError('top_db must be a positive value') raise ValueError('top_db must be a positive value')
self.top_db = torch.jit.Attribute(top_db, Optional[float]) self.top_db = torch.jit.Attribute(top_db, Optional[float])
self.multiplier = 10. if stype == "power" else 20. self.multiplier = 10.0 if stype == 'power' else 20.0
self.amin = 1e-10 self.amin = 1e-10
self.ref_value = 1. self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value)) self.db_multiplier = math.log10(max(self.amin, self.ref_value))
@torch.jit.script_method @torch.jit.script_method
def forward(self, spec): def forward(self, specgram):
# numerically stable implementation from librosa r"""Numerically stable implementation from Librosa
# https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)
Args:
specgram (torch.Tensor): STFT of size (c, f, t)
Returns:
torch.Tensor: STFT after changing scale of size (c, f, t)
"""
return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db)
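A usage sketch: with stype='power' the multiplier is 10, and top_db bounds the dynamic range from below by clipping at max - top_db (input illustrative):
>>> import torch
>>> from torchaudio import transforms
>>> specgram = transforms.Spectrogram(power=2)(torch.randn(1, 16000))
>>> specgram_db = transforms.SpectrogramToDB(stype='power', top_db=80.)(specgram)
>>> bool(specgram_db.max() - specgram_db.min() <= 80.0)
True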
class MFCC(torch.jit.ScriptModule): class MFCC(torch.jit.ScriptModule):
"""Create the Mel-frequency cepstrum coefficients from an audio signal r"""Create the Mel-frequency cepstrum coefficients from an audio signal
By default, this calculates the MFCC on the DB-scaled Mel spectrogram. By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
This is not the textbook implementation, but is implemented here to This is not the textbook implementation, but is implemented here to
...@@ -289,32 +180,32 @@ class MFCC(torch.jit.ScriptModule): ...@@ -289,32 +180,32 @@ class MFCC(torch.jit.ScriptModule):
a full clip. a full clip.
Args: Args:
sr (int) : sample rate of audio signal sample_rate (int): Sample rate of audio signal. (Default: 16000)
n_mfcc (int) : number of mfc coefficients to retain n_mfcc (int): Number of mfc coefficients to retain
dct_type (int) : type of DCT (discrete cosine transform) to use dct_type (int): type of DCT (discrete cosine transform) to use
norm (string, optional) : norm to use norm (string, optional): norm to use
log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
melkwargs (dict, optional): arguments for MelSpectrogram melkwargs (dict, optional): arguments for MelSpectrogram
""" """
__constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
melkwargs=None): melkwargs=None):
super(MFCC, self).__init__() super(MFCC, self).__init__()
supported_dct_types = [2] supported_dct_types = [2]
if dct_type not in supported_dct_types: if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported: {}'.format(dct_type)) raise ValueError('DCT type not supported: {}'.format(dct_type))
self.sr = sr self.sample_rate = sample_rate
self.n_mfcc = n_mfcc self.n_mfcc = n_mfcc
self.dct_type = dct_type self.dct_type = dct_type
self.norm = torch.jit.Attribute(norm, Optional[str]) self.norm = torch.jit.Attribute(norm, Optional[str])
self.top_db = 80. self.top_db = 80.0
self.s2db = SpectrogramToDB("power", self.top_db) self.spectrogram_to_DB = SpectrogramToDB('power', self.top_db)
if melkwargs is not None: if melkwargs is not None:
self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs) self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
else: else:
self.MelSpectrogram = MelSpectrogram(sr=self.sr) self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)
if self.n_mfcc > self.MelSpectrogram.n_mels: if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins') raise ValueError('Cannot select more MFCC coefficients than # mel bins')
...@@ -323,29 +214,28 @@ class MFCC(torch.jit.ScriptModule): ...@@ -323,29 +214,28 @@ class MFCC(torch.jit.ScriptModule):
self.log_mels = log_mels self.log_mels = log_mels
@torch.jit.script_method @torch.jit.script_method
def forward(self, sig): def forward(self, waveform):
""" r"""
Args: Args:
sig (Tensor): Tensor of audio of size (channels [c], samples [n]) waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns: Returns:
spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels torch.Tensor: specgram_mel_db of size (c, `n_mfcc`, t)
is unchanged, hops is the number of hops, and n_mels is the
number of mel bins.
""" """
mel_spect = self.MelSpectrogram(sig) mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels: if self.log_mels:
log_offset = 1e-6 log_offset = 1e-6
mel_spect = torch.log(mel_spect + log_offset) mel_specgram = torch.log(mel_specgram + log_offset)
else: else:
mel_spect = self.s2db(mel_spect) mel_specgram = self.spectrogram_to_DB(mel_specgram)
mfcc = torch.matmul(mel_spect, self.dct_mat) # (c, `n_mels`, t).transpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
return mfcc return mfcc
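A minimal sketch under the renamed sample_rate/n_mfcc parameters (the waveform is illustrative; the time axis follows the underlying MelSpectrogram):
>>> import torch
>>> from torchaudio import transforms
>>> waveform = torch.randn(1, 16000)  # (c, n)
>>> mfcc = transforms.MFCC(sample_rate=16000, n_mfcc=40)(waveform)  # (c, n_mfcc, t)
>>> mfcc.size()[:2]
torch.Size([1, 40])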
class MelSpectrogram(torch.jit.ScriptModule): class MelSpectrogram(torch.jit.ScriptModule):
"""Create MEL Spectrograms from a raw audio signal using the stft r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
function in PyTorch. and MelScale.
Sources: Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
...@@ -353,87 +243,58 @@ class MelSpectrogram(torch.jit.ScriptModule): ...@@ -353,87 +243,58 @@ class MelSpectrogram(torch.jit.ScriptModule):
* http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
Args: Args:
sr (int): sample rate of audio signal sample_rate (int): Sample rate of audio signal. (Default: 16000)
ws (int): window size win_length (int): Window size. (Default: `n_fft`)
hop (int, optional): length of hop between STFT windows. default: `ws` // 2 hop_length (int, optional): Length of hop between STFT windows. (
n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1 Default: `win_length // 2`)
f_max (float, optional): maximum frequency. default: `sr` // 2 n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
f_min (float): minimum frequency. default: 0 f_min (float): Minimum frequency. (Default: 0.)
pad (int): two sided padding of signal f_max (float, optional): Maximum frequency. (Default: `None`)
n_mels (int): number of MEL bins pad (int): Two sided padding of signal. (Default: 0)
window (torch windowing function): default: `torch.hann_window` n_mels (int): Number of mel filterbanks. (Default: 128)
wkwargs (dict, optional): arguments for window function window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
Example: Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True) >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
>>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, l, m) >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t)
""" """
__constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min'] __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']
def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None, def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
pad=0, n_mels=128, window=torch.hann_window, wkwargs=None): pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
super(MelSpectrogram, self).__init__() super(MelSpectrogram, self).__init__()
self.sr = sr self.sample_rate = sample_rate
self.n_fft = n_fft self.n_fft = n_fft
self.ws = ws if ws is not None else n_fft self.win_length = win_length if win_length is not None else n_fft
self.hop = hop if hop is not None else self.ws // 2 self.hop_length = hop_length if hop_length is not None else self.win_length // 2
self.pad = pad self.pad = pad
self.n_mels = n_mels # number of mel frequency bins self.n_mels = n_mels # number of mel frequency bins
self.f_max = torch.jit.Attribute(f_max, Optional[float]) self.f_max = torch.jit.Attribute(f_max, Optional[float])
self.f_min = f_min self.f_min = f_min
self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop, self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
pad=self.pad, window=window, power=2, hop_length=self.hop_length,
normalize=False, wkwargs=wkwargs) pad=self.pad, window_fn=window_fn, power=2,
self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min) normalized=False, wkwargs=wkwargs)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max)
@torch.jit.script_method
def forward(self, sig):
"""
Args:
sig (Tensor): Tensor of audio of size (channels [c], samples [n])
Returns:
spec_mel (Tensor): channels x hops x n_mels (c, l, m), where channels
is unchanged, hops is the number of hops, and n_mels is the
number of mel bins.
"""
spec = self.spec(sig)
spec_mel = self.fm(spec)
return spec_mel
def MEL(*args, **kwargs):
raise DeprecationWarning("MEL has been removed from the library please use MelSpectrogram or librosa")
class BLC2CBL(torch.jit.ScriptModule):
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x Samples length
"""
def __init__(self):
super(BLC2CBL, self).__init__()
@torch.jit.script_method @torch.jit.script_method
def forward(self, tensor): def forward(self, waveform):
""" r"""
Args: Args:
tensor (Tensor): Tensor of spectrogram with shape (BxLxC) waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns: Returns:
tensor (Tensor): Tensor of spectrogram with shape (CxBxL) torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
""" """
return F.BLC2CBL(tensor) specgram = self.spectrogram(waveform)
mel_specgram = self.mel_scale(specgram)
def __repr__(self): return mel_specgram
return self.__class__.__name__ + '()'
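Since the class is now an explicit composition of Spectrogram and MelScale, chaining the two by hand should agree with it; a sketch under default parameters (f_max=None resolves to sample_rate // 2 on both paths):
>>> import torch
>>> from torchaudio import transforms
>>> waveform = torch.randn(1, 16000)
>>> mel1 = transforms.MelSpectrogram(sample_rate=16000)(waveform)
>>> mel2 = transforms.MelScale(sample_rate=16000)(transforms.Spectrogram()(waveform))
>>> torch.allclose(mel1, mel2)
True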
class MuLawEncoding(torch.jit.ScriptModule): class MuLawEncoding(torch.jit.ScriptModule):
"""Encode signal based on mu-law companding. For more info see the r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This algorithm assumes the signal has been scaled to between -1 and 1 and This algorithm assumes the signal has been scaled to between -1 and 1 and
...@@ -441,33 +302,27 @@ class MuLawEncoding(torch.jit.ScriptModule): ...@@ -441,33 +302,27 @@ class MuLawEncoding(torch.jit.ScriptModule):
Args: Args:
quantization_channels (int): Number of channels. default: 256 quantization_channels (int): Number of channels. default: 256
""" """
__constants__ = ['qc'] __constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256): def __init__(self, quantization_channels=256):
super(MuLawEncoding, self).__init__() super(MuLawEncoding, self).__init__()
self.qc = quantization_channels self.quantization_channels = quantization_channels
@torch.jit.script_method @torch.jit.script_method
def forward(self, x): def forward(self, x):
""" r"""
Args: Args:
x (FloatTensor/LongTensor) x (torch.Tensor): A signal to be encoded
Returns: Returns:
x_mu (LongTensor) x_mu (torch.Tensor): An encoded signal
""" """
return F.mu_law_encoding(x, self.qc) return F.mu_law_encoding(x, self.quantization_channels)
def __repr__(self):
return self.__class__.__name__ + '()'
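A sketch assuming the input is already scaled to [-1, 1], as the docstring requires:
>>> import torch
>>> from torchaudio import transforms
>>> waveform = torch.rand(1, 1000) * 2 - 1  # illustrative, in [-1, 1]
>>> x_mu = transforms.MuLawEncoding(quantization_channels=256)(waveform)
>>> bool((x_mu >= 0).all() and (x_mu <= 255).all())
True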
class MuLawExpanding(torch.jit.ScriptModule): class MuLawExpanding(torch.jit.ScriptModule):
"""Decode mu-law encoded signal. For more info see the r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This expects an input with values between 0 and quantization_channels - 1 This expects an input with values between 0 and quantization_channels - 1
...@@ -475,33 +330,27 @@ class MuLawExpanding(torch.jit.ScriptModule): ...@@ -475,33 +330,27 @@ class MuLawExpanding(torch.jit.ScriptModule):
Args: Args:
quantization_channels (int): Number of channels. default: 256 quantization_channels (int): Number of channels. default: 256
""" """
__constants__ = ['qc'] __constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256): def __init__(self, quantization_channels=256):
super(MuLawExpanding, self).__init__() super(MuLawExpanding, self).__init__()
self.qc = quantization_channels self.quantization_channels = quantization_channels
@torch.jit.script_method @torch.jit.script_method
def forward(self, x_mu): def forward(self, x_mu):
""" r"""
Args: Args:
x_mu (Tensor) x_mu (torch.Tensor): A mu-law encoded signal to be decoded
Returns: Returns:
x (Tensor) torch.Tensor: The decoded signal
""" """
return F.mu_law_expanding(x_mu, self.qc) return F.mu_law_expanding(x_mu, self.quantization_channels)
def __repr__(self):
return self.__class__.__name__ + '()'
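A round-trip sketch with MuLawEncoding above; decoding is only approximate because the signal was quantized to quantization_channels levels:
>>> import torch
>>> from torchaudio import transforms
>>> waveform = torch.rand(1, 1000) * 2 - 1  # illustrative, in [-1, 1]
>>> x_mu = transforms.MuLawEncoding()(waveform)
>>> reconstructed = transforms.MuLawExpanding()(x_mu)  # close to waveform, up to quantization error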
class Resample(torch.nn.Module): class Resample(torch.nn.Module):
"""Resamples a signal from one frequency to another. A resampling method can r"""Resamples a signal from one frequency to another. A resampling method can
be given. be given.
Args: Args:
...@@ -516,15 +365,15 @@ class Resample(torch.nn.Module): ...@@ -516,15 +365,15 @@ class Resample(torch.nn.Module):
self.new_freq = new_freq self.new_freq = new_freq
self.resampling_method = resampling_method self.resampling_method = resampling_method
def forward(self, sig): def forward(self, waveform):
""" r"""
Args: Args:
sig (Tensor): the input signal of size (c, n) waveform (torch.Tensor): The input signal of size (c, n)
Returns: Returns:
Tensor: output signal of size (c, m) torch.Tensor: Output signal of size (c, m)
""" """
if self.resampling_method == 'sinc_interpolation': if self.resampling_method == 'sinc_interpolation':
return kaldi.resample_waveform(sig, self.orig_freq, self.new_freq) return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
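A minimal sketch of the sinc_interpolation path, the only method implemented here (the input is illustrative):
>>> import torch
>>> from torchaudio import transforms
>>> waveform = torch.randn(1, 16000)  # one second at 16 kHz
>>> resampled = transforms.Resample(orig_freq=16000, new_freq=8000)(waveform)  # (c, m), m ~ 8000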