Commit 1abe6160 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 5e341f60
...@@ -24,11 +24,11 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap ...@@ -24,11 +24,11 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset): class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
data_dir: str, data_dir: str,
chain_data_cache_path: str,
alignment_dir: str, alignment_dir: str,
template_mmcif_dir: str, template_mmcif_dir: str,
max_template_date: str, max_template_date: str,
config: mlc.ConfigDict, config: mlc.ConfigDict,
chain_data_cache_path: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
max_template_hits: int = 4, max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
...@@ -82,9 +82,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -82,9 +82,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
super(OpenFoldSingleDataset, self).__init__() super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir self.data_dir = data_dir
with open(chain_data_cache_path, "r") as fp: self.chain_data_cache = None
self.chain_data_cache = json.load(fp) if chain_data_cache_path is not None:
assert isinstance(self.chain_data_cache, dict) with open(chain_data_cache_path, "r") as fp:
self.chain_data_cache = json.load(fp)
assert isinstance(self.chain_data_cache, dict)
self.alignment_dir = alignment_dir self.alignment_dir = alignment_dir
self.config = config self.config = config
...@@ -617,6 +619,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -617,6 +619,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.training_mode): if(self.training_mode):
train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
chain_data_cache_path=self.train_chain_data_cache_path,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
filter_path=self.train_filter_path, filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
...@@ -631,6 +634,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -631,6 +634,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(self.distillation_data_dir is not None): if(self.distillation_data_dir is not None):
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
chain_data_cache_path=self.distillation_chain_data_cache_path,
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
filter_path=self.distillation_filter_path, filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment