Commit 3279b28d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove alignment index

parent 55bf27d4
......@@ -37,7 +37,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path: Optional[str] = None,
mode: str = "train",
_output_raw: bool = False,
_alignment_index: Optional[Any] = None
):
"""
Args:
......@@ -84,7 +83,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self._output_raw = _output_raw
self._alignment_index = _alignment_index
valid_modes = ["train", "eval", "predict"]
if(mode not in valid_modes):
......@@ -96,9 +94,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(_alignment_index is not None):
self._chain_ids = list(_alignment_index.keys())
elif(mapping_path is None):
if(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
with open(mapping_path, "r") as f:
......@@ -125,7 +121,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
with open(path, 'r') as f:
mmcif_string = f.read()
......@@ -144,7 +140,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
_alignment_index=_alignment_index
)
return data
......@@ -159,11 +154,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
if(len(spl) == 2):
......@@ -175,11 +165,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path + ".cif")):
data = self._parse_mmcif(
path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
path + ".cif", file_id, chain_id, alignment_dir,
)
elif(os.path.exists(path + ".core")):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir, _alignment_index,
path + ".core", alignment_dir,
)
elif(os.path.exists(path + ".pdb")):
data = self.data_pipeline.process_pdb(
......@@ -187,7 +177,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
_alignment_index=_alignment_index,
)
else:
raise ValueError("Invalid file type")
......@@ -196,7 +185,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
)
if(self._output_raw):
......@@ -486,7 +474,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
template_release_dates_cache_path: Optional[str] = None,
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
......@@ -538,12 +525,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
# An ad-hoc measure for our particular filesystem restrictions
self._alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
......@@ -568,7 +549,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
_alignment_index=self._alignment_index,
)
distillation_dataset = None
......
......@@ -422,38 +422,8 @@ class DataPipeline:
def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[name] = data
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
......@@ -478,25 +448,8 @@ class DataPipeline:
def _parse_template_hits(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
......@@ -512,9 +465,8 @@ class DataPipeline:
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
msa_data = self._parse_msa_data(alignment_dir)
if(len(msa_data) == 0):
if(input_sequence is None):
......@@ -544,7 +496,6 @@ class DataPipeline:
self,
fasta_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
......@@ -558,7 +509,7 @@ class DataPipeline:
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir)
template_features = make_template_features(
input_sequence,
hits,
......@@ -571,7 +522,7 @@ class DataPipeline:
num_res=num_res,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
return {
**sequence_features,
......@@ -584,7 +535,6 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
......@@ -602,7 +552,7 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir)
template_features = make_template_features(
input_sequence,
hits,
......@@ -610,7 +560,7 @@ class DataPipeline:
query_release_date=to_date(mmcif.header["release_date"])
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
return {**mmcif_feats, **template_features, **msa_features}
......@@ -620,7 +570,6 @@ class DataPipeline:
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
......@@ -637,14 +586,14 @@ class DataPipeline:
is_distillation
)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
return {**pdb_feats, **template_features, **msa_features}
......@@ -652,7 +601,6 @@ class DataPipeline:
self,
core_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
......@@ -665,7 +613,7 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir)
template_features = make_template_features(
input_sequence,
hits,
......
......@@ -370,9 +370,6 @@ if __name__ == "__main__":
parser.add_argument(
"--train_epoch_len", type=int, default=10000,
)
parser.add_argument(
"--_alignment_index_path", type=str, default=None,
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment