"packaging/vscode:/vscode.git/clone" did not exist on "e4b9823f4ad2d752c38d6df4f4b901f97cfafd1a"
Unverified Commit 32b9cf80 authored by hwangjeff, committed by GitHub
Browse files

Remove lazy behavior from MelScale (#1636)



Rebases #1571; addresses #1569:

"In 0.9.0 we are deprecating the lazy behavior of MelScale because it can make an invalid 
TorchScript object and it does not align with the design of torchaudio. Now in master 
branch, we can remove the implementation."
Co-authored-by: Pankaj Patil <pankaj.patil2099@hotmail.com>
Co-authored-by: moto <855818+mthrok@users.noreply.github.com>
Co-authored-by: hwangjeff <jeffhwang@fb.com>
parent 8d374c4d
...@@ -33,7 +33,7 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -33,7 +33,7 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_batch_MelScale(self): def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786) specgram = torch.randn(2, 201, 256)
# Single then transform then batch # Single then transform then batch
expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1) expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1)
...@@ -41,7 +41,7 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -41,7 +41,7 @@ class TestTransforms(common_utils.TorchaudioTestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1)) computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394) # shape = (3, 2, 128, 256)
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_batch_InverseMelScale(self): def test_batch_InverseMelScale(self):
......
...@@ -59,10 +59,6 @@ class Transforms(TempDirMixin, TestBaseMixin): ...@@ -59,10 +59,6 @@ class Transforms(TempDirMixin, TestBaseMixin):
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
self._assert_consistency(T.AmplitudeToDB(), spec) self._assert_consistency(T.AmplitudeToDB(), spec)
def test_MelScale_invalid(self):
with self.assertRaises(ValueError):
torch.jit.script(T.MelScale())
def test_MelScale(self): def test_MelScale(self):
spec_f = torch.rand((1, 201, 6)) spec_f = torch.rand((1, 201, 6))
self._assert_consistency(T.MelScale(n_stft=201), spec_f) self._assert_consistency(T.MelScale(n_stft=201), spec_f)
......
...@@ -55,17 +55,17 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -55,17 +55,17 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(mag_to_db_torch, power_to_db_torch) self.assertEqual(mag_to_db_torch, power_to_db_torch)
def test_melscale_load_save(self): def test_melscale_load_save(self):
specgram = torch.ones(1, 1000, 100) specgram = torch.ones(1, 201, 100)
melscale_transform = transforms.MelScale() melscale_transform = transforms.MelScale()
melscale_transform(specgram) melscale_transform(specgram)
melscale_transform_copy = transforms.MelScale(n_stft=1000) melscale_transform_copy = transforms.MelScale()
melscale_transform_copy.load_state_dict(melscale_transform.state_dict()) melscale_transform_copy.load_state_dict(melscale_transform.state_dict())
fb = melscale_transform.fb fb = melscale_transform.fb
fb_copy = melscale_transform_copy.fb fb_copy = melscale_transform_copy.fb
self.assertEqual(fb_copy.size(), (1000, 128)) self.assertEqual(fb_copy.size(), (201, 128))
self.assertEqual(fb, fb_copy) self.assertEqual(fb, fb_copy)
def test_melspectrogram_load_save(self): def test_melspectrogram_load_save(self):
......
import warnings
import torch import torch
import torchaudio.transforms as T import torchaudio.transforms as T
...@@ -63,22 +61,6 @@ class TransformsTestBase(TestBaseMixin): ...@@ -63,22 +61,6 @@ class TransformsTestBase(TestBaseMixin):
assert _get_ratio(relative_diff < 1e-3) > 5e-3 assert _get_ratio(relative_diff < 1e-3) > 5e-3
assert _get_ratio(relative_diff < 1e-5) > 1e-5 assert _get_ratio(relative_diff < 1e-5) > 1e-5
def test_melscale_unset_weight_warning(self):
"""Issue a warning if MelScale initialized without a weight
As part of the deprecation of lazy intialization behavior (#1510),
issue a warning if `n_stft` is not set.
"""
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000)
assert len(caught_warnings) == 1
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
assert len(caught_warnings) == 0
@nested_params( @nested_params(
["sinc_interpolation", "kaiser_window"], ["sinc_interpolation", "kaiser_window"],
[16000, 44100], [16000, 44100],
......
...@@ -244,9 +244,8 @@ class MelScale(torch.nn.Module): ...@@ -244,9 +244,8 @@ class MelScale(torch.nn.Module):
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
f_min (float, optional): Minimum frequency. (Default: ``0.``) f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_stft (int, optional): Number of bins in STFT. Calculated from first input n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``) norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``) (area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
""" """
...@@ -257,7 +256,7 @@ class MelScale(torch.nn.Module): ...@@ -257,7 +256,7 @@ class MelScale(torch.nn.Module):
sample_rate: int = 16000, sample_rate: int = 16000,
f_min: float = 0., f_min: float = 0.,
f_max: Optional[float] = None, f_max: Optional[float] = None,
n_stft: Optional[int] = None, n_stft: int = 201,
norm: Optional[str] = None, norm: Optional[str] = None,
mel_scale: str = "htk") -> None: mel_scale: str = "htk") -> None:
super(MelScale, self).__init__() super(MelScale, self).__init__()
...@@ -269,35 +268,11 @@ class MelScale(torch.nn.Module): ...@@ -269,35 +268,11 @@ class MelScale(torch.nn.Module):
self.mel_scale = mel_scale self.mel_scale = mel_scale
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(
if n_stft is None or n_stft == 0:
warnings.warn(
'Initialization of torchaudio.transforms.MelScale with an unset weight '
'`n_stft=None` is deprecated and will be removed in release 0.10. '
'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. '
'Refer to https://github.com/pytorch/audio/issues/1510 '
'for more details.'
)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale) self.mel_scale)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def __prepare_scriptable__(self):
r"""If `self.fb` is empty, the `forward` method will try to resize the parameter,
which does not work once the transform is scripted. However, this error does not happen
until the transform is executed. This is inconvenient especially if the resulting
TorchScript object is executed in other environments. Therefore, we check the
validity of `self.fb` here and fail if the resulting TS does not work.
Returns:
MelScale: self
"""
if self.fb.numel() == 0:
raise ValueError("n_stft must be provided at construction")
return self
def forward(self, specgram: Tensor) -> Tensor: def forward(self, specgram: Tensor) -> Tensor:
r""" r"""
Args: Args:
...@@ -311,14 +286,6 @@ class MelScale(torch.nn.Module): ...@@ -311,14 +286,6 @@ class MelScale(torch.nn.Module):
shape = specgram.size() shape = specgram.size()
specgram = specgram.reshape(-1, shape[-2], shape[-1]) specgram = specgram.reshape(-1, shape[-2], shape[-1])
if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
# Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb)
# (channel, frequency, time).transpose(...) dot (frequency, n_mels) # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
# -> (channel, time, n_mels).transpose(...) # -> (channel, time, n_mels).transpose(...)
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment