Commit c4a4df22 authored by Gustaf Ahdritz
Browse files

Trim OpenFoldBatchCollator

parent 954ed3d3
......@@ -37,7 +37,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path: Optional[str] = None,
mode: str = "train",
_output_raw: bool = False,
_alignment_index: Optional[Any] = None
_structure_index: Optional[Any] = None,
_alignment_index: Optional[Any] = None,
):
"""
Args:
......@@ -84,8 +85,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self._output_raw = _output_raw
self._structure_index = _structure_index
self._alignment_index = _alignment_index
self.supported_exts = [".cif", ".core", ".pdb"]
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
raise ValueError(f'mode must be one of {valid_modes}')
......@@ -103,7 +107,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
else:
with open(mapping_path, "r") as f:
self._chain_ids = [l.strip() for l in f.readlines()]
self._chain_id_to_idx_dict = {
chain: i for i, chain in enumerate(self._chain_ids)
}
......@@ -173,24 +177,42 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id = None
path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")):
structure_index_entry = None
if(self._structure_index is not None):
structure_index_entry = self._structure_index[name]
assert(len(structure_index_entry["files"]) == 1)
filename, _, _ = structure_index_entry["files"][0]
ext = os.path.splitext(filename)[1]
else:
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
if(ext is None):
raise ValueError("Invalid file type")
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
path, file_id, chain_id, alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".core")):
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index,
path, alignment_dir, _alignment_index,
)
elif(os.path.exists(path + ".pdb")):
elif(ext == ".pdb"):
data = self.data_pipeline.process_pdb(
pdb_path=path + ".pdb",
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
_structure_index=self._structure_index[name],
_alignment_index=_alignment_index,
)
else:
raise ValueError("Invalid file type")
raise ValueError("Extension branch missing")
else:
path = os.path.join(name, name + ".fasta")
data = self.data_pipeline.process_fasta(
......@@ -206,6 +228,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data, self.mode
)
feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)
return feats
def __len__(self):
......@@ -355,20 +379,9 @@ class OpenFoldDataset(torch.utils.data.Dataset):
class OpenFoldBatchCollator:
def __init__(self, config, stage="train"):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def __call__(self, raw_prots):
processed_prots = []
for prot in raw_prots:
features = self.feature_pipeline.process_features(
prot, self.stage
)
processed_prots.append(features)
def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, processed_prots)
return dict_multimap(stack_fn, prots)
class OpenFoldDataLoader(torch.utils.data.DataLoader):
......@@ -486,7 +499,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None,
_alignment_index_path: Optional[str] = None,
_distillation_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
......@@ -539,11 +554,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
# An ad-hoc measure for our particular filesystem restrictions
self._distillation_structure_index = None
if(_distillation_structure_index_path is not None):
with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp)
self._alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
self._distillation_alignment_index = None
if(_distillation_alignment_index_path is not None):
with open(_distillation_alignment_index_path, "r") as fp:
self._distillation_alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
......@@ -567,7 +592,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
_alignment_index=self._alignment_index,
)
......@@ -577,10 +601,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
data_dir=self.distillation_data_dir,
alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path,
max_template_hits=self.train.max_template_hits,
max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
_output_raw=True,
_structure_index=self._distillation_structure_index,
_alignment_index=self._distillation_alignment_index,
)
d_prob = self.config.train.distillation_prob
......@@ -588,7 +613,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob]
probabilities = [1. - d_prob, d_prob]
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
......@@ -615,7 +640,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
mapping_path=None,
max_template_hits=self.config.eval.max_template_hits,
mode="eval",
_output_raw=True,
)
else:
self.eval_dataset = None
......@@ -646,7 +670,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
else:
raise ValueError("Invalid stage")
batch_collator = OpenFoldBatchCollator(self.config, stage)
batch_collator = OpenFoldBatchCollator()
dl = OpenFoldDataLoader(
dataset,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment