Commit fcee03aa authored by Sam DeLuca's avatar Sam DeLuca
Browse files

wip

parent 15a8c321
...@@ -207,7 +207,6 @@ def load_models_from_command_line(args, config): ...@@ -207,7 +207,6 @@ def load_models_from_command_line(args, config):
# Create the output directory # Create the output directory
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
if args.jax_param_path: if args.jax_param_path:
for path in args.jax_param_path.split(","): for path in args.jax_param_path.split(","):
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
...@@ -217,7 +216,7 @@ def load_models_from_command_line(args, config): ...@@ -217,7 +216,7 @@ def load_models_from_command_line(args, config):
model = model.to(args.model_device) model = model.to(args.model_device)
yield model, None yield model, None
if args.openfold_checkpoint_path: if args.openfold_checkpoint_path:
for path in args.openfold_checkpoint_path: for path in args.openfold_checkpoint_path.split(","):
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
checkpoint_basename = None checkpoint_basename = None
...@@ -234,7 +233,7 @@ def load_models_from_command_line(args, config): ...@@ -234,7 +233,7 @@ def load_models_from_command_line(args, config):
if not os.path.isfile(ckpt_path): if not os.path.isfile(ckpt_path):
convert_zero_checkpoint_to_fp32_state_dict( convert_zero_checkpoint_to_fp32_state_dict(
args.openfold_checkpoint_path, path,
ckpt_path, ckpt_path,
) )
else: else:
...@@ -244,7 +243,7 @@ def load_models_from_command_line(args, config): ...@@ -244,7 +243,7 @@ def load_models_from_command_line(args, config):
model.load_state_dict(d["ema"]["params"]) model.load_state_dict(d["ema"]["params"])
model = model.to(args.model_device) model = model.to(args.model_device)
yield model, checkpoint_basename yield model, checkpoint_basename
else: if not args.jax_param_path and not args.openfold_checkpoint_path:
raise ValueError( raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must " "At least one of jax_param_path or openfold_checkpoint_path must "
"be specified." "be specified."
......
#!/bin/bash
# the alphafold script has a few frustrating behaviors in how it parses input flags, this script reduces the need for
# code duplication in the argo workflow
set -x
database_path=$1
fasta_path=$2
base_arguments=(
${fasta_path}
"${database_path}/pdb_mmcif/mmcif_files/"
"--uniref90_database_path" "${database_path}/uniref90/uniref90.fasta"
"--mgnify_database_path" "${database_path}/mgnify/mgy_clusters_2018_12.fa"
"--pdb70_database_path" "${database_path}/pdb70/pdb70"
"--uniclust30_database_path" "${database_path}/uniclust30/uniclust30_2018_08/uniclust30_2018_08"
"--output_dir" "out"
"--bfd_database_path" "${database_path}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
"--model_device" "cuda:0"
"--jackhmmer_binary_path" "/home/samdeluca/miniconda3/envs/openfold_venv/bin/jackhmmer"
"--hhblits_binary_path" "/home/samdeluca/miniconda3/envs/openfold_venv/bin/hhblits"
"--hhsearch_binary_path" "/home/samdeluca/miniconda3/envs/openfold_venv/bin/hhsearch"
"--kalign_binary_path" "/home/samdeluca/miniconda3/envs/openfold_venv/bin/kalign"
"--openfold_checkpoint_path" "/mnt/openfold_params/101-80999.ckpt,/mnt/openfold_params/116-84749.ckpt,/mnt/openfold_params/94-79249.ckpt"
)
python3 run_pretrained_openfold.py ${base_arguments[*]}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment