Commit 71fdc063 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update test scripts

parent b61e99bc
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
from functools import partial from functools import partial
import unittest import unittest
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule from openfold.data.data_modules import OpenFoldMultimerDataModule
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import os import os
...@@ -37,20 +37,25 @@ class TestMultimerDataModule(unittest.TestCase): ...@@ -37,20 +37,25 @@ class TestMultimerDataModule(unittest.TestCase):
"model_1_multimer_v3", "model_1_multimer_v3",
train=True, train=True,
low_prec=True) low_prec=True)
self.data_module = OpenFoldDataModule( self.data_module = OpenFoldMultimerDataModule(
config=self.config.data, config=self.config.data,
batch_seed=42, batch_seed=42,
train_epoch_len=10,
template_mmcif_dir = "/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/", template_mmcif_dir = "/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/",
template_release_dates_cache_path=os.path.join(os.getcwd(),"tests/test_data/mmcif_cache.json"), template_release_dates_cache_path=os.path.join(os.getcwd(),"tests/test_data/mmcif_cache.json"),
max_template_date="2500-01-01", max_template_date="2500-01-01",
train_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"), train_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"),
val_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"), train_alignment_dir=os.path.join(os.getcwd(),"tests/test_data/original_alignments/train"),
val_alignment_dir=os.path.join(os.getcwd(),"tests/test_data/alignments/validation"),
train_alignment_dir=os.path.join(os.getcwd(),"tests/test_data/alignments/train"),
kalign_binary_path=shutil.which('kalign'), kalign_binary_path=shutil.which('kalign'),
train_chain_data_cache_path=os.path.join(os.getcwd(),"tests/test_data/train_chain_data_cache.json") train_mmcif_data_cache_path=os.path.join(os.getcwd(),
"tests/test_data/train_mmcifs_cache.json"),
train_chain_data_cache_path=os.path.join(os.getcwd(),
"tests/test_data/train_chain_data_cache.json"),
) )
def testPrepareData(self): def testPrepareData(self):
self.data_module.prepare_data() self.data_module.prepare_data()
self.data_module.setup() self.data_module.setup()
train_dataset = self.data_module.train_dataset
# feats = next(iter(train_dataset))
# print(f"feats keys: {feats.keys()}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment