Commit 3246e8ca authored by Gustaf Ahdritz
Browse files

Rename mapping parameter

parent b8034138
...@@ -34,7 +34,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -34,7 +34,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None, shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True, treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None, filter_path: Optional[str] = None,
mode: str = "train", mode: str = "train",
alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None,
_output_raw: bool = False, _output_raw: bool = False,
...@@ -102,10 +102,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -102,10 +102,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(alignment_index is not None): if(alignment_index is not None):
self._chain_ids = list(alignment_index.keys()) self._chain_ids = list(alignment_index.keys())
elif(mapping_path is None): elif(filter_path is None):
self._chain_ids = list(os.listdir(alignment_dir)) self._chain_ids = list(os.listdir(alignment_dir))
else: else:
with open(mapping_path, "r") as f: with open(filter_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()] self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
...@@ -496,8 +496,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -496,8 +496,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
predict_data_dir: Optional[str] = None, predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None, predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign', kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None, train_filter_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None, distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None, obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
...@@ -525,8 +525,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -525,8 +525,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_data_dir = predict_data_dir self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path self.kalign_binary_path = kalign_binary_path
self.train_mapping_path = train_mapping_path self.train_filter_path = train_filter_path
self.distillation_mapping_path = distillation_mapping_path self.distillation_filter_path = distillation_filter_path
self.template_release_dates_cache_path = ( self.template_release_dates_cache_path = (
template_release_dates_cache_path template_release_dates_cache_path
) )
...@@ -589,7 +589,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -589,7 +589,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
train_dataset = dataset_gen( train_dataset = dataset_gen(
data_dir=self.train_data_dir, data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path, filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered= shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered, self.config.train.shuffle_top_k_prefiltered,
...@@ -603,7 +603,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -603,7 +603,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
distillation_dataset = dataset_gen( distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path, filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True, treat_pdb_as_distillation=True,
mode="train", mode="train",
...@@ -645,7 +645,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -645,7 +645,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.eval_dataset = dataset_gen( self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir, data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir, alignment_dir=self.val_alignment_dir,
mapping_path=None, filter_path=None,
max_template_hits=self.config.eval.max_template_hits, max_template_hits=self.config.eval.max_template_hits,
mode="eval", mode="eval",
) )
...@@ -655,7 +655,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -655,7 +655,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_dataset = dataset_gen( self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir, data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir, alignment_dir=self.predict_alignment_dir,
mapping_path=None, filter_path=None,
max_template_hits=self.config.predict.max_template_hits, max_template_hits=self.config.predict.max_template_hits,
mode="predict", mode="predict",
) )
......
...@@ -405,14 +405,14 @@ if __name__ == "__main__": ...@@ -405,14 +405,14 @@ if __name__ == "__main__":
help="Path to the kalign binary" help="Path to the kalign binary"
) )
parser.add_argument( parser.add_argument(
"--train_mapping_path", type=str, default=None, "--train_filter_path", type=str, default=None,
help='''Optional path to a .json file containing a mapping from help='''Optional path to a text file containing names of training
consecutive numerical indices to sample names. Used to filter examples to include, one per line. Used to filter the training
the training set''' set'''
) )
parser.add_argument( parser.add_argument(
"--distillation_mapping_path", type=str, default=None, "--distillation_filter_path", type=str, default=None,
help="""See --train_mapping_path""" help="""See --train_filter_path"""
) )
parser.add_argument( parser.add_argument(
"--obsolete_pdbs_file_path", type=str, default=None, "--obsolete_pdbs_file_path", type=str, default=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment