Commit a8d6a41b authored by David Pollack, committed by Soumith Chintala
Browse files

optimization to MEL2 and fixes to filter bank conversion function

parent 3bd4db86
from __future__ import print_function from __future__ import print_function
import os
import torch import torch
import torchaudio import torchaudio
import torchaudio.transforms as transforms import torchaudio.transforms as transforms
...@@ -8,13 +9,17 @@ import unittest ...@@ -8,13 +9,17 @@ import unittest
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
# create a sinewave signal for testing
sr = 16000 sr = 16000
freq = 440 freq = 440
volume = .3 volume = .3
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr).float() * freq / sr)) sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr).float() * freq / sr))
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float() sig.unsqueeze_(1) # (64000, 1)
sig.unsqueeze_(1)
sig = (sig * volume * 2**31).long() sig = (sig * volume * 2**31).long()
# file for stereo stft test
test_dirpath = os.path.dirname(os.path.realpath(__file__))
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
def test_scale(self): def test_scale(self):
...@@ -29,7 +34,7 @@ class Tester(unittest.TestCase): ...@@ -29,7 +34,7 @@ class Tester(unittest.TestCase):
result.min() >= -1. and result.max() <= 1.) result.min() >= -1. and result.max() <= 1.)
repr_test = transforms.Scale() repr_test = transforms.Scale()
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
def test_pad_trim(self): def test_pad_trim(self):
...@@ -52,7 +57,7 @@ class Tester(unittest.TestCase): ...@@ -52,7 +57,7 @@ class Tester(unittest.TestCase):
self.assertEqual(result.size(0), length_new) self.assertEqual(result.size(0), length_new)
repr_test = transforms.PadTrim(max_len=length_new, channels_first=False) repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
def test_downmix_mono(self): def test_downmix_mono(self):
...@@ -70,7 +75,7 @@ class Tester(unittest.TestCase): ...@@ -70,7 +75,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(1) == 1) self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono(channels_first=False) repr_test = transforms.DownmixMono(channels_first=False)
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
def test_lc2cl(self): def test_lc2cl(self):
...@@ -79,7 +84,7 @@ class Tester(unittest.TestCase): ...@@ -79,7 +84,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size()[::-1] == audio.size()) self.assertTrue(result.size()[::-1] == audio.size())
repr_test = transforms.LC2CL() repr_test = transforms.LC2CL()
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
def test_mel(self): def test_mel(self):
...@@ -92,9 +97,10 @@ class Tester(unittest.TestCase): ...@@ -92,9 +97,10 @@ class Tester(unittest.TestCase):
self.assertTrue(result.dim() == 3) self.assertTrue(result.dim() == 3)
repr_test = transforms.MEL() repr_test = transforms.MEL()
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
repr_test = transforms.BLC2CBL() repr_test = transforms.BLC2CBL()
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
def test_compose(self): def test_compose(self):
...@@ -113,7 +119,7 @@ class Tester(unittest.TestCase): ...@@ -113,7 +119,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new) self.assertTrue(result.size(0) == length_new)
repr_test = transforms.Compose(tset) repr_test = transforms.Compose(tset)
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
def test_mu_law_companding(self): def test_mu_law_companding(self):
...@@ -141,17 +147,28 @@ class Tester(unittest.TestCase): ...@@ -141,17 +147,28 @@ class Tester(unittest.TestCase):
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.) self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
repr_test = transforms.MuLawEncoding(quantization_channels) repr_test = transforms.MuLawEncoding(quantization_channels)
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
repr_test = transforms.MuLawExpanding(quantization_channels) repr_test = transforms.MuLawExpanding(quantization_channels)
repr_test.__repr__() self.assertTrue(repr_test.__repr__())
def test_mel2(self): def test_mel2(self):
audio_orig = self.sig.clone() # (16000, 1) audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
spectrogram_torch = transforms.MEL2(window_fn=torch.hamming_window, pad=10)(audio_scaled) # (1, 319, 40) mel_transform = transforms.MEL2(window=torch.hamming_window, pad=10)
spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.max() <= 0.) self.assertTrue(spectrogram_torch.le(0.).all())
self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
# load stereo file
x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
spectrogram_stereo = mel_transform(x_stereo)
self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2)
self.assertTrue(spectrogram_stereo.le(0.).all())
self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -60,7 +60,7 @@ class Scale(object): ...@@ -60,7 +60,7 @@ class Scale(object):
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0) Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
""" """
if not tensor.is_floating_point(): if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32) tensor = tensor.to(torch.float32)
return tensor / self.factor return tensor / self.factor
...@@ -125,7 +125,7 @@ class DownmixMono(object): ...@@ -125,7 +125,7 @@ class DownmixMono(object):
self.ch_dim = int(not channels_first) self.ch_dim = int(not channels_first)
def __call__(self, tensor): def __call__(self, tensor):
if not tensor.is_floating_point(): if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32) tensor = tensor.to(torch.float32)
tensor = torch.mean(tensor, self.ch_dim, True) tensor = torch.mean(tensor, self.ch_dim, True)
...@@ -169,8 +169,8 @@ class SPECTROGRAM(object): ...@@ -169,8 +169,8 @@ class SPECTROGRAM(object):
""" """
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, window_fn=torch.hann_window, wkwargs=None): pad=0, window=torch.hann_window, wkwargs=None):
self.window = window_fn(ws) if wkwargs is None else window_fn(ws, **wkwargs) self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.sr = sr self.sr = sr
self.ws = ws self.ws = ws
self.hop = hop if hop is not None else ws // 2 self.hop = hop if hop is not None else ws // 2
...@@ -197,7 +197,6 @@ class SPECTROGRAM(object): ...@@ -197,7 +197,6 @@ class SPECTROGRAM(object):
if self.pad > 0: if self.pad > 0:
with torch.no_grad(): with torch.no_grad():
sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant") sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws, spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
self.window, center=False, self.window, center=False,
normalized=True, onesided=True).transpose(1, 2) normalized=True, onesided=True).transpose(1, 2)
...@@ -215,16 +214,28 @@ class F2M(object): ...@@ -215,16 +214,28 @@ class F2M(object):
sr (int): sample rate of audio signal sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2 f_max (float, optional): maximum frequency. default: sr // 2
f_min (float): minimum frequency. default: 0 f_min (float): minimum frequency. default: 0
n_fft (int, optional): number of frequency bins in the stft output (the size of the
last dimension of the spectrogram). Inferred from the first input if `None` is given.
""" """
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.): def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_fft=None):
self.n_mels = n_mels self.n_mels = n_mels
self.sr = sr self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2 self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min self.f_min = f_min
self.fb = self._create_fb_matrix(n_fft) if n_fft is not None else n_fft
def __call__(self, spec_f): def __call__(self, spec_f):
if self.fb is None:
self.fb = self._create_fb_matrix(spec_f.size(2))
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m
def _create_fb_matrix(self, n_fft):
""" Create a frequency bin conversion matrix.
n_fft = spec_f.size(2) Args:
n_fft (int): number of filter banks from spectrogram
"""
m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700)) m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
m_max = 2595 * np.log10(1. + (self.f_max / 700)) m_max = 2595 * np.log10(1. + (self.f_max / 700))
...@@ -234,19 +245,12 @@ class F2M(object): ...@@ -234,19 +245,12 @@ class F2M(object):
bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long() bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()
fb = torch.zeros(n_fft, self.n_mels) fb = torch.zeros(n_fft, self.n_mels, dtype=torch.float)
for m in range(1, self.n_mels + 1): for m in range(1, self.n_mels + 1):
f_m_minus = bins[m - 1].item() f_m_minus = bins[m - 1].item()
f_m = bins[m].item()
f_m_plus = bins[m + 1].item() f_m_plus = bins[m + 1].item()
fb[f_m_minus:f_m_plus, m - 1] = torch.bartlett_window(f_m_plus - f_m_minus)
if f_m_minus != f_m: return fb
fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
if f_m != f_m_plus:
fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m
class SPEC2DB(object): class SPEC2DB(object):
...@@ -267,7 +271,7 @@ class SPEC2DB(object): ...@@ -267,7 +271,7 @@ class SPEC2DB(object):
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None: if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db])) spec_db = torch.max(spec_db, torch.tensor(self.top_db, dtype=spec_db.dtype))
return spec_db return spec_db
...@@ -296,8 +300,8 @@ class MEL2(object): ...@@ -296,8 +300,8 @@ class MEL2(object):
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m) >>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
""" """
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, n_mels=40, window_fn=torch.hann_window, wkwargs=None): pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
self.window_fn = window_fn self.window = window
self.sr = sr self.sr = sr
self.ws = ws self.ws = ws
self.hop = hop if hop is not None else ws // 2 self.hop = hop if hop is not None else ws // 2
...@@ -308,6 +312,13 @@ class MEL2(object): ...@@ -308,6 +312,13 @@ class MEL2(object):
self.top_db = -80. self.top_db = -80.
self.f_max = None self.f_max = None
self.f_min = 0. self.f_min = 0.
self.spec = SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window, self.wkwargs)
self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min, self.n_fft)
self.s2db = SPEC2DB("power", self.top_db)
self.transforms = Compose([
self.spec, self.fm, self.s2db,
])
def __call__(self, sig): def __call__(self, sig):
""" """
...@@ -320,15 +331,7 @@ class MEL2(object): ...@@ -320,15 +331,7 @@ class MEL2(object):
number of mel bins. number of mel bins.
""" """
spec_mel_db = self.transforms(sig)
transforms = Compose([
SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window_fn, self.wkwargs),
F2M(self.n_mels, self.sr, self.f_max, self.f_min),
SPEC2DB("power", self.top_db),
])
spec_mel_db = transforms(sig)
return spec_mel_db return spec_mel_db
...@@ -426,7 +429,7 @@ class MuLawEncoding(object): ...@@ -426,7 +429,7 @@ class MuLawEncoding(object):
x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int) x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
elif isinstance(x, torch.Tensor): elif isinstance(x, torch.Tensor):
if not x.is_floating_point(): if not x.dtype.is_floating_point:
x = x.to(torch.float) x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype) mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu * x_mu = torch.sign(x) * torch.log1p(mu *
...@@ -468,7 +471,7 @@ class MuLawExpanding(object): ...@@ -468,7 +471,7 @@ class MuLawExpanding(object):
x = ((x_mu) / mu) * 2 - 1. x = ((x_mu) / mu) * 2 - 1.
x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
elif isinstance(x_mu, torch.Tensor): elif isinstance(x_mu, torch.Tensor):
if not x_mu.is_floating_point(): if not x_mu.dtype.is_floating_point:
x_mu = x_mu.to(torch.float) x_mu = x_mu.to(torch.float)
mu = torch.tensor(mu, dtype=x_mu.dtype) mu = torch.tensor(mu, dtype=x_mu.dtype)
x = ((x_mu) / mu) * 2 - 1. x = ((x_mu) / mu) * 2 - 1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment