Unverified Commit 0dfcbfde authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

Merge pull request #38 from pytorch/transforms

mel spectrograms in pytorch (no longer requires librosa)
parents c844ac63 78be73b7
...@@ -10,8 +10,9 @@ class Tester(unittest.TestCase): ...@@ -10,8 +10,9 @@ class Tester(unittest.TestCase):
sr = 16000 sr = 16000
freq = 440 freq = 440
volume = 0.3 volume = .3
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float() sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
sig.unsqueeze_(1) sig.unsqueeze_(1)
sig = (sig * volume * 2**31).long() sig = (sig * volume * 2**31).long()
...@@ -86,11 +87,11 @@ class Tester(unittest.TestCase): ...@@ -86,11 +87,11 @@ class Tester(unittest.TestCase):
audio = self.sig.clone() audio = self.sig.clone()
audio = transforms.Scale()(audio) audio = transforms.Scale()(audio)
self.assertTrue(len(audio.size()) == 2) self.assertTrue(audio.dim() == 2)
result = transforms.MEL()(audio) result = transforms.MEL()(audio)
self.assertTrue(len(result.size()) == 3) self.assertTrue(result.dim() == 3)
result = transforms.BLC2CBL()(result) result = transforms.BLC2CBL()(result)
self.assertTrue(len(result.size()) == 3) self.assertTrue(result.dim() == 3)
repr_test = transforms.MEL() repr_test = transforms.MEL()
repr_test.__repr__() repr_test.__repr__()
...@@ -146,6 +147,13 @@ class Tester(unittest.TestCase): ...@@ -146,6 +147,13 @@ class Tester(unittest.TestCase):
repr_test = transforms.MuLawExpanding(quantization_channels) repr_test = transforms.MuLawExpanding(quantization_channels)
repr_test.__repr__() repr_test.__repr__()
def test_mel2(self):
    """MEL2 on a scaled (c, n) signal yields a 3-d, non-positive dB spectrogram."""
    sig = self.sig.clone()                 # raw signal, shape (n, 1)
    sig = transforms.Scale()(sig)          # normalized, still (n, 1)
    sig = transforms.LC2CL()(sig)          # channels first, (1, n)
    mel_db = transforms.MEL2()(sig)        # (c, hops, n_mels)
    self.assertTrue(mel_db.dim() == 3)
    # decibels relative to the max are always <= 0
    self.assertTrue(mel_db.max() <= 0.)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
from __future__ import division, print_function from __future__ import division, print_function
import torch import torch
from torch.autograd import Variable
import numpy as np import numpy as np
try: try:
import librosa import librosa
...@@ -7,6 +8,24 @@ except ImportError: ...@@ -7,6 +8,24 @@ except ImportError:
librosa = None librosa = None
def _check_is_variable(tensor):
if isinstance(tensor, torch.Tensor):
is_variable = False
tensor = Variable(tensor, requires_grad=False)
elif isinstance(tensor, Variable):
is_variable = True
else:
raise TypeError("tensor should be a Variable or Tensor, but is {}".format(type(tensor)))
return tensor, is_variable
def _tlog10(x):
"""Pytorch Log10
"""
return torch.log(x) / torch.log(x.new([10]))
class Compose(object): class Compose(object):
"""Composes several transforms together. """Composes several transforms together.
...@@ -137,10 +156,10 @@ class LC2CL(object): ...@@ -137,10 +156,10 @@ class LC2CL(object):
""" """
Args: Args:
tensor (Tensor): Tensor of spectrogram with shape (BxLxC) tensor (Tensor): Tensor of audio signal with shape (LxC)
Returns: Returns:
tensor (Tensor): Tensor of spectrogram with shape (CxBxL) tensor (Tensor): Tensor of audio signal with shape (CxL)
""" """
...@@ -150,6 +169,190 @@ class LC2CL(object): ...@@ -150,6 +169,190 @@ class LC2CL(object):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '()'
class SPECTROGRAM(object):
    """Create a spectrogram from a raw audio signal.

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
        pad (int): two sided padding of signal
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function
    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
                 pad=0, window=torch.hann_window, wkwargs=None):
        if isinstance(window, Variable):
            # caller supplied a precomputed window tensor (e.g. from MEL2)
            self.window = window
        else:
            self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
            # requires_grad=False replaces the removed `volatile=True` flag
            # and matches how MEL2 constructs its window.
            self.window = Variable(self.window, requires_grad=False)
        self.sr = sr
        self.ws = ws
        self.hop = hop if hop is not None else ws // 2
        self.n_fft = n_fft  # number of fft bins
        self.pad = pad
        self.wkwargs = wkwargs

    def __call__(self, sig):
        """
        Args:
            sig (Tensor or Variable): Tensor of audio of size (c, n)

        Returns:
            spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
                is unchanged, hops is the number of hops, and n_fft is the
                number of fourier bins, which should be the window size divided
                by 2 plus 1.
        """
        sig, is_variable = _check_is_variable(sig)
        assert sig.dim() == 2

        spec_f = torch.stft(sig, self.ws, self.hop, self.n_fft,
                            True, self.window, self.pad)  # (c, l, n_fft, 2)
        # normalize by the window energy so magnitudes are comparable
        # across different window choices
        spec_f /= self.window.pow(2).sum().sqrt()
        spec_f = spec_f.pow(2).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
        return spec_f if is_variable else spec_f.data
class F2M(object):
    """This turns a normal STFT into a MEL Frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    Args:
        n_mels (int): number of MEL bins
        sr (int): sample rate of audio signal
        f_max (float, optional): maximum frequency. default: sr // 2
        f_min (float): minimum frequency. default: 0
    """
    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
        self.n_mels = n_mels
        self.sr = sr
        self.f_max = f_max if f_max is not None else sr // 2
        self.f_min = f_min

    def __call__(self, spec_f):
        """
        Args:
            spec_f (Tensor or Variable): power spectrogram of shape (c, l, n_fft)

        Returns:
            spec_m (Tensor or Variable): mel spectrogram of shape (c, l, n_mels)
        """
        spec_f, is_variable = _check_is_variable(spec_f)
        # here n_fft is the number of one-sided fourier bins (last dim)
        n_fft = spec_f.size(2)

        # convert the min/max frequencies to the mel scale:
        # m = 2595 * log10(1 + f / 700)
        m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
        m_max = 2595 * np.log10(1. + (self.f_max / 700))

        # n_mels + 2 equally spaced mel points: each triangular filter
        # needs a left edge, a center and a right edge
        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
        # back to Hz: f = 700 * (10^(m / 2595) - 1)
        f_pts = (700 * (10**(m_pts / 2595) - 1))

        # map each frequency point to an STFT bin index; (n_fft - 1) * 2
        # recovers the window size from the one-sided bin count
        bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()

        # build the triangular filter bank, one column per mel bin
        fb = torch.zeros(n_fft, self.n_mels)
        for m in range(1, self.n_mels + 1):
            f_m_minus = bins[m - 1]
            f_m = bins[m]
            f_m_plus = bins[m + 1]

            # rising slope from the left edge up to the center ...
            if f_m_minus != f_m:
                fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
            # ... and falling slope from the center to the right edge
            # (the guards skip degenerate filters whose edges collapse)
            if f_m != f_m_plus:
                fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
        fb = Variable(fb)

        spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m if is_variable else spec_m.data
class SPEC2DB(object):
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    Args:
        stype (str): scale of input spectrogram ("power" or "magnitude").  The
            power being the elementwise square of the magnitude. default: "power"
        top_db (float, optional): minimum negative cut-off in decibels.  A
            reasonable number is -80.  ``None`` (the default) disables clamping.
    """
    def __init__(self, stype="power", top_db=None):
        self.stype = stype
        # Normalize top_db to a negative cut-off.  The original comparison
        # (`top_db > 0`) raised a TypeError for the documented default of
        # None; guard against None before comparing.
        if top_db is not None and top_db > 0:
            top_db = -top_db
        self.top_db = top_db
        # power spectra use 10*log10, magnitude spectra 20*log10
        self.multiplier = 10. if stype == "power" else 20.

    def __call__(self, spec):
        """
        Args:
            spec (Tensor or Variable): spectrogram to convert

        Returns:
            spec_db (Tensor or Variable): spectrogram in dB relative to its
                maximum, clamped from below at ``top_db`` when one was given
        """
        spec, is_variable = _check_is_variable(spec)
        spec_db = self.multiplier * _tlog10(spec / spec.max())  # power -> dB
        if self.top_db is not None:
            spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
        return spec_db if is_variable else spec_db.data
class MEL2(object):
    """Create MEL Spectrograms from a raw audio signal using the stft
    function in PyTorch.  Hopefully this solves the speed issue of using
    librosa.

    Sources:
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
        pad (int): two sided padding of signal
        n_mels (int): number of MEL bins
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function

    Example:
        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
        >>> sig = transforms.LC2CL()(sig)  # (n, c) -> (c, n)
        >>> spec_mel = transforms.MEL2(sr)(sig)  # (c, l, m)
    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
                 pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
        # precompute the analysis window once, outside of autograd
        win = window(ws) if wkwargs is None else window(ws, **wkwargs)
        self.window = Variable(win, requires_grad=False)
        self.sr = sr
        self.ws = ws
        self.hop = ws // 2 if hop is None else hop
        self.n_fft = n_fft  # number of fourier bins (ws // 2 + 1 by default)
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.wkwargs = wkwargs
        # fixed conversion settings used by the pipeline below
        self.top_db = -80.
        self.f_max = None
        self.f_min = 0.

    def __call__(self, sig):
        """
        Args:
            sig (Tensor or Variable): Tensor of audio of size (channels [c], samples [n])

        Returns:
            spec_mel_db (Tensor or Variable): channels x hops x n_mels (c, l, m), where
                channels is unchanged, hops is the number of hops, and n_mels
                is the number of mel bins.
        """
        sig, is_variable = _check_is_variable(sig)
        # raw signal -> power spectrogram -> mel scale -> decibels
        pipeline = Compose([
            SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
                        self.pad, self.window),
            F2M(self.n_mels, self.sr, self.f_max, self.f_min),
            SPEC2DB("power", self.top_db),
        ])
        spec_mel_db = pipeline(sig)
        return spec_mel_db if is_variable else spec_mel_db.data
class MEL(object): class MEL(object):
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow. """Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
...@@ -164,7 +367,7 @@ class MEL(object): ...@@ -164,7 +367,7 @@ class MEL(object):
""" """
Args: Args:
tensor (Tensor): Tensor of audio of size (samples x channels) tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])
Returns: Returns:
tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
...@@ -172,6 +375,7 @@ class MEL(object): ...@@ -172,6 +375,7 @@ class MEL(object):
is unchanged. is unchanged.
""" """
if librosa is None: if librosa is None:
print("librosa not installed, cannot create spectrograms") print("librosa not installed, cannot create spectrograms")
return tensor return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment