Unverified Commit 86370639 authored by Joel Frank and committed by GitHub

Add LFCC feature to transforms (#1611)

Summary:
- Add linear_fbank method
- Add LFCC in transforms
parent 108a32d9
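For orientation, a minimal usage sketch of the two additions (shapes assume the library defaults exercised by the tests below; illustrative, not part of the diff):

    import torch
    import torchaudio

    waveform = torch.randn(1, 16000)  # 1 second of noise at 16 kHz

    # LFCC transform: spectrogram -> linear filterbank -> DCT, analogous to MFCC
    lfcc = torchaudio.transforms.LFCC(sample_rate=16000, n_filter=128, n_lfcc=40)(waveform)
    print(lfcc.shape)  # torch.Size([1, 40, 81]) at the default hop length of 200

    # standalone linear filterbank matrix
    fb = torchaudio.functional.linear_fbanks(
        n_freqs=201, f_min=0.0, f_max=8000.0, n_filter=128, sample_rate=16000)
    print(fb.shape)  # torch.Size([201, 128])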
......@@ -26,6 +26,11 @@ create_fb_matrix
.. autofunction:: create_fb_matrix
linear_fbanks
-------------
.. autofunction:: linear_fbanks
create_dct
----------
......
......@@ -59,6 +59,13 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward
:hidden:`LFCC`
~~~~~~~~~~~~~~
.. autoclass:: LFCC
.. automethod:: forward
:hidden:`MuLawEncoding`
~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -132,6 +132,21 @@ class Functional(TempDirMixin, TestBaseMixin):
dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
def test_linear_fbanks(self):
if self.device != torch.device('cpu'):
raise unittest.SkipTest('No need to perform test on device other than CPU')
def func(_):
n_stft = 100
f_min = 0.0
f_max = 20.0
n_filter = 10
sample_rate = 16000
return F.linear_fbanks(n_stft, f_min, f_max, n_filter, sample_rate)
dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
def test_amplitude_to_DB(self):
def func(tensor):
multiplier = 10.0
......
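The helper above compares eager and TorchScript execution; the new case is roughly equivalent to this standalone check (a sketch, assuming _assert_consistency scripts the function and compares outputs):

    import torch
    import torchaudio.functional as F

    def func(x: torch.Tensor) -> torch.Tensor:
        # the input is unused; fixed arguments mirror the test above
        return F.linear_fbanks(100, 0.0, 20.0, 10, 16000)

    ts_func = torch.jit.script(func)
    dummy = torch.zeros(1, 1)
    assert torch.allclose(func(dummy), ts_func(dummy))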
......@@ -108,6 +108,13 @@ class AutogradTestMixin(TestBaseMixin):
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform])
@parameterized.expand([(False, ), (True, )])
def test_lfcc(self, log_lf):
sample_rate = 8000
transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf)
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform])
def test_compute_deltas(self):
transform = T.ComputeDeltas()
spec = torch.rand(10, 20)
......
......@@ -127,6 +127,16 @@ class TestTransforms(common_utils.TorchaudioTestCase):
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
def test_batch_lfcc(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2)
        # Transform the single waveform, then batch the result
        expected = torchaudio.transforms.LFCC()(waveform).repeat(3, 1, 1, 1)
        # Batch the waveform, then transform
        computed = torchaudio.transforms.LFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
@parameterized.expand([(True, ), (False, )])
def test_batch_TimeStretch(self, test_pseudo_complex):
rate = 2
......
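The batch test works because LFCC, like MFCC, accepts arbitrary leading dimensions; a quick shape check (illustrative; 8000 samples at the default hop length of 200 give 41 frames):

    import torch
    import torchaudio

    lfcc = torchaudio.transforms.LFCC()
    single = lfcc(torch.randn(2, 8000))      # (channel, time) -> (2, 40, 41)
    batched = lfcc(torch.randn(3, 2, 8000))  # (batch, channel, time) -> (3, 2, 40, 41)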
......@@ -71,6 +71,10 @@ class Transforms(TempDirMixin, TestBaseMixin):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.MFCC(), tensor)
def test_LFCC(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.LFCC(), tensor)
def test_Resample(self):
sr1, sr2 = 16000, 8000
tensor = common_utils.get_whitenoise(sample_rate=sr1)
......
......@@ -180,6 +180,65 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(torch_mfcc_norm_none, norm_check)
def test_lfcc_defaults(self):
"""Check default settings for LFCC transform.
"""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40
n_filter = 128
lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate,
n_filter=n_filter,
n_lfcc=n_lfcc,
norm='ortho')
torch_lfcc = lfcc_transform(audio) # (1, 40, 81)
self.assertEqual(torch_lfcc.dim(), 3)
self.assertEqual(torch_lfcc.shape[1], n_lfcc)
self.assertEqual(torch_lfcc.shape[2], 81)
def test_lfcc_arg_passthrough(self):
"""Check if kwargs get correctly passed to the underlying Spectrogram transform.
"""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40
n_filter = 128
speckwargs = {'win_length': 200}
lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate,
n_filter=n_filter,
n_lfcc=n_lfcc,
norm='ortho',
speckwargs=speckwargs)
torch_lfcc = lfcc_transform(audio) # (1, 40, 161)
self.assertEqual(torch_lfcc.shape[2], 161)
def test_lfcc_norms(self):
"""Check if LFCC-DCT norm works correctly.
"""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40
n_filter = 128
lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate,
n_filter=n_filter,
n_lfcc=n_lfcc,
norm='ortho')
lfcc_transform_norm_none = torchaudio.transforms.LFCC(sample_rate=sample_rate,
n_filter=n_filter,
n_lfcc=n_lfcc,
norm=None)
        torch_lfcc_norm_none = lfcc_transform_norm_none(audio)  # (1, 40, 81)
        norm_check = lfcc_transform(audio)  # (1, 40, 81)
norm_check[:, 0, :] *= math.sqrt(n_filter) * 2
norm_check[:, 1:, :] *= math.sqrt(n_filter / 2) * 2
self.assertEqual(torch_lfcc_norm_none, norm_check)
def test_resample_size(self):
input_path = common_utils.get_asset_path('sinewave.wav')
waveform, sample_rate = common_utils.load_wav(input_path)
......
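The scale factors in test_lfcc_norms follow from the DCT-II normalization in create_dct: relative to norm=None, the 'ortho' matrix is smaller by a factor of 2*sqrt(N) in the k=0 column and 2*sqrt(N/2) in the k>0 columns, where N is the number of filters. A standalone sketch of the same identity:

    import math
    import torch
    import torchaudio.functional as F

    n_filter, n_lfcc = 128, 40
    ortho = F.create_dct(n_lfcc, n_filter, norm='ortho')  # (n_filter, n_lfcc)
    unnorm = F.create_dct(n_lfcc, n_filter, norm=None)
    assert torch.allclose(unnorm[:, 0], ortho[:, 0] * math.sqrt(n_filter) * 2)
    assert torch.allclose(unnorm[:, 1:], ortho[:, 1:] * math.sqrt(n_filter / 2) * 2)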
......@@ -6,6 +6,7 @@ from .functional import (
compute_kaldi_pitch,
create_dct,
create_fb_matrix,
linear_fbanks,
DB_to_amplitude,
detect_pitch_frequency,
griffinlim,
......@@ -55,6 +56,7 @@ __all__ = [
'compute_kaldi_pitch',
'create_dct',
'create_fb_matrix',
'linear_fbanks',
'DB_to_amplitude',
'detect_pitch_frequency',
'griffinlim',
......
......@@ -19,6 +19,7 @@ __all__ = [
"compute_deltas",
"compute_kaldi_pitch",
"create_fb_matrix",
"linear_fbanks",
"create_dct",
"compute_deltas",
"detect_pitch_frequency",
......@@ -376,6 +377,32 @@ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
return freqs
def _create_triangular_filterbank(
all_freqs: Tensor,
f_pts: Tensor,
) -> Tensor:
"""Create a triangular filter bank.
Args:
all_freqs (Tensor): STFT freq points of size (`n_freqs`).
f_pts (Tensor): Filter mid points of size (`n_filter`).
Returns:
fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
"""
    # Adapted from librosa
    # distance between adjacent filter mid points
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_filter + 1)
    # difference between each filter mid point and each stft freq point in hertz
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_filter + 2)
# create overlapping triangles
zero = torch.zeros(1)
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
fb = torch.max(zero, torch.min(down_slopes, up_slopes))
return fb
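As a hand-checkable example of the construction above (values chosen for illustration): with five STFT bins at 0..4 Hz and a single filter whose edge points sit at 0, 2, and 4 Hz, the triangle rises linearly to 1 at the mid point and falls back to 0:

    import torch

    all_freqs = torch.tensor([0., 1., 2., 3., 4.])  # five STFT freq points
    f_pts = torch.tensor([0., 2., 4.])              # one filter: edges at 0 and 4, peak at 2
    fb = _create_triangular_filterbank(all_freqs, f_pts)
    print(fb.squeeze())  # tensor([0.0000, 0.5000, 1.0000, 0.5000, 0.0000])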
def create_fb_matrix(
n_freqs: int,
f_min: float,
......@@ -409,7 +436,6 @@ def create_fb_matrix(
raise ValueError("norm must be one of None or 'slaney'")
# freq bins
-    # Equivalent filterbank construction by Librosa
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
# calculate mel freq bins
......@@ -419,14 +445,8 @@ def create_fb_matrix(
m_pts = torch.linspace(m_min, m_max, n_mels + 2)
f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
-    # calculate the difference between each mel point and each stft freq point in hertz
-    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
-    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
-    # create overlapping triangles
-    zero = torch.zeros(1)
-    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
-    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
-    fb = torch.max(zero, torch.min(down_slopes, up_slopes))
# create filterbank
fb = _create_triangular_filterbank(all_freqs, f_pts)
if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
......@@ -443,6 +463,41 @@ def create_fb_matrix(
return fb
def linear_fbanks(
n_freqs: int,
f_min: float,
f_max: float,
n_filter: int,
sample_rate: int,
) -> Tensor:
r"""Creates a linear triangular filterbank.
Args:
n_freqs (int): Number of frequencies to highlight/apply
f_min (float): Minimum frequency (Hz)
f_max (float): Maximum frequency (Hz)
n_filter (int): Number of (linear) triangular filter
sample_rate (int): Sample rate of the audio waveform
Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_filter``)
meaning number of frequencies to highlight/apply to x the number of filterbanks.
Each column is a filterbank so that assuming there is a matrix A of
size (..., ``n_freqs``), the applied result would be
``A * linear_fbanks(A.size(-1), ...)``.
"""
# freq bins
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
# filter mid-points
f_pts = torch.linspace(f_min, f_max, n_filter + 2)
# create filterbank
fb = _create_triangular_filterbank(all_freqs, f_pts)
return fb
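Applying the filterbank is a plain matrix multiplication over the frequency axis, as the Returns note above describes; a short sketch with assumed shapes:

    import torch
    import torchaudio.functional as F

    n_fft = 400
    spec = torch.rand(1, n_fft // 2 + 1, 81)  # (channel, n_freqs, time) power spectrogram
    fb = F.linear_fbanks(n_freqs=n_fft // 2 + 1, f_min=0.0, f_max=8000.0,
                         n_filter=128, sample_rate=16000)
    filtered = torch.matmul(spec.transpose(-1, -2), fb).transpose(-1, -2)
    print(filtered.shape)  # torch.Size([1, 128, 81])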
def create_dct(
n_mfcc: int,
n_mels: int,
......
......@@ -21,6 +21,7 @@ __all__ = [
'InverseMelScale',
'MelSpectrogram',
'MFCC',
'LFCC',
'MuLawEncoding',
'MuLawDecoding',
'Resample',
......@@ -282,16 +283,8 @@ class MelScale(torch.nn.Module):
Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
"""
-        # pack batch
-        shape = specgram.size()
-        specgram = specgram.reshape(-1, shape[-2], shape[-1])
-        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
-        # -> (channel, time, n_mels).transpose(...)
-        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
-        # unpack batch
-        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
# (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
return mel_specgram
......@@ -532,10 +525,8 @@ class MFCC(torch.nn.Module):
self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
-        if melkwargs is not None:
-            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
-        else:
-            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)
melkwargs = melkwargs or {}
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins')
......@@ -558,12 +549,102 @@ class MFCC(torch.nn.Module):
else:
mel_specgram = self.amplitude_to_DB(mel_specgram)
-        # (..., channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
-        # -> (..., channel, time, n_mfcc).transpose(...)
-        mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
        # (..., time, n_mels) dot (n_mels, n_mfcc) -> (..., n_mfcc, time)
        mfcc = torch.matmul(mel_specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
return mfcc
class LFCC(torch.nn.Module):
r"""Create the linear-frequency cepstrum coefficients from an audio signal.
    By default, this calculates the LFCC on the dB-scaled linear filtered spectrogram.
    This is not the textbook implementation, but is implemented here for
    consistency with librosa.
    The output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs.
    a full clip.
Args:
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
n_filter (int, optional): Number of linear filters to apply. (Default: ``128``)
        n_lfcc (int, optional): Number of LFC coefficients to retain. (Default: ``40``)
f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float or None, optional): Maximum frequency. (Default: ``None``)
dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_lf (bool, optional): whether to use log-lf spectrograms instead of dB-scaled. (Default: ``False``)
speckwargs (dict or None, optional): arguments for Spectrogram. (Default: ``None``)
"""
__constants__ = ['sample_rate', 'n_filter', 'n_lfcc', 'dct_type', 'top_db', 'log_lf']
def __init__(self,
sample_rate: int = 16000,
n_filter: int = 128,
f_min: float = 0.,
f_max: Optional[float] = None,
n_lfcc: int = 40,
dct_type: int = 2,
norm: str = 'ortho',
log_lf: bool = False,
speckwargs: Optional[dict] = None) -> None:
super(LFCC, self).__init__()
supported_dct_types = [2]
if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported: {}'.format(dct_type))
self.sample_rate = sample_rate
self.f_min = f_min
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.n_filter = n_filter
self.n_lfcc = n_lfcc
self.dct_type = dct_type
self.norm = norm
self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
speckwargs = speckwargs or {}
self.Spectrogram = Spectrogram(**speckwargs)
if self.n_lfcc > self.Spectrogram.n_fft:
raise ValueError('Cannot select more LFCC coefficients than # fft bins')
filter_mat = F.linear_fbanks(
n_freqs=self.Spectrogram.n_fft // 2 + 1,
f_min=self.f_min,
f_max=self.f_max,
n_filter=self.n_filter,
sample_rate=self.sample_rate,
)
self.register_buffer("filter_mat", filter_mat)
dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm)
self.register_buffer('dct_mat', dct_mat)
self.log_lf = log_lf
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
Tensor: Linear Frequency Cepstral Coefficients of size (..., ``n_lfcc``, time).
"""
specgram = self.Spectrogram(waveform)
# (..., time, freq) dot (freq, n_filter) -> (..., n_filter, time)
specgram = torch.matmul(specgram.transpose(-1, -2), self.filter_mat).transpose(-1, -2)
if self.log_lf:
log_offset = 1e-6
specgram = torch.log(specgram + log_offset)
else:
specgram = self.amplitude_to_DB(specgram)
# (..., time, n_filter) dot (n_filter, n_lfcc) -> (..., n_lfcc, time)
lfcc = torch.matmul(specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
return lfcc
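The pipeline is Spectrogram -> linear filterbank -> dB (or log) scaling -> DCT. A functional decomposition with the default parameters, as a hedged sketch that should match T.LFCC()(waveform) up to numerical precision:

    import torch
    import torchaudio.functional as F
    import torchaudio.transforms as T

    waveform = torch.randn(1, 16000)
    spec = T.Spectrogram()(waveform)  # power spectrogram, (1, 201, 81)
    fb = F.linear_fbanks(n_freqs=201, f_min=0.0, f_max=8000.0,
                         n_filter=128, sample_rate=16000)
    lin_spec = torch.matmul(spec.transpose(-1, -2), fb).transpose(-1, -2)
    lin_spec = F.amplitude_to_DB(lin_spec, multiplier=10.0, amin=1e-10,
                                 db_multiplier=0.0, top_db=80.0)
    dct_mat = F.create_dct(40, 128, norm='ortho')
    lfcc = torch.matmul(lin_spec.transpose(-1, -2), dct_mat).transpose(-1, -2)
    print(lfcc.shape)  # torch.Size([1, 40, 81])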
class MuLawEncoding(torch.nn.Module):
r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......