Commit 0c4a93f7 authored by Lucas Bickmann's avatar Lucas Bickmann
Browse files

Fix wrong parsing

parent 1026158e
...@@ -283,9 +283,9 @@ def main(args): ...@@ -283,9 +283,9 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()} sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd) model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
if(args.jax_param_path): if(args.resume_from_jax_params):
model_module.load_from_jax(args.jax_param_path) model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.jax_param_path}...") logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
# TorchScript components of the model # TorchScript components of the model
if(args.script_modules): if(args.script_modules):
...@@ -591,7 +591,7 @@ if __name__ == "__main__": ...@@ -591,7 +591,7 @@ if __name__ == "__main__":
if(str(args.precision) == "16" and args.deepspeed_config_path is not None): if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible") raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.jax_param_path is not None and args.resume_from_ckpt is not None): if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch # This re-applies the training-time filters at the beginning of every epoch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment