Unverified Commit a91cae7c authored by moto's avatar moto Committed by GitHub
Browse files

Extract batch test from test_functional and move to the dedicated module (#491)

parent 413bd18e
"""Test numerical consistency among single input and batched input."""
import os
import unittest
import torch
import torchaudio
import torchaudio.functional as F
import common_utils
def _test_batch_shape(functional, tensor, *args, **kwargs):
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# Single then transform then batch
torch.random.manual_seed(42)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.unsqueeze(0).unsqueeze(0)
# 1-Batch then transform
tensors = tensor.unsqueeze(0).unsqueeze(0)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, **kwargs_compare)
return tensors, expected
def _test_batch(functional, tensor, *args, **kwargs):
tensors, expected = _test_batch_shape(functional, tensor, *args, **kwargs)
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# 3-Batch then transform
ind = [3] + [1] * (int(tensors.dim()) - 1)
tensors = tensor.repeat(*ind)
ind = [3] + [1] * (int(expected.dim()) - 1)
expected = expected.repeat(*ind)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
class TestFunctional(unittest.TestCase):
"""Test functions defined in `functional` module"""
def test_griffinlim(self):
n_fft = 400
ws = 400
hop = 200
window = torch.hann_window(ws)
power = 2
normalize = False
momentum = 0.99
n_iter = 32
length = 1000
tensor = torch.rand((1, 201, 6))
_test_batch(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
)
def test_detect_pitch_frequency(self):
filenames = [
'steam-train-whistle-daniel_simon.wav', # 2ch 44100Hz
# Files from https://www.mediacollege.com/audio/tone/download/
'100Hz_44100Hz_16bit_05sec.wav', # 1ch
'440Hz_44100Hz_16bit_05sec.wav', # 1ch
]
for filename in filenames:
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', filename)
waveform, sample_rate = torchaudio.load(filepath)
_test_batch(F.detect_pitch_frequency, waveform, sample_rate)
def test_istft(self):
stft = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])
_test_batch(F.istft, stft, n_fft=4, length=4)
...@@ -23,25 +23,6 @@ class TestFunctional(unittest.TestCase): ...@@ -23,25 +23,6 @@ class TestFunctional(unittest.TestCase):
'steam-train-whistle-daniel_simon.wav') 'steam-train-whistle-daniel_simon.wav')
waveform_train, sr_train = torchaudio.load(test_filepath) waveform_train, sr_train = torchaudio.load(test_filepath)
def test_batch_griffinlim(self):
torch.random.manual_seed(42)
tensor = torch.rand((1, 201, 6))
n_fft = 400
ws = 400
hop = 200
window = torch.hann_window(ws)
power = 2
normalize = False
momentum = 0.99
n_iter = 32
length = 1000
self._test_batch(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
)
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length) computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
...@@ -58,10 +39,6 @@ class TestFunctional(unittest.TestCase): ...@@ -58,10 +39,6 @@ class TestFunctional(unittest.TestCase):
[0.5, 1.0, 1.0, 0.5]]]) [0.5, 1.0, 1.0, 0.5]]])
self._test_compute_deltas(specgram, expected) self._test_compute_deltas(specgram, expected)
def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original # trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)] sound = sound[..., :estimate.size(-1)]
...@@ -298,16 +275,6 @@ class TestFunctional(unittest.TestCase): ...@@ -298,16 +275,6 @@ class TestFunctional(unittest.TestCase):
data_size = (2, 7, 3, 2) data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8) self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
def test_batch_istft(self):
stft = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])
self._test_batch(F.istft, stft, n_fft=4, length=4)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_gain(self): def test_gain(self):
...@@ -383,65 +350,6 @@ class TestFunctional(unittest.TestCase): ...@@ -383,65 +350,6 @@ class TestFunctional(unittest.TestCase):
s = ((freq - freq_ref).abs() > threshold).sum() s = ((freq - freq_ref).abs() > threshold).sum()
self.assertFalse(s) self.assertFalse(s)
# Convert to stereo and batch for testing purposes
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
def _test_batch_shape(self, functional, tensor, *args, **kwargs):
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# Single then transform then batch
torch.random.manual_seed(42)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.unsqueeze(0).unsqueeze(0)
# 1-Batch then transform
tensors = tensor.unsqueeze(0).unsqueeze(0)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
self._compare_estimate(computed, expected, **kwargs_compare)
return tensors, expected
def _test_batch(self, functional, tensor, *args, **kwargs):
tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# 3-Batch then transform
ind = [3] + [1] * (int(tensors.dim()) - 1)
tensors = tensor.repeat(*ind)
ind = [3] + [1] * (int(expected.dim()) - 1)
expected = expected.repeat(*ind)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
def test_DB_to_amplitude(self): def test_DB_to_amplitude(self):
# Make some noise # Make some noise
x = torch.rand(1000) x = torch.rand(1000)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment