Unverified Commit 114461cc authored by jieruan's avatar jieruan Committed by GitHub
Browse files

Expose stft arguments to MelSpectrogram (#1211)

parent af1e457e
......@@ -219,3 +219,19 @@ class Tester(common_utils.TorchaudioTestCase):
computed = transform(specgram)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
self.assertEqual(computed, expected, atol=1e-6, rtol=1e-8)
class SmokeTest(common_utils.TorchaudioTestCase):
    """Construction-only checks that the STFT options are exposed by the transforms.

    No audio is processed: each test builds a transform and inspects the
    attributes stored on the resulting ``Spectrogram`` module.
    """

    def test_spectrogram(self):
        """``Spectrogram`` keeps ``center``, ``pad_mode`` and ``onesided`` as attributes."""
        specgram = transforms.Spectrogram(center=False, pad_mode="reflect", onesided=False)
        self.assertEqual(specgram.center, False)
        self.assertEqual(specgram.pad_mode, "reflect")
        self.assertEqual(specgram.onesided, False)

    def test_melspectrogram(self):
        """``MelSpectrogram`` forwards the STFT options to its inner ``Spectrogram``."""
        # Use a non-default value for every option (the MelSpectrogram defaults
        # are center=True, pad_mode="reflect", onesided=True) so the assertions
        # fail if an argument is silently dropped instead of being forwarded.
        melspecgram = transforms.MelSpectrogram(center=False, pad_mode="constant", onesided=False)
        specgram = melspecgram.spectrogram
        self.assertEqual(specgram.center, False)
        self.assertEqual(specgram.pad_mode, "constant")
        self.assertEqual(specgram.onesided, False)
......@@ -411,6 +411,13 @@ class MelSpectrogram(torch.nn.Module):
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
......@@ -430,7 +437,10 @@ class MelSpectrogram(torch.nn.Module):
window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.,
normalized: bool = False,
wkwargs: Optional[dict] = None) -> None:
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True) -> None:
super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
......@@ -445,7 +455,8 @@ class MelSpectrogram(torch.nn.Module):
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
hop_length=self.hop_length,
pad=self.pad, window_fn=window_fn, power=self.power,
normalized=self.normalized, wkwargs=wkwargs)
normalized=self.normalized, wkwargs=wkwargs,
center=center, pad_mode=pad_mode, onesided=onesided)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
def forward(self, waveform: Tensor) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment