Commit 49767099 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Bring tests up to speed

parent a6f56d16
......@@ -41,7 +41,7 @@ class TestTriangularAttention(unittest.TestCase):
x = torch.rand((batch_size, n_res, n_res, c_z))
shape_before = x.shape
x = tan(x)
x = tan(x, chunk_size=None)
shape_after = x.shape
self.assertTrue(shape_before == shape_after)
......@@ -92,6 +92,7 @@ class TestTriangularAttention(unittest.TestCase):
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
chunk_size=None,
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment