"examples/hello_world" did not exist on "4017bd18d0e84b8463bfe381279d4c5a6fd0c6e0"
Commit 0c4a93f7 authored by Lucas Bickmann's avatar Lucas Bickmann
Browse files

Fix wrong parsing

parent 1026158e
......@@ -283,9 +283,9 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...")
if(args.jax_param_path):
model_module.load_from_jax(args.jax_param_path)
logging.info(f"Successfully loaded JAX parameters at {args.jax_param_path}...")
if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
# TorchScript components of the model
if(args.script_modules):
......@@ -591,7 +591,7 @@ if __name__ == "__main__":
if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.jax_param_path is not None and args.resume_from_ckpt is not None):
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment