"...llm/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "a611726e53f0063f27c410d6893052c208f339f1"
Commit 7fdb503e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix evoformer test

parent d5da89c1
...@@ -193,10 +193,10 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -193,10 +193,10 @@ class TestExtraMSAStack(unittest.TestCase):
ckpt=False, ckpt=False,
inf=inf, inf=inf,
eps=eps, eps=eps,
).eval() ).eval().cuda()
m = torch.rand((batch_size, s_t, n_res, c_m)) m = torch.rand((batch_size, s_t, n_res, c_m), device="cuda")
z = torch.rand((batch_size, n_res, n_res, c_z)) z = torch.rand((batch_size, n_res, n_res, c_z), device="cuda")
msa_mask = torch.randint( msa_mask = torch.randint(
0, 0,
2, 2,
...@@ -205,6 +205,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -205,6 +205,7 @@ class TestExtraMSAStack(unittest.TestCase):
s_t, s_t,
n_res, n_res,
), ),
device="cuda",
) )
pair_mask = torch.randint( pair_mask = torch.randint(
0, 0,
...@@ -214,6 +215,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -214,6 +215,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res, n_res,
n_res, n_res,
), ),
device="cuda",
) )
shape_z_before = z.shape shape_z_before = z.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment