Commit 5a85a461 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

[BC-Breaking] Update InverseMelScale solution (#3280)

Summary:
Address https://github.com/pytorch/audio/issues/2643

- Replace the `SGD`-based optimization with `torch.linalg.lstsq`, which is much faster (see the sketch below).
- Add an autograd test for `InverseMelScale`.
- Update other tests.
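For context, a minimal sketch of what the new solver does. The shapes and random tensors below are purely illustrative and the `driver="gelsd"` choice is just one of the CPU options, not the default; only the `fb` layout and the `torch.linalg.lstsq` call mirror the actual change.

    import torch

    # Hypothetical sizes for illustration only.
    n_stft, n_mels, time = 201, 64, 100
    fb = torch.rand(n_stft, n_mels)        # mel filter bank, shape (freq, n_mels)
    melspec = torch.rand(1, n_mels, time)  # batched mel spectrogram

    # Closed-form replacement for the old per-call SGD loop: solve
    # argmin_S || fb.T @ S - melspec ||  per batch element, then clamp to >= 0.
    spec = torch.relu(
        torch.linalg.lstsq(fb.transpose(-1, -2)[None], melspec, driver="gelsd").solution
    )
    print(spec.shape)  # torch.Size([1, 201, 100])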

Pull Request resolved: https://github.com/pytorch/audio/pull/3280

Reviewed By: hwangjeff

Differential Revision: D45679988

Pulled By: nateanl

fbshipit-source-id: a42e8bff9dc0f38e47e0482fd8a2aad902eedd59
parent 282ed27a
@@ -189,8 +189,9 @@ class AutogradTestMixin(TestBaseMixin):
     def test_melscale(self):
         sample_rate = 8000
         n_fft = 400
-        n_mels = n_fft // 2 + 1
-        transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels)
+        n_stft = n_fft // 2 + 1
+        n_mels = 128
+        transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels, n_stft=n_stft)
         spec = get_spectrogram(
             get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), n_fft=n_fft, power=1
         )
...
@@ -52,11 +52,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
         n_mels = 32
         n_stft = 5
         mel_spec = torch.randn(3, 2, n_mels, 32) ** 2
-        transform = T.InverseMelScale(n_stft, n_mels)
-        # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
-        # exactly same result. For this reason, tolerance is very relaxed here.
-        self.assert_batch_consistency(transform, mel_spec, atol=1.0, rtol=1e-5)
+        transform = T.InverseMelScale(n_stft, n_mels, driver="gelsd")
+        self.assert_batch_consistency(transform, mel_spec)

     def test_batch_compute_deltas(self):
         specgram = torch.randn(3, 2, 31, 2786)
...
@@ -18,11 +18,11 @@ def _get_ratio(mat):
 class TransformsTestBase(TestBaseMixin):
-    def test_InverseMelScale(self):
+    def test_inverse_melscale(self):
         """Gauge the quality of InverseMelScale transform.

         As InverseMelScale is currently implemented with
-        random initialization + iterative optimization,
+        sub-optimal solution (compute matrix inverse + relu),
         it is not practically possible to assert the difference between
         the estimated spectrogram and the original spectrogram as a whole.
         Estimated spectrogram has very huge descrepency locally.
...
@@ -420,7 +420,7 @@ class InverseMelScale(torch.nn.Module):
     .. devices:: CPU CUDA

     It minimizes the euclidian norm between the input mel-spectrogram and the product between
-    the estimated spectrogram and the filter banks using SGD.
+    the estimated spectrogram and the filter banks using `torch.linalg.lstsq`.

     Args:
         n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
@@ -428,13 +428,13 @@ class InverseMelScale(torch.nn.Module):
         sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
         f_min (float, optional): Minimum frequency. (Default: ``0.``)
         f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
-        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
-        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
-        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
-        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
         norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
             (area normalization). (Default: ``None``)
         mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
+        driver (str, optional): Name of the LAPACK/MAGMA method to be used for `torch.lstsq`.
+            For CPU inputs the valid values are ``"gels"``, ``"gelsy"``, ``"gelsd"``, ``"gelss"``.
+            For CUDA input, the only valid driver is ``"gels"``, which assumes that A is full-rank.
+            (Default: ``"gels"``)

     Example
         >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
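The `driver` argument documented above is forwarded to `torch.linalg.lstsq`. As a hedged usage sketch (the file name and parameter values are placeholders taken from the docstring example, not part of this commit):

    import torchaudio
    import torchaudio.transforms as T

    n_fft = 400
    n_stft = n_fft // 2 + 1  # number of STFT bins

    waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
    melspec = T.MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, n_mels=128)(waveform)

    # "gelsd" is a CPU-only choice; CUDA inputs must stick with the default "gels".
    inverse = T.InverseMelScale(n_stft=n_stft, n_mels=128, sample_rate=sample_rate, driver="gelsd")
    spec = inverse(melspec)  # (..., n_stft, time)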
@@ -449,10 +449,6 @@ class InverseMelScale(torch.nn.Module):
         "sample_rate",
         "f_min",
         "f_max",
-        "max_iter",
-        "tolerance_loss",
-        "tolerance_change",
-        "sgdargs",
     ]

     def __init__(
@@ -462,26 +458,23 @@ class InverseMelScale(torch.nn.Module):
         sample_rate: int = 16000,
         f_min: float = 0.0,
         f_max: Optional[float] = None,
-        max_iter: int = 100000,
-        tolerance_loss: float = 1e-5,
-        tolerance_change: float = 1e-8,
-        sgdargs: Optional[dict] = None,
         norm: Optional[str] = None,
         mel_scale: str = "htk",
+        driver: str = "gels",
     ) -> None:
         super(InverseMelScale, self).__init__()
         self.n_mels = n_mels
         self.sample_rate = sample_rate
         self.f_max = f_max or float(sample_rate // 2)
         self.f_min = f_min
-        self.max_iter = max_iter
-        self.tolerance_loss = tolerance_loss
-        self.tolerance_change = tolerance_change
-        self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}
+        self.driver = driver

         if f_min > self.f_max:
             raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
+        if driver not in ["gels", "gelsy", "gelsd", "gelss"]:
+            raise ValueError(f'driver must be one of ["gels", "gelsy", "gelsd", "gelss"]. Found {driver}.')

         fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
         self.register_buffer("fb", fb)
@@ -499,34 +492,10 @@ class InverseMelScale(torch.nn.Module):
         n_mels, time = shape[-2], shape[-1]
         freq, _ = self.fb.size()  # (freq, n_mels)
-        melspec = melspec.transpose(-1, -2)
         if self.n_mels != n_mels:
             raise ValueError("Expected an input with {} mel bins. Found: {}".format(self.n_mels, n_mels))

-        specgram = torch.rand(
-            melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device
-        )
-
-        optim = torch.optim.SGD([specgram], **self.sgdargs)
-
-        loss = float("inf")
-        for _ in range(self.max_iter):
-            optim.zero_grad()
-            diff = melspec - specgram.matmul(self.fb)
-            new_loss = diff.pow(2).sum(axis=-1).mean()
-            # take sum over mel-frequency then average over other dimensions
-            # so that loss threshold is applied par unit timeframe
-            new_loss.backward()
-            optim.step()
-            specgram.data = specgram.data.clamp(min=0)
-            new_loss = new_loss.item()
-            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
-                break
-            loss = new_loss
-
-        specgram.requires_grad_(False)
-        specgram = specgram.clamp(min=0).transpose(-1, -2)
+        specgram = torch.relu(torch.linalg.lstsq(self.fb.transpose(-1, -2)[None], melspec, driver=self.driver).solution)

         # unpack batch
         specgram = specgram.view(shape[:-2] + (freq, time))
...
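To make the shape handling in the new `forward` concrete, here is a hedged standalone sketch of the pack-batch / lstsq / unpack-batch flow. Sizes and tensors are hypothetical and `driver="gelsd"` is just a CPU choice; only the `lstsq` line matches the code above.

    import torch

    n_stft, n_mels = 201, 64
    fb = torch.rand(n_stft, n_mels)          # filter bank buffer, (freq, n_mels)
    melspec = torch.rand(3, 2, n_mels, 50)   # e.g. (batch, channel, n_mels, time)

    # pack batch: collapse the leading dimensions into one
    shape = melspec.size()
    melspec = melspec.view(-1, shape[-2], shape[-1])

    # solve for the linear spectrogram and clamp negative values to zero
    specgram = torch.relu(
        torch.linalg.lstsq(fb.transpose(-1, -2)[None], melspec, driver="gelsd").solution
    )

    # unpack batch: restore the leading dimensions, with freq bins replacing mel bins
    specgram = specgram.view(shape[:-2] + (n_stft, shape[-1]))
    print(specgram.shape)  # torch.Size([3, 2, 201, 50])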