Commit 07b522b7 authored by Sam DeLuca's avatar Sam DeLuca
Browse files

fixing bug in jax handling

parent a2ab7ab7
...@@ -226,7 +226,7 @@ def load_models_from_command_line(args, config): ...@@ -226,7 +226,7 @@ def load_models_from_command_line(args, config):
model = AlphaFold(config) model = AlphaFold(config)
model = model.eval() model = model.eval()
import_jax_weights_( import_jax_weights_(
model, path, version=args.model_name model, path, version=args.config_preset
) )
model = model.to(args.model_device) model = model.to(args.model_device)
logger.info( logger.info(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment