Commit 775f77dd authored by Jennifer
Browse files

Bug fixes; adds a section to convert optimizer state files

parent 260592e0
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import re import re
import logging
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
...@@ -669,6 +670,7 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -669,6 +670,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
def convert_deprecated_v1_keys(state_dict): def convert_deprecated_v1_keys(state_dict):
"""Update older OpenFold model weight names to match the current model code.""" """Update older OpenFold model weight names to match the current model code."""
logging.warning('converting keys...')
replacements = { replacements = {
'template_angle_embedder': 'template_single_embedder', 'template_angle_embedder': 'template_single_embedder',
...@@ -686,17 +688,22 @@ def convert_deprecated_v1_keys(state_dict): ...@@ -686,17 +688,22 @@ def convert_deprecated_v1_keys(state_dict):
converted_state_dict = {} converted_state_dict = {}
for key, value in state_dict.items(): for key, value in state_dict.items():
# For each match, look-up replacement value in the dictionary # For each match, look-up replacement value in the dictionary
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key) new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
### DEBUG: remove before final commit
if key == 'template_angle_embedder.linear_1.weight': if key == 'template_angle_embedder.linear_1.weight':
print(f'old key: {key}, new_key: {new_key}') logging.warning(f'old key: {key}, new_key: {new_key}')
### DEBUG: remove before final commit
# Add prefix for template layers # Add prefix for template layers
template_match = re.match(template_emb_re, new_key) template_match = re.match(template_emb_re, new_key)
if template_match: if template_match:
prefix = template_match.group(1) prefix = template_match.group(1)
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}' new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
# DEBUG: remove before final commit
if key == 'template_angle_embedder.linear_1.weight': if key == 'template_angle_embedder.linear_1.weight':
print(f'old key: {key}, new_key: {new_key}') breakpoint()
logging.warning(f'old key: {key}, new_key: {new_key}')
### DEBUG: remove before final commit
converted_state_dict[new_key] = value converted_state_dict[new_key] = value
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be # Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code. # used to run inference using DeepMind's JAX code.
import logging
import argparse import argparse
import os import os
import shutil import shutil
...@@ -23,47 +24,61 @@ import torch ...@@ -23,47 +24,61 @@ import torch
from openfold.utils.import_weights import convert_deprecated_v1_keys from openfold.utils.import_weights import convert_deprecated_v1_keys
from zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file from zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
def get_latest_checkpoint_dir(checkpoint_dir):
    """Resolve the directory of the most recent DeepSpeed checkpoint.

    Follows the same convention as
    zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint: the tag of the
    newest checkpoint is stored in a plain-text 'latest' file inside
    ``checkpoint_dir``, and the checkpoint itself lives in the sub-directory
    named by that tag.

    Args:
        checkpoint_dir: Root directory of the DeepSpeed checkpoint.

    Returns:
        Path of the tagged checkpoint sub-directory.

    Raises:
        ValueError: If the 'latest' tag file is missing.
    """
    tag_file = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(tag_file):
        raise ValueError(f"Unable to find 'latest' file at {tag_file}")
    with open(tag_file, 'r') as fh:
        tag = fh.read().strip()
    return os.path.join(checkpoint_dir, tag)
def convert_v1_to_v2_weights(args):
    """Convert a deprecated-v1 OpenFold checkpoint to the v2 key-naming scheme.

    Supports two input layouts:
      * a DeepSpeed checkpoint directory (detected via ``os.path.isdir``),
        whose newest tag is read from its 'latest' marker file; the whole
        tree is copied to ``args.output_ckpt_path`` and the model state,
        ``param_shapes`` and every optimizer partition's
        ``param_slice_mappings`` are rewritten with the new key names;
      * a single PyTorch Lightning ``.pt`` checkpoint file, which is
        converted and written to ``args.output_ckpt_path`` directly.

    Args:
        args: Parsed CLI namespace with ``input_ckpt_path`` and
            ``output_ckpt_path`` attributes.

    Raises:
        ValueError: If a DeepSpeed directory has no 'latest' marker file.
    """
    checkpoint_path = args.input_ckpt_path
    is_dir = os.path.isdir(checkpoint_path)
    if is_dir:
        # A DeepSpeed checkpoint directory.
        # BUGFIX: the original message lacked the f-prefix (printing the
        # literal braces) and referenced a non-existent attribute
        # ``args.input_checkpoint_path``; use lazy %-formatting instead.
        logging.info('Converting checkpoint found at %s', checkpoint_path)
        state_dict_key = 'module'
        # Resolve the newest checkpoint tag from the 'latest' marker file
        # (same convention as zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint).
        latest_path = os.path.join(checkpoint_path, 'latest')
        if os.path.isfile(latest_path):
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
        model_output_path = os.path.join(args.output_ckpt_path, tag)
        optim_files = get_optim_files(ds_checkpoint_dir)
        zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
        model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
    else:
        # A single-file PyTorch Lightning checkpoint.
        state_dict_key = 'state_dict'
        model_output_path = args.output_ckpt_path
        model_file = checkpoint_path

    model_dict = torch.load(model_file, map_location=torch.device('cpu'))
    model_dict[state_dict_key] = convert_deprecated_v1_keys(
        model_dict[state_dict_key])

    # The EMA weights, when present, use the same deprecated key names.
    if 'ema' in model_dict:
        ema_state_dict = model_dict['ema']['params']
        model_dict['ema']['params'] = convert_deprecated_v1_keys(
            ema_state_dict)

    if is_dir:
        # DeepSpeed also records parameter shapes and per-parameter optimizer
        # slice mappings under the old names; rewrite those too.
        param_shapes = convert_deprecated_v1_keys(
            model_dict['param_shapes'][0])
        model_dict['param_shapes'] = [param_shapes]
        # Copy the whole checkpoint tree first, then overwrite the converted
        # files inside the copy in place.
        shutil.copytree(checkpoint_path, args.output_ckpt_path)
        out_fname = os.path.join(
            model_output_path, os.path.basename(model_file))
        for optim_file in optim_files:
            optim_dict = torch.load(optim_file)
            # The original made a shallow ``.copy()`` whose nested dicts were
            # still shared with ``optim_dict``, so the copy bought nothing;
            # mutate in place and re-save, which is equivalent and clearer.
            mappings = optim_dict['optimizer_state_dict']['param_slice_mappings']
            mappings[0] = convert_deprecated_v1_keys(mappings[0])
            out_optim_fname = os.path.join(
                model_output_path, os.path.basename(optim_file))
            torch.save(optim_dict, out_optim_fname)
    else:
        out_fname = model_output_path

    torch.save(model_dict, out_fname)
......
...@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import ( ...@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint, get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint get_global_step_from_zero_checkpoint
) )
from scripts.zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
from openfold.utils.logger import PerformanceLoggingCallback from openfold.utils.logger import PerformanceLoggingCallback
...@@ -288,6 +289,16 @@ def main(args): ...@@ -288,6 +289,16 @@ def main(args):
sd = torch.load(args.resume_from_ckpt) sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step']) last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step) model_module.resume_last_lr_step(last_global_step)
### DEBUG:
ds_checkpoint_dir = os.path.join(args.resume_from_ckpt, 'global_step210')
optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
model_dict = torch.load(model_file, map_location=torch.device('cpu'))
###
logging.info("Successfully loaded last lr step...") logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only): if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)): if(os.path.isdir(args.resume_from_ckpt)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment