Commit 5a85a461 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

[BC-Breaking] Update InverseMelScale solution (#3280)

Summary:
Address https://github.com/pytorch/audio/issues/2643

- replace `SGD` optimization with `torch.linalg.lstsq` which is much faster.
- Add autograd test for `InverseMelScale`
- update other tests

Pull Request resolved: https://github.com/pytorch/audio/pull/3280

Reviewed By: hwangjeff

Differential Revision: D45679988

Pulled By: nateanl

fbshipit-source-id: a42e8bff9dc0f38e47e0482fd8a2aad902eedd59
parent 282ed27a
......@@ -189,8 +189,9 @@ class AutogradTestMixin(TestBaseMixin):
def test_melscale(self):
sample_rate = 8000
n_fft = 400
n_mels = n_fft // 2 + 1
transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels)
n_stft = n_fft // 2 + 1
n_mels = 128
transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels, n_stft=n_stft)
spec = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), n_fft=n_fft, power=1
)
......
......@@ -52,11 +52,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
n_mels = 32
n_stft = 5
mel_spec = torch.randn(3, 2, n_mels, 32) ** 2
transform = T.InverseMelScale(n_stft, n_mels)
transform = T.InverseMelScale(n_stft, n_mels, driver="gelsd")
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
self.assert_batch_consistency(transform, mel_spec, atol=1.0, rtol=1e-5)
self.assert_batch_consistency(transform, mel_spec)
def test_batch_compute_deltas(self):
specgram = torch.randn(3, 2, 31, 2786)
......
......@@ -18,11 +18,11 @@ def _get_ratio(mat):
class TransformsTestBase(TestBaseMixin):
def test_InverseMelScale(self):
def test_inverse_melscale(self):
"""Gauge the quality of InverseMelScale transform.
As InverseMelScale is currently implemented with
random initialization + iterative optimization,
sub-optimal solution (compute matrix inverse + relu),
it is not practically possible to assert the difference between
the estimated spectrogram and the original spectrogram as a whole.
The estimated spectrogram can have a very large discrepancy locally.
......
......@@ -420,7 +420,7 @@ class InverseMelScale(torch.nn.Module):
.. devices:: CPU CUDA
It minimizes the Euclidean norm between the input mel-spectrogram and the product between
the estimated spectrogram and the filter banks using SGD.
the estimated spectrogram and the filter banks using `torch.linalg.lstsq`.
Args:
n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
......@@ -428,13 +428,13 @@ class InverseMelScale(torch.nn.Module):
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
driver (str, optional): Name of the LAPACK/MAGMA method to be used for `torch.linalg.lstsq`.
For CPU inputs the valid values are ``"gels"``, ``"gelsy"``, ``"gelsd"``, ``"gelss"``.
For CUDA inputs, the only valid driver is ``"gels"``, which assumes that A is full-rank.
(Default: ``"gels"``)
Example
>>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
......@@ -449,10 +449,6 @@ class InverseMelScale(torch.nn.Module):
"sample_rate",
"f_min",
"f_max",
"max_iter",
"tolerance_loss",
"tolerance_change",
"sgdargs",
]
def __init__(
......@@ -462,26 +458,23 @@ class InverseMelScale(torch.nn.Module):
sample_rate: int = 16000,
f_min: float = 0.0,
f_max: Optional[float] = None,
max_iter: int = 100000,
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None,
norm: Optional[str] = None,
mel_scale: str = "htk",
driver: str = "gels",
) -> None:
super(InverseMelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
self.f_max = f_max or float(sample_rate // 2)
self.f_min = f_min
self.max_iter = max_iter
self.tolerance_loss = tolerance_loss
self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}
self.driver = driver
if f_min > self.f_max:
raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
if driver not in ["gels", "gelsy", "gelsd", "gelss"]:
raise ValueError(f'driver must be one of ["gels", "gelsy", "gelsd", "gelss"]. Found {driver}.')
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
self.register_buffer("fb", fb)
......@@ -499,34 +492,10 @@ class InverseMelScale(torch.nn.Module):
n_mels, time = shape[-2], shape[-1]
freq, _ = self.fb.size() # (freq, n_mels)
melspec = melspec.transpose(-1, -2)
if self.n_mels != n_mels:
raise ValueError("Expected an input with {} mel bins. Found: {}".format(self.n_mels, n_mels))
specgram = torch.rand(
melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device
)
optim = torch.optim.SGD([specgram], **self.sgdargs)
loss = float("inf")
for _ in range(self.max_iter):
optim.zero_grad()
diff = melspec - specgram.matmul(self.fb)
new_loss = diff.pow(2).sum(axis=-1).mean()
# take sum over mel-frequency then average over other dimensions
# so that loss threshold is applied par unit timeframe
new_loss.backward()
optim.step()
specgram.data = specgram.data.clamp(min=0)
new_loss = new_loss.item()
if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
break
loss = new_loss
specgram.requires_grad_(False)
specgram = specgram.clamp(min=0).transpose(-1, -2)
specgram = torch.relu(torch.linalg.lstsq(self.fb.transpose(-1, -2)[None], melspec, driver=self.driver).solution)
# unpack batch
specgram = specgram.view(shape[:-2] + (freq, time))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment