Commit 7b29cac4 authored by Gustaf Ahdritz
Browse files

Add batched inference

parent 98d42cf7
...@@ -92,8 +92,8 @@ extract `.core` files from ProteinNet text files. ...@@ -92,8 +92,8 @@ extract `.core` files from ProteinNet text files.
### Inference ### Inference
To run inference on a sequence using a set of DeepMind's pretrained parameters, To run inference on a sequence or multiple sequences using a set of DeepMind's
run e.g.: pretrained parameters, run e.g.:
```bash ```bash
python3 run_pretrained_openfold.py \ python3 run_pretrained_openfold.py \
...@@ -115,8 +115,8 @@ python3 run_pretrained_openfold.py \ ...@@ -115,8 +115,8 @@ python3 run_pretrained_openfold.py \
where `data` is the same directory as in the previous step. If `jackhmmer`, where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of `hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped. `/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query (see "Training"), you have If you've already computed alignments for the query, you have the option to
the option to circumvent the expensive alignment computation here. circumvent the expensive alignment computation here.
### Training ### Training
......
...@@ -71,13 +71,29 @@ def main(args): ...@@ -71,13 +71,29 @@ def main(args):
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
if(args.use_precomputed_alignments is None):
alignment_dir = os.path.join(output_dir_base, "alignments") alignment_dir = os.path.join(output_dir_base, "alignments")
if not os.path.exists(alignment_dir): else:
os.makedirs(alignment_dir) alignment_dir = args.use_precomputed_alignments
logging.info("Generating features...") # Gather input sequences
with open(args.fasta_path, "r") as fp:
lines = [l.strip() for l in fp.readlines()]
tags, seqs = lines[::2], lines[1::2]
tags = [l[1:] for l in tags]
for tag, seq in zip(tags, seqs):
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
logging.info("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None): if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path, jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path, hhblits_binary_path=args.hhblits_binary_path,
...@@ -92,15 +108,16 @@ def main(args): ...@@ -92,15 +108,16 @@ def main(args):
no_cpus=args.cpus, no_cpus=args.cpus,
) )
alignment_runner.run( alignment_runner.run(
args.fasta_path, alignment_dir fasta_path, local_alignment_dir
) )
else:
alignment_dir = args.use_precomputed_alignments
feature_dict = data_processor.process_fasta( feature_dict = data_processor.process_fasta(
fasta_path=args.fasta_path, alignment_dir=alignment_dir fasta_path=fasta_path, alignment_dir=local_alignment_dir
) )
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features( processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', feature_dict, mode='predict',
) )
...@@ -145,7 +162,7 @@ def main(args): ...@@ -145,7 +162,7 @@ def main(args):
# Save the relaxed PDB. # Save the relaxed PDB.
relaxed_output_path = os.path.join( relaxed_output_path = os.path.join(
args.output_dir, f'relaxed_{args.model_name}.pdb' args.output_dir, f'{tag}_{args.model_name}.pdb'
) )
with open(relaxed_output_path, 'w') as f: with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str) f.write(relaxed_pdb_str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment