"vscode:/vscode.git/clone" did not exist on "df6b97f2f062f97aa57022eabe228f34167b5161"
Commit 42a89403 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update test script

parent 581411fa
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
from functools import partial from functools import partial
import unittest import unittest
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import OpenFoldMultimerDataModule from openfold.data.data_modules import OpenFoldMultimerDataModule,OpenFoldDataModule
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import os import os
...@@ -40,12 +40,12 @@ class TestMultimerDataModule(unittest.TestCase): ...@@ -40,12 +40,12 @@ class TestMultimerDataModule(unittest.TestCase):
self.data_module = OpenFoldMultimerDataModule( self.data_module = OpenFoldMultimerDataModule(
config=self.config.data, config=self.config.data,
batch_seed=42, batch_seed=42,
train_epoch_len=10, train_epoch_len=100,
template_mmcif_dir = "/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/", template_mmcif_dir = "/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/",
template_release_dates_cache_path=os.path.join(os.getcwd(),"tests/test_data/mmcif_cache.json"), template_release_dates_cache_path=os.path.join(os.getcwd(),"tests/test_data/mmcif_cache.json"),
max_template_date="2500-01-01", max_template_date="2500-01-01",
train_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"), train_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"),
train_alignment_dir=os.path.join(os.getcwd(),"tests/test_data/original_alignments/train"), train_alignment_dir=os.path.join(os.getcwd(),"tests/test_data/original_alignments/"),
kalign_binary_path=shutil.which('kalign'), kalign_binary_path=shutil.which('kalign'),
train_mmcif_data_cache_path=os.path.join(os.getcwd(), train_mmcif_data_cache_path=os.path.join(os.getcwd(),
"tests/test_data/train_mmcifs_cache.json"), "tests/test_data/train_mmcifs_cache.json"),
...@@ -57,5 +57,4 @@ class TestMultimerDataModule(unittest.TestCase): ...@@ -57,5 +57,4 @@ class TestMultimerDataModule(unittest.TestCase):
self.data_module.prepare_data() self.data_module.prepare_data()
self.data_module.setup() self.data_module.setup()
train_dataset = self.data_module.train_dataset train_dataset = self.data_module.train_dataset
# feats = next(iter(train_dataset)) all_chain_features,ground_truth = train_dataset[0]
# print(f"feats keys: {feats.keys()}") \ No newline at end of file
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment