Commit 1abe6160 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 5e341f60
......@@ -24,11 +24,11 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
chain_data_cache_path: str,
alignment_dir: str,
template_mmcif_dir: str,
max_template_date: str,
config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
......@@ -82,6 +82,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.chain_data_cache = None
if chain_data_cache_path is not None:
with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
......@@ -617,6 +619,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
......@@ -631,6 +634,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment