Commit 775f77dd authored by Jennifer's avatar Jennifer
Browse files

bugfixes and adds a section to convert optim files

parent 260592e0
......@@ -14,6 +14,7 @@
# limitations under the License.
import re
import logging
from enum import Enum
from dataclasses import dataclass
from functools import partial
......@@ -669,6 +670,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
def convert_deprecated_v1_keys(state_dict):
"""Update older OpenFold model weight names to match the current model code."""
logging.warning('converting keys...')
replacements = {
'template_angle_embedder': 'template_single_embedder',
......@@ -686,17 +688,22 @@ def convert_deprecated_v1_keys(state_dict):
converted_state_dict = {}
for key, value in state_dict.items():
# For each match, look-up replacement value in the dictionary
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key)
new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
### DEBUG: remove before final commit
if key == 'template_angle_embedder.linear_1.weight':
print(f'old key: {key}, new_key: {new_key}')
logging.warning(f'old key: {key}, new_key: {new_key}')
### DEBUG: remove before final commit
# Add prefix for template layers
template_match = re.match(template_emb_re, new_key)
if template_match:
prefix = template_match.group(1)
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
# DEBUG: remove before final commit
if key == 'template_angle_embedder.linear_1.weight':
print(f'old key: {key}, new_key: {new_key}')
breakpoint()
logging.warning(f'old key: {key}, new_key: {new_key}')
### DEBUG: remove before final commit
converted_state_dict[new_key] = value
......
......@@ -15,6 +15,7 @@
# Converts OpenFold v1 checkpoints (a PyTorch Lightning checkpoint file or a
# DeepSpeed checkpoint directory) to the current parameter naming scheme.
import logging
import argparse
import os
import shutil
......@@ -23,47 +24,61 @@ import torch
from openfold.utils.import_weights import convert_deprecated_v1_keys
from zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
def get_latest_checkpoint_dir(checkpoint_dir):
    """Return the sub-directory of ``checkpoint_dir`` named by its 'latest' file.

    Based on zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint.

    Args:
        checkpoint_dir: Directory expected to contain a text file called
            ``latest`` whose content is the tag (sub-directory name) of the
            checkpoint to use.

    Returns:
        ``checkpoint_dir`` joined with the whitespace-stripped tag.

    Raises:
        ValueError: If ``checkpoint_dir`` has no ``latest`` file.
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    # Fail early so the happy path below stays flat.
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()
    return os.path.join(checkpoint_dir, tag)
def convert_v1_to_v2_weights(args):
    """Rewrite a checkpoint so its parameter names use the current key scheme.

    ``args.input_ckpt_path`` is either a directory, which is treated as a
    DeepSpeed checkpoint, or a single file, which is treated as a Pytorch
    Lightning checkpoint. Every key mapping found in the checkpoint is passed
    through ``convert_deprecated_v1_keys`` and the result is written to
    ``args.output_ckpt_path``.

    Args:
        args: Namespace providing ``input_ckpt_path`` and ``output_ckpt_path``.

    Raises:
        ValueError: If the input is a directory without a ``latest`` file.
    """
    # TODO can we use zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint here?
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    if is_dir:
        # A DeepSpeed checkpoint
        logging.info('Converting checkpoint found at %s', checkpoint_path)
        state_dict_key = 'module'
        # The 'latest' file holds the tag (sub-directory name) of the
        # checkpoint to convert; the same tag is reused on the output side.
        latest_path = os.path.join(checkpoint_path, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()
        ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
        optim_files = get_optim_files(ds_checkpoint_dir)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
        model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
    else:
        # A Pytorch Lightning checkpoint
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    cpu = torch.device('cpu')
    model_dict = torch.load(model_file, map_location=cpu)
    model_dict[state_dict_key] = convert_deprecated_v1_keys(
        model_dict[state_dict_key])
    if 'ema' in model_dict:
        ema_state_dict = model_dict['ema']['params']
        model_dict['ema']['params'] = convert_deprecated_v1_keys(
            ema_state_dict)

    if is_dir:
        # Only the first entry of 'param_shapes' is converted and kept,
        # as in the original code.
        param_shapes = convert_deprecated_v1_keys(
            model_dict['param_shapes'][0])
        model_dict['param_shapes'] = [param_shapes]

        # Copy the whole checkpoint tree first; the converted model and
        # optimizer files written below then overwrite their copies.
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(
            model_output_path, os.path.basename(model_file))

        for optim_file in optim_files:
            # Load onto CPU, like the model file, so the conversion also
            # works on hosts without the device the checkpoint was saved on.
            optim_dict = torch.load(optim_file, map_location=cpu)
            slice_mappings = (
                optim_dict['optimizer_state_dict']['param_slice_mappings'])
            slice_mappings[0] = convert_deprecated_v1_keys(slice_mappings[0])
            out_optim_fname = os.path.join(
                model_output_path, os.path.basename(optim_file))
            torch.save(optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
......@@ -75,4 +90,4 @@ if __name__ == "__main__":
args = parser.parse_args()
convert_v1_to_v2_weights(args)
\ No newline at end of file
convert_v1_to_v2_weights(args)
......@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from scripts.zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
from openfold.utils.logger import PerformanceLoggingCallback
......@@ -288,6 +289,16 @@ def main(args):
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
### DEBUG:
ds_checkpoint_dir = os.path.join(args.resume_from_ckpt, 'global_step210')
optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
model_dict = torch.load(model_file, map_location=torch.device('cpu'))
###
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment