Commit 3246e8ca authored by Gustaf Ahdritz
Browse files

Rename mapping parameter

parent b8034138
......@@ -34,7 +34,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
-mapping_path: Optional[str] = None,
+filter_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
......@@ -102,10 +102,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(alignment_index is not None):
self._chain_ids = list(alignment_index.keys())
-elif(mapping_path is None):
+elif(filter_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
-with open(mapping_path, "r") as f:
+with open(filter_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = {
......@@ -496,8 +496,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
predict_data_dir: Optional[str] = None,
predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
-train_mapping_path: Optional[str] = None,
-distillation_mapping_path: Optional[str] = None,
+train_filter_path: Optional[str] = None,
+distillation_filter_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
......@@ -525,8 +525,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_data_dir = predict_data_dir
self.predict_alignment_dir = predict_alignment_dir
self.kalign_binary_path = kalign_binary_path
-self.train_mapping_path = train_mapping_path
-self.distillation_mapping_path = distillation_mapping_path
+self.train_filter_path = train_filter_path
+self.distillation_filter_path = distillation_filter_path
self.template_release_dates_cache_path = (
template_release_dates_cache_path
)
......@@ -589,7 +589,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
train_dataset = dataset_gen(
data_dir=self.train_data_dir,
alignment_dir=self.train_alignment_dir,
-mapping_path=self.train_mapping_path,
+filter_path=self.train_filter_path,
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
......@@ -603,7 +603,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
distillation_dataset = dataset_gen(
data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir,
-mapping_path=self.distillation_mapping_path,
+filter_path=self.distillation_filter_path,
max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
......@@ -645,7 +645,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.eval_dataset = dataset_gen(
data_dir=self.val_data_dir,
alignment_dir=self.val_alignment_dir,
-mapping_path=None,
+filter_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
)
......@@ -655,7 +655,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.predict_dataset = dataset_gen(
data_dir=self.predict_data_dir,
alignment_dir=self.predict_alignment_dir,
-mapping_path=None,
+filter_path=None,
max_template_hits=self.config.predict.max_template_hits,
mode="predict",
)
......
......@@ -405,14 +405,14 @@ if __name__ == "__main__":
help="Path to the kalign binary"
)
parser.add_argument(
-"--train_mapping_path", type=str, default=None,
-help='''Optional path to a .json file containing a mapping from
-consecutive numerical indices to sample names. Used to filter
-the training set'''
+"--train_filter_path", type=str, default=None,
+help='''Optional path to a text file containing names of training
+examples to include, one per line. Used to filter the training
+set'''
)
parser.add_argument(
-"--distillation_mapping_path", type=str, default=None,
-help="""See --train_mapping_path"""
+"--distillation_filter_path", type=str, default=None,
+help="""See --train_filter_path"""
)
parser.add_argument(
"--obsolete_pdbs_file_path", type=str, default=None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment