Commit a8d6a41b authored by David Pollack, committed by Soumith Chintala

optimization to MEL2 and fixes to filter bank conversion function

parent 3bd4db86
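
The main change here is that MEL2 now builds its SPECTROGRAM -> F2M -> SPEC2DB pipeline (including the STFT window tensor) once in __init__ instead of on every call, and F2M fills each triangular mel filter with torch.bartlett_window and can precompute the filter bank when n_fft is known. A minimal usage sketch of the optimized transform (parameter values are illustrative, not taken from the commit):

import torch
import torchaudio.transforms as transforms

mel = transforms.MEL2(sr=16000, ws=400, hop=200, pad=10)  # pipeline and window built once here
sig = torch.randn(1, 4 * 16000)                           # (channels, samples), the layout produced by LC2CL
for _ in range(5):
    spec_db = mel(sig)                                    # repeated calls reuse mel.spec, mel.fm and mel.s2db
print(spec_db.shape)                                      # (1, frames, 40): log-mel spectrogram in dB
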
from __future__ import print_function
import os
import torch
import torchaudio
import torchaudio.transforms as transforms
@@ -8,13 +9,17 @@ import unittest
class Tester(unittest.TestCase):
# create a sinewave signal for testing
sr = 16000
freq = 440
volume = .3
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr).float() * freq / sr))
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
sig.unsqueeze_(1)
sig.unsqueeze_(1) # (64000, 1)
sig = (sig * volume * 2**31).long()
# file for stereo stft test
test_dirpath = os.path.dirname(os.path.realpath(__file__))
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
def test_scale(self):
@@ -29,7 +34,7 @@ class Tester(unittest.TestCase):
result.min() >= -1. and result.max() <= 1.)
repr_test = transforms.Scale()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_pad_trim(self):
@@ -52,7 +57,7 @@ class Tester(unittest.TestCase):
self.assertEqual(result.size(0), length_new)
repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_downmix_mono(self):
@@ -70,7 +75,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono(channels_first=False)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_lc2cl(self):
@@ -79,7 +84,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size()[::-1] == audio.size())
repr_test = transforms.LC2CL()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_mel(self):
@@ -92,9 +97,10 @@ class Tester(unittest.TestCase):
self.assertTrue(result.dim() == 3)
repr_test = transforms.MEL()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
repr_test = transforms.BLC2CBL()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_compose(self):
@@ -113,7 +119,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new)
repr_test = transforms.Compose(tset)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_mu_law_companding(self):
@@ -141,17 +147,28 @@ class Tester(unittest.TestCase):
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
repr_test = transforms.MuLawEncoding(quantization_channels)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
repr_test = transforms.MuLawExpanding(quantization_channels)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_mel2(self):
audio_orig = self.sig.clone() # (64000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (64000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 64000)
spectrogram_torch = transforms.MEL2(window_fn=torch.hamming_window, pad=10)(audio_scaled) # (1, 319, 40)
mel_transform = transforms.MEL2(window=torch.hamming_window, pad=10)
spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.max() <= 0.)
self.assertTrue(spectrogram_torch.le(0.).all())
self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
# load stereo file
x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
spectrogram_stereo = mel_transform(x_stereo)
self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2)
self.assertTrue(spectrogram_stereo.le(0.).all())
self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
if __name__ == '__main__':
unittest.main()
@@ -60,7 +60,7 @@ class Scale(object):
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
"""
if not tensor.is_floating_point():
if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32)
return tensor / self.factor
@@ -125,7 +125,7 @@ class DownmixMono(object):
self.ch_dim = int(not channels_first)
def __call__(self, tensor):
if not tensor.is_floating_point():
if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32)
tensor = torch.mean(tensor, self.ch_dim, True)
@@ -169,8 +169,8 @@ class SPECTROGRAM(object):
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, window_fn=torch.hann_window, wkwargs=None):
self.window = window_fn(ws) if wkwargs is None else window_fn(ws, **wkwargs)
pad=0, window=torch.hann_window, wkwargs=None):
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
@@ -197,7 +197,6 @@ class SPECTROGRAM(object):
if self.pad > 0:
with torch.no_grad():
sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
self.window, center=False,
normalized=True, onesided=True).transpose(1, 2)
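# Shape check against test_mel2 above: with the defaults (n_fft taken as ws=400,
# hop=ws//2=200) and a 64000-sample signal padded by 10 on each side, center=False
# gives floor((64000 + 2*10 - 400) / 200) + 1 = 319 frames and n_fft // 2 + 1 = 201
# one-sided frequency bins; F2M later maps those 201 bins to n_mels=40, matching the
# (1, 319, 40) shape noted in the test.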
@@ -215,16 +214,28 @@ class F2M(object):
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2
f_min (float): minimum frequency. default: 0
n_fft (int, optional): number of frequency bins from the STFT. Calculated from the
first input if `None` is given.
"""
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_fft=None):
self.n_mels = n_mels
self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min
self.fb = self._create_fb_matrix(n_fft) if n_fft is not None else n_fft
def __call__(self, spec_f):
if self.fb is None:
self.fb = self._create_fb_matrix(spec_f.size(2))
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m
def _create_fb_matrix(self, n_fft):
""" Create a frequency bin conversion matrix.
n_fft = spec_f.size(2)
Args:
n_fft (int): number of frequency bins from the spectrogram
"""
m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
m_max = 2595 * np.log10(1. + (self.f_max / 700))
@@ -234,19 +245,12 @@ class F2M(object):
bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()
fb = torch.zeros(n_fft, self.n_mels)
fb = torch.zeros(n_fft, self.n_mels, dtype=torch.float)
for m in range(1, self.n_mels + 1):
f_m_minus = bins[m - 1].item()
f_m = bins[m].item()
f_m_plus = bins[m + 1].item()
if f_m_minus != f_m:
fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
if f_m != f_m_plus:
fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m
fb[f_m_minus:f_m_plus, m - 1] = torch.bartlett_window(f_m_plus - f_m_minus)
return fb
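# A standalone sketch of how the precomputed filter bank is used (parameter values
# chosen for illustration, not taken from the commit): each column of fb is one
# triangular mel filter spanning bins[m - 1]..bins[m + 1], filled above with
# torch.bartlett_window in a single slice assignment, and __call__ then reduces to
# one matmul.
import torch
import torchaudio.transforms as transforms

f2m = transforms.F2M(n_mels=40, sr=16000, n_fft=201)  # fb is built eagerly when n_fft is given
spec_f = torch.rand(1, 319, 201)                      # (channels, frames, frequency bins)
spec_m = f2m(spec_f)                                  # (1, 319, 201) @ (201, 40) -> (1, 319, 40)
assert f2m.fb.shape == (201, 40) and spec_m.shape == (1, 319, 40)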
class SPEC2DB(object):
@@ -267,7 +271,7 @@ class SPEC2DB(object):
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
spec_db = torch.max(spec_db, torch.tensor(self.top_db, dtype=spec_db.dtype))
return spec_db
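# For example, a bin whose power is 1e-10 of the maximum maps to multiplier * log10(1e-10),
# far below the top_db of -80 that MEL2 passes in, and the torch.max above floors it to
# top_db; every returned value therefore lies in [top_db, 0], which is what test_mel2
# checks with le(0.) and ge(top_db).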
@@ -296,8 +300,8 @@ class MEL2(object):
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, n_mels=40, window_fn=torch.hann_window, wkwargs=None):
self.window_fn = window_fn
pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
self.window = window
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
@@ -308,6 +312,13 @@ class MEL2(object):
self.top_db = -80.
self.f_max = None
self.f_min = 0.
self.spec = SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window, self.wkwargs)
self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min, self.n_fft)
self.s2db = SPEC2DB("power", self.top_db)
self.transforms = Compose([
self.spec, self.fm, self.s2db,
])
def __call__(self, sig):
"""
@@ -320,15 +331,7 @@ class MEL2(object):
number of mel bins.
"""
transforms = Compose([
SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window_fn, self.wkwargs),
F2M(self.n_mels, self.sr, self.f_max, self.f_min),
SPEC2DB("power", self.top_db),
])
spec_mel_db = transforms(sig)
spec_mel_db = self.transforms(sig)
return spec_mel_db
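# The Compose built in __init__ above, the window tensor held by self.spec and, when
# n_fft is given, the filter bank held by self.fm are now created once and reused on
# every call instead of being reconstructed here for each input signal.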
@@ -426,7 +429,7 @@ class MuLawEncoding(object):
x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
elif isinstance(x, torch.Tensor):
if not x.is_floating_point():
if not x.dtype.is_floating_point:
x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu *
@@ -468,7 +471,7 @@ class MuLawExpanding(object):
x = ((x_mu) / mu) * 2 - 1.
x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
elif isinstance(x_mu, torch.Tensor):
if not x_mu.is_floating_point():
if not x_mu.dtype.is_floating_point:
x_mu = x_mu.to(torch.float)
mu = torch.tensor(mu, dtype=x_mu.dtype)
x = ((x_mu) / mu) * 2 - 1.
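# A round-trip sketch for the two companding transforms (the channel count and the
# error bound are illustrative): encoding maps a float signal in [-1, 1] to integer
# codes in [0, quantization_channels - 1] and expanding maps the codes back to an
# approximation of the original signal, as exercised by test_mu_law_companding.
import torch
import torchaudio.transforms as transforms

q = 256
x = torch.rand(1000) * 2 - 1                 # float signal in [-1, 1]
x_mu = transforms.MuLawEncoding(q)(x)        # integer codes in [0, 255]
x_rec = transforms.MuLawExpanding(q)(x_mu)   # approximately x again
assert (x - x_rec).abs().max() < 0.1         # coarse bound on the companding error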