Unverified Commit bc1ffb11 authored by moto's avatar moto Committed by GitHub
Browse files

Simplify helper function (#514)

parent 3695a0ef
...@@ -10,19 +10,7 @@ import common_utils ...@@ -10,19 +10,7 @@ import common_utils
from common_utils import AudioBackendScope, BACKENDS from common_utils import AudioBackendScope, BACKENDS
def _test_batch_shape(functional, tensor, *args, **kwargs): def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# Single then transform then batch # Single then transform then batch
torch.random.manual_seed(42) torch.random.manual_seed(42)
...@@ -36,24 +24,13 @@ def _test_batch_shape(functional, tensor, *args, **kwargs): ...@@ -36,24 +24,13 @@ def _test_batch_shape(functional, tensor, *args, **kwargs):
computed = functional(tensors.clone(), *args, **kwargs) computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape) assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, **kwargs_compare) assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
return tensors, expected return tensors, expected
def _test_batch(functional, tensor, *args, **kwargs): def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
tensors, expected = _test_batch_shape(functional, tensor, *args, **kwargs) tensors, expected = _test_batch_shape(functional, tensor, *args, atol=atol, rtol=rtol, **kwargs)
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
# 3-Batch then transform # 3-Batch then transform
...@@ -67,7 +44,7 @@ def _test_batch(functional, tensor, *args, **kwargs): ...@@ -67,7 +44,7 @@ def _test_batch(functional, tensor, *args, **kwargs):
computed = functional(tensors.clone(), *args, **kwargs) computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape) assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, **kwargs_compare) assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment