Unverified commit 7ef1178e, authored by moto and committed by GitHub
Browse files

Refactor test_functional module (#503)

parent 9b288109
...@@ -10,52 +10,50 @@ import pytest ...@@ -10,52 +10,50 @@ import pytest
import common_utils import common_utils
class TestFunctional(unittest.TestCase): class TestComputeDeltas(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)] """Test suite for correctness of compute_deltas"""
number_of_trials = 100 def test_one_channel(self):
specgram = torch.tensor([1., 2., 3., 4.]) specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
test_dirpath, test_dir = common_utils.create_temp_assets_dir() def test_two_channels(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]])
computed = F.compute_deltas(specgram, win_length=3)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.wav')
waveform_train, sr_train = torchaudio.load(test_filepath)
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length) # trim sound for case when constructed signal is shorter than original
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) sound = sound[..., :estimate.size(-1)]
torch.testing.assert_allclose(computed, expected, atol=atol, rtol=rtol)
def test_compute_deltas_onechannel(self): assert sound.shape == estimate.shape, (sound.shape, estimate.shape)
specgram = self.specgram.unsqueeze(0).unsqueeze(0) assert torch.allclose(sound, estimate, atol=atol, rtol=rtol)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
self._test_compute_deltas(specgram, expected)
def test_compute_deltas_twochannel(self):
specgram = self.specgram.repeat(1, 2, 1)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]])
self._test_compute_deltas(specgram, expected)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): def _test_istft_is_inverse_of_stft(kwargs):
# trim sound for case when constructed signal is shorter than original # generates a random sound signal for each trial and then does the stft/istft
sound = sound[..., :estimate.size(-1)] # operation to check whether we can reconstruct signal
for data_size in [(2, 20), (3, 15), (4, 10)]:
for i in range(100):
self.assertTrue(sound.shape == estimate.shape, (sound.shape, estimate.shape)) sound = common_utils.random_float_tensor(i, data_size)
self.assertTrue(torch.allclose(sound, estimate, atol=atol, rtol=rtol))
def _test_istft_is_inverse_of_stft(self, kwargs): stft = torch.stft(sound, **kwargs)
# generates a random sound signal for each trial and then does the stft/istft estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
# operation to check whether we can reconstruct signal
for data_size in self.data_sizes:
for i in range(self.number_of_trials):
sound = common_utils.random_float_tensor(i, data_size) _compare_estimate(sound, estimate)
stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate) class TestIstft(unittest.TestCase):
"""Test suite for correctness of istft with various input"""
number_of_trials = 100
def test_istft_is_inverse_of_stft1(self): def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided # hann_window, centered, normalized, onesided
...@@ -69,8 +67,7 @@ class TestFunctional(unittest.TestCase): ...@@ -69,8 +67,7 @@ class TestFunctional(unittest.TestCase):
'normalized': True, 'normalized': True,
'onesided': True, 'onesided': True,
} }
_test_istft_is_inverse_of_stft(kwargs1)
self._test_istft_is_inverse_of_stft(kwargs1)
def test_istft_is_inverse_of_stft2(self): def test_istft_is_inverse_of_stft2(self):
# hann_window, centered, not normalized, not onesided # hann_window, centered, not normalized, not onesided
...@@ -84,8 +81,7 @@ class TestFunctional(unittest.TestCase): ...@@ -84,8 +81,7 @@ class TestFunctional(unittest.TestCase):
'normalized': False, 'normalized': False,
'onesided': False, 'onesided': False,
} }
_test_istft_is_inverse_of_stft(kwargs2)
self._test_istft_is_inverse_of_stft(kwargs2)
def test_istft_is_inverse_of_stft3(self): def test_istft_is_inverse_of_stft3(self):
# hamming_window, centered, normalized, not onesided # hamming_window, centered, normalized, not onesided
...@@ -99,8 +95,7 @@ class TestFunctional(unittest.TestCase): ...@@ -99,8 +95,7 @@ class TestFunctional(unittest.TestCase):
'normalized': True, 'normalized': True,
'onesided': False, 'onesided': False,
} }
_test_istft_is_inverse_of_stft(kwargs3)
self._test_istft_is_inverse_of_stft(kwargs3)
def test_istft_is_inverse_of_stft4(self): def test_istft_is_inverse_of_stft4(self):
# hamming_window, not centered, not normalized, onesided # hamming_window, not centered, not normalized, onesided
...@@ -115,8 +110,7 @@ class TestFunctional(unittest.TestCase): ...@@ -115,8 +110,7 @@ class TestFunctional(unittest.TestCase):
'normalized': False, 'normalized': False,
'onesided': True, 'onesided': True,
} }
_test_istft_is_inverse_of_stft(kwargs4)
self._test_istft_is_inverse_of_stft(kwargs4)
def test_istft_is_inverse_of_stft5(self): def test_istft_is_inverse_of_stft5(self):
# hamming_window, not centered, not normalized, not onesided # hamming_window, not centered, not normalized, not onesided
...@@ -131,8 +125,7 @@ class TestFunctional(unittest.TestCase): ...@@ -131,8 +125,7 @@ class TestFunctional(unittest.TestCase):
'normalized': False, 'normalized': False,
'onesided': False, 'onesided': False,
} }
_test_istft_is_inverse_of_stft(kwargs5)
self._test_istft_is_inverse_of_stft(kwargs5)
def test_istft_of_ones(self): def test_istft_of_ones(self):
# stft = torch.stft(torch.ones(4), 4) # stft = torch.stft(torch.ones(4), 4)
...@@ -143,14 +136,14 @@ class TestFunctional(unittest.TestCase): ...@@ -143,14 +136,14 @@ class TestFunctional(unittest.TestCase):
]) ])
estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
self._compare_estimate(torch.ones(4), estimate) _compare_estimate(torch.ones(4), estimate)
def test_istft_of_zeros(self): def test_istft_of_zeros(self):
# stft = torch.stft(torch.zeros(4), 4) # stft = torch.stft(torch.zeros(4), 4)
stft = torch.zeros((3, 5, 2)) stft = torch.zeros((3, 5, 2))
estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
self._compare_estimate(torch.zeros(4), estimate) _compare_estimate(torch.zeros(4), estimate)
def test_istft_requires_overlap_windows(self): def test_istft_requires_overlap_windows(self):
# the window is size 1 but it hops 20 so there is a gap which throws an error # the window is size 1 but it hops 20 so there is a gap which throws an error
...@@ -199,7 +192,7 @@ class TestFunctional(unittest.TestCase): ...@@ -199,7 +192,7 @@ class TestFunctional(unittest.TestCase):
estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L, estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L,
window=torch.ones(L), center=False, normalized=False) window=torch.ones(L), center=False, normalized=False)
# There is a larger error due to the scaling of amplitude # There is a larger error due to the scaling of amplitude
self._compare_estimate(sound, estimate, atol=1e-3) _compare_estimate(sound, estimate, atol=1e-3)
def test_istft_of_sine(self): def test_istft_of_sine(self):
self._test_istft_of_sine(amplitude=123, L=5, n=1) self._test_istft_of_sine(amplitude=123, L=5, n=1)
...@@ -219,7 +212,7 @@ class TestFunctional(unittest.TestCase): ...@@ -219,7 +212,7 @@ class TestFunctional(unittest.TestCase):
istft2 = torchaudio.functional.istft(tensor2, **kwargs) istft2 = torchaudio.functional.istft(tensor2, **kwargs)
istft = a * istft1 + b * istft2 istft = a * istft1 + b * istft2
estimate = torchaudio.functional.istft(a * tensor1 + b * tensor2, **kwargs) estimate = torchaudio.functional.istft(a * tensor1 + b * tensor2, **kwargs)
self._compare_estimate(istft, estimate, atol, rtol) _compare_estimate(istft, estimate, atol, rtol)
def test_linearity_of_istft1(self): def test_linearity_of_istft1(self):
# hann_window, centered, normalized, onesided # hann_window, centered, normalized, onesided
...@@ -273,10 +266,13 @@ class TestFunctional(unittest.TestCase): ...@@ -273,10 +266,13 @@ class TestFunctional(unittest.TestCase):
data_size = (2, 7, 3, 2) data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8) self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
class TestDetectPitchFrequency(unittest.TestCase):
def test_pitch(self): def test_pitch(self):
test_dirpath, test_dir = common_utils.create_temp_assets_dir() test_filepath_100 = os.path.join(
test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav") common_utils.TEST_DIR_PATH, 'assets', "100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav") test_filepath_440 = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', "440Hz_44100Hz_16bit_05sec.wav")
# Files from https://www.mediacollege.com/audio/tone/download/ # Files from https://www.mediacollege.com/audio/tone/download/
tests = [ tests = [
...@@ -293,6 +289,8 @@ class TestFunctional(unittest.TestCase): ...@@ -293,6 +289,8 @@ class TestFunctional(unittest.TestCase):
s = ((freq - freq_ref).abs() > threshold).sum() s = ((freq - freq_ref).abs() > threshold).sum()
self.assertFalse(s) self.assertFalse(s)
class TestDB_to_amplitude(unittest.TestCase):
def test_DB_to_amplitude(self): def test_DB_to_amplitude(self):
# Make some noise # Make some noise
x = torch.rand(1000) x = torch.rand(1000)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment