Commit ddf922df authored by Sachin Kadyan
Browse files

Made initialization consistent for template_featurizer in all modules.

Added obsolete_pdbs_file_path and release_dates_path at all initialization points.
parent 2abd8c1d
...@@ -31,6 +31,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -31,6 +31,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
mapping_path: Optional[str] = None, mapping_path: Optional[str] = None,
max_template_hits: int = 4, max_template_hits: int = 4,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None, shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True, treat_pdb_as_distillation: bool = True,
...@@ -67,6 +68,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -67,6 +68,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
from this total quantity. from this total quantity.
template_release_dates_cache_path: template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache. Path to the output of scripts/generate_mmcif_cache.
obsolete_pdbs_file_path:
Path to the file containing replacements for obsolete PDBs.
shuffle_top_k_prefiltered: shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to parsing max_template_hits of them. Can be used to
...@@ -112,7 +115,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -112,7 +115,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_hits=max_template_hits, max_hits=max_template_hits,
kalign_binary_path=kalign_binary_path, kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path, release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=None, obsolete_pdbs_path=obsolete_pdbs_file_path,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered, _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
) )
...@@ -325,6 +328,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -325,6 +328,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None, train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None, distillation_mapping_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
**kwargs **kwargs
...@@ -348,6 +352,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -348,6 +352,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.template_release_dates_cache_path = ( self.template_release_dates_cache_path = (
template_release_dates_cache_path template_release_dates_cache_path
) )
self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
self.batch_seed = batch_seed self.batch_seed = batch_seed
if(self.train_data_dir is None and self.predict_data_dir is None): if(self.train_data_dir is None and self.predict_data_dir is None):
...@@ -384,6 +389,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -384,6 +389,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path=self.kalign_binary_path, kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path= template_release_dates_cache_path=
self.template_release_dates_cache_path, self.template_release_dates_cache_path,
obsolete_pdbs_file_path=
self.obsolete_pdbs_file_path,
) )
if(self.training_mode): if(self.training_mode):
......
...@@ -57,7 +57,7 @@ def main(args): ...@@ -57,7 +57,7 @@ def main(args):
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates, max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path, kalign_binary_path=args.kalign_binary_path,
release_dates_path=None, release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
......
...@@ -43,3 +43,6 @@ def add_data_args(parser: argparse.ArgumentParser): ...@@ -43,3 +43,6 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None '--obsolete_pdbs_path', type=str, default=None
) )
parser.add_argument(
'--release_dates_path', type=str, default=None
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment