Unverified Commit babc24af authored by moto's avatar moto Committed by GitHub
Browse files

Add test for InverseMelScale (#448)



* Inverse Mel Scale Implementation

* Inverse Mel Scale Docs

* Better working version.

* GPU fix

* These shouldn't go on git..

* Even better one, but does not support JITability.

* Remove JITability test

* Flake8

* n_stft is a must

* minor clean up of initialization

* Add librosa consistency test

This PR follows up #366 and adds test for `InverseMelScale` (and `MelScale`) for librosa compatibility.

For `MelScale` compatibility test;
1. Generate spectrogram
2. Feed the spectrogram to `torchaudio.transforms.MelScale` instance
3. Feed the spectrogram to `librosa.feature.melspectrogram` function.
4. Compare the result from 2 and 3 elementwise.
Element-wise numerical comparison is possible because, under the hood, both implementations use the same algorithm.

The `InverseMelScale` compatibility test is more elaborate than that.
1. Generate the original spectrogram
2. Convert the original spectrogram to Mel scale using `torchaudio.transforms.MelScale` instance
3. Reconstruct spectrogram using torchaudio implementation
3.1. Feed the Mel spectrogram to `torchaudio.transforms.InverseMelScale` instance and get reconstructed spectrogram.
3.2. Compute the sum of element-wise P1 distance of the original spectrogram and that from 3.1.
4. Reconstruct spectrogram using librosa
4.1. Feed the Mel spectrogram to `librosa.feature.inverse.mel_to_stft` function and get reconstructed spectrogram.
4.2. Compute the sum of element-wise P1 distance of the original spectrogram and that from 4.1. (this is the reference.)
5. Check that the resulting P1 distances are in roughly the same value range.

Element-wise numerical comparison is not possible because different algorithms are used to compute the inverse. Some values of the reconstructed spectrograms can differ in magnitude.
Therefore the strategy here is to check that the P1 distance (reconstruction loss) is not that different from the value obtained using `librosa`. For this purpose, the threshold was chosen empirically:

```
print('p1 dist (orig <-> ta):', torch.dist(spec_orig, spec_ta, p=1))
print('p1 dist (orig <-> lr):', torch.dist(spec_orig, spec_lr, p=1))
>>> p1 dist (orig <-> ta): tensor(1482.1917)
>>> p1 dist (orig <-> lr): tensor(1420.7103)
```

This value can vary based on the length and the kind of the signal being processed, so it was handpicked.

* Address review feedbacks

* Support arbitrary batch dimensions.

* Add batch test

* Use view for batch

* fix sgd

* Use negative indices and update docstring

* Update threshold
Co-authored-by: default avatarCharles J.Y. Yoon <jaeyeun97@gmail.com>
parent 2cf59c41
......@@ -37,6 +37,14 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward
:hidden:`InverseMelScale`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: InverseMelScale
.. automethod:: forward
:hidden:`MelSpectrogram`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -410,6 +410,25 @@ class Tester(unittest.TestCase):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_batch_InverseMelScale(self):
    """InverseMelScale gives comparable results with and without extra batch dimensions."""
    n_fft = 8
    n_mels = 32
    # Number of STFT bins, derived the same way as in ``test_InverseMelScale``.
    n_stft = n_fft // 2 + 1
    # Squaring keeps the random mel spectrogram non-negative.
    mel_spec = torch.randn(2, n_mels, 32) ** 2

    # Single then transform then batch
    expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

    # Batch then transform
    computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

    # shape = (3, 2, n_stft, 32)
    self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))

    # InverseMelScale runs SGD starting from randomly initialized values, so the two
    # runs do not yield exactly the same result. For this reason, the tolerance is
    # very relaxed here.
    self.assertTrue(torch.allclose(computed, expected, atol=1.0))
def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)
......@@ -509,5 +528,97 @@ class Tester(unittest.TestCase):
_test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)
class TestLibrosaConsistency(unittest.TestCase):
    """Compare torchaudio's ``MelScale`` / ``InverseMelScale`` against librosa."""
    test_dirpath = None
    test_dir = None

    @classmethod
    def setUpClass(cls):
        # NOTE(review): nothing in this class releases ``test_dir`` afterwards;
        # confirm whether ``create_temp_assets_dir`` requires explicit cleanup.
        cls.test_dirpath, cls.test_dir = common_utils.create_temp_assets_dir()

    def _to_librosa(self, sound):
        """Convert a tensor to a NumPy array with size-1 dimensions removed."""
        return sound.cpu().numpy().squeeze()

    def _get_sample_data(self, *asset_paths, **kwargs):
        """Load an asset and average it over dim 0 (keeping that dim).

        Args:
            *asset_paths: Path components below the ``assets`` directory.
            **kwargs: Forwarded to ``torchaudio.load``.

        Returns:
            tuple: ``(sound, sample_rate)`` where ``sound`` is the mean over dim 0
            (presumably the channel dimension -- verify against ``torchaudio.load``).
        """
        file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths)
        sound, sample_rate = torchaudio.load(file_path, **kwargs)
        return sound.mean(dim=0, keepdim=True), sample_rate

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    def test_MelScale(self):
        """MelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        hop_length = n_fft // 4

        # Prepare spectrogram input. We use torchaudio to compute one.
        sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3')
        spec_ta = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        spec_lr = self._to_librosa(spec_ta)

        # Perform MelScale with torchaudio and librosa
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
        melspec_lr = librosa.feature.melspectrogram(
            S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
            win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)

        # Note: Using relaxed rtol instead of atol
        self.assertTrue(
            torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3))

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    def test_InverseMelScale(self):
        """InverseMelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        n_stft = n_fft // 2 + 1
        hop_length = n_fft // 4

        # Prepare mel spectrogram input. We use torchaudio to compute one.
        sound, sample_rate = self._get_sample_data(
            'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
        spec_orig = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
        melspec_lr = self._to_librosa(melspec_ta)

        # Perform InverseMelScale with torchaudio and librosa
        spec_ta = transforms.InverseMelScale(
            n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
        spec_lr = librosa.feature.inverse.mel_to_stft(
            melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
        spec_lr = torch.from_numpy(spec_lr[None, ...])

        # Align the scale of the spectrograms.
        # librosa returns a magnitude (not power) spectrogram while the torchaudio
        # spectrograms here are power spectrograms, so take the square root of the latter.
        spec_orig = spec_orig.sqrt()
        spec_ta = spec_ta.sqrt()

        threshold = 2.0
        # This threshold was chosen empirically, based on the following observation
        #
        # torch.dist(spec_lr, spec_ta, p=float('inf'))
        # >>> tensor(1.9666)
        #
        # The spectrograms reconstructed by librosa and torchaudio are not very comparable elementwise.
        # This is because they use different approximation algorithms and the resulting values can
        # differ in magnitude (although most of them are very close).
        # See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
        # See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
        # distance over frequencies.
        self.assertTrue(torch.allclose(spec_ta, spec_lr, atol=threshold))

        threshold = 1700.0
        # This threshold was chosen empirically, based on the following observations
        #
        # torch.dist(spec_orig, spec_ta, p=1)
        # >>> tensor(1644.3516)
        # torch.dist(spec_orig, spec_lr, p=1)
        # >>> tensor(1420.7103)
        # torch.dist(spec_lr, spec_ta, p=1)
        # >>> tensor(943.2759)
        self.assertTrue(torch.dist(spec_orig, spec_ta, p=1) < threshold)
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,7 @@ __all__ = [
'GriffinLim',
'AmplitudeToDB',
'MelScale',
'InverseMelScale',
'MelSpectrogram',
'MFCC',
'MuLawEncoding',
......@@ -233,6 +234,90 @@ class MelScale(torch.nn.Module):
return mel_specgram
class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the euclidean norm between the input mel-spectrogram and the product between
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict, optional): Arguments for the SGD optimizer.
            (Default: ``{'lr': 0.1, 'momentum': 0.9}``)
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000,
                 tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None):
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, melspec):
        r"""
        Args:
            melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            torch.Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        shape = melspec.size()
        n_mels, time = shape[-2], shape[-1]
        assert self.n_mels == n_mels, \
            'Expected %d mel bins in dimension -2 but got %d' % (self.n_mels, n_mels)
        freq, _ = self.fb.size()  # (freq, n_mels)

        # pack batch
        # ``detach`` keeps the internal optimization from back-propagating into the caller's
        # graph (the result never carried gradients back to ``melspec`` anyway), and
        # ``reshape`` also accepts non-contiguous input, unlike ``view``.
        melspec = melspec.detach().reshape(-1, n_mels, time).transpose(-1, -2)

        # The solver needs autograd, so enable it locally even when the caller is
        # running inside ``torch.no_grad()``.
        with torch.enable_grad():
            specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                                  dtype=melspec.dtype, device=melspec.device)

            optim = torch.optim.SGD([specgram], **self.sgdargs)

            loss = float('inf')
            for _ in range(self.max_iter):
                optim.zero_grad()
                diff = melspec - specgram.matmul(self.fb)
                # take sum over mel-frequency then average over other dimensions
                # so that loss threshold is applied per unit timeframe
                new_loss = diff.pow(2).sum(dim=-1).mean()
                new_loss.backward()
                optim.step()
                # project the estimate back onto non-negative values after every step
                specgram.data = specgram.data.clamp(min=0)

                new_loss = new_loss.item()
                if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                    break
                loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.reshape(shape[:-2] + (freq, time))
        return specgram
class MelSpectrogram(torch.nn.Module):
r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
and MelScale.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment