"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "3817a135c5809f7e6b8a9c3efb5c504e109d3271"
Unverified Commit 7ea9c9e4 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #201 from brianloyal/timings

Write inference and relaxation timings to a file
parents 6e930a6c 450f5236
...@@ -120,6 +120,7 @@ def run_model(model, batch, tag, args): ...@@ -120,6 +120,7 @@ def run_model(model, batch, tag, args):
out = model(batch) out = model(batch)
inference_time = time.perf_counter() - t inference_time = time.perf_counter() - t
logger.info(f"Inference time: {inference_time}") logger.info(f"Inference time: {inference_time}")
update_timings({"inference": inference_time}, os.path.join(args.output_dir, "timings.json"))
model.config.template.enabled = template_enabled model.config.template.enabled = template_enabled
...@@ -480,7 +481,10 @@ def main(args): ...@@ -480,7 +481,10 @@ def main(args):
os.environ["CUDA_VISIBLE_DEVICES"] = device_no os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
logger.info(f"Relaxation time: {time.perf_counter() - t}") relaxation_time = time.perf_counter() - t
logger.info(f"Relaxation time: {relaxation_time}")
update_timings({"relaxation": relaxation_time}, os.path.join(args.output_dir, "timings.json"))
# Save the relaxed PDB. # Save the relaxed PDB.
relaxed_output_path = os.path.join( relaxed_output_path = os.path.join(
...@@ -500,6 +504,22 @@ def main(args): ...@@ -500,6 +504,22 @@ def main(args):
logger.info(f"Model output written to {output_dict_path}...") logger.info(f"Model output written to {output_dict_path}...")
def update_timings(dict, output_file=os.path.join(os.getcwd(), "timings.json")):
"""Write dictionary of one or more run step times to a file"""
import json
if os.path.exists(output_file):
with open(output_file, "r") as f:
try:
timings = json.load(f)
except json.JSONDecodeError:
logger.info(f"Overwriting non-standard JSON in {output_file}.")
timings = {}
else:
timings = {}
timings.update(dict)
with open(output_file, "w") as f:
json.dump(timings, f)
return output_file
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment