Commit 9a617649 authored by Gustaf Ahdritz
Browse files

Finish bringing the alignment index out of hiding

parent 805b45cc
...@@ -36,9 +36,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -36,9 +36,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
treat_pdb_as_distillation: bool = True, treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None, mapping_path: Optional[str] = None,
mode: str = "train", mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False, _output_raw: bool = False,
_structure_index: Optional[Any] = None, _structure_index: Optional[Any] = None,
_alignment_index: Optional[Any] = None,
): ):
""" """
Args: Args:
...@@ -84,9 +84,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -84,9 +84,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.config = config self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw self._output_raw = _output_raw
self._structure_index = _structure_index self._structure_index = _structure_index
self._alignment_index = _alignment_index
self.supported_exts = [".cif", ".core", ".pdb"] self.supported_exts = [".cif", ".core", ".pdb"]
...@@ -100,8 +100,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -100,8 +100,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold" "scripts/generate_mmcif_cache.py before running OpenFold"
) )
if(_alignment_index is not None): if(alignment_index is not None):
self._chain_ids = list(_alignment_index.keys()) self._chain_ids = list(alignment_index.keys())
elif(mapping_path is None): elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir)) self._chain_ids = list(os.listdir(alignment_dir))
else: else:
...@@ -129,7 +129,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -129,7 +129,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw): if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index): def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f: with open(path, 'r') as f:
mmcif_string = f.read() mmcif_string = f.read()
...@@ -148,7 +148,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -148,7 +148,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object, mmcif=mmcif_object,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
chain_id=chain_id, chain_id=chain_id,
_alignment_index=_alignment_index alignment_index=alignment_index
) )
return data return data
...@@ -163,10 +163,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -163,10 +163,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name = self.idx_to_chain_id(idx) name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name) alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None alignment_index = None
if(self._alignment_index is not None): if(self.alignment_index is not None):
alignment_dir = self.alignment_dir alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name] alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'): if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1) spl = name.rsplit('_', 1)
...@@ -196,11 +196,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -196,11 +196,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path += ext path += ext
if(ext == ".cif"): if(ext == ".cif"):
data = self._parse_mmcif( data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, _alignment_index, path, file_id, chain_id, alignment_dir, alignment_index,
) )
elif(ext == ".core"): elif(ext == ".core"):
data = self.data_pipeline.process_core( data = self.data_pipeline.process_core(
path, alignment_dir, _alignment_index, path, alignment_dir, alignment_index,
) )
elif(ext == ".pdb"): elif(ext == ".pdb"):
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
...@@ -208,8 +208,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -208,8 +208,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation, is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id, chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=self._structure_index[name], _structure_index=self._structure_index[name],
_alignment_index=_alignment_index,
) )
else: else:
raise ValueError("Extension branch missing") raise ValueError("Extension branch missing")
...@@ -218,7 +218,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -218,7 +218,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
fasta_path=path, fasta_path=path,
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
_alignment_index=_alignment_index, alignment_index=alignment_index,
) )
if(self._output_raw): if(self._output_raw):
...@@ -500,8 +500,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -500,8 +500,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
batch_seed: Optional[int] = None, batch_seed: Optional[int] = None,
train_epoch_len: int = 50000, train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None, _distillation_structure_index_path: Optional[str] = None,
_alignment_index_path: Optional[str] = None, alignment_index_path: Optional[str] = None,
_distillation_alignment_index_path: Optional[str] = None, distillation_alignment_index_path: Optional[str] = None,
**kwargs **kwargs
): ):
super(OpenFoldDataModule, self).__init__() super(OpenFoldDataModule, self).__init__()
...@@ -559,15 +559,15 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -559,15 +559,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
with open(_distillation_structure_index_path, "r") as fp: with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp) self._distillation_structure_index = json.load(fp)
self._alignment_index = None self.alignment_index = None
if(_alignment_index_path is not None): if(alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp: with open(alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp) self.alignment_index = json.load(fp)
self._distillation_alignment_index = None self.distillation_alignment_index = None
if(_distillation_alignment_index_path is not None): if(distillation_alignment_index_path is not None):
with open(_distillation_alignment_index_path, "r") as fp: with open(distillation_alignment_index_path, "r") as fp:
self._distillation_alignment_index = json.load(fp) self.distillation_alignment_index = json.load(fp)
def setup(self): def setup(self):
# Most of the arguments are the same for the three datasets # Most of the arguments are the same for the three datasets
...@@ -592,7 +592,7 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -592,7 +592,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.config.train.shuffle_top_k_prefiltered, self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False, treat_pdb_as_distillation=False,
mode="train", mode="train",
_alignment_index=self._alignment_index, alignment_index=self.alignment_index,
) )
distillation_dataset = None distillation_dataset = None
...@@ -604,8 +604,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -604,8 +604,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True, treat_pdb_as_distillation=True,
mode="train", mode="train",
alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index, _structure_index=self._distillation_structure_index,
_alignment_index=self._distillation_alignment_index,
) )
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
...@@ -625,7 +625,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -625,7 +625,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.train_chain_data_cache_path, self.train_chain_data_cache_path,
] ]
generator = None
if(self.batch_seed is not None): if(self.batch_seed is not None):
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1) generator = generator.manual_seed(self.batch_seed + 1)
...@@ -638,7 +637,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -638,7 +637,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
generator=generator, generator=generator,
_roll_at_init=False, _roll_at_init=False,
) )
if(self.val_data_dir is not None): if(self.val_data_dir is not None):
self.eval_dataset = dataset_gen( self.eval_dataset = dataset_gen(
......
...@@ -462,18 +462,18 @@ class DataPipeline: ...@@ -462,18 +462,18 @@ class DataPipeline:
def _parse_msa_data( def _parse_msa_data(
self, self,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[Any] = None, alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msa_data = {} msa_data = {}
if(_alignment_index is not None): if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb") fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size): def read_msa(start, size):
fp.seek(start) fp.seek(start)
msa = fp.read(size).decode("utf-8") msa = fp.read(size).decode("utf-8")
return msa return msa
for (name, start, size) in _alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1] ext = os.path.splitext(name)[-1]
if(ext == ".a3m"): if(ext == ".a3m"):
...@@ -517,17 +517,17 @@ class DataPipeline: ...@@ -517,17 +517,17 @@ class DataPipeline:
def _parse_template_hits( def _parse_template_hits(
self, self,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[Any] = None alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
all_hits = {} all_hits = {}
if(_alignment_index is not None): if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb') fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size): def read_template(start, size):
fp.seek(start) fp.seek(start)
return fp.read(size).decode("utf-8") return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]: for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1] ext = os.path.splitext(name)[-1]
if(ext == ".hhr"): if(ext == ".hhr"):
...@@ -550,9 +550,9 @@ class DataPipeline: ...@@ -550,9 +550,9 @@ class DataPipeline:
def _get_msas(self, def _get_msas(self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
): ):
msa_data = self._parse_msa_data(alignment_dir, _alignment_index) msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0): if(len(msa_data) == 0):
if(input_sequence is None): if(input_sequence is None):
raise ValueError( raise ValueError(
...@@ -576,10 +576,10 @@ class DataPipeline: ...@@ -576,10 +576,10 @@ class DataPipeline:
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None alignment_index: Optional[str] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas( msas, deletion_matrices = self._get_msas(
alignment_dir, input_sequence, _alignment_index alignment_dir, input_sequence, alignment_index
) )
msa_features = make_msa_features( msa_features = make_msa_features(
msas=msas, msas=msas,
...@@ -592,7 +592,7 @@ class DataPipeline: ...@@ -592,7 +592,7 @@ class DataPipeline:
self, self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f: with open(fasta_path) as f:
...@@ -606,7 +606,7 @@ class DataPipeline: ...@@ -606,7 +606,7 @@ class DataPipeline:
input_description = input_descs[0] input_description = input_descs[0]
num_res = len(input_sequence) num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir, _alignment_index) hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -619,7 +619,7 @@ class DataPipeline: ...@@ -619,7 +619,7 @@ class DataPipeline:
num_res=num_res, num_res=num_res,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return { return {
**sequence_features, **sequence_features,
...@@ -632,7 +632,7 @@ class DataPipeline: ...@@ -632,7 +632,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a specific chain in an mmCIF object. Assembles features for a specific chain in an mmCIF object.
...@@ -650,7 +650,7 @@ class DataPipeline: ...@@ -650,7 +650,7 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id) mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id] input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, _alignment_index) hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -658,7 +658,7 @@ class DataPipeline: ...@@ -658,7 +658,7 @@ class DataPipeline:
query_release_date=to_date(mmcif.header["release_date"]) query_release_date=to_date(mmcif.header["release_date"])
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**mmcif_feats, **template_features, **msa_features} return {**mmcif_feats, **template_features, **msa_features}
...@@ -669,7 +669,7 @@ class DataPipeline: ...@@ -669,7 +669,7 @@ class DataPipeline:
is_distillation: bool = True, is_distillation: bool = True,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_structure_index: Optional[str] = None, _structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a PDB file. Assembles features for a protein in a PDB file.
...@@ -696,14 +696,14 @@ class DataPipeline: ...@@ -696,14 +696,14 @@ class DataPipeline:
is_distillation=is_distillation is_distillation=is_distillation
) )
hits = self._parse_template_hits(alignment_dir, _alignment_index) hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
self.template_featurizer, self.template_featurizer,
) )
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**pdb_feats, **template_features, **msa_features} return {**pdb_feats, **template_features, **msa_features}
...@@ -711,7 +711,7 @@ class DataPipeline: ...@@ -711,7 +711,7 @@ class DataPipeline:
self, self,
core_path: str, core_path: str,
alignment_dir: str, alignment_dir: str,
_alignment_index: Optional[str] = None, alignment_index: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a protein in a ProteinNet .core file. Assembles features for a protein in a ProteinNet .core file.
...@@ -724,7 +724,7 @@ class DataPipeline: ...@@ -724,7 +724,7 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper() description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description) core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, _alignment_index) hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features( template_features = make_template_features(
input_sequence, input_sequence,
hits, hits,
...@@ -809,7 +809,7 @@ class DataPipeline: ...@@ -809,7 +809,7 @@ class DataPipeline:
alignment_dir = os.path.join( alignment_dir = os.path.join(
super_alignment_dir, desc super_alignment_dir, desc
) )
hits = self._parse_template_hits(alignment_dir, _alignment_index=None) hits = self._parse_template_hits(alignment_dir, alignment_index=None)
template_features = make_template_features( template_features = make_template_features(
seq, seq,
hits, hits,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment