Unverified Commit 21269247 authored by moto's avatar moto Committed by GitHub
Browse files

Fix incomplete tests in batch and transforms (#506)

* Fix test_compute_deltas_twochannels

* Fix 3batch test helper
parent d1adb7f6
......@@ -64,7 +64,10 @@ def _test_batch(functional, tensor, *args, **kwargs):
expected = expected.repeat(*ind)
torch.random.manual_seed(42)
_ = functional(tensors.clone(), *args, **kwargs)
computed = functional(tensors.clone(), *args, **kwargs)
assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, **kwargs_compare)
class TestFunctional(unittest.TestCase):
......
......@@ -212,11 +212,12 @@ class Tester(unittest.TestCase):
def test_compute_deltas_twochannel(self):
specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
_ = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]])
transform = transforms.ComputeDeltas()
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]])
transform = transforms.ComputeDeltas(win_length=3)
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected, atol=1e-6, rtol=1e-8)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment