Commit 5621ac05 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update the test input to be A2B3

parent 4666e15e
......@@ -21,9 +21,9 @@ import unittest
from openfold.config import model_config
from openfold.data import data_transforms
from openfold.model.model import AlphaFold
from openfold.utils.loss import AlphaFoldMultimerLoss
from openfold.utils.tensor_utils import tensor_tree_map
from tests.config import consts
from .unifold_permutation import multi_chain_perm_align
import logging
logger = logging.getLogger(__name__)
import os
......@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label
"""
self.test_data_dir = os.path.join(os.getcwd(),"tests/test_data")
self.label_ids = ['label_1','label_2','label_2']
self.label_ids = ['label_1','label_1','label_2','label_2','label_2']
self.asym_id = [1]*9+[2]*9+[3]*13+[4]*13 + [5]*13
def test_dry_run(self):
n_seq = consts.n_seq
n_templ = consts.n_templ
n_res = consts.n_res +9
n_res = len(self.asym_id)
n_extra_seq = consts.n_extra
c = model_config(consts.model, train=True)
......@@ -54,6 +54,7 @@ class TestPermutation(unittest.TestCase):
# deepspeed for this test
model = AlphaFold(c)
multimer_loss = AlphaFoldMultimerLoss(c)
example_label = [pickle.load(open(os.path.join(self.test_data_dir,f"{i}.pkl"),'rb'))
for i in self.label_ids]
batch = {}
......@@ -62,8 +63,6 @@ class TestPermutation(unittest.TestCase):
tf, c.model.input_embedder.tf_dim
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
print(f"target_feat shape is {batch['target_feat'].size()}")
print(f"batch_dim is {batch['target_feat'].shape[:-2]}")
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
......@@ -83,23 +82,19 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# #
asym_id = [1]*9+[2]*9+[3]*13
asym_id = self.asym_id
batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch['entity_id'] = torch.tensor([1]*18+[2]*13,dtype=torch.float64)
batch['entity_id'] = torch.tensor([1]*18+[2]*39,dtype=torch.float64)
batch["sym_id"] = torch.tensor(asym_id,dtype=torch.float64)
batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains
# batch["num_sym"] = torch.tensor([1]*18+[2]*13,dtype=torch.int64) # currently there are just 2 chains
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
print(f"max_recycling_iters is {c.data.common.max_recycling_iters}")
input_batch = tensor_tree_map(add_recycling_dims, batch)
batch = tensor_tree_map(add_recycling_dims, batch)
with torch.no_grad():
out = model(input_batch)
print("finished running multimer forward")
print(f"out is {type(out)} and has keys {out.keys()}")
print(f"final_atom_positions is {out['final_atom_positions'].shape}")
print(f"out itpm score is {out['iptm_score']}")
multi_chain_perm_align(out,batch,example_label)
\ No newline at end of file
out = model(batch)
permutated_labels = multimer_loss(out,(batch,example_label))
print(f"permuated_labels is {type(permutated_labels)} and keys are:\n {permutated_labels.keys()}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment