Unverified Commit 89dee905 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #117 from CyrusBiotechnology/run-multiple-models

Use multiple models for inference
parents a48860cb a2ab7ab7
...@@ -19,6 +19,7 @@ import gc ...@@ -19,6 +19,7 @@ import gc
import logging import logging
import numpy as np import numpy as np
import os import os
from copy import deepcopy
import pickle import pickle
from pytorch_lightning.utilities.deepspeed import ( from pytorch_lightning.utilities.deepspeed import (
...@@ -58,10 +59,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -58,10 +59,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir = os.path.join(alignment_dir, tag) local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None): if(args.use_precomputed_alignments is None):
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
if not os.path.exists(local_alignment_dir): if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir) os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path, jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path, hhblits_binary_path=args.hhblits_binary_path,
...@@ -161,69 +162,125 @@ def prep_output(out, batch, feature_dict, feature_processor, args): ...@@ -161,69 +162,125 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return unrelaxed_protein return unrelaxed_protein
def main(args): def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature_processor, prediction_dir):
# Create the output directory with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
os.makedirs(args.output_dir, exist_ok=True) data = fp.read()
# Prep the model lines = [
config = model_config(args.config_preset) l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
logger.info(f"Using config preset {args.config_preset}...") ][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
model = AlphaFold(config) output_name = f'{tag}_{args.config_preset}'
model = model.eval() if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if os.path.exists(unrelaxed_output_path):
return
precompute_alignments(tags, seqs, alignment_dir, args)
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
if len(seqs) == 1:
seq = seqs[0]
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
if(args.jax_param_path): local_alignment_dir = os.path.join(alignment_dir, tag)
import_jax_weights_( feature_dict = data_processor.process_fasta(
model, args.jax_param_path, version=args.config_preset fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
) )
logger.info( else:
f"Successfully loaded JAX parameters at {args.jax_param_path}..." with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_multiseq_fasta(
fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
) )
elif(args.openfold_checkpoint_path):
if(os.path.isdir(args.openfold_checkpoint_path)): # Remove temporary FASTA file
# A DeepSpeed checkpoint os.remove(tmp_fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
return processed_feature_dict, tag, feature_dict
def load_models_from_command_line(args, config):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
if args.jax_param_path:
for path in args.jax_param_path.split(","):
model = AlphaFold(config)
model = model.eval()
import_jax_weights_(
model, path, version=args.model_name
)
model = model.to(args.model_device)
logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..."
)
yield model, None
if args.openfold_checkpoint_path:
for path in args.openfold_checkpoint_path.split(","):
model = AlphaFold(config)
model = model.eval()
checkpoint_basename = os.path.splitext( checkpoint_basename = os.path.splitext(
os.path.basename( os.path.basename(
os.path.normpath(args.openfold_checkpoint_path) os.path.normpath(path)
) )
)[0] )[0]
ckpt_path = os.path.join( if os.path.isdir(path):
args.output_dir, # A DeepSpeed checkpoint
checkpoint_basename + ".pt", ckpt_path = os.path.join(
) args.output_dir,
checkpoint_basename + ".pt",
if(not os.path.isfile(ckpt_path)):
convert_zero_checkpoint_to_fp32_state_dict(
args.openfold_checkpoint_path,
ckpt_path,
) )
d = torch.load(ckpt_path) if not os.path.isfile(ckpt_path):
model.load_state_dict(d["ema"]["params"]) convert_zero_checkpoint_to_fp32_state_dict(
else: path,
# A checkpoint from the public release, which only contains EMA ckpt_path,
# params )
ckpt_path = args.openfold_checkpoint_path d = torch.load(ckpt_path)
d = torch.load(ckpt_path) model.load_state_dict(d["ema"]["params"])
else:
if("ema" in d): ckpt_path = path
# The public weights have had this done to them already d = torch.load(ckpt_path)
d = d["ema"]["params"]
if ("ema" in d):
model.load_state_dict(d) # The public weights have had this done to them already
d = d["ema"]["params"]
logger.info( model.load_state_dict(d)
f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..." model = model.to(args.model_device)
) logger.info(
else: f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..."
)
yield model, checkpoint_basename
if not args.jax_param_path and not args.openfold_checkpoint_path:
raise ValueError( raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must " "At least one of jax_param_path or openfold_checkpoint_path must "
"be specified." "be specified."
) )
model = model.to(args.model_device)
def main(args):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.config_preset)
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
...@@ -244,128 +301,92 @@ def main(args): ...@@ -244,128 +301,92 @@ def main(args):
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
if(args.use_precomputed_alignments is None): if args.use_precomputed_alignments is None:
alignment_dir = os.path.join(output_dir_base, "alignments") alignment_dir = os.path.join(output_dir_base, "alignments")
else: else:
alignment_dir = args.use_precomputed_alignments alignment_dir = args.use_precomputed_alignments
logger.info(f"Using precomputed alignments at {alignment_dir}...") logger.info(f"Using precomputed alignments at {alignment_dir}...")
prediction_dir = os.path.join(args.output_dir, "predictions") prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True) os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir): for fasta_file in os.listdir(args.fasta_dir):
# Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
data = fp.read()
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if(os.path.exists(unrelaxed_output_path)):
continue
precompute_alignments(tags, seqs, alignment_dir, args) batch_data = generate_batch(
fasta_file,
args.fasta_dir,
alignment_dir,
data_processor,
feature_processor,
prediction_dir)
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") if batch_data is None:
if(len(seqs) == 1): # this file has already been processed
seq = seqs[0] continue
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
)
else:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_multiseq_fasta(
fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
)
# Remove temporary FASTA file batch, tag, feature_dict = batch_data
os.remove(tmp_fasta_path)
processed_feature_dict = feature_processor.process_features( for model, model_version in load_models_from_command_line(args, config):
feature_dict, mode='predict',
)
batch = processed_feature_dict working_batch = deepcopy(batch)
out = run_model(model, batch, tag, args) out = run_model(model, working_batch, tag, args)
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) working_batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), working_batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output(
out, batch, feature_dict, feature_processor, args
)
output_name = f'{tag}_{args.config_preset}' unrelaxed_protein = prep_output(
if(args.output_postfix is not None): out, working_batch, feature_dict, feature_processor, args
output_name = f'{output_name}_{args.output_postfix}' )
# Save the unrelaxed PDB. output_name = f'{tag}_{args.config_preset}'
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...") if model_version is not None:
output_name = f'{output_name}_{model_version}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
if(not args.skip_relaxation): # Save the unrelaxed PDB.
amber_relaxer = relax.AmberRelaxation( unrelaxed_output_path = os.path.join(
use_gpu=(args.model_device != "cpu"), prediction_dir, f'{output_name}_unrelaxed.pdb'
**config.relax,
)
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if("cuda" in args.model_device):
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
logger.info(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_relaxed.pdb'
) )
with open(relaxed_output_path, 'w') as fp: with open(unrelaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str) fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation:
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
**config.relax,
)
logger.info(f"Relaxed output written to {relaxed_output_path}...") # Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if "cuda" in args.model_device:
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
logger.info(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_relaxed.pdb'
)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...")
if(args.save_outputs): if args.save_outputs:
output_dict_path = os.path.join( output_dict_path = os.path.join(
args.output_dir, f'{output_name}_output_dict.pkl' args.output_dir, f'{output_name}_output_dict.pkl'
) )
with open(output_dict_path, "wb") as fp: with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Model output written to {output_dict_path}...") logger.info(f"Model output written to {output_dict_path}...")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment