Commit 56632d44 authored by Gustaf Ahdritz
Browse files

Add non-distillation PDB option to data module

parent f707a9ea
......@@ -33,6 +33,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_hits: int = 4,
template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
mode: str = "train",
_output_raw: bool = False,
):
......@@ -71,6 +72,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
"train", "val", or "predict"
"""
......@@ -78,6 +83,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.data_dir = data_dir
self.alignment_dir = alignment_dir
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self._output_raw = _output_raw
......@@ -162,10 +168,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path + ".core", alignment_dir
)
else:
# Try to search for a distillation PDB file instead
data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb",
alignment_dir=alignment_dir
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
)
else:
path = os.path.join(name, name + ".fasta")
......@@ -387,6 +394,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
)
......@@ -397,6 +405,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path,
max_template_hits=self.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
_output_raw=True,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment