Commit 7b29cac4 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add batched inference

parent 98d42cf7
...@@ -92,8 +92,8 @@ extract `.core` files from ProteinNet text files. ...@@ -92,8 +92,8 @@ extract `.core` files from ProteinNet text files.
### Inference ### Inference
To run inference on a sequence using a set of DeepMind's pretrained parameters, To run inference on a sequence or multiple sequences using a set of DeepMind's
run e.g.: pretrained parameters, run e.g.:
```bash ```bash
python3 run_pretrained_openfold.py \ python3 run_pretrained_openfold.py \
...@@ -115,8 +115,8 @@ python3 run_pretrained_openfold.py \ ...@@ -115,8 +115,8 @@ python3 run_pretrained_openfold.py \
where `data` is the same directory as in the previous step. If `jackhmmer`, where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of `hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped. `/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query (see "Training"), you have If you've already computed alignments for the query, you have the option to
the option to circumvent the expensive alignment computation here. circumvent the expensive alignment computation here.
### Training ### Training
......
...@@ -71,84 +71,101 @@ def main(args): ...@@ -71,84 +71,101 @@ def main(args):
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
alignment_dir = os.path.join(output_dir_base, "alignments")
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
logging.info("Generating features...")
if(args.use_precomputed_alignments is None): if(args.use_precomputed_alignments is None):
alignment_runner = data_pipeline.AlignmentRunner( alignment_dir = os.path.join(output_dir_base, "alignments")
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(
args.fasta_path, alignment_dir
)
else: else:
alignment_dir = args.use_precomputed_alignments alignment_dir = args.use_precomputed_alignments
feature_dict = data_processor.process_fasta( # Gather input sequences
fasta_path=args.fasta_path, alignment_dir=alignment_dir with open(args.fasta_path, "r") as fp:
) lines = [l.strip() for l in fp.readlines()]
processed_feature_dict = feature_processor.process_features( tags, seqs = lines[::2], lines[1::2]
feature_dict, mode='predict', tags = [l[1:] for l in tags]
)
for tag, seq in zip(tags, seqs):
logging.info("Executing model...") fasta_path = os.path.join(args.output_dir, "tmp.fasta")
batch = processed_feature_dict with open(fasta_path, "w") as fp:
with torch.no_grad(): fp.write(f">{tag}\n{seq}")
batch = {
k:torch.as_tensor(v, device=args.model_device) logging.info("Generating features...")
for k,v in batch.items() local_alignment_dir = os.path.join(alignment_dir, tag)
} if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
t = time.time() os.makedirs(local_alignment_dir)
out = model(batch)
logging.info(f"Inference time: {time.time() - t}") alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
# Toss out the recycling dimensions --- we don't need them anymore hhblits_binary_path=args.hhblits_binary_path,
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) hhsearch_binary_path=args.hhsearch_binary_path,
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
plddt = out["plddt"] bfd_database_path=args.bfd_database_path,
mean_plddt = np.mean(plddt) uniclust30_database_path=args.uniclust30_database_path,
small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(
fasta_path, local_alignment_dir
)
plddt_b_factors = np.repeat( feature_dict = data_processor.process_fasta(
plddt[..., None], residue_constants.atom_type_num, axis=-1 fasta_path=fasta_path, alignment_dir=local_alignment_dir
) )
unrelaxed_protein = protein.from_prediction( # Remove temporary FASTA file
features=batch, os.remove(fasta_path)
result=out,
b_factors=plddt_b_factors
)
amber_relaxer = relax.AmberRelaxation(
**config.relax
)
# Relax the prediction. processed_feature_dict = feature_processor.process_features(
t = time.time() feature_dict, mode='predict',
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) )
logging.info(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB. logging.info("Executing model...")
relaxed_output_path = os.path.join( batch = processed_feature_dict
args.output_dir, f'relaxed_{args.model_name}.pdb' with torch.no_grad():
) batch = {
with open(relaxed_output_path, 'w') as f: k:torch.as_tensor(v, device=args.model_device)
f.write(relaxed_pdb_str) for k,v in batch.items()
}
t = time.time()
out = model(batch)
logging.info(f"Inference time: {time.time() - t}")
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors
)
amber_relaxer = relax.AmberRelaxation(
**config.relax
)
# Relax the prediction.
t = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
logging.info(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}.pdb'
)
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment