Unverified Commit 6db8522e authored by moto, committed by GitHub

Update InverseMelScale comparison test (#1437)

* Remove an invalid InverseMelScale comparison unit test

Similar to #1426, `test_InverseMelScale` in `librosa_compatibility_test` does not
actually ensure compatibility with librosa. Keeping this test can give a misleading
impression of the function's numerical compatibility with librosa.

* Add test for InverseMelScale

The new test compares the result of the inverse mel scale transform against the reference spectrogram, so librosa is no longer needed.
The test serves more as insurance that changes to the implementation of InverseMelScale only improve the result, not the other way around.
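As a rough sketch of the comparison approach (illustrative only; the parameter values and the stand-in random spectrogram below are not the ones used in the committed test, which appears in full in the diff):

import torch
import torchaudio.transforms as T

# Round-trip a spectrogram through MelScale -> InverseMelScale and measure
# how many elements of the estimate land close to the reference.
n_fft, n_mels, sample_rate = 400, 64, 8000
n_stft = n_fft // 2 + 1

reference = torch.rand(1, n_stft, 50)  # stand-in for a real spectrogram
melspec = T.MelScale(n_mels=n_mels, sample_rate=sample_rate)(reference)
estimated = T.InverseMelScale(n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec)

# Element-wise agreement is poor, so check the fraction of close elements
# instead of asserting on the whole tensor.
relative_diff = torch.abs((estimated - reference) / (reference + 1e-60))
ratio_close = (relative_diff < 1e-1).float().mean().item()
print(f"Fraction of elements within 10% of the reference: {ratio_close:.3f}")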
parent c1ef2edd
@@ -202,58 +202,3 @@ class TestTransforms(common_utils.TorchaudioTestCase):
            win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
        # Note: Using relaxed rtol instead of atol
        self.assertEqual(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)

    def test_InverseMelScale(self):
        """InverseMelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        n_stft = n_fft // 2 + 1
        hop_length = n_fft // 4

        # Prepare mel spectrogram input. We use torchaudio to compute one.
        path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        sound, sample_rate = common_utils.load_wav(path)
        sound = sound[:, 2**10:2**10 + 2**14]
        sound = sound.mean(dim=0, keepdim=True)
        spec_orig = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        melspec_ta = torchaudio.transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
        melspec_lr = melspec_ta.cpu().numpy().squeeze()

        # Perform InverseMelScale with torchaudio and librosa
        spec_ta = torchaudio.transforms.InverseMelScale(
            n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
        spec_lr = librosa.feature.inverse.mel_to_stft(
            melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
        spec_lr = torch.from_numpy(spec_lr[None, ...])

        # Align dimensions
        # librosa does not return a power spectrogram, while torchaudio does
        spec_orig = spec_orig.sqrt()
        spec_ta = spec_ta.sqrt()

        threshold = 2.0
        # This threshold was chosen empirically, based on the following observation
        #
        # torch.dist(spec_lr, spec_ta, p=float('inf'))
        # >>> tensor(1.9666)
        #
        # The spectrograms reconstructed by librosa and torchaudio are not comparable elementwise.
        # This is because they use different approximation algorithms, and the resulting values can
        # differ in magnitude (although most of them are very close).
        # See
        # https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
        # https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
        # distance over frequencies.
        self.assertEqual(spec_ta, spec_lr, atol=threshold, rtol=1e-5)

        threshold = 1700.0
        # This threshold was chosen empirically, based on the following observations
        #
        # torch.dist(spec_orig, spec_ta, p=1)
        # >>> tensor(1644.3516)
        # torch.dist(spec_orig, spec_lr, p=1)
        # >>> tensor(1420.7103)
        # torch.dist(spec_lr, spec_ta, p=1)
        # >>> tensor(943.2759)
        assert torch.dist(spec_orig, spec_ta, p=1) < threshold
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .transforms_test_impl import TransformsTestBase


class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase):
    device = 'cpu'
    dtype = torch.float32


class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase):
    device = 'cpu'
    dtype = torch.float64
import torch

from torchaudio_unittest.common_utils import (
    PytorchTestCase,
    skipIfNoCuda,
)
from .transforms_test_impl import TransformsTestBase


@skipIfNoCuda
class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase):
    device = 'cuda'
    dtype = torch.float32


@skipIfNoCuda
class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase):
    device = 'cuda'
    dtype = torch.float64
import torch
import torchaudio.transforms as T

from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
    get_spectrogram,
)


def _get_ratio(mat):
    return (mat.sum() / mat.numel()).item()


class TransformsTestBase(TestBaseMixin):
    def test_InverseMelScale(self):
        """Gauge the quality of InverseMelScale transform.

        As InverseMelScale is currently implemented with
        random initialization + iterative optimization,
        it is not practically possible to assert the difference between
        the estimated spectrogram and the original spectrogram as a whole.
        The estimated spectrogram has very large discrepancies locally.
        Thus, in this test, we gauge what percentage of elements is below
        a certain tolerance.

        At the moment, the quality of the estimated spectrogram is not good.
        If the implementation is changed in a way that makes the quality worse,
        this test will fail.
        """
        n_fft = 400
        power = 1
        n_mels = 64
        sample_rate = 8000
        n_stft = n_fft // 2 + 1

        # Generate reference spectrogram and input mel-scaled spectrogram
        expected = get_spectrogram(
            get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2),
            n_fft=n_fft, power=power).to(self.device, self.dtype)
        input = T.MelScale(
            n_mels=n_mels, sample_rate=sample_rate
        ).to(self.device, self.dtype)(expected)

        # Run transform
        transform = T.InverseMelScale(
            n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype)
        torch.random.manual_seed(0)
        result = transform(input)

        # Compare
        epsilon = 1e-60
        relative_diff = torch.abs((result - expected) / (expected + epsilon))

        for tol in [1e-1, 1e-3, 1e-5, 1e-10]:
            print(
                f"Ratio of relative diff smaller than {tol:e} is "
                f"{_get_ratio(relative_diff < tol)}")
        assert _get_ratio(relative_diff < 1e-1) > 0.2
        assert _get_ratio(relative_diff < 1e-3) > 5e-3
        assert _get_ratio(relative_diff < 1e-5) > 1e-5