Unverified Commit 7a0d4192 authored by kunalb6's avatar kunalb6 Committed by GitHub
Browse files

make seed parameterized (#614)


Co-authored-by: default avatarKunal Bhandari <bkunal@fb.com>

Closes #610 
parent c80d9a71
......@@ -9,23 +9,23 @@ import torchaudio.functional as F
import common_utils
def _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, **kwargs):
def _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
# run then batch the result
torch.random.manual_seed(42)
torch.random.manual_seed(seed)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.repeat([batch_size] + [1] * expected.dim())
# batch the input and run
torch.random.manual_seed(42)
torch.random.manual_seed(seed)
pattern = [batch_size] + [1] * tensor.dim()
computed = functional(tensor.repeat(pattern), *args, **kwargs)
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
_test_batch_consistency(functional, tensor, *args, batch_size=1, atol=atol, rtol=rtol, **kwargs)
_test_batch_consistency(functional, tensor, *args, batch_size=3, atol=atol, rtol=rtol, **kwargs)
def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
_test_batch_consistency(functional, tensor, *args, batch_size=1, atol=atol, rtol=rtol, seed=seed, **kwargs)
_test_batch_consistency(functional, tensor, *args, batch_size=3, atol=atol, rtol=rtol, seed=seed, **kwargs)
class TestFunctional(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment