Commit 7b29cac4 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add batched inference

parent 98d42cf7
......@@ -92,8 +92,8 @@ extract `.core` files from ProteinNet text files.
### Inference
To run inference on a sequence using a set of DeepMind's pretrained parameters,
run e.g.:
To run inference on a sequence or multiple sequences using a set of DeepMind's
pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
......@@ -115,8 +115,8 @@ python3 run_pretrained_openfold.py \
where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query (see "Training"), you have
the option to circumvent the expensive alignment computation here.
If you've already computed alignments for the query, you have the option to
circumvent the expensive alignment computation here.
### Training
......
......@@ -71,13 +71,29 @@ def main(args):
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if(args.use_precomputed_alignments is None):
alignment_dir = os.path.join(output_dir_base, "alignments")
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
else:
alignment_dir = args.use_precomputed_alignments
logging.info("Generating features...")
# Gather input sequences
with open(args.fasta_path, "r") as fp:
lines = [l.strip() for l in fp.readlines()]
tags, seqs = lines[::2], lines[1::2]
tags = [l[1:] for l in tags]
for tag, seq in zip(tags, seqs):
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
logging.info("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
......@@ -92,15 +108,16 @@ def main(args):
no_cpus=args.cpus,
)
alignment_runner.run(
args.fasta_path, alignment_dir
fasta_path, local_alignment_dir
)
else:
alignment_dir = args.use_precomputed_alignments
feature_dict = data_processor.process_fasta(
fasta_path=args.fasta_path, alignment_dir=alignment_dir
fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
......@@ -145,7 +162,7 @@ def main(args):
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
args.output_dir, f'relaxed_{args.model_name}.pdb'
args.output_dir, f'{tag}_{args.model_name}.pdb'
)
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment