"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "6aff8be6547967f1e14fcd65679dff5ee8445da3"
Commit 9dee6caa authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update inference script

parent 7642bef9
...@@ -57,7 +57,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -57,7 +57,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
if not os.path.exists(local_alignment_dir): if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir) os.makedirs(local_alignment_dir)
use_small_bfd=(args.bfd_database_path is None)
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path, jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path, hhblits_binary_path=args.hhblits_binary_path,
...@@ -71,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -71,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
no_cpus=args.cpus, no_cpus=args.cpus,
) )
alignment_runner.run( alignment_runner.run(
tmp_fasta_path, local_alignment_dir fasta_path, local_alignment_dir
) )
# Remove temporary FASTA file # Remove temporary FASTA file
...@@ -87,7 +86,7 @@ def run_model(model, batch, tag, args): ...@@ -87,7 +86,7 @@ def run_model(model, batch, tag, args):
} }
# Disable templates if there aren't any in the batch # Disable templates if there aren't any in the batch
model.config.template.enabled = any([ model.config.template.enabled = model.config.template.enabled and any([
"template_" in k for k in batch "template_" in k for k in batch
]) ])
...@@ -165,6 +164,7 @@ def main(args): ...@@ -165,6 +164,7 @@ def main(args):
# Prep the model # Prep the model
config = model_config(args.model_name) config = model_config(args.model_name)
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
...@@ -174,6 +174,7 @@ def main(args): ...@@ -174,6 +174,7 @@ def main(args):
) )
elif(args.openfold_checkpoint_path): elif(args.openfold_checkpoint_path):
if(os.path.isdir(args.openfold_checkpoint_path)): if(os.path.isdir(args.openfold_checkpoint_path)):
# A DeepSpeed checkpoint
checkpoint_basename = os.path.splitext( checkpoint_basename = os.path.splitext(
os.path.basename( os.path.basename(
os.path.normpath(args.openfold_checkpoint_path) os.path.normpath(args.openfold_checkpoint_path)
...@@ -193,13 +194,15 @@ def main(args): ...@@ -193,13 +194,15 @@ def main(args):
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"]) model.load_state_dict(d["ema"]["params"])
else: else:
# A checkpoint from the public release, which only contains EMA
# params
ckpt_path = args.openfold_checkpoint_path ckpt_path = args.openfold_checkpoint_path
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
if("ema" in d): if("ema" in d):
# The public weights have had this done to them already # The public weights have had this done to them already
d = d["ema"]["params"] d = d["ema"]["params"]
model.load_state_dict(d) model.load_state_dict(d)
else: else:
raise ValueError( raise ValueError(
...@@ -218,6 +221,8 @@ def main(args): ...@@ -218,6 +221,8 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
use_small_bfd=(args.bfd_database_path is None)
data_processor = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
...@@ -234,7 +239,7 @@ def main(args): ...@@ -234,7 +239,7 @@ def main(args):
else: else:
alignment_dir = args.use_precomputed_alignments alignment_dir = args.use_precomputed_alignments
prediction_dir = os.path.join(args.output_dir, "predictions") prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True) os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir): for fasta_file in os.listdir(args.fasta_dir):
...@@ -249,9 +254,21 @@ def main(args): ...@@ -249,9 +254,21 @@ def main(args):
tags, seqs = lines[::2], lines[1::2] tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags] tags = [t.split()[0] for t in tags]
assert len(tags) == len(set(tags)), "All FASTA tags must be unique" # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags) tag = '-'.join(tags)
output_name = f'{tag}_{args.model_name}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if(os.path.exists(unrelaxed_output_path)):
continue
precompute_alignments(tags, seqs, alignment_dir, args) precompute_alignments(tags, seqs, alignment_dir, args)
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
...@@ -272,10 +289,10 @@ def main(args): ...@@ -272,10 +289,10 @@ def main(args):
feature_dict = data_processor.process_multiseq_fasta( feature_dict = data_processor.process_multiseq_fasta(
fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir, fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
) )
# Remove temporary FASTA file # Remove temporary FASTA file
os.remove(tmp_fasta_path) os.remove(tmp_fasta_path)
processed_feature_dict = feature_processor.process_features( processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', feature_dict, mode='predict',
) )
...@@ -286,14 +303,14 @@ def main(args): ...@@ -286,14 +303,14 @@ def main(args):
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output( unrelaxed_protein = prep_output(
out, batch, feature_dict, feature_processor, args out, batch, feature_dict, feature_processor, args
) )
output_name = f'{tag}_{args.model_name}' output_name = f'{tag}_{args.model_name}'
if(args.output_postfix is not None): if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}' output_name = f'{output_name}_{args.output_postfix}_{tag_postfix}'
# Save the unrelaxed PDB. # Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join( unrelaxed_output_path = os.path.join(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment