Commit 7fdb503e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix evoformer test

parent d5da89c1
......@@ -193,10 +193,10 @@ class TestExtraMSAStack(unittest.TestCase):
ckpt=False,
inf=inf,
eps=eps,
).eval()
).eval().cuda()
m = torch.rand((batch_size, s_t, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
m = torch.rand((batch_size, s_t, n_res, c_m), device="cuda")
z = torch.rand((batch_size, n_res, n_res, c_z), device="cuda")
msa_mask = torch.randint(
0,
2,
......@@ -205,6 +205,7 @@ class TestExtraMSAStack(unittest.TestCase):
s_t,
n_res,
),
device="cuda",
)
pair_mask = torch.randint(
0,
......@@ -214,6 +215,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
n_res,
),
device="cuda",
)
shape_z_before = z.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment