Commit 9776b696 authored by jnwei's avatar jnwei
Browse files

Merge weight-loading changes into setup-improvements

parents 9f346d35 ddfccd56
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import re import re
import logging
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
...@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict): ...@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
} }
convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys()))) convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
template_emb_re = re.compile(r"^((module\.)?(model\.)?)(template(?!_embedder).*)")
converted_state_dict = {} converted_state_dict = {}
for key, value in state_dict.items(): for key, value in state_dict.items():
# For each match, look-up replacement value in the dictionary # For each match, look-up replacement value in the dictionary
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key) new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
# Add prefix for template modules # Add prefix for template layers
if new_key.startswith('template'): template_match = re.match(template_emb_re, new_key)
new_key = f'template_embedder.{new_key}' if template_match:
prefix = template_match.group(1)
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
converted_state_dict[new_key] = value converted_state_dict[new_key] = value
......
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import logging
import argparse
import os
import shutil
import torch
from openfold.utils.import_weights import convert_deprecated_v1_keys
from zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
def convert_v1_to_v2_weights(args):
checkpoint_path = args.input_ckpt_path
is_dir = os.path.isdir(checkpoint_path)
if is_dir:
# A DeepSpeed checkpoint
logging.info(
'Converting deepspeed checkpoint found at {args.input_checkpoint_path}')
state_dict_key = 'module'
latest_path = os.path.join(checkpoint_path, 'latest')
if os.path.isfile(latest_path):
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
else:
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
model_output_path = os.path.join(args.output_ckpt_path, tag)
optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
else:
# A Pytorch Lightning checkpoint
logging.info(
'Converting pytorch lightning checkpoint found at {args.input_checkpoint_path}')
state_dict_key = 'state_dict'
model_output_path = args.output_ckpt_path
model_file = checkpoint_path
model_dict = torch.load(model_file, map_location=torch.device('cpu'))
model_dict[state_dict_key] = convert_deprecated_v1_keys(
model_dict[state_dict_key])
if 'ema' in model_dict:
ema_state_dict = model_dict['ema']['params']
model_dict['ema']['params'] = convert_deprecated_v1_keys(
ema_state_dict)
if is_dir:
param_shapes = convert_deprecated_v1_keys(
model_dict['param_shapes'][0])
model_dict['param_shapes'] = [param_shapes]
shutil.copytree(checkpoint_path, args.output_ckpt_path)
out_fname = os.path.join(
model_output_path, os.path.basename(model_file))
for optim_file in optim_files:
optim_dict = torch.load(optim_file)
new_optim_dict = optim_dict.copy()
new_optim_dict['optimizer_state_dict']['param_slice_mappings'][0] = convert_deprecated_v1_keys(
optim_dict['optimizer_state_dict']['param_slice_mappings'][0])
out_optim_fname = os.path.join(
model_output_path, os.path.basename(optim_file))
torch.save(new_optim_dict, out_optim_fname)
else:
out_fname = model_output_path
torch.save(model_dict, out_fname)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_ckpt_path", type=str)
parser.add_argument("output_ckpt_path", type=str)
args = parser.parse_args()
convert_v1_to_v2_weights(args)
#!/usr/bin/env python #!/usr/bin/env python
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application. # application.
...@@ -12,13 +17,27 @@ import torch ...@@ -12,13 +17,27 @@ import torch
import glob import glob
import math import math
import os import os
from collections import OrderedDict
import re import re
from collections import OrderedDict
from dataclasses import dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment. # DeepSpeed data structures it has to be available in the current python environment.
import deepspeed
from deepspeed.utils import logger from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
@dataclass
class zero_model_state:
buffers: dict()
param_shapes: dict()
shared_params: list
ds_version: int
frozen_param_shapes: dict()
frozen_param_fragments: dict()
debug = 0 debug = 0
...@@ -26,12 +45,25 @@ debug = 0 ...@@ -26,12 +45,25 @@ debug = 0
device = torch.device('cpu') device = torch.device('cpu')
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return [atoi(c) for c in re.split(r'(\d+)', text)]
def get_model_state_file(checkpoint_dir, zero_stage): def get_model_state_file(checkpoint_dir, zero_stage):
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
# there should be only one file # there should be only one file
if zero_stage == 2: if zero_stage <= 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3: elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
...@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage): ...@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
return file return file
def get_optim_files(checkpoint_dir): def get_checkpoint_files(checkpoint_dir, glob_pattern):
# XXX: need to test that this simple glob rule works for multi-node setup too # XXX: need to test that this simple glob rule works for multi-node setup too
optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt"))) ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
if len(optim_files) == 0: if len(ckpt_files) == 0:
raise FileNotFoundError( raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
return optim_files return ckpt_files
def parse_model_state(file): def get_optim_files(checkpoint_dir):
state_dict = torch.load(file, map_location=device) return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
if "buffer_names" not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
buffer_names = state_dict["buffer_names"]
if debug:
print("Found buffers:", buffer_names)
# recover just the buffers while restoring them to fp32 if they were saved in fp16 def get_model_state_files(checkpoint_dir):
buffers = { return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
k: v.float()
for k,
v in state_dict["module"].items() if k in buffer_names def parse_model_states(files):
} zero_model_states = []
return buffers for file in files:
state_dict = torch.load(file, map_location=device)
if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
buffer_names = state_dict[BUFFER_NAMES]
if debug:
print("Found buffers:", buffer_names)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
param_shapes = state_dict[PARAM_SHAPES]
# collect parameters that are included in param_shapes
param_names = []
for s in param_shapes:
for name in s.keys():
param_names.append(name)
# update with frozen parameters
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
if frozen_param_shapes is not None:
if debug:
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
param_names += list(frozen_param_shapes.keys())
# handle shared params
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
ds_version = state_dict.get(DS_VERSION, None)
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
z_model_state = zero_model_state(buffers=buffers,
param_shapes=param_shapes,
shared_params=shared_params,
ds_version=ds_version,
frozen_param_shapes=frozen_param_shapes,
frozen_param_fragments=frozen_param_fragments)
zero_model_states.append(z_model_state)
return zero_model_states
def parse_optim_states(files, ds_checkpoint_dir): def parse_optim_states(files, ds_checkpoint_dir):
...@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir): ...@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files) total_files = len(files)
state_dicts = [] state_dicts = []
for f in files: for f in files:
state_dicts.append(torch.load(f, map_location=device)) state_dict = torch.load(f, map_location=device)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
state_dicts.append(state_dict)
if not "zero_stage" in state_dicts[0]['optimizer_state_dict']: if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
raise ValueError(f"{files[0]} is not a zero checkpoint") raise ValueError(f"{files[0]} is not a zero checkpoint")
zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"] zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
world_size = state_dicts[0]['optimizer_state_dict']["partition_count"] world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
param_shapes = state_dicts[0]["param_shapes"]
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just # parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size. # use the max of the partition_count to get the dp world_size.
...@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir): ...@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir):
) )
# the groups are named differently in each stage # the groups are named differently in each stage
if zero_stage == 2: if zero_stage <= 2:
fp32_groups_key = "single_partition_of_fp32_groups" fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
elif zero_stage == 3: elif zero_stage == 3:
fp32_groups_key = "fp32_flat_groups" fp32_groups_key = FP32_FLAT_GROUPS
else: else:
raise ValueError(f"unknown zero stage {zero_stage}") raise ValueError(f"unknown zero stage {zero_stage}")
if zero_stage == 2: if zero_stage <= 2:
fp32_flat_groups = [ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
state_dicts[i]['optimizer_state_dict'][fp32_groups_key]
for i in range(len(state_dicts))
]
elif zero_stage == 3: elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one # if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor # flattened tensor per group - for simplicity merge them into a single tensor
...@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir): ...@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
# will require matching the sub-lists of param_shapes for each param group flattened tensor # will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups = [ fp32_flat_groups = [
torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
0) for i in range(len(state_dicts))
] ]
return zero_stage, world_size, param_shapes, fp32_flat_groups return zero_stage, world_size, fp32_flat_groups
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
...@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): ...@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
optim_files = get_optim_files(ds_checkpoint_dir) optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
print( print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
model_files = get_model_state_files(ds_checkpoint_dir)
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
buffers = parse_model_state(model_file) zero_model_states = parse_model_states(model_files)
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
if zero_stage == 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, if zero_stage <= 2:
param_shapes, return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
fp32_flat_groups,
buffers)
elif zero_stage == 3: elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
param_shapes,
fp32_flat_groups,
buffers) def _zero2_merge_frozen_params(state_dict, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
return
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
if debug:
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
state_dict[name] = frozen_param_fragments[name]
if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes, param_shapes = zero_model_states[0].param_shapes
fp32_flat_groups,
buffers):
# Reconstruction protocol: # Reconstruction protocol:
# #
...@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
if debug: if debug:
for i in range(world_size): for i in range(world_size):
for j in range(len(fp32_flat_groups[0])): for j in range(len(fp32_flat_groups[0])):
print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
# XXX: memory usage doubles here (zero2) # XXX: memory usage doubles here (zero2)
num_param_groups = len(fp32_flat_groups[0]) num_param_groups = len(fp32_flat_groups[0])
...@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
merged_partitions = [sd[i] for sd in fp32_flat_groups] merged_partitions = [sd[i] for sd in fp32_flat_groups]
full_single_fp32_vector = torch.cat(merged_partitions, 0) full_single_fp32_vector = torch.cat(merged_partitions, 0)
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
avail_numel = sum([ avail_numel = sum(
full_single_fp32_vector.numel() [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
for full_single_fp32_vector in merged_single_partition_of_fp32_groups
])
if debug: if debug:
wanted_params = sum([len(shapes) for shapes in param_shapes]) wanted_params = sum([len(shapes) for shapes in param_shapes])
wanted_numel = sum( wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
[sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
# not asserting if there is a mismatch due to possible padding # not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.") print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.") print(f"Need {wanted_numel} numels in {wanted_params} params.")
state_dict = OrderedDict()
# buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
# params # params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution # out-of-core computing solution
...@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
total_params += 1 total_params += 1
if debug: if debug:
print( print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
)
state_dict[name] = full_single_fp32_vector.narrow(
0,
offset,
unpartitioned_numel).view(shape)
offset += unpartitioned_numel offset += unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
...@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
# Sanity check # Sanity check
if offset != avail_numel: if offset != avail_numel:
raise ValueError( raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
f"consumed {offset} numels out of {avail_numel} - something is wrong")
print( print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
)
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
_zero2_merge_frozen_params(state_dict, zero_model_states)
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict return state_dict
...@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size): ...@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
return partitioned_numel, padding_numel return partitioned_numel, padding_numel
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
param_shapes, if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
fp32_flat_groups, return
buffers):
if debug:
for i in range(world_size):
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in zero_model_states[0].frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
if debug:
print(
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
avail_numel = fp32_flat_groups[0].numel() * world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any # param, re-consolidating each param, while dealing with padding if any
avail_numel = fp32_flat_groups[0].numel() * world_size
# merge list of dicts, preserving order # merge list of dicts, preserving order
param_shapes = {k: v for d in param_shapes for k, v in d.items()} param_shapes = {k: v for d in param_shapes for k, v in d.items()}
if debug: if debug:
for i in range(world_size): for i in range(world_size):
print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}") print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
wanted_params = len(param_shapes) wanted_params = len(param_shapes)
wanted_numel = sum(shape.numel() for shape in param_shapes.values()) wanted_numel = sum(shape.numel() for shape in param_shapes.values())
# not asserting if there is a mismatch due to possible padding # not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.") avail_numel = fp32_flat_groups[0].numel() * world_size
print(f"Need {wanted_numel} numels in {wanted_params} params.") print(f"Trainable params: Have {avail_numel} numels to process.")
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
state_dict = OrderedDict()
# buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
# params # params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
...@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, ...@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
if debug: if debug:
print( print(
f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
) )
# XXX: memory usage doubles here # XXX: memory usage doubles here
state_dict[name] = torch.cat( state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
offset, 0).narrow(0, 0, unpartitioned_numel).view(shape)
partitioned_numel)
for i in range(world_size)),
0).narrow(0,
0,
unpartitioned_numel).view(shape)
offset += partitioned_numel offset += partitioned_numel
offset *= world_size offset *= world_size
# Sanity check # Sanity check
if offset != avail_numel: if offset != avail_numel:
raise ValueError( raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
f"consumed {offset} numels out of {avail_numel} - something is wrong")
print( print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
)
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict return state_dict
...@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir): ...@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("checkpoint_dir",
"checkpoint_dir", type=str,
type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
parser.add_argument( parser.add_argument(
"output_file", "output_file",
type=str, type=str,
help= help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)" parser.add_argument("-t",
) "--tag",
type=str,
default=None,
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
parser.add_argument("-d", "--debug", action='store_true', help="enable debug") parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
args = parser.parse_args() args = parser.parse_args()
debug = args.debug debug = args.debug
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
...@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import ( ...@@ -39,6 +39,7 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint, get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint get_global_step_from_zero_checkpoint
) )
from scripts.zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
from openfold.utils.logger import PerformanceLoggingCallback from openfold.utils.logger import PerformanceLoggingCallback
...@@ -294,8 +295,13 @@ def main(args): ...@@ -294,8 +295,13 @@ def main(args):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt) sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
else: else:
sd = torch.load(args.resume_from_ckpt) sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()} if 'module' in sd:
import_openfold_weights_(model=model_module, state_dict=sd) module_sd = {k[len("module."):]:v for k,v in sd['module'].items()}
import_openfold_weights_(model=model_module, state_dict=module_sd)
elif 'state_dict' in sd:
import_openfold_weights_(model=model_module, state_dict=sd['state_dict'])
else:
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params): if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params) model_module.load_from_jax(args.resume_from_jax_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment