Unverified Commit 8a86c463 authored by Borun Dev Chowdhury's avatar Borun Dev Chowdhury Committed by GitHub
Browse files

Raise error when scripting invalid MelScale (#1505)

parent 9d621fd3
...@@ -59,6 +59,10 @@ class Transforms(TempDirMixin, TestBaseMixin): ...@@ -59,6 +59,10 @@ class Transforms(TempDirMixin, TestBaseMixin):
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
self._assert_consistency(T.AmplitudeToDB(), spec) self._assert_consistency(T.AmplitudeToDB(), spec)
def test_MelScale_invalid(self):
with self.assertRaises(ValueError):
torch.jit.script(T.MelScale())
def test_MelScale(self): def test_MelScale(self):
spec_f = torch.rand((1, 201, 6)) spec_f = torch.rand((1, 201, 6))
self._assert_consistency(T.MelScale(n_stft=201), spec_f) self._assert_consistency(T.MelScale(n_stft=201), spec_f)
......
...@@ -284,6 +284,20 @@ class MelScale(torch.nn.Module): ...@@ -284,6 +284,20 @@ class MelScale(torch.nn.Module):
self.mel_scale) self.mel_scale)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def __prepare_scriptable__(self):
r"""If `self.fb` is empty, the `forward` method will try to resize the parameter,
which does not work once the transform is scripted. However, this error does not happen
until the transform is executed. This is inconvenient especially if the resulting
TorchScript object is executed in other environments. Therefore, we check the
validity of `self.fb` here and fail if the resulting TS does not work.
Returns:
MelScale: self
"""
if self.fb.numel() == 0:
raise ValueError("n_stft must be provided at construction")
return self
def forward(self, specgram: Tensor) -> Tensor: def forward(self, specgram: Tensor) -> Tensor:
r""" r"""
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment