"...text-generation-inference.git" did not exist on "e3e487dc711449c23826cfe1d74786f71309d6bd"
Commit 78be73b7 authored by David Pollack, committed by Soumith Chintala

mel spectrograms in pytorch (no longer req librosa)

parent c844ac63
@@ -10,8 +10,9 @@ class Tester(unittest.TestCase):
    sr = 16000
    freq = 440
-    volume = 0.3
+    volume = .3
    sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
+    # sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
    sig.unsqueeze_(1)
    sig = (sig * volume * 2**31).long()
@@ -86,11 +87,11 @@ class Tester(unittest.TestCase):
        audio = self.sig.clone()
        audio = transforms.Scale()(audio)
-        self.assertTrue(len(audio.size()) == 2)
+        self.assertTrue(audio.dim() == 2)
        result = transforms.MEL()(audio)
-        self.assertTrue(len(result.size()) == 3)
+        self.assertTrue(result.dim() == 3)
        result = transforms.BLC2CBL()(result)
-        self.assertTrue(len(result.size()) == 3)
+        self.assertTrue(result.dim() == 3)
        repr_test = transforms.MEL()
        repr_test.__repr__()
@@ -146,6 +147,13 @@ class Tester(unittest.TestCase):
        repr_test = transforms.MuLawExpanding(quantization_channels)
        repr_test.__repr__()

+    def test_mel2(self):
+        audio_orig = self.sig.clone()  # (64000, 1)
+        audio_scaled = transforms.Scale()(audio_orig)  # (64000, 1)
+        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 64000)
+        spectrogram_torch = transforms.MEL2()(audio_scaled)  # (1, 319, 40)
+        self.assertTrue(spectrogram_torch.dim() == 3)
+        self.assertTrue(spectrogram_torch.max() <= 0.)
+
if __name__ == '__main__':
    unittest.main()
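For context, the fixture these tests share is four seconds of a 440 Hz cosine at 16 kHz, scaled by volume * 2**31 into a 32-bit integer range (which transforms.Scale() presumably maps back to [-1, 1] later). A standalone re-creation, illustrative only and not part of the commit, shows where the shapes in the test_mel2 comments come from:

import numpy as np
import torch

# Rebuild the test fixture: 4 s of a 440 Hz cosine sampled at 16 kHz,
# scaled into a 32-bit integer range.
sr, freq, volume = 16000, 440, .3
sig = torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr).float()
sig = (sig.unsqueeze(1) * volume * 2 ** 31).long()
print(sig.size())  # torch.Size([64000, 1]), i.e. (samples, channels)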
from __future__ import division, print_function
import torch
+from torch.autograd import Variable
import numpy as np
try:
    import librosa
@@ -7,6 +8,24 @@ except ImportError:
    librosa = None
def _check_is_variable(tensor):
if isinstance(tensor, torch.Tensor):
is_variable = False
tensor = Variable(tensor, requires_grad=False)
elif isinstance(tensor, Variable):
is_variable = True
else:
raise TypeError("tensor should be a Variable or Tensor, but is {}".format(type(tensor)))
return tensor, is_variable
def _tlog10(x):
"""Pytorch Log10
"""
return torch.log(x) / torch.log(x.new([10]))
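_tlog10 computes a base-10 logarithm through the identity log10(x) = ln(x) / ln(10), presumably because a Variable-friendly log10 was not available at the time. A quick standalone check of that identity (illustrative, not part of the commit):

import math
import torch

# log10 via the natural-log ratio, exactly as _tlog10 does it.
x = torch.Tensor([0.5, 1., 10., 100.])
via_ratio = torch.log(x) / torch.log(x.new([10]))
reference = torch.Tensor([math.log10(v) for v in x.tolist()])
print(torch.max(torch.abs(via_ratio - reference)))  # ~0, up to float32 rounding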
class Compose(object):
    """Composes several transforms together.
@@ -137,10 +156,10 @@ class LC2CL(object):
        """
        Args:
-            tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
+            tensor (Tensor): Tensor of audio signal with shape (LxC)

        Returns:
-            tensor (Tensor): Tensor of spectrogram with shape (CxBxL)
+            tensor (Tensor): Tensor of audio signal with shape (CxL)
        """
@@ -150,6 +169,190 @@ class LC2CL(object):
        return self.__class__.__name__ + '()'
class SPECTROGRAM(object):
"""Create a spectrogram from a raw audio signal
Args:
sr (int): sample rate of audio signal
ws (int): window size, often called the fft size as well
hop (int, optional): length of hop between STFT windows. default: ws // 2
n_fft (int, optional): number of fft bins. default: ws // 2 + 1
pad (int): two sided padding of signal
window (torch windowing function): default: torch.hann_window
wkwargs (dict, optional): arguments for window function
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, window=torch.hann_window, wkwargs=None):
if isinstance(window, Variable):
self.window = window
else:
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.window = Variable(self.window, volatile=True)
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
self.n_fft = n_fft # number of fft bins
self.pad = pad
self.wkwargs = wkwargs
def __call__(self, sig):
"""
Args:
sig (Tensor or Variable): Tensor of audio of size (c, n)
Returns:
spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
is unchanged, hops is the number of hops, and n_fft is the
number of fourier bins, which should be the window size divided
by 2 plus 1.
"""
sig, is_variable = _check_is_variable(sig)
assert sig.dim() == 2
spec_f = torch.stft(sig, self.ws, self.hop, self.n_fft,
True, self.window, self.pad) # (c, l, n_fft, 2)
spec_f /= self.window.pow(2).sum().sqrt()
spec_f = spec_f.pow(2).sum(-1) # get power of "complex" tensor (c, l, n_fft)
return spec_f if is_variable else spec_f.data
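The closing pow(2).sum(-1) collapses the trailing (real, imaginary) pair returned by the old torch.stft into a power spectrum, |z|^2 = re^2 + im^2. A tiny standalone illustration (not part of the commit):

import torch

# Two "complex" bins stored as (re, im) pairs, as in the stft output above.
z = torch.Tensor([[3., 4.],
                  [1., 0.]])
print(z.pow(2).sum(-1))  # [25., 1.]: |3+4j|^2 = 25, |1+0j|^2 = 1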
class F2M(object):
"""This turns a normal STFT into a MEL Frequency STFT, using a conversion
matrix. This uses triangular filter banks.
Args:
n_mels (int): number of MEL bins
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2
f_min (float): minimum frequency. default: 0
"""
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
self.n_mels = n_mels
self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min
def __call__(self, spec_f):
spec_f, is_variable = _check_is_variable(spec_f)
n_fft = spec_f.size(2)
m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
m_max = 2595 * np.log10(1. + (self.f_max / 700))
m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
f_pts = (700 * (10**(m_pts / 2595) - 1))
bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()
fb = torch.zeros(n_fft, self.n_mels)
for m in range(1, self.n_mels + 1):
f_m_minus = bins[m - 1]
f_m = bins[m]
f_m_plus = bins[m + 1]
if f_m_minus != f_m:
fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
if f_m != f_m_plus:
fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
fb = Variable(fb)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m if is_variable else spec_m.data
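F2M builds its triangular filter bank on the usual mel mapping, m = 2595 * log10(1 + f / 700), with inverse f = 700 * (10**(m / 2595) - 1). A standalone round trip over those two formulas (illustrative, not the class itself):

import numpy as np

# Hz <-> mel with the same constants F2M uses.
def hz_to_mel(f):
    return 2595. * np.log10(1. + f / 700.)

def mel_to_hz(m):
    return 700. * (10. ** (m / 2595.) - 1.)

for f in (0., 440., 1000., 8000.):
    assert abs(mel_to_hz(hz_to_mel(f)) - f) < 1e-6
print(round(hz_to_mel(1000.), 1))  # 1000.0: 1 kHz lands near 1000 mel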
class SPEC2DB(object):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
Args:
        stype (str): scale of the input spectrogram, "power" or "magnitude";
            power is the elementwise square of the magnitude. default: "power"
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is -80.
"""
def __init__(self, stype="power", top_db=None):
self.stype = stype
        # store as a negative cut-off; guard the comparison against the None default
        self.top_db = -top_db if top_db is not None and top_db > 0 else top_db
self.multiplier = 10. if stype == "power" else 20.
def __call__(self, spec):
spec, is_variable = _check_is_variable(spec)
spec_db = self.multiplier * _tlog10(spec / spec.max()) # power -> dB
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
return spec_db if is_variable else spec_db.data
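Because SPEC2DB divides by the spectrogram's own maximum before taking the log, the largest bin maps to 0 dB and everything else comes out negative before being clamped at top_db; this is what lets test_mel2 assert spectrogram_torch.max() <= 0. A standalone sketch of the same arithmetic (not the class itself):

import torch

# Power -> dB relative to the maximum, clamped at -80 dB.
spec = torch.Tensor([[1e-9, 1e-4, 1e-2, 1.]])
spec_db = 10. * torch.log(spec / spec.max()) / torch.log(torch.Tensor([10.]))
spec_db = torch.max(spec_db, torch.Tensor([-80.]))
print(spec_db)  # [[-80., -40., -20., 0.]]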
class MEL2(object):
"""Create MEL Spectrograms from a raw audio signal using the stft
function in PyTorch. Hopefully this solves the speed issue of using
librosa.
Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
* https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
* http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
Args:
sr (int): sample rate of audio signal
ws (int): window size, often called the fft size as well
hop (int, optional): length of hop between STFT windows. default: ws // 2
n_fft (int, optional): number of fft bins. default: ws // 2 + 1
pad (int): two sided padding of signal
n_mels (int): number of MEL bins
window (torch windowing function): default: torch.hann_window
wkwargs (dict, optional): arguments for window function
Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> sig = transforms.LC2CL()(sig) # (n, c) -> (c, n)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.window = Variable(self.window, requires_grad=False)
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
self.n_fft = n_fft # number of fourier bins (ws // 2 + 1 by default)
self.pad = pad
self.n_mels = n_mels # number of mel frequency bins
self.wkwargs = wkwargs
self.top_db = -80.
self.f_max = None
self.f_min = 0.
def __call__(self, sig):
"""
Args:
sig (Tensor): Tensor of audio of size (channels [c], samples [n])
Returns:
spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
is unchanged, hops is the number of hops, and n_mels is the
number of mel bins.
"""
sig, is_variable = _check_is_variable(sig)
transforms = Compose([
SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window),
F2M(self.n_mels, self.sr, self.f_max, self.f_min),
SPEC2DB("power", self.top_db),
])
spec_mel_db = transforms(sig)
return spec_mel_db if is_variable else spec_mel_db.data
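MEL2 is simply the composition SPECTROGRAM -> F2M -> SPEC2DB defined above. With the defaults (ws=400, hop=200, n_mels=40), no centering or padding of the signal, and the 4-second test fixture, the (1, 319, 40) shape asserted in test_mel2 follows from simple frame arithmetic (illustrative sketch, not part of the commit):

# Expected MEL2 output shape for the 64000-sample test fixture.
n_samples = 4 * 16000                   # the test signal
ws, hop, n_mels = 400, 200, 40          # MEL2 defaults
n_hops = (n_samples - ws) // hop + 1    # 319 STFT frames
n_freqs = ws // 2 + 1                   # 201 bins out of SPECTROGRAM, before F2M
print((1, n_hops, n_mels))              # (1, 319, 40), matching test_mel2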
class MEL(object):
    """Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
@@ -164,7 +367,7 @@ class MEL(object):
        """
        Args:
-            tensor (Tensor): Tensor of audio of size (samples x channels)
+            tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])

        Returns:
            tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
@@ -172,6 +375,7 @@ class MEL(object):
            is unchanged.
        """
        if librosa is None:
            print("librosa not installed, cannot create spectrograms")
            return tensor