"git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "59c80aa2115948d0d0bfafec3d5870e01e292ed2"
Commit 3279b28d authored by Gustaf Ahdritz

Remove alignment index

parent 55bf27d4
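
This removes the ad-hoc `_alignment_index` code path throughout the data pipeline. When set, the index let each chain's alignment files (`.a3m`, `.sto`, `.hhr`) be read out of a single packed database file by byte offset instead of from a per-chain directory; per the removed comment in `OpenFoldDataModule`, it was a workaround for particular filesystem restrictions. For reference, the removed code implies an index of roughly the following shape. The `"db"` and `"files"` keys come from the diff; the chain ID, file names, and offsets below are hypothetical:

    # Hypothetical reconstruction of the removed alignment-index format.
    # Top-level keys are chain IDs. "db" names the packed database file
    # inside alignment_dir; "files" lists (filename, start, size) triples
    # used to seek into it.
    {
        "1abc_A": {
            "db": "alignment_db.bin",
            "files": [
                ["uniref90_hits.a3m", 0, 183211],
                ["mgnify_hits.sto", 183211, 952004],
                ["pdb70_hits.hhr", 1135215, 44120]
            ]
        }
    }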
@@ -37,7 +37,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         mapping_path: Optional[str] = None,
         mode: str = "train",
         _output_raw: bool = False,
-        _alignment_index: Optional[Any] = None
     ):
         """
             Args:
@@ -84,7 +83,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         self.treat_pdb_as_distillation = treat_pdb_as_distillation
         self.mode = mode
         self._output_raw = _output_raw
-        self._alignment_index = _alignment_index

         valid_modes = ["train", "eval", "predict"]
         if(mode not in valid_modes):
@@ -96,9 +94,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )

-        if(_alignment_index is not None):
-            self._chain_ids = list(_alignment_index.keys())
-        elif(mapping_path is None):
+        if(mapping_path is None):
            self._chain_ids = list(os.listdir(alignment_dir))
         else:
             with open(mapping_path, "r") as f:
@@ -125,7 +121,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         if(not self._output_raw):
             self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

-    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
+    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
         with open(path, 'r') as f:
             mmcif_string = f.read()
@@ -144,7 +140,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             mmcif=mmcif_object,
             alignment_dir=alignment_dir,
             chain_id=chain_id,
-            _alignment_index=_alignment_index
         )

         return data
@@ -159,11 +154,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         name = self.idx_to_chain_id(idx)
         alignment_dir = os.path.join(self.alignment_dir, name)

-        _alignment_index = None
-        if(self._alignment_index is not None):
-            alignment_dir = self.alignment_dir
-            _alignment_index = self._alignment_index[name]
-
         if(self.mode == 'train' or self.mode == 'eval'):
             spl = name.rsplit('_', 1)
             if(len(spl) == 2):
@@ -175,11 +165,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             path = os.path.join(self.data_dir, file_id)
             if(os.path.exists(path + ".cif")):
                 data = self._parse_mmcif(
-                    path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
+                    path + ".cif", file_id, chain_id, alignment_dir,
                 )
             elif(os.path.exists(path + ".core")):
                 data = self.data_pipeline.process_core(
-                    path + ".core", alignment_dir, _alignment_index,
+                    path + ".core", alignment_dir,
                 )
             elif(os.path.exists(path + ".pdb")):
                 data = self.data_pipeline.process_pdb(
@@ -187,7 +177,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                     alignment_dir=alignment_dir,
                     is_distillation=self.treat_pdb_as_distillation,
                     chain_id=chain_id,
-                    _alignment_index=_alignment_index,
                 )
             else:
                 raise ValueError("Invalid file type")
@@ -196,7 +185,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             data = self.data_pipeline.process_fasta(
                 fasta_path=path,
                 alignment_dir=alignment_dir,
-                _alignment_index=_alignment_index,
             )

         if(self._output_raw):
@@ -486,7 +474,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
         template_release_dates_cache_path: Optional[str] = None,
         batch_seed: Optional[int] = None,
         train_epoch_len: int = 50000,
-        _alignment_index_path: Optional[str] = None,
         **kwargs
     ):
         super(OpenFoldDataModule, self).__init__()
@@ -538,12 +525,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 'be specified as well'
             )

-        # An ad-hoc measure for our particular filesystem restrictions
-        self._alignment_index = None
-        if(_alignment_index_path is not None):
-            with open(_alignment_index_path, "r") as fp:
-                self._alignment_index = json.load(fp)
-
     def setup(self):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
@@ -568,7 +549,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
             treat_pdb_as_distillation=False,
             mode="train",
             _output_raw=True,
-            _alignment_index=self._alignment_index,
         )

         distillation_dataset = None
...
@@ -422,89 +422,42 @@ class DataPipeline:
     def _parse_msa_data(
         self,
         alignment_dir: str,
-        _alignment_index: Optional[Any] = None,
     ) -> Mapping[str, Any]:
         msa_data = {}
-        if(_alignment_index is not None):
-            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
-
-            def read_msa(start, size):
-                fp.seek(start)
-                msa = fp.read(size).decode("utf-8")
-                return msa
-
-            for (name, start, size) in _alignment_index["files"]:
-                ext = os.path.splitext(name)[-1]
-
-                if(ext == ".a3m"):
-                    msa, deletion_matrix = parsers.parse_a3m(
-                        read_msa(start, size)
-                    )
-                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
-                elif(ext == ".sto"):
-                    msa, deletion_matrix, _ = parsers.parse_stockholm(
-                        read_msa(start, size)
-                    )
-                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
-                else:
-                    continue
-
-                msa_data[name] = data
-
-            fp.close()
-        else:
-            for f in os.listdir(alignment_dir):
-                path = os.path.join(alignment_dir, f)
-                ext = os.path.splitext(f)[-1]
-
-                if(ext == ".a3m"):
-                    with open(path, "r") as fp:
-                        msa, deletion_matrix = parsers.parse_a3m(fp.read())
-                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
-                elif(ext == ".sto"):
-                    with open(path, "r") as fp:
-                        msa, deletion_matrix, _ = parsers.parse_stockholm(
-                            fp.read()
-                        )
-                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
-                else:
-                    continue
-
-                msa_data[f] = data
+        for f in os.listdir(alignment_dir):
+            path = os.path.join(alignment_dir, f)
+            ext = os.path.splitext(f)[-1]
+
+            if(ext == ".a3m"):
+                with open(path, "r") as fp:
+                    msa, deletion_matrix = parsers.parse_a3m(fp.read())
+                data = {"msa": msa, "deletion_matrix": deletion_matrix}
+            elif(ext == ".sto"):
+                with open(path, "r") as fp:
+                    msa, deletion_matrix, _ = parsers.parse_stockholm(
+                        fp.read()
+                    )
+                data = {"msa": msa, "deletion_matrix": deletion_matrix}
+            else:
+                continue
+
+            msa_data[f] = data

         return msa_data

     def _parse_template_hits(
         self,
         alignment_dir: str,
-        _alignment_index: Optional[Any] = None
     ) -> Mapping[str, Any]:
         all_hits = {}
-        if(_alignment_index is not None):
-            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
-
-            def read_template(start, size):
-                fp.seek(start)
-                return fp.read(size).decode("utf-8")
-
-            for (name, start, size) in _alignment_index["files"]:
-                ext = os.path.splitext(name)[-1]
-
-                if(ext == ".hhr"):
-                    hits = parsers.parse_hhr(read_template(start, size))
-                    all_hits[name] = hits
-
-            fp.close()
-        else:
-            for f in os.listdir(alignment_dir):
-                path = os.path.join(alignment_dir, f)
-                ext = os.path.splitext(f)[-1]
-
-                if(ext == ".hhr"):
-                    with open(path, "r") as fp:
-                        hits = parsers.parse_hhr(fp.read())
-                    all_hits[f] = hits
+        for f in os.listdir(alignment_dir):
+            path = os.path.join(alignment_dir, f)
+            ext = os.path.splitext(f)[-1]
+
+            if(ext == ".hhr"):
+                with open(path, "r") as fp:
+                    hits = parsers.parse_hhr(fp.read())
+                all_hits[f] = hits

         return all_hits
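
For anyone who had already packed their alignments, a small script can expand the database back into the flat per-chain layout (one directory per chain containing `.a3m`, `.sto`, and `.hhr` files) that the simplified readers above expect. A minimal sketch, assuming the index shape described under the commit message; `unpack_alignment_db` is a hypothetical helper, not part of this commit:

    import json
    import os

    def unpack_alignment_db(index_path: str, alignment_dir: str, out_dir: str):
        # Hypothetical helper: expand a packed alignment database into
        # one directory of alignment files per chain, mirroring the
        # seek/read logic of the removed index-based readers.
        with open(index_path, "r") as fp:
            index = json.load(fp)

        for chain_id, entry in index.items():
            chain_dir = os.path.join(out_dir, chain_id)
            os.makedirs(chain_dir, exist_ok=True)
            with open(os.path.join(alignment_dir, entry["db"]), "rb") as db:
                for name, start, size in entry["files"]:
                    db.seek(start)
                    with open(os.path.join(chain_dir, name), "wb") as out:
                        out.write(db.read(size))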
@@ -512,9 +465,8 @@ class DataPipeline:
         self,
         alignment_dir: str,
         input_sequence: Optional[str] = None,
-        _alignment_index: Optional[str] = None
     ) -> Mapping[str, Any]:
-        msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
+        msa_data = self._parse_msa_data(alignment_dir)

         if(len(msa_data) == 0):
             if(input_sequence is None):
@@ -544,7 +496,6 @@ class DataPipeline:
         self,
         fasta_path: str,
         alignment_dir: str,
-        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """Assembles features for a single sequence in a FASTA file"""
         with open(fasta_path) as f:
@@ -558,7 +509,7 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)

-        hits = self._parse_template_hits(alignment_dir, _alignment_index)
+        hits = self._parse_template_hits(alignment_dir)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -571,7 +522,7 @@ class DataPipeline:
             num_res=num_res,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence)

         return {
             **sequence_features,
@@ -584,7 +535,6 @@ class DataPipeline:
         mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
         alignment_dir: str,
         chain_id: Optional[str] = None,
-        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a specific chain in an mmCIF object.
@@ -602,7 +552,7 @@ class DataPipeline:
         mmcif_feats = make_mmcif_features(mmcif, chain_id)

         input_sequence = mmcif.chain_to_seqres[chain_id]
-        hits = self._parse_template_hits(alignment_dir, _alignment_index)
+        hits = self._parse_template_hits(alignment_dir)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -610,7 +560,7 @@ class DataPipeline:
             query_release_date=to_date(mmcif.header["release_date"])
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence)

         return {**mmcif_feats, **template_features, **msa_features}
@@ -620,7 +570,6 @@ class DataPipeline:
         alignment_dir: str,
         is_distillation: bool = True,
         chain_id: Optional[str] = None,
-        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a protein in a PDB file.
@@ -637,14 +586,14 @@ class DataPipeline:
             is_distillation
         )

-        hits = self._parse_template_hits(alignment_dir, _alignment_index)
+        hits = self._parse_template_hits(alignment_dir)
         template_features = make_template_features(
             input_sequence,
             hits,
             self.template_featurizer,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence)

         return {**pdb_feats, **template_features, **msa_features}
@@ -652,7 +601,6 @@ class DataPipeline:
         self,
         core_path: str,
         alignment_dir: str,
-        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a protein in a ProteinNet .core file.
@@ -665,7 +613,7 @@ class DataPipeline:
         description = os.path.splitext(os.path.basename(core_path))[0].upper()
         core_feats = make_protein_features(protein_object, description)

-        hits = self._parse_template_hits(alignment_dir, _alignment_index)
+        hits = self._parse_template_hits(alignment_dir)
         template_features = make_template_features(
             input_sequence,
             hits,
...
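
With the index gone, every `DataPipeline.process_*` entry point now takes just its input (a path, or a parsed mmCIF object) plus the flat per-chain `alignment_dir`. A minimal usage sketch against the new `process_fasta` signature, assuming an already-constructed `DataPipeline` instance named `data_pipeline`; both paths below are hypothetical:

    # data_pipeline: an existing DataPipeline instance.
    feats = data_pipeline.process_fasta(
        fasta_path="fastas/1abc_A.fasta",
        alignment_dir="alignments/1abc_A",
    )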
@@ -370,9 +370,6 @@ if __name__ == "__main__":
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
     )
-    parser.add_argument(
-        "--_alignment_index_path", type=str, default=None,
-    )
     parser = pl.Trainer.add_argparse_args(parser)

     # Disable the initial validation pass
...