Unverified Commit 5f5df1d6 authored by moto's avatar moto Committed by GitHub
Browse files

Use torch.testing.assert_allclose (#513)

* grep -l 'torch.allclose' -r test | xargs sed -i 's/assert torch.allclose/torch.testing.assert_allclose/g'

* grep -l 'torch.allclose' -r test | xargs sed -i 's/self.assertTrue(torch.allclose(\(.*\)))/torch.testing.assert_allclose(\1)/g'

* Fix missing atol/rtol, wrong shape, argument order. Remove redundant shape assertions
parent bc1ffb11
......@@ -23,8 +23,7 @@ def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
return tensors, expected
......@@ -43,8 +42,7 @@ def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
class TestFunctional(unittest.TestCase):
......@@ -96,8 +94,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_batch_Resample(self):
waveform = torch.randn(2, 2786)
......@@ -108,8 +105,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786)
......@@ -121,8 +117,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_batch_InverseMelScale(self):
n_mels = 32
......@@ -136,11 +131,10 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))
# shape = (3, 2, n_mels, 32)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
assert torch.allclose(computed, expected, atol=1.0)
torch.testing.assert_allclose(computed, expected, atol=1.0, rtol=1e-5)
def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)
......@@ -152,8 +146,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_batch_mulaw(self):
test_filepath = os.path.join(
......@@ -169,8 +162,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
# Single then transform then batch
waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
......@@ -180,8 +172,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawDecoding()(computed)
# shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_batch_spectrogram(self):
test_filepath = os.path.join(
......@@ -193,9 +184,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_batch_melspectrogram(self):
test_filepath = os.path.join(
......@@ -207,9 +196,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -223,9 +210,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-5)
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
def test_batch_TimeStretch(self):
test_filepath = os.path.join(
......@@ -260,8 +245,7 @@ class TestTransforms(unittest.TestCase):
hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-5)
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
def test_batch_Fade(self):
test_filepath = os.path.join(
......@@ -275,9 +259,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_batch_Vol(self):
test_filepath = os.path.join(
......@@ -289,9 +271,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
if __name__ == '__main__':
......
......@@ -77,7 +77,7 @@ class Test_Kaldi(unittest.TestCase):
for r in range(m):
extract_window(window, waveform, r, window_size, window_shift, snip_edges)
self.assertTrue(torch.allclose(window, output))
torch.testing.assert_allclose(window, output)
def test_get_strided(self):
# generate any combination where 0 < window_size <= num_samples and
......@@ -104,7 +104,7 @@ class Test_Kaldi(unittest.TestCase):
sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
print(y >> 16)
self.assertTrue(sample_rate == sr)
self.assertTrue(torch.allclose(y, sound))
torch.testing.assert_allclose(y, sound)
def _print_diagnostic(self, output, expect_output):
# given an output and expected output, it will print the absolute/relative errors (max and mean squared)
......@@ -156,8 +156,7 @@ class Test_Kaldi(unittest.TestCase):
output = get_output_fn(sound, args)
self._print_diagnostic(output, kaldi_output)
self.assertTrue(output.shape, kaldi_output.shape)
self.assertTrue(torch.allclose(output, kaldi_output, atol=atol, rtol=rtol))
torch.testing.assert_allclose(output, kaldi_output, atol=atol, rtol=rtol)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -299,7 +298,7 @@ class Test_Kaldi(unittest.TestCase):
ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim]
self.assertTrue(torch.allclose(ground_truth, estimate, atol=atol, rtol=rtol))
torch.testing.assert_allclose(estimate, ground_truth, atol=atol, rtol=rtol)
def test_resample_waveform_downsample_accuracy(self):
for i in range(1, 20):
......@@ -324,7 +323,7 @@ class Test_Kaldi(unittest.TestCase):
for i in range(num_channels):
single_channel = sound * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2)
self.assertTrue(torch.allclose(multi_sound_sampled[i, :], single_channel_sampled, rtol=1e-4))
torch.testing.assert_allclose(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-8)
if __name__ == '__main__':
......
......@@ -16,8 +16,7 @@ class TestComputeDeltas(unittest.TestCase):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def test_two_channels(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
......@@ -25,16 +24,13 @@ class TestComputeDeltas(unittest.TestCase):
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
torch.testing.assert_allclose(computed, expected)
def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]
assert sound.shape == estimate.shape, (sound.shape, estimate.shape)
assert torch.allclose(sound, estimate, atol=atol, rtol=rtol)
torch.testing.assert_allclose(estimate, sound, atol=atol, rtol=rtol)
def _test_istft_is_inverse_of_stft(kwargs):
......@@ -308,13 +304,13 @@ class TestDB_to_amplitude(unittest.TestCase):
db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(torch.abs(x), x2, atol=5e-5))
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
# Spectrogram amplitude -> DB -> amplitude
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(spec, x2, atol=5e-5))
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
# Waveform power -> DB -> power
multiplier = 10.
......@@ -323,13 +319,13 @@ class TestDB_to_amplitude(unittest.TestCase):
db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(torch.abs(x), x2, atol=5e-5))
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
# Spectrogram power -> DB -> power
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(spec, x2, atol=5e-5))
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
@pytest.mark.parametrize('complex_tensor', [
......@@ -341,7 +337,7 @@ def test_complex_norm(complex_tensor, power):
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
norm_tensor = F.complex_norm(complex_tensor, power)
assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5)
torch.testing.assert_allclose(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
@pytest.mark.parametrize('specgram', [
......
......@@ -25,7 +25,7 @@ class TestFunctionalFiltering(unittest.TestCase):
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=dtype, device=device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert torch.allclose(waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5)
torch.testing.assert_allclose(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
def test_lfilter_basic(self):
self._test_lfilter_basic(torch.float32, torch.device("cpu"))
......@@ -112,7 +112,7 @@ class TestFunctionalFiltering(unittest.TestCase):
E.append_effect_to_chain("gain", [3])
sox_gain_waveform = E.sox_build_flow_effects()[0]
assert torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04)
torch.testing.assert_allclose(waveform_gain, sox_gain_waveform, atol=1e-04, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -128,13 +128,13 @@ class TestFunctionalFiltering(unittest.TestCase):
E.append_effect_to_chain("dither", [])
sox_dither_waveform = E.sox_build_flow_effects()[0]
assert torch.allclose(waveform_dithered, sox_dither_waveform, atol=1e-04)
torch.testing.assert_allclose(waveform_dithered, sox_dither_waveform, atol=1e-04, rtol=1e-5)
E.clear_chain()
E.append_effect_to_chain("dither", ["-s"])
sox_dither_waveform_ns = E.sox_build_flow_effects()[0]
assert torch.allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02)
torch.testing.assert_allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -157,7 +157,7 @@ class TestFunctionalFiltering(unittest.TestCase):
E.append_effect_to_chain("dither", ["-s"])
wf_vctk_sox = E.sox_build_flow_effects()[0]
assert torch.allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)
torch.testing.assert_allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -178,7 +178,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -199,7 +199,7 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
# TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-3, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -220,7 +220,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -242,7 +242,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -264,7 +264,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -285,7 +285,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -307,7 +307,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -329,7 +329,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -351,7 +351,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -369,7 +369,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.deemph_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -387,7 +387,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.riaa_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -409,7 +409,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -435,7 +435,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
torch.testing.assert_allclose(waveform_lfilter_out, waveform_sox_out, atol=1e-4, rtol=1e-5)
if __name__ == "__main__":
......
......@@ -51,7 +51,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
momentum=momentum, init=init, length=length)
lr_out = torch.from_numpy(lr_out).unsqueeze(0)
assert torch.allclose(ta_out, lr_out, atol=5e-5)
torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
librosa_fb = librosa.filters.mel(sr=sample_rate,
......@@ -68,7 +68,8 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
n_freqs=(n_fft // 2 + 1))
for i_mel_bank in range(n_mels):
assert torch.allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4)
torch.testing.assert_allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]),
atol=1e-4, rtol=1e-5)
def test_create_fb(self):
self._test_create_fb()
......@@ -91,18 +92,18 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
ta_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
lr_out = librosa.core.power_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out).unsqueeze(0)
lr_out = torch.from_numpy(lr_out)
assert torch.allclose(ta_out, lr_out, atol=5e-5)
torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
# Amplitude to DB
multiplier = 20.0
ta_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
lr_out = librosa.core.amplitude_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out).unsqueeze(0)
lr_out = torch.from_numpy(lr_out)
assert torch.allclose(ta_out, lr_out, atol=5e-5)
torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
@pytest.mark.parametrize('complex_specgrams', [
......@@ -164,7 +165,7 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
out_torch = spect_transform(sound).squeeze().cpu()
assert torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)
torch.testing.assert_allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
......@@ -175,25 +176,25 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
assert torch.allclose(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)
torch.testing.assert_allclose(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)
# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
assert torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3)
torch.testing.assert_allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)
mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
assert torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
torch.testing.assert_allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)
power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
assert torch.allclose(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
torch.testing.assert_allclose(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)
# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
......@@ -214,8 +215,8 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()
assert torch.allclose(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)
torch.testing.assert_allclose(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)
class TestTransforms(_LibrosaMixin, unittest.TestCase):
......@@ -289,7 +290,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol
assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)
torch.testing.assert_allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)
def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa"""
......@@ -332,7 +333,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
assert torch.allclose(spec_ta, spec_lr, atol=threshold)
torch.testing.assert_allclose(spec_ta, spec_lr, atol=threshold, rtol=1e-5)
threshold = 1700.0
# This threshold was choosen empirically, based on the following observations
......
......@@ -22,7 +22,7 @@ def _test_torchscript_functional_shape(py_method, *args, **kwargs):
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs)
assert torch.allclose(jit_out, py_out)
torch.testing.assert_allclose(jit_out, py_out)
def _test_lfilter(waveform):
......@@ -316,7 +316,7 @@ def _test_script_module(f, tensor, *args, **kwargs):
py_out = py_method(tensor)
jit_out = jit_method(tensor)
assert torch.allclose(jit_out, py_out)
torch.testing.assert_allclose(jit_out, py_out)
if RUN_CUDA:
......@@ -328,7 +328,7 @@ def _test_script_module(f, tensor, *args, **kwargs):
py_out = py_method(tensor)
jit_out = jit_method(tensor)
assert torch.allclose(jit_out, py_out)
torch.testing.assert_allclose(jit_out, py_out)
class TestTransforms(unittest.TestCase):
......
......@@ -53,7 +53,7 @@ class Tester(unittest.TestCase):
mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))
self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch))
torch.testing.assert_allclose(mag_to_db_torch, power_to_db_torch)
def test_melscale_load_save(self):
specgram = torch.ones(1, 1000, 100)
......@@ -67,7 +67,7 @@ class Tester(unittest.TestCase):
fb_copy = melscale_transform_copy.fb
self.assertEqual(fb_copy.size(), (1000, 128))
self.assertTrue(torch.allclose(fb, fb_copy))
torch.testing.assert_allclose(fb, fb_copy)
def test_melspectrogram_load_save(self):
waveform = self.waveform.float()
......@@ -83,10 +83,10 @@ class Tester(unittest.TestCase):
fb = mel_spectrogram_transform.mel_scale.fb
fb_copy = mel_spectrogram_transform_copy.mel_scale.fb
self.assertTrue(torch.allclose(window, window_copy))
torch.testing.assert_allclose(window, window_copy)
# the default for n_fft = 400 and n_mels = 128
self.assertEqual(fb_copy.size(), (201, 128))
self.assertTrue(torch.allclose(fb, fb_copy))
torch.testing.assert_allclose(fb, fb_copy)
def test_mel2(self):
top_db = 80.
......@@ -217,7 +217,7 @@ class Tester(unittest.TestCase):
transform = transforms.ComputeDeltas(win_length=3)
computed = transform(specgram)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-6, rtol=1e-8)
torch.testing.assert_allclose(computed, expected, atol=1e-6, rtol=1e-8)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment