Unverified Commit 5f5df1d6 authored by moto's avatar moto Committed by GitHub
Browse files

Use torch.testing.assert_allclose (#513)

* grep -l 'torch.allclose' -r test | xargs sed -i 's/assert torch.allclose/torch.testing.assert_allclose/g'

* grep -l 'torch.allclose' -r test | xargs sed -i 's/self.assertTrue(torch.allclose(\(.*\)))/torch.testing.assert_allclose(\1)/g'

* Fix missing atol/rtol, wrong shape, argument order. Remove redundant shape assertions
parent bc1ffb11
...@@ -23,8 +23,7 @@ def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs) ...@@ -23,8 +23,7 @@ def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs)
torch.random.manual_seed(42) torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs) computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape) torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
return tensors, expected return tensors, expected
...@@ -43,8 +42,7 @@ def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs): ...@@ -43,8 +42,7 @@ def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
torch.random.manual_seed(42) torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs) computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape) torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
...@@ -96,8 +94,7 @@ class TestTransforms(unittest.TestCase): ...@@ -96,8 +94,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1)) computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
def test_batch_Resample(self): def test_batch_Resample(self):
waveform = torch.randn(2, 2786) waveform = torch.randn(2, 2786)
...@@ -108,8 +105,7 @@ class TestTransforms(unittest.TestCase): ...@@ -108,8 +105,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
def test_batch_MelScale(self): def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786) specgram = torch.randn(2, 31, 2786)
...@@ -121,8 +117,7 @@ class TestTransforms(unittest.TestCase): ...@@ -121,8 +117,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1)) computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
def test_batch_InverseMelScale(self): def test_batch_InverseMelScale(self):
n_mels = 32 n_mels = 32
...@@ -136,11 +131,10 @@ class TestTransforms(unittest.TestCase): ...@@ -136,11 +131,10 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1)) computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))
# shape = (3, 2, n_mels, 32) # shape = (3, 2, n_mels, 32)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here. # exactly same result. For this reason, tolerance is very relaxed here.
assert torch.allclose(computed, expected, atol=1.0) torch.testing.assert_allclose(computed, expected, atol=1.0, rtol=1e-5)
def test_batch_compute_deltas(self): def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786) specgram = torch.randn(2, 31, 2786)
...@@ -152,8 +146,7 @@ class TestTransforms(unittest.TestCase): ...@@ -152,8 +146,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1)) computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
def test_batch_mulaw(self): def test_batch_mulaw(self):
test_filepath = os.path.join( test_filepath = os.path.join(
...@@ -169,8 +162,7 @@ class TestTransforms(unittest.TestCase): ...@@ -169,8 +162,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawEncoding()(waveform_batched) computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
# Single then transform then batch # Single then transform then batch
waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded) waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
...@@ -180,8 +172,7 @@ class TestTransforms(unittest.TestCase): ...@@ -180,8 +172,7 @@ class TestTransforms(unittest.TestCase):
computed = torchaudio.transforms.MuLawDecoding()(computed) computed = torchaudio.transforms.MuLawDecoding()(computed)
# shape = (3, 2, 201, 1394) # shape = (3, 2, 201, 1394)
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
def test_batch_spectrogram(self): def test_batch_spectrogram(self):
test_filepath = os.path.join( test_filepath = os.path.join(
...@@ -193,9 +184,7 @@ class TestTransforms(unittest.TestCase): ...@@ -193,9 +184,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_melspectrogram(self): def test_batch_melspectrogram(self):
test_filepath = os.path.join( test_filepath = os.path.join(
...@@ -207,9 +196,7 @@ class TestTransforms(unittest.TestCase): ...@@ -207,9 +196,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -223,9 +210,7 @@ class TestTransforms(unittest.TestCase): ...@@ -223,9 +210,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-5)
def test_batch_TimeStretch(self): def test_batch_TimeStretch(self):
test_filepath = os.path.join( test_filepath = os.path.join(
...@@ -260,8 +245,7 @@ class TestTransforms(unittest.TestCase): ...@@ -260,8 +245,7 @@ class TestTransforms(unittest.TestCase):
hop_length=512, hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1)) )(complex_specgrams.repeat(3, 1, 1, 1, 1))
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
assert torch.allclose(computed, expected, atol=1e-5)
def test_batch_Fade(self): def test_batch_Fade(self):
test_filepath = os.path.join( test_filepath = os.path.join(
...@@ -275,9 +259,7 @@ class TestTransforms(unittest.TestCase): ...@@ -275,9 +259,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
def test_batch_Vol(self): def test_batch_Vol(self):
test_filepath = os.path.join( test_filepath = os.path.join(
...@@ -289,9 +271,7 @@ class TestTransforms(unittest.TestCase): ...@@ -289,9 +271,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
torch.testing.assert_allclose(computed, expected)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -77,7 +77,7 @@ class Test_Kaldi(unittest.TestCase): ...@@ -77,7 +77,7 @@ class Test_Kaldi(unittest.TestCase):
for r in range(m): for r in range(m):
extract_window(window, waveform, r, window_size, window_shift, snip_edges) extract_window(window, waveform, r, window_size, window_shift, snip_edges)
self.assertTrue(torch.allclose(window, output)) torch.testing.assert_allclose(window, output)
def test_get_strided(self): def test_get_strided(self):
# generate any combination where 0 < window_size <= num_samples and # generate any combination where 0 < window_size <= num_samples and
...@@ -104,7 +104,7 @@ class Test_Kaldi(unittest.TestCase): ...@@ -104,7 +104,7 @@ class Test_Kaldi(unittest.TestCase):
sound, sample_rate = torchaudio.load(test_filepath, normalization=False) sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
print(y >> 16) print(y >> 16)
self.assertTrue(sample_rate == sr) self.assertTrue(sample_rate == sr)
self.assertTrue(torch.allclose(y, sound)) torch.testing.assert_allclose(y, sound)
def _print_diagnostic(self, output, expect_output): def _print_diagnostic(self, output, expect_output):
# given an output and expected output, it will print the absolute/relative errors (max and mean squared) # given an output and expected output, it will print the absolute/relative errors (max and mean squared)
...@@ -156,8 +156,7 @@ class Test_Kaldi(unittest.TestCase): ...@@ -156,8 +156,7 @@ class Test_Kaldi(unittest.TestCase):
output = get_output_fn(sound, args) output = get_output_fn(sound, args)
self._print_diagnostic(output, kaldi_output) self._print_diagnostic(output, kaldi_output)
self.assertTrue(output.shape, kaldi_output.shape) torch.testing.assert_allclose(output, kaldi_output, atol=atol, rtol=rtol)
self.assertTrue(torch.allclose(output, kaldi_output, atol=atol, rtol=rtol))
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -299,7 +298,7 @@ class Test_Kaldi(unittest.TestCase): ...@@ -299,7 +298,7 @@ class Test_Kaldi(unittest.TestCase):
ground_truth = ground_truth[..., n_to_trim:-n_to_trim] ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim] estimate = estimate[..., n_to_trim:-n_to_trim]
self.assertTrue(torch.allclose(ground_truth, estimate, atol=atol, rtol=rtol)) torch.testing.assert_allclose(estimate, ground_truth, atol=atol, rtol=rtol)
def test_resample_waveform_downsample_accuracy(self): def test_resample_waveform_downsample_accuracy(self):
for i in range(1, 20): for i in range(1, 20):
...@@ -324,7 +323,7 @@ class Test_Kaldi(unittest.TestCase): ...@@ -324,7 +323,7 @@ class Test_Kaldi(unittest.TestCase):
for i in range(num_channels): for i in range(num_channels):
single_channel = sound * (i + 1) * 1.5 single_channel = sound * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2) single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2)
self.assertTrue(torch.allclose(multi_sound_sampled[i, :], single_channel_sampled, rtol=1e-4)) torch.testing.assert_allclose(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-8)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -16,8 +16,7 @@ class TestComputeDeltas(unittest.TestCase): ...@@ -16,8 +16,7 @@ class TestComputeDeltas(unittest.TestCase):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]]) specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]) expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3) computed = F.compute_deltas(specgram, win_length=3)
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
def test_two_channels(self): def test_two_channels(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0], specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
...@@ -25,16 +24,13 @@ class TestComputeDeltas(unittest.TestCase): ...@@ -25,16 +24,13 @@ class TestComputeDeltas(unittest.TestCase):
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]]) [0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3) computed = F.compute_deltas(specgram, win_length=3)
assert computed.shape == expected.shape, (computed.shape, expected.shape) torch.testing.assert_allclose(computed, expected)
assert torch.allclose(computed, expected)
def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8): def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original # trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)] sound = sound[..., :estimate.size(-1)]
torch.testing.assert_allclose(estimate, sound, atol=atol, rtol=rtol)
assert sound.shape == estimate.shape, (sound.shape, estimate.shape)
assert torch.allclose(sound, estimate, atol=atol, rtol=rtol)
def _test_istft_is_inverse_of_stft(kwargs): def _test_istft_is_inverse_of_stft(kwargs):
...@@ -308,13 +304,13 @@ class TestDB_to_amplitude(unittest.TestCase): ...@@ -308,13 +304,13 @@ class TestDB_to_amplitude(unittest.TestCase):
db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(torch.abs(x), x2, atol=5e-5)) torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
# Spectrogram amplitude -> DB -> amplitude # Spectrogram amplitude -> DB -> amplitude
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(spec, x2, atol=5e-5)) torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
# Waveform power -> DB -> power # Waveform power -> DB -> power
multiplier = 10. multiplier = 10.
...@@ -323,13 +319,13 @@ class TestDB_to_amplitude(unittest.TestCase): ...@@ -323,13 +319,13 @@ class TestDB_to_amplitude(unittest.TestCase):
db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(torch.abs(x), x2, atol=5e-5)) torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
# Spectrogram power -> DB -> power # Spectrogram power -> DB -> power
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None) db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power) x2 = F.DB_to_amplitude(db, ref, power)
self.assertTrue(torch.allclose(spec, x2, atol=5e-5)) torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
@pytest.mark.parametrize('complex_tensor', [ @pytest.mark.parametrize('complex_tensor', [
...@@ -341,7 +337,7 @@ def test_complex_norm(complex_tensor, power): ...@@ -341,7 +337,7 @@ def test_complex_norm(complex_tensor, power):
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2) expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
norm_tensor = F.complex_norm(complex_tensor, power) norm_tensor = F.complex_norm(complex_tensor, power)
assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5) torch.testing.assert_allclose(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
@pytest.mark.parametrize('specgram', [ @pytest.mark.parametrize('specgram', [
......
...@@ -25,7 +25,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -25,7 +25,7 @@ class TestFunctionalFiltering(unittest.TestCase):
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=dtype, device=device) a_coeffs = torch.tensor([1, 0, 0, 0], dtype=dtype, device=device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert torch.allclose(waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5) torch.testing.assert_allclose(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
def test_lfilter_basic(self): def test_lfilter_basic(self):
self._test_lfilter_basic(torch.float32, torch.device("cpu")) self._test_lfilter_basic(torch.float32, torch.device("cpu"))
...@@ -112,7 +112,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -112,7 +112,7 @@ class TestFunctionalFiltering(unittest.TestCase):
E.append_effect_to_chain("gain", [3]) E.append_effect_to_chain("gain", [3])
sox_gain_waveform = E.sox_build_flow_effects()[0] sox_gain_waveform = E.sox_build_flow_effects()[0]
assert torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04) torch.testing.assert_allclose(waveform_gain, sox_gain_waveform, atol=1e-04, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -128,13 +128,13 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -128,13 +128,13 @@ class TestFunctionalFiltering(unittest.TestCase):
E.append_effect_to_chain("dither", []) E.append_effect_to_chain("dither", [])
sox_dither_waveform = E.sox_build_flow_effects()[0] sox_dither_waveform = E.sox_build_flow_effects()[0]
assert torch.allclose(waveform_dithered, sox_dither_waveform, atol=1e-04) torch.testing.assert_allclose(waveform_dithered, sox_dither_waveform, atol=1e-04, rtol=1e-5)
E.clear_chain() E.clear_chain()
E.append_effect_to_chain("dither", ["-s"]) E.append_effect_to_chain("dither", ["-s"])
sox_dither_waveform_ns = E.sox_build_flow_effects()[0] sox_dither_waveform_ns = E.sox_build_flow_effects()[0]
assert torch.allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02) torch.testing.assert_allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -157,7 +157,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -157,7 +157,7 @@ class TestFunctionalFiltering(unittest.TestCase):
E.append_effect_to_chain("dither", ["-s"]) E.append_effect_to_chain("dither", ["-s"])
wf_vctk_sox = E.sox_build_flow_effects()[0] wf_vctk_sox = E.sox_build_flow_effects()[0]
assert torch.allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03) torch.testing.assert_allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -178,7 +178,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -178,7 +178,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ) output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -199,7 +199,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -199,7 +199,7 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ) output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
# TBD - this fails at the 1e-4 level, debug why # TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-3, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -220,7 +220,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -220,7 +220,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q) output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -242,7 +242,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -242,7 +242,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN) output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -264,7 +264,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -264,7 +264,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN) output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -285,7 +285,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -285,7 +285,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q) output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -307,7 +307,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -307,7 +307,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE) output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -329,7 +329,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -329,7 +329,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE) output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -351,7 +351,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -351,7 +351,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q) output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -369,7 +369,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -369,7 +369,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.deemph_biquad(waveform, sample_rate) output_waveform = F.deemph_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -387,7 +387,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -387,7 +387,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.riaa_biquad(waveform, sample_rate) output_waveform = F.riaa_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -409,7 +409,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -409,7 +409,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True) waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q) output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -435,7 +435,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -435,7 +435,7 @@ class TestFunctionalFiltering(unittest.TestCase):
waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]) waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
) )
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) torch.testing.assert_allclose(waveform_lfilter_out, waveform_sox_out, atol=1e-4, rtol=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -51,7 +51,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase): ...@@ -51,7 +51,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
momentum=momentum, init=init, length=length) momentum=momentum, init=init, length=length)
lr_out = torch.from_numpy(lr_out).unsqueeze(0) lr_out = torch.from_numpy(lr_out).unsqueeze(0)
assert torch.allclose(ta_out, lr_out, atol=5e-5) torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0): def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
librosa_fb = librosa.filters.mel(sr=sample_rate, librosa_fb = librosa.filters.mel(sr=sample_rate,
...@@ -68,7 +68,8 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase): ...@@ -68,7 +68,8 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
n_freqs=(n_fft // 2 + 1)) n_freqs=(n_fft // 2 + 1))
for i_mel_bank in range(n_mels): for i_mel_bank in range(n_mels):
assert torch.allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4) torch.testing.assert_allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]),
atol=1e-4, rtol=1e-5)
def test_create_fb(self): def test_create_fb(self):
self._test_create_fb() self._test_create_fb()
...@@ -91,18 +92,18 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase): ...@@ -91,18 +92,18 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
ta_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) ta_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
lr_out = librosa.core.power_to_db(spec.numpy()) lr_out = librosa.core.power_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out).unsqueeze(0) lr_out = torch.from_numpy(lr_out)
assert torch.allclose(ta_out, lr_out, atol=5e-5) torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
# Amplitude to DB # Amplitude to DB
multiplier = 20.0 multiplier = 20.0
ta_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) ta_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
lr_out = librosa.core.amplitude_to_db(spec.numpy()) lr_out = librosa.core.amplitude_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out).unsqueeze(0) lr_out = torch.from_numpy(lr_out)
assert torch.allclose(ta_out, lr_out, atol=5e-5) torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
@pytest.mark.parametrize('complex_specgrams', [ @pytest.mark.parametrize('complex_specgrams', [
...@@ -164,7 +165,7 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate) ...@@ -164,7 +165,7 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power) y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
out_torch = spect_transform(sound).squeeze().cpu() out_torch = spect_transform(sound).squeeze().cpu()
assert torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5) torch.testing.assert_allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
# test mel spectrogram # test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram( melspect_transform = torchaudio.transforms.MelSpectrogram(
...@@ -175,25 +176,25 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate) ...@@ -175,25 +176,25 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None) hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel) librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu() torch_mel = melspect_transform(sound).squeeze().cpu()
assert torch.allclose( torch.testing.assert_allclose(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3) torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)
# test s2db # test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.) power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu() power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa) power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
assert torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3) torch.testing.assert_allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)
mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.) mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu() mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa) mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
assert torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3) torch.testing.assert_allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)
power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu() power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa) db_librosa_tensor = torch.from_numpy(db_librosa)
assert torch.allclose( torch.testing.assert_allclose(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3) power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)
# test MFCC # test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft} melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
...@@ -214,8 +215,8 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate) ...@@ -214,8 +215,8 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc) librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu() torch_mfcc = mfcc_transform(sound).squeeze().cpu()
assert torch.allclose( torch.testing.assert_allclose(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3) torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)
class TestTransforms(_LibrosaMixin, unittest.TestCase): class TestTransforms(_LibrosaMixin, unittest.TestCase):
...@@ -289,7 +290,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase): ...@@ -289,7 +290,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length, S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None) win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol # Note: Using relaxed rtol instead of atol
assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3) torch.testing.assert_allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)
def test_InverseMelScale(self): def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa""" """InverseMelScale transform is comparable to that of librosa"""
...@@ -332,7 +333,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase): ...@@ -332,7 +333,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm # https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf # https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies. # distance over frequencies.
assert torch.allclose(spec_ta, spec_lr, atol=threshold) torch.testing.assert_allclose(spec_ta, spec_lr, atol=threshold, rtol=1e-5)
threshold = 1700.0 threshold = 1700.0
# This threshold was choosen empirically, based on the following observations # This threshold was choosen empirically, based on the following observations
......
...@@ -22,7 +22,7 @@ def _test_torchscript_functional_shape(py_method, *args, **kwargs): ...@@ -22,7 +22,7 @@ def _test_torchscript_functional_shape(py_method, *args, **kwargs):
def _test_torchscript_functional(py_method, *args, **kwargs): def _test_torchscript_functional(py_method, *args, **kwargs):
jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs) jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs)
assert torch.allclose(jit_out, py_out) torch.testing.assert_allclose(jit_out, py_out)
def _test_lfilter(waveform): def _test_lfilter(waveform):
...@@ -316,7 +316,7 @@ def _test_script_module(f, tensor, *args, **kwargs): ...@@ -316,7 +316,7 @@ def _test_script_module(f, tensor, *args, **kwargs):
py_out = py_method(tensor) py_out = py_method(tensor)
jit_out = jit_method(tensor) jit_out = jit_method(tensor)
assert torch.allclose(jit_out, py_out) torch.testing.assert_allclose(jit_out, py_out)
if RUN_CUDA: if RUN_CUDA:
...@@ -328,7 +328,7 @@ def _test_script_module(f, tensor, *args, **kwargs): ...@@ -328,7 +328,7 @@ def _test_script_module(f, tensor, *args, **kwargs):
py_out = py_method(tensor) py_out = py_method(tensor)
jit_out = jit_method(tensor) jit_out = jit_method(tensor)
assert torch.allclose(jit_out, py_out) torch.testing.assert_allclose(jit_out, py_out)
class TestTransforms(unittest.TestCase): class TestTransforms(unittest.TestCase):
......
...@@ -53,7 +53,7 @@ class Tester(unittest.TestCase): ...@@ -53,7 +53,7 @@ class Tester(unittest.TestCase):
mag_to_db_torch = mag_to_db_transform(torch.abs(waveform)) mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2)) power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))
self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch)) torch.testing.assert_allclose(mag_to_db_torch, power_to_db_torch)
def test_melscale_load_save(self): def test_melscale_load_save(self):
specgram = torch.ones(1, 1000, 100) specgram = torch.ones(1, 1000, 100)
...@@ -67,7 +67,7 @@ class Tester(unittest.TestCase): ...@@ -67,7 +67,7 @@ class Tester(unittest.TestCase):
fb_copy = melscale_transform_copy.fb fb_copy = melscale_transform_copy.fb
self.assertEqual(fb_copy.size(), (1000, 128)) self.assertEqual(fb_copy.size(), (1000, 128))
self.assertTrue(torch.allclose(fb, fb_copy)) torch.testing.assert_allclose(fb, fb_copy)
def test_melspectrogram_load_save(self): def test_melspectrogram_load_save(self):
waveform = self.waveform.float() waveform = self.waveform.float()
...@@ -83,10 +83,10 @@ class Tester(unittest.TestCase): ...@@ -83,10 +83,10 @@ class Tester(unittest.TestCase):
fb = mel_spectrogram_transform.mel_scale.fb fb = mel_spectrogram_transform.mel_scale.fb
fb_copy = mel_spectrogram_transform_copy.mel_scale.fb fb_copy = mel_spectrogram_transform_copy.mel_scale.fb
self.assertTrue(torch.allclose(window, window_copy)) torch.testing.assert_allclose(window, window_copy)
# the default for n_fft = 400 and n_mels = 128 # the default for n_fft = 400 and n_mels = 128
self.assertEqual(fb_copy.size(), (201, 128)) self.assertEqual(fb_copy.size(), (201, 128))
self.assertTrue(torch.allclose(fb, fb_copy)) torch.testing.assert_allclose(fb, fb_copy)
def test_mel2(self): def test_mel2(self):
top_db = 80. top_db = 80.
...@@ -217,7 +217,7 @@ class Tester(unittest.TestCase): ...@@ -217,7 +217,7 @@ class Tester(unittest.TestCase):
transform = transforms.ComputeDeltas(win_length=3) transform = transforms.ComputeDeltas(win_length=3)
computed = transform(specgram) computed = transform(specgram)
assert computed.shape == expected.shape, (computed.shape, expected.shape) assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-6, rtol=1e-8) torch.testing.assert_allclose(computed, expected, atol=1e-6, rtol=1e-8)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment