"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "733e6ff84eecaf5f987ca368bee12274d4a8eb59"
Unverified commit 703614f2, authored by jieruan, committed by GitHub
Browse files

Expose normalization method to Mel transforms (#1212)

parent a4c095a3
...@@ -248,6 +248,8 @@ class MelScale(torch.nn.Module): ...@@ -248,6 +248,8 @@ class MelScale(torch.nn.Module):
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_stft (int, optional): Number of bins in STFT. Calculated from first input n_stft (int, optional): Number of bins in STFT. Calculated from first input
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``) if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
""" """
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
...@@ -256,17 +258,19 @@ class MelScale(torch.nn.Module): ...@@ -256,17 +258,19 @@ class MelScale(torch.nn.Module):
sample_rate: int = 16000, sample_rate: int = 16000,
f_min: float = 0., f_min: float = 0.,
f_max: Optional[float] = None, f_max: Optional[float] = None,
n_stft: Optional[int] = None) -> None: n_stft: Optional[int] = None,
norm: Optional[str] = None) -> None:
super(MelScale, self).__init__() super(MelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min self.f_min = f_min
self.norm = norm
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def forward(self, specgram: Tensor) -> Tensor: def forward(self, specgram: Tensor) -> Tensor:
...@@ -283,7 +287,8 @@ class MelScale(torch.nn.Module): ...@@ -283,7 +287,8 @@ class MelScale(torch.nn.Module):
specgram = specgram.reshape(-1, shape[-2], shape[-1]) specgram = specgram.reshape(-1, shape[-2], shape[-1])
if self.fb.numel() == 0: if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate) tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
self.n_mels, self.sample_rate, self.norm)
# Attributes cannot be reassigned outside __init__ so workaround # Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size()) self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb) self.fb.copy_(tmp_fb)
...@@ -315,6 +320,8 @@ class InverseMelScale(torch.nn.Module): ...@@ -315,6 +320,8 @@ class InverseMelScale(torch.nn.Module):
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``) tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``) tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``) sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
""" """
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss', __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs'] 'tolerance_change', 'sgdargs']
...@@ -328,7 +335,8 @@ class InverseMelScale(torch.nn.Module): ...@@ -328,7 +335,8 @@ class InverseMelScale(torch.nn.Module):
max_iter: int = 100000, max_iter: int = 100000,
tolerance_loss: float = 1e-5, tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8, tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None) -> None: sgdargs: Optional[dict] = None,
norm: Optional[str] = None) -> None:
super(InverseMelScale, self).__init__() super(InverseMelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
...@@ -341,7 +349,7 @@ class InverseMelScale(torch.nn.Module): ...@@ -341,7 +349,7 @@ class InverseMelScale(torch.nn.Module):
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def forward(self, melspec: Tensor) -> Tensor: def forward(self, melspec: Tensor) -> Tensor:
...@@ -418,6 +426,8 @@ class MelSpectrogram(torch.nn.Module): ...@@ -418,6 +426,8 @@ class MelSpectrogram(torch.nn.Module):
:attr:`center` is ``True``. Default: ``"reflect"`` :attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True`` avoid redundancy. Default: ``True``
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
Example Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True) >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...@@ -440,7 +450,8 @@ class MelSpectrogram(torch.nn.Module): ...@@ -440,7 +450,8 @@ class MelSpectrogram(torch.nn.Module):
wkwargs: Optional[dict] = None, wkwargs: Optional[dict] = None,
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
onesided: bool = True) -> None: onesided: bool = True,
norm: Optional[str] = None) -> None:
super(MelSpectrogram, self).__init__() super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_fft = n_fft self.n_fft = n_fft
...@@ -457,7 +468,7 @@ class MelSpectrogram(torch.nn.Module): ...@@ -457,7 +468,7 @@ class MelSpectrogram(torch.nn.Module):
pad=self.pad, window_fn=window_fn, power=self.power, pad=self.pad, window_fn=window_fn, power=self.power,
normalized=self.normalized, wkwargs=wkwargs, normalized=self.normalized, wkwargs=wkwargs,
center=center, pad_mode=pad_mode, onesided=onesided) center=center, pad_mode=pad_mode, onesided=onesided)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1) self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm)
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment