Commit 07b522b7 authored by Sam DeLuca's avatar Sam DeLuca
Browse files

fixing bug in jax handling

parent a2ab7ab7
......@@ -226,7 +226,7 @@ def load_models_from_command_line(args, config):
model = AlphaFold(config)
model = model.eval()
import_jax_weights_(
model, path, version=args.model_name
model, path, version=args.config_preset
)
model = model.to(args.model_device)
logger.info(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment