Commit 48668ca3 authored by Jennifer Wei's avatar Jennifer Wei
Browse files

Adds float wrapper to to vectors in TestExtraMSAStack in

 test_evoformer.py
parent 2134cc09
...@@ -206,7 +206,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -206,7 +206,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res, n_res,
), ),
device="cuda", device="cuda",
) ).float()
pair_mask = torch.randint( pair_mask = torch.randint(
0, 0,
2, 2,
...@@ -216,7 +216,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -216,7 +216,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res, n_res,
), ),
device="cuda", device="cuda",
) ).float()
shape_z_before = z.shape shape_z_before = z.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment