Unverified Commit 657f0a02 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor batch test helper function (#519)

parent 86d54160
...@@ -10,39 +10,23 @@ import common_utils ...@@ -10,39 +10,23 @@ import common_utils
from common_utils import AudioBackendScope, BACKENDS from common_utils import AudioBackendScope, BACKENDS
def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs): def _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, **kwargs):
# Single then transform then batch # run then batch the result
torch.random.manual_seed(42) torch.random.manual_seed(42)
expected = functional(tensor.clone(), *args, **kwargs) expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.unsqueeze(0).unsqueeze(0) expected = expected.repeat([batch_size] + [1] * expected.dim())
# 1-Batch then transform
tensors = tensor.unsqueeze(0).unsqueeze(0)
# batch the input and run
torch.random.manual_seed(42) torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs) pattern = [batch_size] + [1] * tensor.dim()
computed = functional(tensor.repeat(pattern), *args, **kwargs)
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol) torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
return tensors, expected
def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs): def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
tensors, expected = _test_batch_shape(functional, tensor, *args, atol=atol, rtol=rtol, **kwargs) _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=atol, rtol=rtol, **kwargs)
_test_batch_consistency(functional, tensor, *args, batch_size=3, atol=atol, rtol=rtol, **kwargs)
# 3-Batch then transform
ind = [3] + [1] * (int(tensors.dim()) - 1)
tensors = tensor.repeat(*ind)
ind = [3] + [1] * (int(expected.dim()) - 1)
expected = expected.repeat(*ind)
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment