You need to sign in or sign up before continuing.
Unverified Commit 264ab15a authored by Brian White's avatar Brian White Committed by GitHub
Browse files

Add deprecation warning to MelScale for unset weight (#1515)

parent 02589246
import warnings
import torch import torch
import torchaudio.transforms as T import torchaudio.transforms as T
...@@ -39,7 +41,7 @@ class TransformsTestBase(TestBaseMixin): ...@@ -39,7 +41,7 @@ class TransformsTestBase(TestBaseMixin):
get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2),
n_fft=n_fft, power=power).to(self.device, self.dtype) n_fft=n_fft, power=power).to(self.device, self.dtype)
input = T.MelScale( input = T.MelScale(
n_mels=n_mels, sample_rate=sample_rate n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft
).to(self.device, self.dtype)(expected) ).to(self.device, self.dtype)(expected)
# Run transform # Run transform
...@@ -59,3 +61,19 @@ class TransformsTestBase(TestBaseMixin): ...@@ -59,3 +61,19 @@ class TransformsTestBase(TestBaseMixin):
assert _get_ratio(relative_diff < 1e-1) > 0.2 assert _get_ratio(relative_diff < 1e-1) > 0.2
assert _get_ratio(relative_diff < 1e-3) > 5e-3 assert _get_ratio(relative_diff < 1e-3) > 5e-3
assert _get_ratio(relative_diff < 1e-5) > 1e-5 assert _get_ratio(relative_diff < 1e-5) > 1e-5
def test_melscale_unset_weight_warning(self):
"""Issue a warning if MelScale initialized without a weight
As part of the deprecation of lazy intialization behavior (#1510),
issue a warning if `n_stft` is not set.
"""
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000)
assert len(caught_warnings) == 1
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
assert len(caught_warnings) == 0
...@@ -283,6 +283,15 @@ class MelScale(torch.nn.Module): ...@@ -283,6 +283,15 @@ class MelScale(torch.nn.Module):
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
if n_stft is None or n_stft == 0:
warnings.warn(
'Initialization of torchaudio.transforms.MelScale with an unset weight '
'`n_stft=None` is deprecated and will be removed from a future release. '
'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. '
'Refer to https://github.com/pytorch/audio/issues/1510 '
'for more details.'
)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale) self.mel_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment