Commit 49767099 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Bring tests up to speed

parent a6f56d16
...@@ -41,7 +41,7 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -41,7 +41,7 @@ class TestTriangularAttention(unittest.TestCase):
x = torch.rand((batch_size, n_res, n_res, c_z)) x = torch.rand((batch_size, n_res, n_res, c_z))
shape_before = x.shape shape_before = x.shape
x = tan(x) x = tan(x, chunk_size=None)
shape_after = x.shape shape_after = x.shape
self.assertTrue(shape_before == shape_after) self.assertTrue(shape_before == shape_after)
...@@ -92,6 +92,7 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -92,6 +92,7 @@ class TestTriangularAttention(unittest.TestCase):
out_repro = module( out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
chunk_size=None,
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment