Commit 173f055e authored by Dingquan Yu's avatar Dingquan Yu
Browse files

update test script

parent 0d348792
...@@ -70,16 +70,11 @@ class TestMultimerDataModule(unittest.TestCase): ...@@ -70,16 +70,11 @@ class TestMultimerDataModule(unittest.TestCase):
self.data_module.prepare_data() self.data_module.prepare_data()
self.data_module.setup() self.data_module.setup()
train_dataset = self.data_module.train_dataset train_dataset = self.data_module.train_dataset
all_chain_features,ground_truth = train_dataset[0] all_chain_features,ground_truth = train_dataset[1]
asym_ids = all_chain_features['asym_id'].unique()
print(f"asym_ids is {asym_ids}")
print(f"ground truth:")
add_batch_size_dimension = lambda t: ( add_batch_size_dimension = lambda t: (
t.unsqueeze(0) t.unsqueeze(0)
) )
all_chain_features = tensor_tree_map(add_batch_size_dimension,all_chain_features) all_chain_features = tensor_tree_map(add_batch_size_dimension,all_chain_features)
with torch.no_grad(): with torch.no_grad():
out = self.model(all_chain_features) out = self.model(all_chain_features)
print(f"out masked_msa_logits is: {out['masked_msa_logits'].shape}")
self.multimer_loss(out,(all_chain_features,ground_truth)) self.multimer_loss(out,(all_chain_features,ground_truth))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment