Commit 6b547775 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Merge branch 'main' of github.com:aqlaboratory/openfold into main

parents 9a617649 15fb714a
![header ](imgs/of_banner.png) ![header ](imgs/of_banner.png)
_Figure: Comparison of OpenFold, AlphaFold2, and experimental structure of Streptomyces tokunonensis TokK protein (pdb code 7KDX), related to novel antibiotics used for rare infections including during COVID-19 infection._ _Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B._
# OpenFold # OpenFold
......
...@@ -208,36 +208,57 @@ def generate_feature_dict( ...@@ -208,36 +208,57 @@ def generate_feature_dict(
return feature_dict return feature_dict
def get_model_basename(model_path):
    """Return the last component of ``model_path`` with its extension removed.

    The path is normalized before the final component is taken, so a
    trailing separator (e.g. on a directory path) does not yield an empty
    name.

    Args:
        model_path: Path to a parameter file or checkpoint directory.

    Returns:
        The final path component, minus its last extension.
    """
    normalized = os.path.normpath(model_path)
    stem, _extension = os.path.splitext(os.path.basename(normalized))
    return stem
def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create the prediction directory and return its path.

    Args:
        output_dir: Root directory under which ``predictions`` is created.
        model_name: Name of the per-model subdirectory; only used when
            ``multiple_model_mode`` is truthy.
        multiple_model_mode: When truthy, predictions go to
            ``<output_dir>/predictions/<model_name>``; otherwise they go
            directly to ``<output_dir>/predictions``.

    Returns:
        The path of the directory, which exists when this function returns.
    """
    path_parts = [output_dir, "predictions"]
    if multiple_model_mode:
        path_parts.append(model_name)

    prediction_dir = os.path.join(*path_parts)
    # exist_ok=True lets repeated calls for the same model reuse the directory.
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    """Count the models named by the two comma-separated path arguments.

    Each argument is either falsy (``None`` or an empty string), in which
    case it contributes nothing, or a string of paths joined by commas.
    Every comma-delimited entry is counted as-is, so a stray leading or
    trailing comma adds an (empty) entry to the total.

    Args:
        openfold_checkpoint_path: Comma-separated OpenFold checkpoint paths,
            or a falsy value.
        jax_param_path: Comma-separated JAX parameter paths, or a falsy
            value.

    Returns:
        The total number of comma-delimited entries across both arguments.
    """
    return sum(
        len(path_list.split(","))
        for path_list in (openfold_checkpoint_path, jax_param_path)
        if path_list
    )
def load_models_from_command_line(args, config): def load_models_from_command_line(args, config):
# Create the output directory # Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
multiple_model_mode = count_models_to_evaluate(args.openfold_checkpoint_path, args.jax_param_path) > 1
if multiple_model_mode:
logger.info(f"evaluating multiple models")
if args.jax_param_path: if args.jax_param_path:
for path in args.jax_param_path.split(","): for path in args.jax_param_path.split(","):
model_basename = get_model_basename(path)
model_version = "_".join(model_basename.split("_")[1:])
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
import_jax_weights_( import_jax_weights_(
model, path, version=args.config_preset model, path, version=model_version
) )
model = model.to(args.model_device) model = model.to(args.model_device)
logger.info( logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..." f"Successfully loaded JAX parameters at {path}..."
)
model_version = os.path.basename(
os.path.normpath(args.jax_param_path),
) )
model_version = os.path.splitext(model_version)[0] output_directory = make_output_directory(args.output_dir, model_basename, multiple_model_mode)
yield model, model_version yield model, output_directory
if args.openfold_checkpoint_path: if args.openfold_checkpoint_path:
for path in args.openfold_checkpoint_path.split(","): for path in args.openfold_checkpoint_path.split(","):
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
checkpoint_basename = os.path.splitext( checkpoint_basename = get_model_basename(path)
os.path.basename(
os.path.normpath(path)
)
)[0]
if os.path.isdir(path): if os.path.isdir(path):
# A DeepSpeed checkpoint # A DeepSpeed checkpoint
ckpt_path = os.path.join( ckpt_path = os.path.join(
...@@ -256,17 +277,17 @@ def load_models_from_command_line(args, config): ...@@ -256,17 +277,17 @@ def load_models_from_command_line(args, config):
ckpt_path = path ckpt_path = path
d = torch.load(ckpt_path) d = torch.load(ckpt_path)
if ("ema" in d): if "ema" in d:
# The public weights have had this done to them already # The public weights have had this done to them already
d = d["ema"]["params"] d = d["ema"]["params"]
model.load_state_dict(d) model.load_state_dict(d)
model = model.to(args.model_device) model = model.to(args.model_device)
logger.info( logger.info(
f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..." f"Loaded OpenFold parameters at {path}..."
) )
output_directory = make_output_directory(args.output_dir, checkpoint_basename, multiple_model_mode)
yield model, checkpoint_basename yield model, output_directory
if not args.jax_param_path and not args.openfold_checkpoint_path: if not args.jax_param_path and not args.openfold_checkpoint_path:
raise ValueError( raise ValueError(
...@@ -308,9 +329,6 @@ def main(args): ...@@ -308,9 +329,6 @@ def main(args):
alignment_dir = args.use_precomputed_alignments alignment_dir = args.use_precomputed_alignments
logger.info(f"Using precomputed alignments at {alignment_dir}...") logger.info(f"Using precomputed alignments at {alignment_dir}...")
prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")): for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
# Gather input sequences # Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp: with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
...@@ -323,14 +341,6 @@ def main(args): ...@@ -323,14 +341,6 @@ def main(args):
output_name = f'{tag}_{args.config_preset}' output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None: if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}' output_name = f'{output_name}_{args.output_postfix}'
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
# Output already exists
if os.path.exists(unrelaxed_output_path):
continue
precompute_alignments(tags, seqs, alignment_dir, args) precompute_alignments(tags, seqs, alignment_dir, args)
...@@ -346,7 +356,7 @@ def main(args): ...@@ -346,7 +356,7 @@ def main(args):
feature_dict, mode='predict', feature_dict, mode='predict',
) )
for model, model_version in load_models_from_command_line(args, config): for model, output_directory in load_models_from_command_line(args, config):
working_batch = deepcopy(processed_feature_dict) working_batch = deepcopy(processed_feature_dict)
out = run_model(model, working_batch, tag, args) out = run_model(model, working_batch, tag, args)
...@@ -358,6 +368,14 @@ def main(args): ...@@ -358,6 +368,14 @@ def main(args):
out, working_batch, feature_dict, feature_processor, args out, working_batch, feature_dict, feature_processor, args
) )
unrelaxed_output_path = os.path.join(
output_directory, f'{output_name}_unrelaxed.pdb'
)
# Output already exists
if os.path.exists(unrelaxed_output_path):
continue
with open(unrelaxed_output_path, 'w') as fp: with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein)) fp.write(protein.to_pdb(unrelaxed_protein))
...@@ -382,7 +400,7 @@ def main(args): ...@@ -382,7 +400,7 @@ def main(args):
# Save the relaxed PDB. # Save the relaxed PDB.
relaxed_output_path = os.path.join( relaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_relaxed.pdb' output_directory, f'{output_name}_relaxed.pdb'
) )
with open(relaxed_output_path, 'w') as fp: with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str) fp.write(relaxed_pdb_str)
...@@ -391,7 +409,7 @@ def main(args): ...@@ -391,7 +409,7 @@ def main(args):
if args.save_outputs: if args.save_outputs:
output_dict_path = os.path.join( output_dict_path = os.path.join(
args.output_dir, f'{output_name}_output_dict.pkl' output_directory, f'{output_name}_output_dict.pkl'
) )
with open(output_dict_path, "wb") as fp: with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment