Unverified Commit 2897f366 authored by discort's avatar discort Committed by GitHub
Browse files

Replace torch.assert_allclose with assertEqual (#1387)

parent f2b75427
...@@ -45,9 +45,9 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): ...@@ -45,9 +45,9 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
window[f, s] = wave[s_in_wave] window[f, s] = wave[s_in_wave]
@common_utils.skipIfNoSoxBackend @common_utils.skipIfNoSox
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
backend = 'sox' backend = 'sox_io'
kaldi_output_dir = common_utils.get_asset_path('kaldi') kaldi_output_dir = common_utils.get_asset_path('kaldi')
test_filepath = common_utils.get_asset_path('kaldi_file.wav') test_filepath = common_utils.get_asset_path('kaldi_file.wav')
...@@ -91,7 +91,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -91,7 +91,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
for r in range(m): for r in range(m):
extract_window(window, waveform, r, window_size, window_shift, snip_edges) extract_window(window, waveform, r, window_size, window_shift, snip_edges)
torch.testing.assert_allclose(window, output) self.assertEqual(window, output)
def test_get_strided(self): def test_get_strided(self):
# generate any combination where 0 < window_size <= num_samples and # generate any combination where 0 < window_size <= num_samples and
...@@ -116,7 +116,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -116,7 +116,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
sound, sample_rate = torchaudio.load(self.test_filepath, normalization=False) sound, sample_rate = torchaudio.load(self.test_filepath, normalization=False)
print(y >> 16) print(y >> 16)
self.assertTrue(sample_rate == sr) self.assertTrue(sample_rate == sr)
torch.testing.assert_allclose(y, sound) self.assertEqual(y, sound)
def _print_diagnostic(self, output, expect_output): def _print_diagnostic(self, output, expect_output):
# given an output and expected output, it will print the absolute/relative errors (max and mean squared) # given an output and expected output, it will print the absolute/relative errors (max and mean squared)
...@@ -170,7 +170,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -170,7 +170,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
output = get_output_fn(sound, args) output = get_output_fn(sound, args)
self._print_diagnostic(output, kaldi_output) self._print_diagnostic(output, kaldi_output)
torch.testing.assert_allclose(output, kaldi_output, atol=atol, rtol=rtol) self.assertEqual(output, kaldi_output, atol=atol, rtol=rtol)
def test_mfcc_empty(self): def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error # Passing in an empty tensor should result in an error
...@@ -178,7 +178,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -178,7 +178,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def test_resample_waveform(self): def test_resample_waveform(self):
def get_output_fn(sound, args): def get_output_fn(sound, args):
output = kaldi.resample_waveform(sound, args[1], args[2]) output = kaldi.resample_waveform(sound.to(torch.float32), args[1], args[2])
return output return output
self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5) self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)
...@@ -221,7 +221,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -221,7 +221,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
ground_truth = ground_truth[..., n_to_trim:-n_to_trim] ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim] estimate = estimate[..., n_to_trim:-n_to_trim]
torch.testing.assert_allclose(estimate, ground_truth, atol=atol, rtol=rtol) self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)
def test_resample_waveform_downsample_accuracy(self): def test_resample_waveform_downsample_accuracy(self):
for i in range(1, 20): for i in range(1, 20):
...@@ -246,4 +246,4 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -246,4 +246,4 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
single_channel = self.test1_signal * (i + 1) * 1.5 single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr, single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2) self.test1_signal_sr // 2)
torch.testing.assert_allclose(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7) self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment