Commit 173f055e authored by Dingquan Yu's avatar Dingquan Yu
Browse files

update test script

parent 0d348792
......@@ -70,16 +70,11 @@ class TestMultimerDataModule(unittest.TestCase):
self.data_module.prepare_data()
self.data_module.setup()
train_dataset = self.data_module.train_dataset
all_chain_features,ground_truth = train_dataset[0]
asym_ids = all_chain_features['asym_id'].unique()
print(f"asym_ids is {asym_ids}")
print(f"ground truth:")
all_chain_features,ground_truth = train_dataset[1]
add_batch_size_dimension = lambda t: (
t.unsqueeze(0)
)
all_chain_features = tensor_tree_map(add_batch_size_dimension,all_chain_features)
with torch.no_grad():
out = self.model(all_chain_features)
print(f"out masked_msa_logits is: {out['masked_msa_logits'].shape}")
self.multimer_loss(out,(all_chain_features,ground_truth))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment