"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "42ee22970346d274abb6e882132c95cb6f01adc9"
Commit 9dee6caa authored by Gustaf Ahdritz
Browse files

Update inference script

parent 7642bef9
......@@ -57,7 +57,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
use_small_bfd=(args.bfd_database_path is None)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
......@@ -71,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
no_cpus=args.cpus,
)
alignment_runner.run(
tmp_fasta_path, local_alignment_dir
fasta_path, local_alignment_dir
)
# Remove temporary FASTA file
......@@ -87,7 +86,7 @@ def run_model(model, batch, tag, args):
}
# Disable templates if there aren't any in the batch
model.config.template.enabled = any([
model.config.template.enabled = model.config.template.enabled and any([
"template_" in k for k in batch
])
......@@ -165,6 +164,7 @@ def main(args):
# Prep the model
config = model_config(args.model_name)
model = AlphaFold(config)
model = model.eval()
......@@ -174,6 +174,7 @@ def main(args):
)
elif(args.openfold_checkpoint_path):
if(os.path.isdir(args.openfold_checkpoint_path)):
# A DeepSpeed checkpoint
checkpoint_basename = os.path.splitext(
os.path.basename(
os.path.normpath(args.openfold_checkpoint_path)
......@@ -193,13 +194,15 @@ def main(args):
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
else:
# A checkpoint from the public release, which only contains EMA
# params
ckpt_path = args.openfold_checkpoint_path
d = torch.load(ckpt_path)
if("ema" in d):
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
else:
raise ValueError(
......@@ -218,6 +221,8 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path
)
use_small_bfd=(args.bfd_database_path is None)
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
......@@ -234,7 +239,7 @@ def main(args):
else:
alignment_dir = args.use_precomputed_alignments
prediction_dir = os.path.join(args.output_dir, "predictions")
prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir):
......@@ -249,9 +254,21 @@ def main(args):
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.model_name}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if(os.path.exists(unrelaxed_output_path)):
continue
precompute_alignments(tags, seqs, alignment_dir, args)
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
......@@ -272,10 +289,10 @@ def main(args):
feature_dict = data_processor.process_multiseq_fasta(
fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
)
# Remove temporary FASTA file
os.remove(tmp_fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
......@@ -286,14 +303,14 @@ def main(args):
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output(
out, batch, feature_dict, feature_processor, args
)
output_name = f'{tag}_{args.model_name}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}'
output_name = f'{output_name}_{args.output_postfix}_{tag_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment