Commit c6ac105d authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

remove recycling dimensions

parent d3b2b265
...@@ -83,7 +83,7 @@ class TestPermutation(unittest.TestCase): ...@@ -83,7 +83,7 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes # Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains # 2 chains
# # # #
asym_id = [1]*9 + [2]*13 asym_id = [1]*9+[2]*13
batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64) batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,)) # batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch['entity_id'] = torch.tensor(asym_id,dtype=torch.float64) batch['entity_id'] = torch.tensor(asym_id,dtype=torch.float64)
...@@ -94,10 +94,10 @@ class TestPermutation(unittest.TestCase): ...@@ -94,10 +94,10 @@ class TestPermutation(unittest.TestCase):
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters) t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
) )
print(f"max_recycling_iters is {c.data.common.max_recycling_iters}") print(f"max_recycling_iters is {c.data.common.max_recycling_iters}")
batch = tensor_tree_map(add_recycling_dims, batch) input_batch = tensor_tree_map(add_recycling_dims, batch)
with torch.no_grad(): with torch.no_grad():
out = model(batch) out = model(input_batch)
print("finished running multimer forward") print("finished running multimer forward")
print(f"out is {type(out)} and has keys {out.keys()}") print(f"out is {type(out)} and has keys {out.keys()}")
print(f"final_atom_positions is {out['final_atom_positions'].shape}") print(f"final_atom_positions is {out['final_atom_positions'].shape}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment