Unverified Commit 9d621fd3 authored by Kirill Ignatev's avatar Kirill Ignatev Committed by GitHub
Browse files

Add autograd tests for TimeMasking/FrequencyMasking (#1498)

parent 1f136671
......@@ -125,6 +125,31 @@ class AutogradTestMixin(TestBaseMixin):
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)
@parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)])
def test_masking(self, masking_transform):
sample_rate = 8000
n_fft = 400
spectrogram = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2),
n_fft=n_fft, power=1)
deterministic_transform = _DeterministicWrapper(masking_transform(400))
self.assert_grad(deterministic_transform, [spectrogram])
@parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)])
def test_masking_iid(self, masking_transform):
sample_rate = 8000
n_fft = 400
specs = [get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2, seed=i),
n_fft=n_fft, power=1)
for i in range(3)
]
batch = torch.stack(specs)
assert batch.ndim == 4
deterministic_transform = _DeterministicWrapper(masking_transform(400, True))
self.assert_grad(deterministic_transform, [batch])
def test_spectral_centroid(self):
sample_rate = 8000
transform = T.SpectralCentroid(sample_rate=sample_rate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment