"vscode:/vscode.git/clone" did not exist on "faa48887471d862bfed66c32850304f6e409e86c"
Commit 99cdb062 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix even more undefined references

parent b1e4dc52
......@@ -71,11 +71,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(
fasta_path, local_alignment_dir
tmp_fasta_path, local_alignment_dir
)
# Remove temporary FASTA file
......@@ -133,7 +132,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
remark = ', '.join([
f"no_recycling={no_recycling}",
f"max_templates={feature_processor.config.predict.max_templates}",
f"config_preset={args.model_name}",
f"config_preset={args.config_preset}",
])
# For multi-chain FASTAs
......@@ -167,16 +166,16 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)
# Prep the model
config = model_config(args.model_name)
config = model_config(args.config_preset)
logger.info(f"Using config preset {args.model_name}...")
logger.info(f"Using config preset {args.config_preset}...")
model = AlphaFold(config)
model = model.eval()
if(args.jax_param_path):
import_jax_weights_(
model, args.jax_param_path, version=args.model_name
model, args.jax_param_path, version=args.config_preset
)
logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..."
......@@ -234,8 +233,6 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path
)
use_small_bfd=(args.bfd_database_path is None)
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
......@@ -271,7 +268,7 @@ def main(args):
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.model_name}'
output_name = f'{tag}_{args.config_preset}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}'
......@@ -322,9 +319,9 @@ def main(args):
out, batch, feature_dict, feature_processor, args
)
output_name = f'{tag}_{args.model_name}'
output_name = f'{tag}_{args.config_preset}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}_{tag_postfix}'
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
......@@ -394,7 +391,7 @@ if __name__ == "__main__":
device name is accepted (e.g. "cpu", "cuda:0")"""
)
parser.add_argument(
"--model_name", type=str, default="model_1",
"--config_preset", type=str, default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
)
......@@ -441,7 +438,7 @@ if __name__ == "__main__":
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.model_name + ".npz"
"params_" + args.config_preset + ".npz"
)
if(args.model_device == "cpu" and torch.cuda.is_available()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment