Commit ce27a6ca authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add logging to inference script

parent f38f346d
......@@ -45,6 +45,11 @@ from openfold.utils.tensor_utils import (
from scripts.utils import add_data_args
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
def precompute_alignments(tags, seqs, alignment_dir, args):
for tag, seq in zip(tags, seqs):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
......@@ -53,7 +58,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
logging.info(f"Generating alignments for {tag}...")
logger.info(f"Generating alignments for {tag}...")
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
......@@ -78,7 +83,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
def run_model(model, batch, tag, args):
logging.info("Executing model...")
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.model_device)
......@@ -90,10 +94,10 @@ def run_model(model, batch, tag, args):
"template_" in k for k in batch
])
logging.info(f"Running inference for {tag}...")
logger.info(f"Running inference for {tag}...")
t = time.perf_counter()
out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}")
logger.info(f"Inference time: {time.perf_counter() - t}")
return out
......@@ -165,6 +169,8 @@ def main(args):
# Prep the model
config = model_config(args.model_name)
logger.info(f"Using config preset {args.model_name}...")
model = AlphaFold(config)
model = model.eval()
......@@ -172,6 +178,9 @@ def main(args):
import_jax_weights_(
model, args.jax_param_path, version=args.model_name
)
logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..."
)
elif(args.openfold_checkpoint_path):
if(os.path.isdir(args.openfold_checkpoint_path)):
# A DeepSpeed checkpoint
......@@ -204,6 +213,10 @@ def main(args):
d = d["ema"]["params"]
model.load_state_dict(d)
logger.info(
f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..."
)
else:
raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must "
......@@ -238,6 +251,7 @@ def main(args):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
logger.info(f"Using precomputed alignments at {alignment_dir}...")
prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
......@@ -319,6 +333,8 @@ def main(args):
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...")
if(not args.skip_relaxation):
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
......@@ -326,6 +342,7 @@ def main(args):
)
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if("cuda" in args.model_device):
......@@ -333,7 +350,7 @@ def main(args):
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
logging.info(f"Relaxation time: {time.perf_counter() - t}")
logger.info(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
......@@ -342,6 +359,8 @@ def main(args):
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...")
if(args.save_outputs):
output_dict_path = os.path.join(
args.output_dir, f'{output_name}_output_dict.pkl'
......@@ -349,6 +368,7 @@ def main(args):
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Model output written to {output_dict_path}...")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment