"setup.py" did not exist on "db6454cde74ea461ab4b336323fdeb636c74df47"
Commit 7b29cac4 authored by Gustaf Ahdritz

Add batched inference

parent 98d42cf7
@@ -92,8 +92,8 @@ extract `.core` files from ProteinNet text files.
### Inference
-To run inference on a sequence using a set of DeepMind's pretrained parameters,
-run e.g.:
+To run inference on a sequence or multiple sequences using a set of DeepMind's
+pretrained parameters, run e.g.:
```bash
python3 run_pretrained_openfold.py \
@@ -115,8 +115,8 @@ python3 run_pretrained_openfold.py \
where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
-If you've already computed alignments for the query (see "Training"), you have
-the option to circumvent the expensive alignment computation here.
+If you've already computed alignments for the query, you have the option to
+circumvent the expensive alignment computation here.
### Training
......
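As an aside (not part of the commit): the precomputed-alignments option mentioned above corresponds to the `--use_precomputed_alignments` flag handled in the script diff below. A minimal sketch of that invocation, with hypothetical paths and with the remaining arguments following the (truncated) README example:

```bash
# Sketch only; paths are hypothetical and the remaining arguments follow
# the (truncated) README example above. The alignments directory is
# expected to contain one subdirectory per query tag.
python3 run_pretrained_openfold.py \
    "$FASTA_PATH" \
    --use_precomputed_alignments ./alignments
```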
@@ -71,84 +71,101 @@ def main(args):
     feature_processor = feature_pipeline.FeaturePipeline(config.data)
     if not os.path.exists(output_dir_base):
         os.makedirs(output_dir_base)
-    alignment_dir = os.path.join(output_dir_base, "alignments")
-    if not os.path.exists(alignment_dir):
-        os.makedirs(alignment_dir)
-    logging.info("Generating features...")
     if(args.use_precomputed_alignments is None):
-        alignment_runner = data_pipeline.AlignmentRunner(
-            jackhmmer_binary_path=args.jackhmmer_binary_path,
-            hhblits_binary_path=args.hhblits_binary_path,
-            hhsearch_binary_path=args.hhsearch_binary_path,
-            uniref90_database_path=args.uniref90_database_path,
-            mgnify_database_path=args.mgnify_database_path,
-            bfd_database_path=args.bfd_database_path,
-            uniclust30_database_path=args.uniclust30_database_path,
-            small_bfd_database_path=args.small_bfd_database_path,
-            pdb70_database_path=args.pdb70_database_path,
-            use_small_bfd=use_small_bfd,
-            no_cpus=args.cpus,
-        )
-        alignment_runner.run(
-            args.fasta_path, alignment_dir
-        )
+        alignment_dir = os.path.join(output_dir_base, "alignments")
     else:
         alignment_dir = args.use_precomputed_alignments
-    feature_dict = data_processor.process_fasta(
-        fasta_path=args.fasta_path, alignment_dir=alignment_dir
-    )
-    processed_feature_dict = feature_processor.process_features(
-        feature_dict, mode='predict',
-    )
-    logging.info("Executing model...")
-    batch = processed_feature_dict
-    with torch.no_grad():
-        batch = {
-            k:torch.as_tensor(v, device=args.model_device)
-            for k,v in batch.items()
-        }
-        t = time.time()
-        out = model(batch)
-        logging.info(f"Inference time: {time.time() - t}")
-    # Toss out the recycling dimensions --- we don't need them anymore
-    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
-    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
-    plddt = out["plddt"]
-    mean_plddt = np.mean(plddt)
+    # Gather input sequences
+    with open(args.fasta_path, "r") as fp:
+        lines = [l.strip() for l in fp.readlines()]
+    tags, seqs = lines[::2], lines[1::2]
+    tags = [l[1:] for l in tags]
+    for tag, seq in zip(tags, seqs):
+        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
+        with open(fasta_path, "w") as fp:
+            fp.write(f">{tag}\n{seq}")
+        logging.info("Generating features...")
+        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if(args.use_precomputed_alignments is None):
+            if not os.path.exists(local_alignment_dir):
+                os.makedirs(local_alignment_dir)
+            alignment_runner = data_pipeline.AlignmentRunner(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                small_bfd_database_path=args.small_bfd_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                use_small_bfd=use_small_bfd,
+                no_cpus=args.cpus,
+            )
+            alignment_runner.run(
+                fasta_path, local_alignment_dir
+            )
-    plddt_b_factors = np.repeat(
-        plddt[..., None], residue_constants.atom_type_num, axis=-1
-    )
+        feature_dict = data_processor.process_fasta(
+            fasta_path=fasta_path, alignment_dir=local_alignment_dir
+        )
-    unrelaxed_protein = protein.from_prediction(
-        features=batch,
-        result=out,
-        b_factors=plddt_b_factors
-    )
-    amber_relaxer = relax.AmberRelaxation(
-        **config.relax
-    )
+        # Remove temporary FASTA file
+        os.remove(fasta_path)
-    # Relax the prediction.
-    t = time.time()
-    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-    logging.info(f"Relaxation time: {time.time() - t}")
+        processed_feature_dict = feature_processor.process_features(
+            feature_dict, mode='predict',
+        )
-    # Save the relaxed PDB.
-    relaxed_output_path = os.path.join(
-        args.output_dir, f'relaxed_{args.model_name}.pdb'
-    )
-    with open(relaxed_output_path, 'w') as f:
-        f.write(relaxed_pdb_str)
logging.info("Executing model...")
batch = processed_feature_dict
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.model_device)
for k,v in batch.items()
}
t = time.time()
out = model(batch)
logging.info(f"Inference time: {time.time() - t}")
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors
)
amber_relaxer = relax.AmberRelaxation(
**config.relax
)
# Relax the prediction.
t = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
logging.info(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}.pdb'
)
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
 if __name__ == "__main__":
......
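For illustration (not part of the diff): the new loop reads the input FASTA as alternating header and sequence lines (`lines[::2]` paired with `lines[1::2]`), so each record is assumed to occupy exactly two lines. A hypothetical multi-sequence input built that way:

```bash
# Hypothetical two-record input; each ">" header line is immediately
# followed by its full sequence on a single line, matching the
# lines[::2] / lines[1::2] pairing in the loop above.
cat > targets.fasta <<EOF
>prot_A
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
>prot_B
MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDK
EOF
```

Each record is then written to a temporary `tmp.fasta` in the output directory, run through the alignment, inference, and relaxation steps, and saved as `{tag}_{model_name}.pdb`.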