Commit 31c976b5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

correct the file name

parent a4000358
...@@ -13,30 +13,20 @@ ...@@ -13,30 +13,20 @@
# limitations under the License. # limitations under the License.
from pathlib import Path from pathlib import Path
import shutil
import pickle import pickle
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from functools import partial
import unittest import unittest
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule from openfold.data.data_modules import OpenFoldDataModule
from openfold.utils.tensor_utils import tensor_tree_map
from tests.config import consts
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import os import os
import io, contextlib
from tests.data_utils import (
random_template_feats,
random_extra_msa_feats,
random_affines_vector, random_affines_4x4
)
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
class TestPermutation(unittest.TestCase): class TestMultimerDataModule(unittest.TestCase):
def setUp(self): def setUp(self):
""" """
Set up model config Set up model config
...@@ -50,9 +40,17 @@ class TestPermutation(unittest.TestCase): ...@@ -50,9 +40,17 @@ class TestPermutation(unittest.TestCase):
self.data_module = OpenFoldDataModule( self.data_module = OpenFoldDataModule(
config=self.config.data, config=self.config.data,
batch_seed=42, batch_seed=42,
template_mmcif_dir: str, template_mmcif_dir = "/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/",
max_template_date: str, template_release_dates_cache_path=os.path.join(os.getcwd(),"tests/test_data/mmcif_cache.json"),
val_data_dir: Optional[str] = None, max_template_date="2500-01-01",
val_alignment_dir: Optional[str] = None, train_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"),
val_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"),
val_alignment_dir=os.path.join(os.getcwd(),"tests/test_data/alignments/validation"),
train_alignment_dir=os.path.join(os.getcwd(),"tests/test_data/alignments/train"),
kalign_binary_path=shutil.which('kalign'),
train_chain_data_cache_path=os.path.join(os.getcwd(),"tests/test_data/train_chain_data_cache.json")
) )
def testPrepareData(self):
self.data_module.prepare_data()
self.data_module.setup()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment