Commit 4f38c826 authored by Christina Floristean's avatar Christina Floristean
Browse files

Added alignment indexing to multimer pipeline and fixed jackhmmer query return type

parent bcabb8e3
......@@ -451,7 +451,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None
if self.alignment_index is not None:
alignment_index = {k: v for k, v in self.alignment_index.items()
if f'{mmcif_id}_' in k}
if self.mode == 'train' or self.mode == 'eval':
path = os.path.join(self.data_dir, f"{mmcif_id}")
......@@ -476,7 +480,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=self.alignment_dir
alignment_dir=self.alignment_dir,
alignment_index=alignment_index
)
if self._output_raw:
......
......@@ -794,7 +794,7 @@ class DataPipeline:
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None,
alignment_index: Optional[Any] = None,
):
msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0):
......@@ -814,7 +814,7 @@ class DataPipeline:
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
msas = self._get_msas(
......@@ -846,7 +846,7 @@ class DataPipeline:
self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
alignment_index: Optional[Any] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
......@@ -899,7 +899,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
alignment_index: Optional[str] = None,
alignment_index: Optional[Any] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
......@@ -946,7 +946,7 @@ class DataPipeline:
is_distillation: bool = True,
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
alignment_index: Optional[str] = None,
alignment_index: Optional[Any] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
......@@ -1000,7 +1000,7 @@ class DataPipeline:
self,
core_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
alignment_index: Optional[Any] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
......@@ -1156,39 +1156,59 @@ class DataPipelineMultimer:
sequence: str,
description: str,
chain_alignment_dir: str,
chain_alignment_index: Optional[Any],
is_homomer_or_monomer: bool
) -> FeatureDict:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str = f'>{chain_id}\n{sequence}\n'
if not os.path.exists(chain_alignment_dir):
if chain_alignment_index is None and not os.path.exists(chain_alignment_dir):
raise ValueError(f"Alignments for {chain_id} not found...")
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
chain_features = self._monomer_data_pipeline.process_fasta(
fasta_path=chain_fasta_path,
alignment_dir=chain_alignment_dir
alignment_dir=chain_alignment_dir,
alignment_index=chain_alignment_index
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(
chain_alignment_dir
chain_alignment_dir,
chain_alignment_index
)
chain_features.update(all_seq_msa_features)
return chain_features
@staticmethod
def _all_seq_msa_features(alignment_dir):
def _all_seq_msa_features(alignment_dir, alignment_index):
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
if not os.path.exists(uniprot_msa_path):
chain_id = os.path.basename(os.path.normpath(alignment_dir))
raise ValueError(f"Missing 'uniprot_hits.sto' for {chain_id}. "
f"This is required for Multimer MSA pairing.")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
if alignment_index is not None:
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
start, size = next(iter((start, size) for name, start, size in alignment_index["files"]
if name == 'uniprot_hits.sto'))
msa = parsers.parse_stockholm(read_msa(start, size))
fp.close()
else:
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
if not os.path.exists(uniprot_msa_path):
chain_id = os.path.basename(os.path.normpath(alignment_dir))
raise ValueError(f"Missing 'uniprot_hits.sto' for {chain_id}. "
f"This is required for Multimer MSA pairing.")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
......@@ -1202,15 +1222,14 @@ class DataPipelineMultimer:
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[Any] = None
) -> FeatureDict:
"""Creates features."""
with open(fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
......@@ -1221,16 +1240,22 @@ class DataPipelineMultimer:
)
continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc),
chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
......@@ -1238,15 +1263,12 @@ class DataPipelineMultimer:
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
......@@ -1279,55 +1301,54 @@ class DataPipelineMultimer:
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
alignment_index: Optional[str] = None,
alignment_index: Optional[Any] = None,
) -> FeatureDict:
all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id])
if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue
if alignment_index is not None:
chain_alignment_index = alignment_index.get(desc)
chain_alignment_dir = alignment_dir
else:
chain_alignment_index = None
chain_alignment_dir = os.path.join(alignment_dir, desc)
chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc),
chain_alignment_dir=chain_alignment_dir,
chain_alignment_index=chain_alignment_index,
is_homomer_or_monomer=is_homomer_or_monomer
)
chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features
all_chain_features = add_assembly_features(all_chain_features)
np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)
return np_example
\ No newline at end of file
......@@ -190,11 +190,11 @@ class Jackhmmer:
def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None
) -> Sequence[Mapping[str, Any]]:
return self.query_multiple([input_fasta_path], max_sequences)[0]
) -> Sequence[Sequence[Mapping[str, Any]]]:
return self.query_multiple([input_fasta_path], max_sequences)
def query_multiple(self,
input_fasta_paths: str,
input_fasta_paths: Sequence[str],
max_sequences: Optional[int] = None
) -> Sequence[Sequence[Mapping[str, Any]]]:
"""Queries the database using Jackhmmer."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment