"examples/multimodal/graphs/__init__.py" did not exist on "fd42de29338ed5b8b29b37fb3106145c45619429"
Commit ddf922df authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Made initialization consistent for template_featurizer in all modules.

Added obsolete_pdbs_file and release_dates_path in all initialization points.
parent 2abd8c1d
......@@ -31,6 +31,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None,
max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
......@@ -67,6 +68,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
......@@ -112,7 +115,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=None,
obsolete_pdbs_path=obsolete_pdbs_file_path,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
......@@ -325,6 +328,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
**kwargs
......@@ -348,6 +352,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.template_release_dates_cache_path = (
template_release_dates_cache_path
)
self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
self.batch_seed = batch_seed
if(self.train_data_dir is None and self.predict_data_dir is None):
......@@ -384,6 +389,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
)
if(self.training_mode):
......
......@@ -57,7 +57,7 @@ def main(args):
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=None,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
......
......@@ -43,3 +43,6 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None
)
parser.add_argument(
'--release_dates_path', type=str, default=None
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment