Commit 56632d44 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add non-distillation PDB option to data module

parent f707a9ea
...@@ -33,6 +33,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -33,6 +33,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_hits: int = 4, max_template_hits: int = 4,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
shuffle_top_k_prefiltered: Optional[int] = None, shuffle_top_k_prefiltered: Optional[int] = None,
treat_pdb_as_distillation: bool = True,
mode: str = "train", mode: str = "train",
_output_raw: bool = False, _output_raw: bool = False,
): ):
...@@ -71,6 +72,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -71,6 +72,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
parsing max_template_hits of them. Can be used to parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling approximate DeepMind's training-time template subsampling
scheme much more performantly. scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode: mode:
"train", "val", or "predict" "train", "val", or "predict"
""" """
...@@ -78,6 +83,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -78,6 +83,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.data_dir = data_dir self.data_dir = data_dir
self.alignment_dir = alignment_dir self.alignment_dir = alignment_dir
self.config = config self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode self.mode = mode
self._output_raw = _output_raw self._output_raw = _output_raw
...@@ -162,10 +168,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -162,10 +168,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path + ".core", alignment_dir path + ".core", alignment_dir
) )
else: else:
# Try to search for a distillation PDB file instead
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb", pdb_path=path + ".pdb",
alignment_dir=alignment_dir alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
) )
else: else:
path = os.path.join(name, name + ".fasta") path = os.path.join(name, name + ".fasta")
...@@ -387,6 +394,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -387,6 +394,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
shuffle_top_k_prefiltered= shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered, self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train", mode="train",
_output_raw=True, _output_raw=True,
) )
...@@ -397,6 +405,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -397,6 +405,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path, mapping_path=self.distillation_mapping_path,
max_template_hits=self.train.max_template_hits, max_template_hits=self.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train", mode="train",
_output_raw=True, _output_raw=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment