Commit caac73c0 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update test

parent 64493e08
...@@ -20,8 +20,12 @@ import torch.nn as nn ...@@ -20,8 +20,12 @@ import torch.nn as nn
import numpy as np import numpy as np
from functools import partial from functools import partial
import unittest import unittest
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import OpenFoldMultimerDataModule,OpenFoldDataModule from openfold.data.data_modules import OpenFoldMultimerDataModule,OpenFoldDataModule
from openfold.model.model import AlphaFold
from openfold.utils.loss import AlphaFoldMultimerLoss
from tests.config import consts
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import os import os
...@@ -52,9 +56,30 @@ class TestMultimerDataModule(unittest.TestCase): ...@@ -52,9 +56,30 @@ class TestMultimerDataModule(unittest.TestCase):
train_chain_data_cache_path=os.path.join(os.getcwd(), train_chain_data_cache_path=os.path.join(os.getcwd(),
"tests/test_data/train_chain_data_cache.json"), "tests/test_data/train_chain_data_cache.json"),
) )
# setup model
self.c = model_config(consts.model, train=True)
self.c.loss.masked_msa.num_classes = 22 # somehow need overwrite this part in multimer loss config
self.c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
self.c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
self.model = AlphaFold(self.c)
self.multimer_loss = AlphaFoldMultimerLoss(self.c.loss)
def testPrepareData(self): def testPrepareData(self):
self.data_module.prepare_data() self.data_module.prepare_data()
self.data_module.setup() self.data_module.setup()
train_dataset = self.data_module.train_dataset train_dataset = self.data_module.train_dataset
all_chain_features,ground_truth = train_dataset[0] all_chain_features,ground_truth = train_dataset[0]
asym_ids = all_chain_features['asym_id'].unique()
print(f"asym_ids is {asym_ids}")
print(f"ground truth:")
add_batch_size_dimension = lambda t: (
t.unsqueeze(0)
)
all_chain_features = tensor_tree_map(add_batch_size_dimension,all_chain_features)
with torch.no_grad():
out = self.model(all_chain_features)
print(f"out masked_msa_logits is: {out['masked_msa_logits'].shape}")
self.multimer_loss(out,(all_chain_features,ground_truth))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment