Commit c4a4df22 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Trim OpenFoldBatchCollator

parent 954ed3d3
...@@ -37,7 +37,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -37,7 +37,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path: Optional[str] = None, mapping_path: Optional[str] = None,
mode: str = "train", mode: str = "train",
_output_raw: bool = False, _output_raw: bool = False,
_alignment_index: Optional[Any] = None _structure_index: Optional[Any] = None,
_alignment_index: Optional[Any] = None,
): ):
""" """
Args: Args:
...@@ -84,8 +85,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -84,8 +85,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.treat_pdb_as_distillation = treat_pdb_as_distillation self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode self.mode = mode
self._output_raw = _output_raw self._output_raw = _output_raw
self._structure_index = _structure_index
self._alignment_index = _alignment_index self._alignment_index = _alignment_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"] valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes): if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}') raise ValueError(f'mode must be one of {valid_modes}')
...@@ -103,7 +107,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -103,7 +107,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
else: else:
with open(mapping_path, "r") as f: with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()] self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = { self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids) chain: i for i, chain in enumerate(self._chain_ids)
} }
...@@ -173,24 +177,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -173,24 +177,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id = None chain_id = None
path = os.path.join(self.data_dir, file_id) path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")): structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif( data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index, path, file_id, chain_id, alignment_dir, _alignment_index,
) )
elif(os.path.exists(path + ".core")): elif(ext == ".core"):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index, path, alignment_dir, _alignment_index,
) )
elif(os.path.exists(path + ".pdb")): elif(ext == ".pdb"):
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb", pdb_path=path,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation, is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id, chain_id=chain_id,
_structure_index=self._structure_index[name],
_alignment_index=_alignment_index, _alignment_index=_alignment_index,
) )
else: else:
raise ValueError("Invalid file type") raise ValueError("Extension branch missing")
else: else:
path = os.path.join(name, name + ".fasta") path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
...@@ -206,6 +228,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -206,6 +228,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode data, self.mode
) )
feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)
return feats return feats
def __len__(self): def __len__(self):
...@@ -355,20 +379,9 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -355,20 +379,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __init__(self, config, stage="train"): def __call__(self, prots):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
stack_fn = partial(torch.stack, dim=0) stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots) return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader): class OpenFoldDataLoader(torch.utils.data.DataLoader):
...@@ -486,7 +499,9 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -486,7 +499,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
train_epoch_len: int = 50000, train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None,
_alignment_index_path: Optional[str] = None, _alignment_index_path: Optional[str] = None,
_distillation_alignment_index_path: Optional[str] = None,
**kwargs **kwargs
): ):
super(OpenFoldDataModule, self).__init__() super(OpenFoldDataModule, self).__init__()
...@@ -539,11 +554,21 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -539,11 +554,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
) )
# An ad-hoc measure for our particular filesystem restrictions # An ad-hoc measure for our particular filesystem restrictions
self._distillation_structure_index = None
if(_distillation_structure_index_path is not None):
with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp)
self._alignment_index = None self._alignment_index = None
if(_alignment_index_path is not None): if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp: with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp) self._alignment_index = json.load(fp)
self._distillation_alignment_index = None
if(_distillation_alignment_index_path is not None):
with open(_distillation_alignment_index_path, "r") as fp:
self._distillation_alignment_index = json.load(fp)
def setup(self): def setup(self):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset, dataset_gen = partial(OpenFoldSingleDataset,
...@@ -567,7 +592,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -567,7 +592,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.config.train.shuffle_top_k_prefiltered, self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False, treat_pdb_as_distillation=False,
mode="train", mode="train",
_output_raw=True,
_alignment_index=self._alignment_index, _alignment_index=self._alignment_index,
) )
...@@ -577,10 +601,11 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -577,10 +601,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
data_dir=self.distillation_data_dir, data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path, mapping_path=self.distillation_mapping_path,
max_template_hits=self.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True, treat_pdb_as_distillation=True,
mode="train", mode="train",
_output_raw=True, _structure_index=self._distillation_structure_index,
_alignment_index=self._distillation_alignment_index,
) )
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
...@@ -588,7 +613,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -588,7 +613,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(distillation_dataset is not None): if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset] datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob] probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [ chain_data_cache_paths = [
self.train_chain_data_cache_path, self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path, self.distillation_chain_data_cache_path,
...@@ -615,7 +640,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -615,7 +640,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
mapping_path=None, mapping_path=None,
max_template_hits=self.config.eval.max_template_hits, max_template_hits=self.config.eval.max_template_hits,
mode="eval", mode="eval",
_output_raw=True,
) )
else: else:
self.eval_dataset = None self.eval_dataset = None
...@@ -646,7 +670,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -646,7 +670,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
else: else:
raise ValueError("Invalid stage") raise ValueError("Invalid stage")
batch_collator = OpenFoldBatchCollator(self.config, stage) batch_collator = OpenFoldBatchCollator()
dl = OpenFoldDataLoader( dl = OpenFoldDataLoader(
dataset, dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment