Commit f4df7217 authored by Jennifer
Browse files

Changes resume_model_weights_only flag in train_openfold to reload weights...

Changes the resume_model_weights_only flag in train_openfold so that it reloads the model weights only, without reading the global step from the checkpoint to resume the learning-rate schedule.
parent f1cd1381
......@@ -282,28 +282,38 @@ def main(args):
)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
if 'module' in sd:
module_sd = {k[len("module."):]:v for k,v in sd['module'].items()}
import_openfold_weights_(model=model_module, state_dict=module_sd)
elif 'state_dict' in sd:
import_openfold_weights_(model=model_module, state_dict=sd['state_dict'])
else:
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
if args.resume_from_ckpt:
if args.resume_model_weights_only:
# Load the checkpoint
if os.path.isdir(args.resume_from_ckpt):
sd = get_fp32_state_dict_from_zero_checkpoint(
args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
# Process the state dict
if 'module' in sd:
sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
import_openfold_weights_(model=model_module, state_dict=sd)
elif 'state_dict' in sd:
import_openfold_weights_(
model=model_module, state_dict=sd['state_dict'])
else:
# Loading from pre-trained model
sd = {'model.'+k: v for k, v in sd.items()}
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...")
else: # Loads a checkpoint to start from a specific time step
if os.path.isdir(args.resume_from_ckpt):
last_global_step = get_global_step_from_zero_checkpoint(
args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if args.resume_from_jax_params:
model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment