Commit 7e15d2f9 authored by Jeremy Howard's avatar Jeremy Howard Committed by Soumith Chintala
Browse files

Rename classes in line with PyTorch standards. Remove redundent slow...

Rename classes in line with PyTorch standards. Remove redundent slow librosa-based `MEL`. Add missing docstring params. (#78)

* Bug fix: Use correct device for MEL2 functions so MEL2 works on CUDA tensors

* Rename classes in line with PyTorch standards. Remove redundent
slow librosa-based `MEL`. Add missing docstring params.

* fix param names
parent b311c4cc
...@@ -86,22 +86,6 @@ class Tester(unittest.TestCase): ...@@ -86,22 +86,6 @@ class Tester(unittest.TestCase):
repr_test = transforms.LC2CL() repr_test = transforms.LC2CL()
self.assertTrue(repr_test.__repr__()) self.assertTrue(repr_test.__repr__())
def test_mel(self):
audio = self.sig.clone()
audio = transforms.Scale()(audio)
self.assertTrue(audio.dim() == 2)
result = transforms.MEL()(audio)
self.assertTrue(result.dim() == 3)
result = transforms.BLC2CBL()(result)
self.assertTrue(result.dim() == 3)
repr_test = transforms.MEL()
self.assertTrue(repr_test.__repr__())
repr_test = transforms.BLC2CBL()
self.assertTrue(repr_test.__repr__())
def test_compose(self): def test_compose(self):
audio_orig = self.sig.clone() audio_orig = self.sig.clone()
...@@ -155,7 +139,7 @@ class Tester(unittest.TestCase): ...@@ -155,7 +139,7 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone() # (16000, 1) audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
mel_transform = transforms.MEL2() mel_transform = transforms.MelSpectrogram()
# check defaults # check defaults
spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40) spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.dim() == 3)
...@@ -166,7 +150,7 @@ class Tester(unittest.TestCase): ...@@ -166,7 +150,7 @@ class Tester(unittest.TestCase):
self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all()) self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
# check options # check options
mel_transform2 = transforms.MEL2(window=torch.hamming_window, pad=10, ws=500, hop=125, n_fft=800, n_mels=50) mel_transform2 = transforms.MelSpectrogram(window=torch.hamming_window, pad=10, ws=500, hop=125, n_fft=800, n_mels=50)
spectrogram2_torch = mel_transform2(audio_scaled) # (1, 506, 50) spectrogram2_torch = mel_transform2(audio_scaled) # (1, 506, 50)
self.assertTrue(spectrogram2_torch.dim() == 3) self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram2_torch.le(0.).all()) self.assertTrue(spectrogram2_torch.le(0.).all())
...@@ -183,7 +167,7 @@ class Tester(unittest.TestCase): ...@@ -183,7 +167,7 @@ class Tester(unittest.TestCase):
self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all()) self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels) self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
# check filterbank matrix creation # check filterbank matrix creation
fb_matrix_transform = transforms.F2M(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400) fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
......
from __future__ import division, print_function from __future__ import division, print_function
import torch import torch
import numpy as np import numpy as np
try:
import librosa
except ImportError:
librosa = None
class Compose(object): class Compose(object):
"""Composes several transforms together. """Composes several transforms together.
...@@ -155,7 +150,7 @@ class LC2CL(object): ...@@ -155,7 +150,7 @@ class LC2CL(object):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '()'
class SPECTROGRAM(object): class Spectrogram(object):
"""Create a spectrogram from a raw audio signal """Create a spectrogram from a raw audio signal
Args: Args:
...@@ -205,17 +200,17 @@ class SPECTROGRAM(object): ...@@ -205,17 +200,17 @@ class SPECTROGRAM(object):
return spec_f return spec_f
class F2M(object): class MelScale(object):
"""This turns a normal STFT into a MEL Frequency STFT, using a conversion """This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks. matrix. This uses triangular filter banks.
Args: Args:
n_mels (int): number of MEL bins n_mels (int): number of mel bins
sr (int): sample rate of audio signal sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: `sr` // 2 f_max (float, optional): maximum frequency. default: `sr` // 2
f_min (float): minimum frequency. default: 0 f_min (float): minimum frequency. default: 0
n_stft (int, optional): number of filter banks from stft. Calculated from first input n_stft (int, optional): number of filter banks from stft. Calculated from first input
if `None` is given. See `n_fft` in `SPECTROGRAM`. if `None` is given. See `n_fft` in `Spectrogram`.
""" """
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None): def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None):
self.n_mels = n_mels self.n_mels = n_mels
...@@ -261,7 +256,7 @@ class F2M(object): ...@@ -261,7 +256,7 @@ class F2M(object):
return 700. * (10**(mel / 2595.) - 1.) return 700. * (10**(mel / 2595.) - 1.)
class SPEC2DB(object): class SpectogramToDB(object):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale. """Turns a spectrogram from the power/amplitude scale to the decibel scale.
Args: Args:
...@@ -285,10 +280,9 @@ class SPEC2DB(object): ...@@ -285,10 +280,9 @@ class SPEC2DB(object):
return spec_db return spec_db
class MEL2(object): class MelSpectrogram(object):
"""Create MEL Spectrograms from a raw audio signal using the stft """Create MEL Spectrograms from a raw audio signal using the stft
function in PyTorch. Hopefully this solves the speed issue of using function in PyTorch.
librosa.
Sources: Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
...@@ -300,6 +294,8 @@ class MEL2(object): ...@@ -300,6 +294,8 @@ class MEL2(object):
ws (int): window size ws (int): window size
hop (int, optional): length of hop between STFT windows. default: `ws` // 2 hop (int, optional): length of hop between STFT windows. default: `ws` // 2
n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1 n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1
f_max (float, optional): maximum frequency. default: `sr` // 2
f_min (float): minimum frequency. default: 0
pad (int): two sided padding of signal pad (int): two sided padding of signal
n_mels (int): number of MEL bins n_mels (int): number of MEL bins
window (torch windowing function): default: `torch.hann_window` window (torch windowing function): default: `torch.hann_window`
...@@ -307,9 +303,9 @@ class MEL2(object): ...@@ -307,9 +303,9 @@ class MEL2(object):
Example: Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True) >>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m) >>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, l, m)
""" """
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, fmin=0., fmax=None, def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, f_min=0., f_max=None,
pad=0, n_mels=40, window=torch.hann_window, wkwargs=None): pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
self.window = window self.window = window
self.sr = sr self.sr = sr
...@@ -320,12 +316,12 @@ class MEL2(object): ...@@ -320,12 +316,12 @@ class MEL2(object):
self.n_mels = n_mels # number of mel frequency bins self.n_mels = n_mels # number of mel frequency bins
self.wkwargs = wkwargs self.wkwargs = wkwargs
self.top_db = -80. self.top_db = -80.
self.f_max = fmax self.f_max = f_max
self.f_min = fmin self.f_min = f_min
self.spec = SPECTROGRAM(self.ws, self.hop, self.n_fft, self.spec = Spectrogram(self.ws, self.hop, self.n_fft,
self.pad, self.window, self.wkwargs) self.pad, self.window, self.wkwargs)
self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min) self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
self.s2db = SPEC2DB("power", self.top_db) self.s2db = SpectogramToDB("power", self.top_db)
self.transforms = Compose([ self.transforms = Compose([
self.spec, self.fm, self.s2db, self.spec, self.fm, self.s2db,
]) ])
...@@ -345,48 +341,6 @@ class MEL2(object): ...@@ -345,48 +341,6 @@ class MEL2(object):
return spec_mel_db return spec_mel_db
class MEL(object):
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
Usage (see librosa.feature.melspectrogram docs):
MEL(sr=16000, n_fft=1600, hop_length=800, n_mels=64)
"""
def __init__(self, **kwargs):
self.kwargs = kwargs
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])
Returns:
tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
the number of mel bins, hops is the number of hops, and channels
is unchanged.
"""
if librosa is None:
print("librosa not installed, cannot create spectrograms")
return tensor
L = []
for i in range(tensor.size(1)):
nparr = tensor[:, i].numpy() # (samples, )
sgram = librosa.feature.melspectrogram(
nparr, **self.kwargs) # (n_mels, hops)
L.append(sgram)
L = np.stack(L, 2) # (n_mels, hops, channels)
tensor = torch.from_numpy(L).type_as(tensor)
return tensor
def __repr__(self):
return self.__class__.__name__ + '()'
class BLC2CBL(object): class BLC2CBL(object):
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x Samples length Bands x Samples length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment