Unverified Commit 264ab15a authored by Brian White's avatar Brian White Committed by GitHub
Browse files

Add deprecation warning to MelScale for unset weight (#1515)

parent 02589246
import warnings
import torch import torch
import torchaudio.transforms as T import torchaudio.transforms as T
...@@ -39,7 +41,7 @@ class TransformsTestBase(TestBaseMixin): ...@@ -39,7 +41,7 @@ class TransformsTestBase(TestBaseMixin):
get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2),
n_fft=n_fft, power=power).to(self.device, self.dtype) n_fft=n_fft, power=power).to(self.device, self.dtype)
input = T.MelScale( input = T.MelScale(
n_mels=n_mels, sample_rate=sample_rate n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft
).to(self.device, self.dtype)(expected) ).to(self.device, self.dtype)(expected)
# Run transform # Run transform
...@@ -59,3 +61,19 @@ class TransformsTestBase(TestBaseMixin): ...@@ -59,3 +61,19 @@ class TransformsTestBase(TestBaseMixin):
assert _get_ratio(relative_diff < 1e-1) > 0.2 assert _get_ratio(relative_diff < 1e-1) > 0.2
assert _get_ratio(relative_diff < 1e-3) > 5e-3 assert _get_ratio(relative_diff < 1e-3) > 5e-3
assert _get_ratio(relative_diff < 1e-5) > 1e-5 assert _get_ratio(relative_diff < 1e-5) > 1e-5
def test_melscale_unset_weight_warning(self):
"""Issue a warning if MelScale initialized without a weight
As part of the deprecation of lazy intialization behavior (#1510),
issue a warning if `n_stft` is not set.
"""
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000)
assert len(caught_warnings) == 1
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
assert len(caught_warnings) == 0
...@@ -283,6 +283,15 @@ class MelScale(torch.nn.Module): ...@@ -283,6 +283,15 @@ class MelScale(torch.nn.Module):
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
if n_stft is None or n_stft == 0:
warnings.warn(
'Initialization of torchaudio.transforms.MelScale with an unset weight '
'`n_stft=None` is deprecated and will be removed from a future release. '
'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. '
'Refer to https://github.com/pytorch/audio/issues/1510 '
'for more details.'
)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale) self.mel_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment