"...llm/vllm/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "eb022ec9091c11fa9bd098e64f7c43d06b06b8d9"
Commit 5bfad074 authored by Jennifer
Browse files

Changes resume_model_weights_only flag in train_openfold to reload weights...

Changes resume_model_weights_only flag in train_openfold to reload weights only, without parsing a time step.
parent 521bc6e9
...@@ -282,28 +282,38 @@ def main(args): ...@@ -282,28 +282,38 @@ def main(args):
) )
model_module = OpenFoldWrapper(config) model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt): if args.resume_from_ckpt:
if(os.path.isdir(args.resume_from_ckpt)): if args.resume_model_weights_only:
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt) # Load the checkpoint
else: if os.path.isdir(args.resume_from_ckpt):
sd = torch.load(args.resume_from_ckpt) sd = get_fp32_state_dict_from_zero_checkpoint(
last_global_step = int(sd['global_step']) args.resume_from_ckpt)
model_module.resume_last_lr_step(last_global_step) else:
logging.info("Successfully loaded last lr step...") sd = torch.load(args.resume_from_ckpt)
if(args.resume_from_ckpt and args.resume_model_weights_only): # Process the state dict
if(os.path.isdir(args.resume_from_ckpt)): if 'module' in sd:
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt) sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
else: import_openfold_weights_(model=model_module, state_dict=sd)
sd = torch.load(args.resume_from_ckpt) elif 'state_dict' in sd:
if 'module' in sd: import_openfold_weights_(
module_sd = {k[len("module."):]:v for k,v in sd['module'].items()} model=model_module, state_dict=sd['state_dict'])
import_openfold_weights_(model=model_module, state_dict=module_sd) else:
elif 'state_dict' in sd: # Loading from pre-trained model
import_openfold_weights_(model=model_module, state_dict=sd['state_dict']) sd = {'model.'+k: v for k, v in sd.items()}
else: import_openfold_weights_(model=model_module, state_dict=sd)
import_openfold_weights_(model=model_module, state_dict=sd) logging.info("Successfully loaded model weights...")
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params): else: # Loads a checkpoint to start from a specific time step
if os.path.isdir(args.resume_from_ckpt):
last_global_step = get_global_step_from_zero_checkpoint(
args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if args.resume_from_jax_params:
model_module.load_from_jax(args.resume_from_jax_params) model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment