Commit e31e0378 authored by Jennifer's avatar Jennifer
Browse files

Add log statement to weight conversion script

parent 775f77dd
......@@ -670,7 +670,6 @@ def import_jax_weights_(model, npz_path, version="model_1"):
def convert_deprecated_v1_keys(state_dict):
"""Update older OpenFold model weight names to match the current model code."""
logging.warning('converting keys...')
replacements = {
'template_angle_embedder': 'template_single_embedder',
......@@ -683,27 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
}
convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
template_emb_re = re.compile("((module\\.)?(model\\.))?(template(?!_embedder).*)")
template_emb_re = re.compile(r"^((module\.)?(model\.)?)(template(?!_embedder).*)")
converted_state_dict = {}
for key, value in state_dict.items():
# For each match, look-up replacement value in the dictionary
new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
### DEBUG: remove before final commit
if key == 'template_angle_embedder.linear_1.weight':
logging.warning(f'old key: {key}, new_key: {new_key}')
### DEBUG: remove before final commit
# Add prefix for template layers
template_match = re.match(template_emb_re, new_key)
if template_match:
prefix = template_match.group(1)
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
# DEBUG: remove before final commit
if key == 'template_angle_embedder.linear_1.weight':
breakpoint()
logging.warning(f'old key: {key}, new_key: {new_key}')
### DEBUG: remove before final commit
converted_state_dict[new_key] = value
......
......@@ -31,7 +31,7 @@ def convert_v1_to_v2_weights(args):
if is_dir:
# A DeepSpeed checkpoint
logging.info(
'Converting checkpoint found at {args.input_checkpoint_path}')
'Converting deepspeed checkpoint found at {args.input_checkpoint_path}')
state_dict_key = 'module'
latest_path = os.path.join(checkpoint_path, 'latest')
if os.path.isfile(latest_path):
......@@ -47,6 +47,8 @@ def convert_v1_to_v2_weights(args):
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
else:
# A Pytorch Lightning checkpoint
logging.info(
'Converting pytorch lightning checkpoint found at {args.input_checkpoint_path}')
state_dict_key = 'state_dict'
model_output_path = args.output_ckpt_path
model_file = checkpoint_path
......
......@@ -289,16 +289,6 @@ def main(args):
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
### DEBUG:
ds_checkpoint_dir = os.path.join(args.resume_from_ckpt, 'global_step210')
optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
model_dict = torch.load(model_file, map_location=torch.device('cpu'))
###
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment