Commit 1df591b0 authored by Jennifer's avatar Jennifer
Browse files

updates zero_to_fp32.py for new deepspeed version and import_weight bugfix

parent bb3f51e5
...@@ -688,8 +688,9 @@ def convert_deprecated_v1_keys(state_dict): ...@@ -688,8 +688,9 @@ def convert_deprecated_v1_keys(state_dict):
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key) new_key = convert_key_re.sub(lambda m: replacements[m.group()], key)
# Add prefix for template modules # Add prefix for template modules
if new_key.startswith('template'): subheader = re.search('(?<=model.).*$', new_key).group()
new_key = f'template_embedder.{new_key}' if subheader.startswith('template'):
new_key = f'model.template_embedder.{subheader}'
converted_state_dict[new_key] = value converted_state_dict[new_key] = value
......
#!/usr/bin/env python #!/usr/bin/env python
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application. # application.
...@@ -12,13 +17,27 @@ import torch ...@@ -12,13 +17,27 @@ import torch
import glob import glob
import math import math
import os import os
from collections import OrderedDict
import re import re
from collections import OrderedDict
from dataclasses import dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment. # DeepSpeed data structures it has to be available in the current python environment.
import deepspeed
from deepspeed.utils import logger from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
@dataclass
class zero_model_state:
buffers: dict()
param_shapes: dict()
shared_params: list
ds_version: int
frozen_param_shapes: dict()
frozen_param_fragments: dict()
debug = 0 debug = 0
...@@ -26,12 +45,25 @@ debug = 0 ...@@ -26,12 +45,25 @@ debug = 0
device = torch.device('cpu') device = torch.device('cpu')
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return [atoi(c) for c in re.split(r'(\d+)', text)]
def get_model_state_file(checkpoint_dir, zero_stage): def get_model_state_file(checkpoint_dir, zero_stage):
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
# there should be only one file # there should be only one file
if zero_stage == 2: if zero_stage <= 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3: elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
...@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage): ...@@ -42,33 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
return file return file
def get_optim_files(checkpoint_dir): def get_checkpoint_files(checkpoint_dir, glob_pattern):
# XXX: need to test that this simple glob rule works for multi-node setup too # XXX: need to test that this simple glob rule works for multi-node setup too
optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt"))) ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
if len(optim_files) == 0: if len(ckpt_files) == 0:
raise FileNotFoundError( raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
return optim_files return ckpt_files
def parse_model_state(file): def get_optim_files(checkpoint_dir):
state_dict = torch.load(file, map_location=device) return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
if "buffer_names" not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
buffer_names = state_dict["buffer_names"]
if debug:
print("Found buffers:", buffer_names)
# recover just the buffers while restoring them to fp32 if they were saved in fp16 def get_model_state_files(checkpoint_dir):
buffers = { return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
k: v.float()
for k,
v in state_dict["module"].items() if k in buffer_names def parse_model_states(files):
} zero_model_states = []
return buffers for file in files:
state_dict = torch.load(file, map_location=device)
if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
buffer_names = state_dict[BUFFER_NAMES]
if debug:
print("Found buffers:", buffer_names)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
param_shapes = state_dict[PARAM_SHAPES]
# collect parameters that are included in param_shapes
param_names = []
for s in param_shapes:
for name in s.keys():
param_names.append(name)
# update with frozen parameters
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
if frozen_param_shapes is not None:
if debug:
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
param_names += list(frozen_param_shapes.keys())
# handle shared params
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
ds_version = state_dict.get(DS_VERSION, None)
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
z_model_state = zero_model_state(buffers=buffers,
param_shapes=param_shapes,
shared_params=shared_params,
ds_version=ds_version,
frozen_param_shapes=frozen_param_shapes,
frozen_param_fragments=frozen_param_fragments)
zero_model_states.append(z_model_state)
return zero_model_states
def parse_optim_states(files, ds_checkpoint_dir): def parse_optim_states(files, ds_checkpoint_dir):
...@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir): ...@@ -76,13 +143,17 @@ def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files) total_files = len(files)
state_dicts = [] state_dicts = []
for f in files: for f in files:
state_dicts.append(torch.load(f, map_location=device)) state_dict = torch.load(f, map_location=device)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
state_dicts.append(state_dict)
if not "zero_stage" in state_dicts[0]['optimizer_state_dict']: if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
raise ValueError(f"{files[0]} is not a zero checkpoint") raise ValueError(f"{files[0]} is not a zero checkpoint")
zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"] zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
world_size = state_dicts[0]['optimizer_state_dict']["partition_count"] world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
param_shapes = state_dicts[0]["param_shapes"]
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just # parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size. # use the max of the partition_count to get the dp world_size.
...@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir): ...@@ -97,18 +168,15 @@ def parse_optim_states(files, ds_checkpoint_dir):
) )
# the groups are named differently in each stage # the groups are named differently in each stage
if zero_stage == 2: if zero_stage <= 2:
fp32_groups_key = "single_partition_of_fp32_groups" fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
elif zero_stage == 3: elif zero_stage == 3:
fp32_groups_key = "fp32_flat_groups" fp32_groups_key = FP32_FLAT_GROUPS
else: else:
raise ValueError(f"unknown zero stage {zero_stage}") raise ValueError(f"unknown zero stage {zero_stage}")
if zero_stage == 2: if zero_stage <= 2:
fp32_flat_groups = [ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
state_dicts[i]['optimizer_state_dict'][fp32_groups_key]
for i in range(len(state_dicts))
]
elif zero_stage == 3: elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one # if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor # flattened tensor per group - for simplicity merge them into a single tensor
...@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir): ...@@ -117,11 +185,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
# will require matching the sub-lists of param_shapes for each param group flattened tensor # will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups = [ fp32_flat_groups = [
torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
0) for i in range(len(state_dicts))
] ]
return zero_stage, world_size, param_shapes, fp32_flat_groups return zero_stage, world_size, fp32_flat_groups
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
...@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): ...@@ -135,29 +202,54 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
optim_files = get_optim_files(ds_checkpoint_dir) optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
print( print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
model_files = get_model_state_files(ds_checkpoint_dir)
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
buffers = parse_model_state(model_file) zero_model_states = parse_model_states(model_files)
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
if zero_stage == 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, if zero_stage <= 2:
param_shapes, return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
fp32_flat_groups,
buffers)
elif zero_stage == 3: elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
param_shapes,
fp32_flat_groups,
buffers) def _zero2_merge_frozen_params(state_dict, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
return
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
if debug:
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
state_dict[name] = frozen_param_fragments[name]
if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes, param_shapes = zero_model_states[0].param_shapes
fp32_flat_groups,
buffers):
# Reconstruction protocol: # Reconstruction protocol:
# #
...@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -166,7 +258,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
if debug: if debug:
for i in range(world_size): for i in range(world_size):
for j in range(len(fp32_flat_groups[0])): for j in range(len(fp32_flat_groups[0])):
print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
# XXX: memory usage doubles here (zero2) # XXX: memory usage doubles here (zero2)
num_param_groups = len(fp32_flat_groups[0]) num_param_groups = len(fp32_flat_groups[0])
...@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -175,26 +267,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
merged_partitions = [sd[i] for sd in fp32_flat_groups] merged_partitions = [sd[i] for sd in fp32_flat_groups]
full_single_fp32_vector = torch.cat(merged_partitions, 0) full_single_fp32_vector = torch.cat(merged_partitions, 0)
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
avail_numel = sum([ avail_numel = sum(
full_single_fp32_vector.numel() [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
for full_single_fp32_vector in merged_single_partition_of_fp32_groups
])
if debug: if debug:
wanted_params = sum([len(shapes) for shapes in param_shapes]) wanted_params = sum([len(shapes) for shapes in param_shapes])
wanted_numel = sum( wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
[sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
# not asserting if there is a mismatch due to possible padding # not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.") print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.") print(f"Need {wanted_numel} numels in {wanted_params} params.")
state_dict = OrderedDict()
# buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
# params # params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution # out-of-core computing solution
...@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -210,13 +292,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
total_params += 1 total_params += 1
if debug: if debug:
print( print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
)
state_dict[name] = full_single_fp32_vector.narrow(
0,
offset,
unpartitioned_numel).view(shape)
offset += unpartitioned_numel offset += unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
...@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, ...@@ -239,12 +316,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
# Sanity check # Sanity check
if offset != avail_numel: if offset != avail_numel:
raise ValueError( raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
f"consumed {offset} numels out of {avail_numel} - something is wrong")
print( print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
)
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
_zero2_merge_frozen_params(state_dict, zero_model_states)
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict return state_dict
...@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size): ...@@ -256,34 +349,61 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
return partitioned_numel, padding_numel return partitioned_numel, padding_numel
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
param_shapes, if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
fp32_flat_groups, return
buffers):
if debug:
for i in range(world_size):
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in zero_model_states[0].frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
if debug:
print(
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
avail_numel = fp32_flat_groups[0].numel() * world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any # param, re-consolidating each param, while dealing with padding if any
avail_numel = fp32_flat_groups[0].numel() * world_size
# merge list of dicts, preserving order # merge list of dicts, preserving order
param_shapes = {k: v for d in param_shapes for k, v in d.items()} param_shapes = {k: v for d in param_shapes for k, v in d.items()}
if debug: if debug:
for i in range(world_size): for i in range(world_size):
print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}") print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
wanted_params = len(param_shapes) wanted_params = len(param_shapes)
wanted_numel = sum(shape.numel() for shape in param_shapes.values()) wanted_numel = sum(shape.numel() for shape in param_shapes.values())
# not asserting if there is a mismatch due to possible padding # not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.") avail_numel = fp32_flat_groups[0].numel() * world_size
print(f"Need {wanted_numel} numels in {wanted_params} params.") print(f"Trainable params: Have {avail_numel} numels to process.")
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
state_dict = OrderedDict()
# buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
# params # params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
...@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, ...@@ -301,30 +421,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
if debug: if debug:
print( print(
f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
) )
# XXX: memory usage doubles here # XXX: memory usage doubles here
state_dict[name] = torch.cat( state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
offset, 0).narrow(0, 0, unpartitioned_numel).view(shape)
partitioned_numel)
for i in range(world_size)),
0).narrow(0,
0,
unpartitioned_numel).view(shape)
offset += partitioned_numel offset += partitioned_numel
offset *= world_size offset *= world_size
# Sanity check # Sanity check
if offset != avail_numel: if offset != avail_numel:
raise ValueError( raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
f"consumed {offset} numels out of {avail_numel} - something is wrong")
print( print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
)
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict return state_dict
...@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir): ...@@ -447,19 +578,21 @@ def get_global_step_from_zero_checkpoint(checkpoint_dir):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("checkpoint_dir",
"checkpoint_dir", type=str,
type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
parser.add_argument( parser.add_argument(
"output_file", "output_file",
type=str, type=str,
help= help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)" parser.add_argument("-t",
) "--tag",
type=str,
default=None,
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
parser.add_argument("-d", "--debug", action='store_true', help="enable debug") parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
args = parser.parse_args() args = parser.parse_args()
debug = args.debug debug = args.debug
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment