"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "ee4ef06bac4ee3f6bed53a3b77cb95c5ba5d824e"
Commit 1fa6ffab authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Refactor run_pretrained_openfold.py a little

parent b4b849af
...@@ -162,36 +162,28 @@ def prep_output(out, batch, feature_dict, feature_processor, args): ...@@ -162,36 +162,28 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return unrelaxed_protein return unrelaxed_protein
def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature_processor, prediction_dir): def parse_fasta(data):
with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
data = fp.read()
lines = [ lines = [
l.replace('\n', '') l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1) for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:] ][1:]
tags, seqs = lines[::2], lines[1::2] tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags] tags = [t.split()[0] for t in tags]
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if os.path.exists(unrelaxed_output_path): return tags, seqs
return
precompute_alignments(tags, seqs, alignment_dir, args)
def generate_feature_dict(
tags,
seqs,
alignment_dir,
data_processor,
args,
):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
if len(seqs) == 1: if len(seqs) == 1:
tag = tags[0]
seq = seqs[0] seq = seqs[0]
with open(tmp_fasta_path, "w") as fp: with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}") fp.write(f">{tag}\n{seq}")
...@@ -212,10 +204,7 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature ...@@ -212,10 +204,7 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
# Remove temporary FASTA file # Remove temporary FASTA file
os.remove(tmp_fasta_path) os.remove(tmp_fasta_path)
processed_feature_dict = feature_processor.process_features( return feature_dict
feature_dict, mode='predict',
)
return processed_feature_dict, tag, feature_dict
def load_models_from_command_line(args, config): def load_models_from_command_line(args, config):
...@@ -226,13 +215,18 @@ def load_models_from_command_line(args, config): ...@@ -226,13 +215,18 @@ def load_models_from_command_line(args, config):
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
import_jax_weights_( import_jax_weights_(
model, path, version=args.model_name model, path, version=args.config_preset
) )
model = model.to(args.model_device) model = model.to(args.model_device)
logger.info( logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..." f"Successfully loaded JAX parameters at {args.jax_param_path}..."
) )
yield model, None model_version = os.path.basename(
os.path.normpath(args.jax_param_path),
)
model_version = os.path.splitext(model_version)[0]
yield model, model_version
if args.openfold_checkpoint_path: if args.openfold_checkpoint_path:
for path in args.openfold_checkpoint_path.split(","): for path in args.openfold_checkpoint_path.split(","):
model = AlphaFold(config) model = AlphaFold(config)
...@@ -264,11 +258,14 @@ def load_models_from_command_line(args, config): ...@@ -264,11 +258,14 @@ def load_models_from_command_line(args, config):
# The public weights have had this done to them already # The public weights have had this done to them already
d = d["ema"]["params"] d = d["ema"]["params"]
model.load_state_dict(d) model.load_state_dict(d)
model = model.to(args.model_device) model = model.to(args.model_device)
logger.info( logger.info(
f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..." f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..."
) )
yield model, checkpoint_basename yield model, checkpoint_basename
if not args.jax_param_path and not args.openfold_checkpoint_path: if not args.jax_param_path and not args.openfold_checkpoint_path:
raise ValueError( raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must " "At least one of jax_param_path or openfold_checkpoint_path must "
...@@ -311,23 +308,40 @@ def main(args): ...@@ -311,23 +308,40 @@ def main(args):
os.makedirs(prediction_dir, exist_ok=True) os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir): for fasta_file in os.listdir(args.fasta_dir):
with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
data = fp.read()
tags, seqs = parse_fasta(data)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
# Output already exists
if os.path.exists(unrelaxed_output_path):
continue
batch_data = generate_batch( precompute_alignments(tags, seqs, alignment_dir, args)
fasta_file,
args.fasta_dir, feature_dict = generate_feature_dict(
tags,
seqs,
alignment_dir, alignment_dir,
data_processor, data_processor,
feature_processor, args,
prediction_dir) )
if batch_data is None:
# this file has already been processed
continue
batch, tag, feature_dict = batch_data processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
for model, model_version in load_models_from_command_line(args, config): for model, model_version in load_models_from_command_line(args, config):
working_batch = deepcopy(batch) working_batch = deepcopy(batch)
out = run_model(model, working_batch, tag, args) out = run_model(model, working_batch, tag, args)
...@@ -339,21 +353,11 @@ def main(args): ...@@ -339,21 +353,11 @@ def main(args):
out, working_batch, feature_dict, feature_processor, args out, working_batch, feature_dict, feature_processor, args
) )
output_name = f'{tag}_{args.config_preset}'
if model_version is not None:
output_name = f'{output_name}_{model_version}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
with open(unrelaxed_output_path, 'w') as fp: with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein)) fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...") logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation: if not args.skip_relaxation:
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"), use_gpu=(args.model_device != "cpu"),
...@@ -377,6 +381,7 @@ def main(args): ...@@ -377,6 +381,7 @@ def main(args):
) )
with open(relaxed_output_path, 'w') as fp: with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str) fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...") logger.info(f"Relaxed output written to {relaxed_output_path}...")
if args.save_outputs: if args.save_outputs:
...@@ -388,6 +393,7 @@ def main(args): ...@@ -388,6 +393,7 @@ def main(args):
logger.info(f"Model output written to {output_dict_path}...") logger.info(f"Model output written to {output_dict_path}...")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
...@@ -413,8 +419,7 @@ if __name__ == "__main__": ...@@ -413,8 +419,7 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--config_preset", type=str, default="model_1", "--config_preset", type=str, default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or help="""Name of a model config preset defined in openfold/config.py"""
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
) )
parser.add_argument( parser.add_argument(
"--jax_param_path", type=str, default=None, "--jax_param_path", type=str, default=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment