Unverified commit 6f26b0ad, authored by Dingquan Yu and committed by GitHub
Browse files

Merge branch 'multimer' into speedup-dataloader

parents 78ecfc64 58d65692
......@@ -726,7 +726,8 @@ class DataPipeline:
)
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename):
# Multimer "uniprot_hits" processed separately.
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
msa = parsers.parse_stockholm(read_msa(start, size))
else:
continue
......
......@@ -191,12 +191,21 @@ class Jackhmmer:
input_fasta_path: str,
max_sequences: Optional[int] = None
) -> Sequence[Mapping[str, Any]]:
return self.query_multiple([input_fasta_path], max_sequences)[0]
def query_multiple(self,
input_fasta_paths: str,
max_sequences: Optional[int] = None
) -> Sequence[Sequence[Mapping[str, Any]]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
single_chunk_results = []
for input_fasta_path in input_fasta_paths:
single_chunk_result = self._query_chunk(
input_fasta_path, self.database_path, max_sequences,
)
return [single_chunk_result]
single_chunk_results.append(single_chunk_result)
return single_chunk_results
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
......@@ -211,7 +220,7 @@ class Jackhmmer:
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
chunked_output = []
chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
for i in range(1, self.num_streamed_chunks + 1):
# Copy the chunk locally
if i == 1:
......@@ -229,7 +238,8 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
for fasta_idx, input_fasta_path in enumerate(input_fasta_paths):
chunked_outputs[fasta_idx].append(
self._query_chunk(
input_fasta_path,
db_local_chunk(i),
......@@ -239,11 +249,10 @@ class Jackhmmer:
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
future = next_future
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if(i < self.num_streamed_chunks):
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
return chunked_outputs
......@@ -716,7 +716,7 @@ class InvariantPointAttentionMultimer(nn.Module):
o_pt_norm = o_pt.norm(epsilon=1e-8)
if (_offload_inference):
z[0] = z[0].to(o_pt.device)
z[0] = z[0].to(o_pt.x.device)
o_pair = torch.einsum('...ijh, ...ijc->...ihc', a, z[0].to(dtype=a.dtype))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment