Commit 4f38c826 authored by Christina Floristean's avatar Christina Floristean
Browse files

Added alignment indexing to multimer pipeline and fixed jackhmmer query return type

parent bcabb8e3
...@@ -451,7 +451,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -451,7 +451,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx) mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None alignment_index = None
if self.alignment_index is not None:
alignment_index = {k: v for k, v in self.alignment_index.items()
if f'{mmcif_id}_' in k}
if self.mode == 'train' or self.mode == 'eval': if self.mode == 'train' or self.mode == 'eval':
path = os.path.join(self.data_dir, f"{mmcif_id}") path = os.path.join(self.data_dir, f"{mmcif_id}")
...@@ -476,7 +480,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset): ...@@ -476,7 +480,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
path = os.path.join(self.data_dir, f"{mmcif_id}.fasta") path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
data = self.data_pipeline.process_fasta( data = self.data_pipeline.process_fasta(
fasta_path=path, fasta_path=path,
alignment_dir=self.alignment_dir alignment_dir=self.alignment_dir,
alignment_index=alignment_index
) )
if self._output_raw: if self._output_raw:
......
...@@ -794,7 +794,7 @@ class DataPipeline: ...@@ -794,7 +794,7 @@ class DataPipeline:
def _get_msas(self, def _get_msas(self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
): ):
msa_data = self._parse_msa_data(alignment_dir, alignment_index) msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0): if(len(msa_data) == 0):
...@@ -814,7 +814,7 @@ class DataPipeline: ...@@ -814,7 +814,7 @@ class DataPipeline:
self, self,
alignment_dir: str, alignment_dir: str,
input_sequence: Optional[str] = None, input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None alignment_index: Optional[Any] = None
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
msas = self._get_msas( msas = self._get_msas(
...@@ -846,7 +846,7 @@ class DataPipeline: ...@@ -846,7 +846,7 @@ class DataPipeline:
self, self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
...@@ -899,7 +899,7 @@ class DataPipeline: ...@@ -899,7 +899,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
""" """
...@@ -946,7 +946,7 @@ class DataPipeline: ...@@ -946,7 +946,7 @@ class DataPipeline:
is_distillation: bool = True, is_distillation: bool = True,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
_structure_index: Optional[str] = None, _structure_index: Optional[str] = None,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
""" """
...@@ -1000,7 +1000,7 @@ class DataPipeline: ...@@ -1000,7 +1000,7 @@ class DataPipeline:
self, self,
core_path: str, core_path: str,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
seqemb_mode: bool = False, seqemb_mode: bool = False,
) -> FeatureDict: ) -> FeatureDict:
""" """
...@@ -1156,30 +1156,49 @@ class DataPipelineMultimer: ...@@ -1156,30 +1156,49 @@ class DataPipelineMultimer:
sequence: str, sequence: str,
description: str, description: str,
chain_alignment_dir: str, chain_alignment_dir: str,
chain_alignment_index: Optional[Any],
is_homomer_or_monomer: bool is_homomer_or_monomer: bool
) -> FeatureDict: ) -> FeatureDict:
"""Runs the monomer pipeline on a single chain.""" """Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{chain_id}\n{sequence}\n' chain_fasta_str = f'>{chain_id}\n{sequence}\n'
if not os.path.exists(chain_alignment_dir):
if chain_alignment_index is None and not os.path.exists(chain_alignment_dir):
raise ValueError(f"Alignments for {chain_id} not found...") raise ValueError(f"Alignments for {chain_id} not found...")
with temp_fasta_file(chain_fasta_str) as chain_fasta_path: with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
chain_features = self._monomer_data_pipeline.process_fasta( chain_features = self._monomer_data_pipeline.process_fasta(
fasta_path=chain_fasta_path, fasta_path=chain_fasta_path,
alignment_dir=chain_alignment_dir alignment_dir=chain_alignment_dir,
alignment_index=chain_alignment_index
) )
# We only construct the pairing features if there are 2 or more unique # We only construct the pairing features if there are 2 or more unique
# sequences. # sequences.
if not is_homomer_or_monomer: if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features( all_seq_msa_features = self._all_seq_msa_features(
chain_alignment_dir chain_alignment_dir,
chain_alignment_index
) )
chain_features.update(all_seq_msa_features) chain_features.update(all_seq_msa_features)
return chain_features return chain_features
@staticmethod @staticmethod
def _all_seq_msa_features(alignment_dir): def _all_seq_msa_features(alignment_dir, alignment_index):
"""Get MSA features for unclustered uniprot, for pairing.""" """Get MSA features for unclustered uniprot, for pairing."""
if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
start, size = next(iter((start, size) for name, start, size in alignment_index["files"]
if name == 'uniprot_hits.sto'))
msa = parsers.parse_stockholm(read_msa(start, size))
fp.close()
else:
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto") uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
if not os.path.exists(uniprot_msa_path): if not os.path.exists(uniprot_msa_path):
chain_id = os.path.basename(os.path.normpath(alignment_dir)) chain_id = os.path.basename(os.path.normpath(alignment_dir))
...@@ -1189,6 +1208,7 @@ class DataPipelineMultimer: ...@@ -1189,6 +1208,7 @@ class DataPipelineMultimer:
with open(uniprot_msa_path, "r") as fp: with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read() uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string) msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa]) all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + ( valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers', 'msa_species_identifiers',
...@@ -1202,15 +1222,14 @@ class DataPipelineMultimer: ...@@ -1202,15 +1222,14 @@ class DataPipelineMultimer:
def process_fasta(self, def process_fasta(self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[Any] = None
) -> FeatureDict: ) -> FeatureDict:
"""Creates features.""" """Creates features."""
with open(fasta_path) as f: with open(fasta_path) as f:
input_fasta_str = f.read() input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1 is_homomer_or_monomer = len(set(input_seqs)) == 1
...@@ -1221,16 +1240,22 @@ class DataPipelineMultimer: ...@@ -1221,16 +1240,22 @@ class DataPipelineMultimer:
) )
continue continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
description=desc, description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc), chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
...@@ -1238,15 +1263,12 @@ class DataPipelineMultimer: ...@@ -1238,15 +1263,12 @@ class DataPipelineMultimer:
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
...@@ -1279,55 +1301,54 @@ class DataPipelineMultimer: ...@@ -1279,55 +1301,54 @@ class DataPipelineMultimer:
self, self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
alignment_index: Optional[str] = None, alignment_index: Optional[Any] = None,
) -> FeatureDict: ) -> FeatureDict:
all_chain_features = {} all_chain_features = {}
sequence_features = {} sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1 is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items(): for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id]) desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features: if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy( all_chain_features[desc] = copy.deepcopy(
sequence_features[seq] sequence_features[seq]
) )
continue continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain( chain_features = self._process_single_chain(
chain_id=desc, chain_id=desc,
sequence=seq, sequence=seq,
description=desc, description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc), chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer is_homomer_or_monomer=is_homomer_or_monomer
) )
chain_features = convert_monomer_features( chain_features = convert_monomer_features(
chain_features, chain_features,
chain_id=desc chain_id=desc
) )
mmcif_feats = self.get_mmcif_features(mmcif, chain_id) mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats) chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features) all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge( np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features, all_chain_features=all_chain_features,
) )
# Pad MSA to avoid zero-sized extra_msa. # Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512) np_example = pad_msa(np_example, 512)
return np_example return np_example
\ No newline at end of file
...@@ -190,11 +190,11 @@ class Jackhmmer: ...@@ -190,11 +190,11 @@ class Jackhmmer:
def query(self, def query(self,
input_fasta_path: str, input_fasta_path: str,
max_sequences: Optional[int] = None max_sequences: Optional[int] = None
) -> Sequence[Mapping[str, Any]]: ) -> Sequence[Sequence[Mapping[str, Any]]]:
return self.query_multiple([input_fasta_path], max_sequences)[0] return self.query_multiple([input_fasta_path], max_sequences)
def query_multiple(self, def query_multiple(self,
input_fasta_paths: str, input_fasta_paths: Sequence[str],
max_sequences: Optional[int] = None max_sequences: Optional[int] = None
) -> Sequence[Sequence[Mapping[str, Any]]]: ) -> Sequence[Sequence[Mapping[str, Any]]]:
"""Queries the database using Jackhmmer.""" """Queries the database using Jackhmmer."""
......
Markdown is supported
Attach a file by dragging & dropping or selecting one.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment