Unverified Commit 7ef1178e authored by moto's avatar moto Committed by GitHub
Browse files

Refactor test_functional module (#503)

parent 9b288109
......@@ -10,52 +10,50 @@ import pytest
import common_utils
class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)]
number_of_trials = 100
specgram = torch.tensor([1., 2., 3., 4.])
class TestComputeDeltas(unittest.TestCase):
"""Test suite for correctness of compute_deltas"""
def test_one_channel(self):
    """Compare compute_deltas (win_length=3) with the expected deltas for a one-channel input."""
    features = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
    reference = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
    deltas = F.compute_deltas(features, win_length=3)
    # Check the shapes first; on mismatch both shapes are reported.
    shapes = (deltas.shape, reference.shape)
    assert shapes[0] == shapes[1], shapes
    assert torch.allclose(deltas, reference)
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
def test_two_channels(self):
    """Compare compute_deltas (win_length=3) with the expected deltas for two identical channels."""
    channel = [1.0, 2.0, 3.0, 4.0]
    channel_deltas = [0.5, 1.0, 1.0, 0.5]
    features = torch.tensor([[channel, channel]])
    reference = torch.tensor([[channel_deltas, channel_deltas]])
    deltas = F.compute_deltas(features, win_length=3)
    # Check the shapes first; on mismatch both shapes are reported.
    shapes = (deltas.shape, reference.shape)
    assert shapes[0] == shapes[1], shapes
    assert torch.allclose(deltas, reference)
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.wav')
waveform_train, sr_train = torchaudio.load(test_filepath)
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
    """Run F.compute_deltas on ``specgram`` and compare the result with ``expected``.

    The shapes must be equal; the values are then compared with
    torch.testing.assert_allclose using ``atol`` and ``rtol``.
    """
    deltas = F.compute_deltas(specgram, win_length=win_length)
    shape_pair = (deltas.shape, expected.shape)
    self.assertTrue(shape_pair[0] == shape_pair[1], shape_pair)
    torch.testing.assert_allclose(deltas, expected, atol=atol, rtol=rtol)
def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]
def test_compute_deltas_onechannel(self):
    """Check compute_deltas on self.specgram with two leading singleton dimensions added."""
    with_one_dim = self.specgram.unsqueeze(0)
    with_two_dims = with_one_dim.unsqueeze(0)
    target = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
    self._test_compute_deltas(with_two_dims, target)
assert sound.shape == estimate.shape, (sound.shape, estimate.shape)
assert torch.allclose(sound, estimate, atol=atol, rtol=rtol)
def test_compute_deltas_twochannel(self):
    """Check compute_deltas on self.specgram repeated with ``repeat(1, 2, 1)``."""
    target_row = [0.5, 1.0, 1.0, 0.5]
    target = torch.tensor([[target_row, target_row]])
    stacked = self.specgram.repeat(1, 2, 1)
    self._test_compute_deltas(stacked, target)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]
def _test_istft_is_inverse_of_stft(kwargs):
# generates a random sound signal for each trial and then applies stft followed by
# istft to check whether we can reconstruct the signal
for data_size in [(2, 20), (3, 15), (4, 10)]:
for i in range(100):
self.assertTrue(sound.shape == estimate.shape, (sound.shape, estimate.shape))
self.assertTrue(torch.allclose(sound, estimate, atol=atol, rtol=rtol))
sound = common_utils.random_float_tensor(i, data_size)
def _test_istft_is_inverse_of_stft(self, kwargs):
# generates a random sound signal for each trial and then applies stft followed by
# istft to check whether we can reconstruct the signal
for data_size in self.data_sizes:
for i in range(self.number_of_trials):
stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
sound = common_utils.random_float_tensor(i, data_size)
_compare_estimate(sound, estimate)
stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)
class TestIstft(unittest.TestCase):
"""Test suite for correctness of istft with various input"""
number_of_trials = 100
def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
......@@ -69,8 +67,7 @@ class TestFunctional(unittest.TestCase):
'normalized': True,
'onesided': True,
}
self._test_istft_is_inverse_of_stft(kwargs1)
_test_istft_is_inverse_of_stft(kwargs1)
def test_istft_is_inverse_of_stft2(self):
# hann_window, centered, not normalized, not onesided
......@@ -84,8 +81,7 @@ class TestFunctional(unittest.TestCase):
'normalized': False,
'onesided': False,
}
self._test_istft_is_inverse_of_stft(kwargs2)
_test_istft_is_inverse_of_stft(kwargs2)
def test_istft_is_inverse_of_stft3(self):
# hamming_window, centered, normalized, not onesided
......@@ -99,8 +95,7 @@ class TestFunctional(unittest.TestCase):
'normalized': True,
'onesided': False,
}
self._test_istft_is_inverse_of_stft(kwargs3)
_test_istft_is_inverse_of_stft(kwargs3)
def test_istft_is_inverse_of_stft4(self):
# hamming_window, not centered, not normalized, onesided
......@@ -115,8 +110,7 @@ class TestFunctional(unittest.TestCase):
'normalized': False,
'onesided': True,
}
self._test_istft_is_inverse_of_stft(kwargs4)
_test_istft_is_inverse_of_stft(kwargs4)
def test_istft_is_inverse_of_stft5(self):
# hamming_window, not centered, not normalized, not onesided
......@@ -131,8 +125,7 @@ class TestFunctional(unittest.TestCase):
'normalized': False,
'onesided': False,
}
self._test_istft_is_inverse_of_stft(kwargs5)
_test_istft_is_inverse_of_stft(kwargs5)
def test_istft_of_ones(self):
# stft = torch.stft(torch.ones(4), 4)
......@@ -143,14 +136,14 @@ class TestFunctional(unittest.TestCase):
])
estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
self._compare_estimate(torch.ones(4), estimate)
_compare_estimate(torch.ones(4), estimate)
def test_istft_of_zeros(self):
# stft = torch.stft(torch.zeros(4), 4)
stft = torch.zeros((3, 5, 2))
estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
self._compare_estimate(torch.zeros(4), estimate)
_compare_estimate(torch.zeros(4), estimate)
def test_istft_requires_overlap_windows(self):
# the window has size 1 but the hop length is 20, so there is a gap, which throws an error
......@@ -199,7 +192,7 @@ class TestFunctional(unittest.TestCase):
estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L,
window=torch.ones(L), center=False, normalized=False)
# There is a larger error due to the scaling of amplitude
self._compare_estimate(sound, estimate, atol=1e-3)
_compare_estimate(sound, estimate, atol=1e-3)
def test_istft_of_sine(self):
self._test_istft_of_sine(amplitude=123, L=5, n=1)
......@@ -219,7 +212,7 @@ class TestFunctional(unittest.TestCase):
istft2 = torchaudio.functional.istft(tensor2, **kwargs)
istft = a * istft1 + b * istft2
estimate = torchaudio.functional.istft(a * tensor1 + b * tensor2, **kwargs)
self._compare_estimate(istft, estimate, atol, rtol)
_compare_estimate(istft, estimate, atol, rtol)
def test_linearity_of_istft1(self):
# hann_window, centered, normalized, onesided
......@@ -273,10 +266,13 @@ class TestFunctional(unittest.TestCase):
data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
class TestDetectPitchFrequency(unittest.TestCase):
def test_pitch(self):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav")
test_filepath_100 = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', "100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', "440Hz_44100Hz_16bit_05sec.wav")
# Files from https://www.mediacollege.com/audio/tone/download/
tests = [
......@@ -293,6 +289,8 @@ class TestFunctional(unittest.TestCase):
s = ((freq - freq_ref).abs() > threshold).sum()
self.assertFalse(s)
class TestDB_to_amplitude(unittest.TestCase):
def test_DB_to_amplitude(self):
# Make some noise
x = torch.rand(1000)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment