Unverified Commit a14f6962 authored by Kashyap Chhatbar's avatar Kashyap Chhatbar Committed by GitHub
Browse files

Enable saving prediction output (#133)

- Added argument and save by default
- Include all tags in output file for multimer preset
parent b7ee0ff3
......@@ -241,7 +241,7 @@ def inference_multimer_model(args):
][1:]
tags, seqs = lines[::2], lines[1::2]
output_prefix = "_and_".join(tags)
for tag, seq in zip(tags, seqs):
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
......@@ -268,7 +268,6 @@ def inference_multimer_model(args):
feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=True,
......@@ -282,6 +281,13 @@ def inference_multimer_model(args):
out = result_q.get()
if args.save_prediction_result:
# Save the prediction result .pkl
prediction_result_path = os.path.join(args.output_dir,
f'{output_prefix}_{args.model_name}.pkl')
with open(prediction_result_path, 'wb') as f:
pickle.dump(out, f)
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
......@@ -296,7 +302,7 @@ def inference_multimer_model(args):
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
f'{output_prefix}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
......@@ -312,7 +318,7 @@ def inference_multimer_model(args):
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
f'{output_prefix}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
......@@ -500,6 +506,9 @@ if __name__ == "__main__":
type=str,
default='full_dbs',
choices=('reduced_dbs', 'full_dbs'))
parser.add_argument('--save_prediction_result',
type=bool,
default=True)
parser.add_argument('--data_random_seed', type=str, default=None)
parser.add_argument(
"--model_preset",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment