"...llm/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "65a2dfabc6e893b11d700d29906a84c5801c0a08"
Commit 4f38c826 authored by Christina Floristean's avatar Christina Floristean
Browse files

Added alignment indexing to multimer pipeline and fixed jackhmmer query return type

parent bcabb8e3
...@@ -451,7 +451,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -451,7 +451,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None alignment_index = None
if self.alignment_index is not None:
alignment_index = {k: v for k, v in self.alignment_index.items()
if f'{mmcif_id}_' in k}
if self.mode == 'train' or self.mode == 'eval': if self.mode == 'train' or self.mode == 'eval':
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
...@@ -476,7 +480,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -476,7 +480,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
path = os.path.join(self.data_dir, f"{mmcif_id}.fasta") path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
fasta_path=path, fasta_path=path,
alignment_dir=self.alignment_dir alignment_dir=self.alignment_dir,
alignment_index=alignment_index
) )
if self._output_raw: if self._output_raw:
......
...@@ -794,7 +794,7 @@ class DataPipeline: ...@@ -794,7 +794,7 @@ class DataPipeline:
def _get_msas(self, def _get_msas(self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
): ):
msa_data = self._parse_msa_data(alignment_dir, alignment_index) msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0): if(len(msa_data) == 0):
...@@ -814,7 +814,7 @@ class DataPipeline: ...@@ -814,7 +814,7 @@ class DataPipeline:
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas = self._get_msas( msas = self._get_msas(
...@@ -846,7 +846,7 @@ class DataPipeline: ...@@ -846,7 +846,7 @@ class DataPipeline:
self, self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
...@@ -899,7 +899,7 @@ class DataPipeline: ...@@ -899,7 +899,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
""" """
...@@ -946,7 +946,7 @@ class DataPipeline: ...@@ -946,7 +946,7 @@ class DataPipeline:
is_distillation: bool = True, is_distillation: bool = True,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_structure_index: Optional[str] = None, _structure_index: Optional[str] = None,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
""" """
...@@ -1000,7 +1000,7 @@ class DataPipeline: ...@@ -1000,7 +1000,7 @@ class DataPipeline:
self, self,
core_path: str, core_path: str,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
""" """
...@@ -1156,39 +1156,59 @@ class DataPipelineMultimer: ...@@ -1156,39 +1156,59 @@ class DataPipelineMultimer:
sequence: str, sequence: str,
description: str, description: str,
chain_alignment_dir: str, chain_alignment_dir: str,
chain_alignment_index: Optional[Any],
is_homomer_or_monomer: bool is_homomer_or_monomer: bool
) -> FeatureDict: ) -> FeatureDict:
"""Runs the monomer pipeline on a single chain.""" """Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{chain_id}\n{sequence}\n' chain_fasta_str = f'>{chain_id}\n{sequence}\n'
if not os.path.exists(chain_alignment_dir):
if chain_alignment_index is None and not os.path.exists(chain_alignment_dir):
raise ValueError(f"Alignments for {chain_id} not found...") raise ValueError(f"Alignments for {chain_id} not found...")
with temp_fasta_file(chain_fasta_str) as chain_fasta_path: with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
chain_features = self._monomer_data_pipeline.process_fasta( chain_features = self._monomer_data_pipeline.process_fasta(
fasta_path=chain_fasta_path, fasta_path=chain_fasta_path,
alignment_dir=chain_alignment_dir alignment_dir=chain_alignment_dir,
alignment_index=chain_alignment_index
) )
# We only construct the pairing features if there are 2 or more unique # We only construct the pairing features if there are 2 or more unique
# sequences. # sequences.
if not is_homomer_or_monomer: if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features( all_seq_msa_features = self._all_seq_msa_features(
chain_alignment_dir chain_alignment_dir,
chain_alignment_index
) )
chain_features.update(all_seq_msa_features) chain_features.update(all_seq_msa_features)
return chain_features return chain_features
@staticmethod @staticmethod
def _all_seq_msa_features(alignment_dir): def _all_seq_msa_features(alignment_dir, alignment_index):
"""Get MSA features for unclustered uniprot, for pairing.""" """Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto") if alignment_index is not None:
if not os.path.exists(uniprot_msa_path): fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
chain_id = os.path.basename(os.path.normpath(alignment_dir))
raise ValueError(f"Missing 'uniprot_hits.sto' for {chain_id}. " def read_msa(start, size):
f"This is required for Multimer MSA pairing.") fp.seek(start)
msa = fp.read(size).decode("utf-8")
with open(uniprot_msa_path, "r") as fp: return msa
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string) start, size = next(iter((start, size) for name, start, size in alignment_index["files"]
if name == 'uniprot_hits.sto'))
msa = parsers.parse_stockholm(read_msa(start, size))
fp.close()
else:
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
if not os.path.exists(uniprot_msa_path):
chain_id = os.path.basename(os.path.normpath(alignment_dir))
raise ValueError(f"Missing 'uniprot_hits.sto' for {chain_id}. "
f"This is required for Multimer MSA pairing.")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa]) all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + ( valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers', 'msa_species_identifiers',
...@@ -1202,15 +1222,14 @@ class DataPipelineMultimer: ...@@ -1202,15 +1222,14 @@ class DataPipelineMultimer:
def process_fasta(self, def process_fasta(self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[Any] = None
) -> FeatureDict: ) -> FeatureDict:
"""Creates features.""" """Creates features."""
with open(fasta_path) as f: with open(fasta_path) as f:
input_fasta_str = f.read() input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1 is_homomer_or_monomer = len(set(input_seqs)) == 1
...@@ -1221,16 +1240,22 @@ class DataPipelineMultimer: ...@@ -1221,16 +1240,22 @@ class DataPipelineMultimer:
) )
continue continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
description=desc, description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc), chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
...@@ -1238,15 +1263,12 @@ class DataPipelineMultimer: ...@@ -1238,15 +1263,12 @@ class DataPipelineMultimer:
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
...@@ -1279,55 +1301,54 @@ class DataPipelineMultimer: ...@@ -1279,55 +1301,54 @@ class DataPipelineMultimer:
self, self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
) -> FeatureDict: ) -> FeatureDict:
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1 is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items(): for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id]) desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features: if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy( all_chain_features[desc] = copy.deepcopy(
sequence_features[seq] sequence_features[seq]
) )
continue continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
description=desc, description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc), chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
) )
mmcif_feats = self.get_mmcif_features(mmcif, chain_id) mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats) chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
return np_example return np_example
\ No newline at end of file
...@@ -190,11 +190,11 @@ class Jackhmmer: ...@@ -190,11 +190,11 @@ class Jackhmmer:
def query(self,
          input_fasta_path: str,
          max_sequences: Optional[int] = None
) -> Sequence[Sequence[Mapping[str, Any]]]:
    """Run jackhmmer for a single FASTA file.

    Thin convenience wrapper: wraps the path in a one-element list and
    delegates to query_multiple. The nested per-database result list is
    returned as-is (no [0] unwrap), so the shape matches
    query_multiple's return type exactly.
    """
    single_input = [input_fasta_path]
    return self.query_multiple(single_input, max_sequences)
def query_multiple(self, def query_multiple(self,
input_fasta_paths: str, input_fasta_paths: Sequence[str],
max_sequences: Optional[int] = None max_sequences: Optional[int] = None
) -> Sequence[Sequence[Mapping[str, Any]]]: ) -> Sequence[Sequence[Mapping[str, Any]]]:
"""Queries the database using Jackhmmer.""" """Queries the database using Jackhmmer."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment