Commit 9dee6caa authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update inference script

parent 7642bef9
...@@ -57,7 +57,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -57,7 +57,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
if not os.path.exists(local_alignment_dir): if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir) os.makedirs(local_alignment_dir)
use_small_bfd=(args.bfd_database_path is None)
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path, jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path, hhblits_binary_path=args.hhblits_binary_path,
...@@ -71,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -71,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
no_cpus=args.cpus, no_cpus=args.cpus,
) )
alignment_runner.run( alignment_runner.run(
tmp_fasta_path, local_alignment_dir fasta_path, local_alignment_dir
) )
# Remove temporary FASTA file # Remove temporary FASTA file
...@@ -87,7 +86,7 @@ def run_model(model, batch, tag, args): ...@@ -87,7 +86,7 @@ def run_model(model, batch, tag, args):
} }
# Disable templates if there aren't any in the batch # Disable templates if there aren't any in the batch
model.config.template.enabled = any([ model.config.template.enabled = model.config.template.enabled and any([
"template_" in k for k in batch "template_" in k for k in batch
]) ])
...@@ -165,6 +164,7 @@ def main(args): ...@@ -165,6 +164,7 @@ def main(args):
# Prep the model # Prep the model
config = model_config(args.model_name) config = model_config(args.model_name)
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
...@@ -174,6 +174,7 @@ def main(args): ...@@ -174,6 +174,7 @@ def main(args):
) )
elif(args.openfold_checkpoint_path): elif(args.openfold_checkpoint_path):
if(os.path.isdir(args.openfold_checkpoint_path)): if(os.path.isdir(args.openfold_checkpoint_path)):
# A DeepSpeed checkpoint
checkpoint_basename = os.path.splitext( checkpoint_basename = os.path.splitext(
os.path.basename( os.path.basename(
os.path.normpath(args.openfold_checkpoint_path) os.path.normpath(args.openfold_checkpoint_path)
...@@ -193,6 +194,8 @@ def main(args): ...@@ -193,6 +194,8 @@ def main(args):
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"]) model.load_state_dict(d["ema"]["params"])
else: else:
# A checkpoint from the public release, which only contains EMA
# params
ckpt_path = args.openfold_checkpoint_path ckpt_path = args.openfold_checkpoint_path
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
...@@ -218,6 +221,8 @@ def main(args): ...@@ -218,6 +221,8 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
use_small_bfd=(args.bfd_database_path is None)
data_processor = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
...@@ -249,9 +254,21 @@ def main(args): ...@@ -249,9 +254,21 @@ def main(args):
tags, seqs = lines[::2], lines[1::2] tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags] tags = [t.split()[0] for t in tags]
assert len(tags) == len(set(tags)), "All FASTA tags must be unique" # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags) tag = '-'.join(tags)
output_name = f'{tag}_{args.model_name}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if(os.path.exists(unrelaxed_output_path)):
continue
precompute_alignments(tags, seqs, alignment_dir, args) precompute_alignments(tags, seqs, alignment_dir, args)
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
...@@ -293,7 +310,7 @@ def main(args): ...@@ -293,7 +310,7 @@ def main(args):
output_name = f'{tag}_{args.model_name}' output_name = f'{tag}_{args.model_name}'
if(args.output_postfix is not None): if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}' output_name = f'{output_name}_{args.output_postfix}_{tag_postfix}'
# Save the unrelaxed PDB. # Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join( unrelaxed_output_path = os.path.join(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment