Commit 99cdb062 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix even more undefined references

parent b1e4dc52
...@@ -71,11 +71,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -71,11 +71,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
bfd_database_path=args.bfd_database_path, bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path, uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path, pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus, no_cpus=args.cpus,
) )
alignment_runner.run( alignment_runner.run(
fasta_path, local_alignment_dir tmp_fasta_path, local_alignment_dir
) )
# Remove temporary FASTA file # Remove temporary FASTA file
...@@ -133,7 +132,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args): ...@@ -133,7 +132,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
remark = ', '.join([ remark = ', '.join([
f"no_recycling={no_recycling}", f"no_recycling={no_recycling}",
f"max_templates={feature_processor.config.predict.max_templates}", f"max_templates={feature_processor.config.predict.max_templates}",
f"config_preset={args.model_name}", f"config_preset={args.config_preset}",
]) ])
# For multi-chain FASTAs # For multi-chain FASTAs
...@@ -167,16 +166,16 @@ def main(args): ...@@ -167,16 +166,16 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
# Prep the model # Prep the model
config = model_config(args.model_name) config = model_config(args.config_preset)
logger.info(f"Using config preset {args.model_name}...") logger.info(f"Using config preset {args.config_preset}...")
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
if(args.jax_param_path): if(args.jax_param_path):
import_jax_weights_( import_jax_weights_(
model, args.jax_param_path, version=args.model_name model, args.jax_param_path, version=args.config_preset
) )
logger.info( logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..." f"Successfully loaded JAX parameters at {args.jax_param_path}..."
...@@ -234,8 +233,6 @@ def main(args): ...@@ -234,8 +233,6 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
use_small_bfd=(args.bfd_database_path is None)
data_processor = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
...@@ -271,7 +268,7 @@ def main(args): ...@@ -271,7 +268,7 @@ def main(args):
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique" # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags) tag = '-'.join(tags)
output_name = f'{tag}_{args.model_name}' output_name = f'{tag}_{args.config_preset}'
if(args.output_postfix is not None): if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}' output_name = f'{output_name}_{args.output_postfix}'
...@@ -322,9 +319,9 @@ def main(args): ...@@ -322,9 +319,9 @@ def main(args):
out, batch, feature_dict, feature_processor, args out, batch, feature_dict, feature_processor, args
) )
output_name = f'{tag}_{args.model_name}' output_name = f'{tag}_{args.config_preset}'
if(args.output_postfix is not None): if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}_{tag_postfix}' output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB. # Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join( unrelaxed_output_path = os.path.join(
...@@ -394,7 +391,7 @@ if __name__ == "__main__": ...@@ -394,7 +391,7 @@ if __name__ == "__main__":
device name is accepted (e.g. "cpu", "cuda:0")""" device name is accepted (e.g. "cpu", "cuda:0")"""
) )
parser.add_argument( parser.add_argument(
"--model_name", type=str, default="model_1", "--config_preset", type=str, default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub.""" model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
) )
...@@ -441,7 +438,7 @@ if __name__ == "__main__": ...@@ -441,7 +438,7 @@ if __name__ == "__main__":
if(args.jax_param_path is None and args.openfold_checkpoint_path is None): if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join( args.jax_param_path = os.path.join(
"openfold", "resources", "params", "openfold", "resources", "params",
"params_" + args.model_name + ".npz" "params_" + args.config_preset + ".npz"
) )
if(args.model_device == "cpu" and torch.cuda.is_available()): if(args.model_device == "cpu" and torch.cuda.is_available()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment