"vscode:/vscode.git/clone" did not exist on "20a39403b0d9a87bbd8d3d399b3e9c06011526a4"
Commit a8d6a41b authored by David Pollack's avatar David Pollack Committed by Soumith Chintala
Browse files

optimization to MEL2 and fixes to filter bank conversion function

parent 3bd4db86
from __future__ import print_function
import os
import torch
import torchaudio
import torchaudio.transforms as transforms
......@@ -8,13 +9,17 @@ import unittest
class Tester(unittest.TestCase):
# create a sinewave signal for testing
sr = 16000
freq = 440
volume = .3
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr).float() * freq / sr))
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
sig.unsqueeze_(1)
sig.unsqueeze_(1) # (64000, 1)
sig = (sig * volume * 2**31).long()
# file for stereo stft test
test_dirpath = os.path.dirname(os.path.realpath(__file__))
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
def test_scale(self):
......@@ -29,7 +34,7 @@ class Tester(unittest.TestCase):
result.min() >= -1. and result.max() <= 1.)
repr_test = transforms.Scale()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_pad_trim(self):
......@@ -52,7 +57,7 @@ class Tester(unittest.TestCase):
self.assertEqual(result.size(0), length_new)
repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_downmix_mono(self):
......@@ -70,7 +75,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono(channels_first=False)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_lc2cl(self):
......@@ -79,7 +84,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size()[::-1] == audio.size())
repr_test = transforms.LC2CL()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_mel(self):
......@@ -92,9 +97,10 @@ class Tester(unittest.TestCase):
self.assertTrue(result.dim() == 3)
repr_test = transforms.MEL()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
repr_test = transforms.BLC2CBL()
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_compose(self):
......@@ -113,7 +119,7 @@ class Tester(unittest.TestCase):
self.assertTrue(result.size(0) == length_new)
repr_test = transforms.Compose(tset)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_mu_law_companding(self):
......@@ -141,17 +147,28 @@ class Tester(unittest.TestCase):
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
repr_test = transforms.MuLawEncoding(quantization_channels)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
repr_test = transforms.MuLawExpanding(quantization_channels)
repr_test.__repr__()
self.assertTrue(repr_test.__repr__())
def test_mel2(self):
    """End-to-end check of the MEL2 transform on mono and stereo input.

    Defect fixed: this block contained leftover pre-refactor lines —
    a duplicate MEL2 construction using the removed ``window_fn=``
    keyword (renamed to ``window=``) whose result was immediately
    overwritten, and a stale ``max() <= 0.`` assertion superseded by
    the element-wise ``le(0.).all()`` check. The stale lines are removed.
    """
    audio_orig = self.sig.clone()  # (16000, 1)
    audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
    audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
    mel_transform = transforms.MEL2(window=torch.hamming_window, pad=10)
    spectrogram_torch = mel_transform(audio_scaled)  # (1, 319, 40)
    self.assertTrue(spectrogram_torch.dim() == 3)
    # dB-scaled power spectrogram: every value lies in [top_db, 0]
    self.assertTrue(spectrogram_torch.le(0.).all())
    self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
    self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
    # load stereo file and check channel handling
    x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
    spectrogram_stereo = mel_transform(x_stereo)
    self.assertTrue(spectrogram_stereo.dim() == 3)
    self.assertTrue(spectrogram_stereo.size(0) == 2)
    self.assertTrue(spectrogram_stereo.le(0.).all())
    self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
    self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
if __name__ == '__main__':
unittest.main()
......@@ -60,7 +60,7 @@ class Scale(object):
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
"""
if not tensor.is_floating_point():
if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32)
return tensor / self.factor
......@@ -125,7 +125,7 @@ class DownmixMono(object):
self.ch_dim = int(not channels_first)
def __call__(self, tensor):
if not tensor.is_floating_point():
if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32)
tensor = torch.mean(tensor, self.ch_dim, True)
......@@ -169,8 +169,8 @@ class SPECTROGRAM(object):
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, window_fn=torch.hann_window, wkwargs=None):
self.window = window_fn(ws) if wkwargs is None else window_fn(ws, **wkwargs)
pad=0, window=torch.hann_window, wkwargs=None):
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
......@@ -197,7 +197,6 @@ class SPECTROGRAM(object):
if self.pad > 0:
with torch.no_grad():
sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
self.window, center=False,
normalized=True, onesided=True).transpose(1, 2)
......@@ -215,16 +214,28 @@ class F2M(object):
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2
f_min (float): minimum frequency. default: 0
n_fft (int, optional): number of filter banks from stft. Calculated from first input
if `None` is given.
"""
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_fft=None):
self.n_mels = n_mels
self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min
self.fb = self._create_fb_matrix(n_fft) if n_fft is not None else n_fft
def __call__(self, spec_f):
    """Project a linear-frequency spectrogram onto the mel scale.

    Args:
        spec_f (Tensor): spectrogram with frequency bins on the last
            dimension — assumed shape (c, l, n_fft) per the matmul
            comment below; TODO confirm against SPECTROGRAM's output.

    Returns:
        Tensor: mel spectrogram of shape (c, l, n_mels).
    """
    # Lazy initialization: when n_fft was not given at construction
    # time, self.fb is None and the filter bank is built from the
    # first input's bin count, then cached on the instance.
    if self.fb is None:
        self.fb = self._create_fb_matrix(spec_f.size(2))
    spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
    return spec_m
def _create_fb_matrix(self, n_fft):
""" Create a frequency bin conversion matrix.
n_fft = spec_f.size(2)
Args:
n_fft (int): number of filter banks from spectrogram
"""
m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
m_max = 2595 * np.log10(1. + (self.f_max / 700))
......@@ -234,19 +245,12 @@ class F2M(object):
bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()
fb = torch.zeros(n_fft, self.n_mels)
fb = torch.zeros(n_fft, self.n_mels, dtype=torch.float)
for m in range(1, self.n_mels + 1):
f_m_minus = bins[m - 1].item()
f_m = bins[m].item()
f_m_plus = bins[m + 1].item()
if f_m_minus != f_m:
fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
if f_m != f_m_plus:
fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m
fb[f_m_minus:f_m_plus, m - 1] = torch.bartlett_window(f_m_plus - f_m_minus)
return fb
class SPEC2DB(object):
......@@ -267,7 +271,7 @@ class SPEC2DB(object):
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
spec_db = torch.max(spec_db, torch.tensor(self.top_db, dtype=spec_db.dtype))
return spec_db
......@@ -296,8 +300,8 @@ class MEL2(object):
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, n_mels=40, window_fn=torch.hann_window, wkwargs=None):
self.window_fn = window_fn
pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
self.window = window
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
......@@ -308,6 +312,13 @@ class MEL2(object):
self.top_db = -80.
self.f_max = None
self.f_min = 0.
self.spec = SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window, self.wkwargs)
self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min, self.n_fft)
self.s2db = SPEC2DB("power", self.top_db)
self.transforms = Compose([
self.spec, self.fm, self.s2db,
])
def __call__(self, sig):
"""
......@@ -320,15 +331,7 @@ class MEL2(object):
number of mel bins.
"""
transforms = Compose([
SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window_fn, self.wkwargs),
F2M(self.n_mels, self.sr, self.f_max, self.f_min),
SPEC2DB("power", self.top_db),
])
spec_mel_db = transforms(sig)
spec_mel_db = self.transforms(sig)
return spec_mel_db
......@@ -426,7 +429,7 @@ class MuLawEncoding(object):
x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
elif isinstance(x, torch.Tensor):
if not x.is_floating_point():
if not x.dtype.is_floating_point:
x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu *
......@@ -468,7 +471,7 @@ class MuLawExpanding(object):
x = ((x_mu) / mu) * 2 - 1.
x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
elif isinstance(x_mu, torch.Tensor):
if not x_mu.is_floating_point():
if not x_mu.dtype.is_floating_point:
x_mu = x_mu.to(torch.float)
mu = torch.tensor(mu, dtype=x_mu.dtype)
x = ((x_mu) / mu) * 2 - 1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment