Commit 857e9b7c authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix evoformer paths in some old tests

parent 9a8910f9
...@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase): ...@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
) )
][1].transpose(-1, -2) ][1].transpose(-1, -2)
), ),
model.evoformer.blocks[1].outer_product_mean.linear_1.weight, model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
), ),
] ]
......
...@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = ( out_repro = (
model.evoformer.blocks[0] model.evoformer.blocks[0].core
.outer_product_mean( .outer_product_mean(
torch.as_tensor(msa_act).cuda(), torch.as_tensor(msa_act).cuda(),
chunk_size=4, chunk_size=4,
......
...@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase): ...@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = ( out_repro = (
model.evoformer.blocks[0] model.evoformer.blocks[0].core
.pair_transition( .pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4, chunk_size=4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment