Unverified Commit 7531c6bf authored by Stas Bekman, committed by GitHub

full fp32 weights reconstruction for zero 2+3 (#892)

parent 7f03282c
@@ -3,10 +3,13 @@ Copyright 2019 The Microsoft DeepSpeed Team
'''
import os
import stat
import torch
import warnings
import hashlib
import torch.distributed as dist
from collections import OrderedDict
from shutil import copyfile
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
@@ -1684,8 +1687,30 @@ class DeepSpeedEngine(Module):
        torch.save(state, save_path)
        self._curr_save_path = None

    def _get_param_shapes(self):
        param_shapes = OrderedDict()
        for name, param in self.module.named_parameters():
            param_shapes[name] = param.ds_shape if hasattr(param,
                                                           "ds_shape") else param.shape
            # print(f"saving param {name} {param_shapes[name]}")
        return param_shapes
    def _copy_recovery_script(self, save_path):
        base_dir = os.path.dirname(os.path.dirname(__file__))
        script = "zero_to_fp32.py"
        src = os.path.join(base_dir, "utils", script)
        dst = os.path.join(save_path, script)
        logger.info(f"creating recovery script {dst}")
        copyfile(src, dst)
        # make executable
        os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)
    def _save_zero_checkpoint(self, save_path, tag):
        zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
        zero_sd = dict(
            optimizer_state_dict=self.optimizer.state_dict(),
            param_shapes=self._get_param_shapes(),
        )
        torch.save(zero_sd, zero_checkpoint_name)
        self._copy_recovery_script(save_path)
        logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
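
With this change the ZeRO checkpoint file written by _save_zero_checkpoint carries the parameter shapes next to the optimizer state, which is what the recovery script relies on. A minimal sketch of the saved structure, with keys taken from the code in this commit and all values purely illustrative:

    zero_sd = {
        'optimizer_state_dict': {
            'zero_stage': 3,               # read back by parse_optim_states() below
            'partition_count': 4,          # i.e. the world size (illustrative value)
            'fp32_flat_groups': [...],     # named 'single_partition_of_fp32_groups' under ZeRO-2
            # ... the rest of the optimizer state ...
        },
        'param_shapes': OrderedDict([('embed.weight', torch.Size([50257, 768])), ...]),  # illustrative
    }

The new zero_to_fp32.py recovery script, copied from the package's utils directory into the checkpoint directory by _copy_recovery_script above: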
#!/usr/bin/env python
# This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py global_step1 pytorch_model.bin
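#
# A minimal, hypothetical sketch of consuming the extracted file afterwards (MyModel and the
# file name are placeholders, not part of this script):
#
#   state_dict = torch.load("pytorch_model.bin", map_location="cpu")
#   model = MyModel(...)               # any plain PyTorch model with matching parameter names
#   model.load_state_dict(state_dict)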
import argparse
import torch
import glob
import os
from collections import OrderedDict
import deepspeed
# While this script doesn't use deepspeed to recover data, the checkpoints are pickled with
# DeepSpeed data structures, so deepspeed has to be importable in the current python environment.
def get_optim_files(checkpoint_dir):
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # XXX: need to test that this simple glob rule works for multi-node setup too
    optim_files = sorted(glob.glob(f"{checkpoint_dir}/*_optim_states.pt"))

    if len(optim_files) == 0:
        raise FileNotFoundError(
            f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")

    return optim_files
def parse_optim_states(files):
    state_dicts = []
    for f in files:
        state_dicts.append(torch.load(f))

    if "zero_stage" not in state_dicts[0]['optimizer_state_dict']:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]

    # the groups are named differently in each stage
    if zero_stage == 2:
        fp32_groups_key = "single_partition_of_fp32_groups"
    elif zero_stage == 3:
        fp32_groups_key = "fp32_flat_groups"
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    param_shapes = state_dicts[0]["param_shapes"]
    fp32_flat_groups = [
        state_dicts[i]['optimizer_state_dict'][fp32_groups_key][0]
        for i in range(len(state_dicts))
    ]
    world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]

    return zero_stage, world_size, param_shapes, fp32_flat_groups
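
# Illustrative shape of the return value above (not real data): zero_stage=3, world_size=4,
# param_shapes=OrderedDict(name -> torch.Size), fp32_flat_groups=[one 1-D fp32 tensor per rank].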
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = unpartitioned_numel // world_size
    return partitioned_numel, padding_numel
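
# Worked example for the helper above (illustrative numbers): unpartitioned_numel=10 and
# world_size=4 gives remainder=2, hence padding_numel=2 and partitioned_numel=2.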
def convert_zero_chkpt_to_fp32_consolid_state_dict(checkpoint_dir, output_file):
    """
    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be
    loaded with ``torch.load(file)`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the deepspeed checkpoint folder
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
    """

    print(f"Processing zero checkpoint '{checkpoint_dir}'")

    optim_files = get_optim_files(checkpoint_dir)
    zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files)
    print(
        f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
    # Reconstruction protocol:
    #
    # - for zero2 we just need to concat the partitions back to back and reconsolidate over a
    #   single huge flat buffer - no need to deal with padding, since if there is any it will
    #   only be in the tail of the last partition and will simply be left out
    #
    # - for zero3 we need to zip the partitions together at the boundary of each param,
    #   re-consolidating each param while dealing with padding if any
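    #
    # For example (illustrative numbers): with world_size=2 and a 10-element param at offset 0,
    # zero2 slices elements 0..9 straight out of the concatenated vector, while zero3 takes
    # partitioned_numel=5 elements at offset 0 from each rank's flat group and concatenates them.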
    if zero_stage == 2:
        # XXX: memory usage doubles here (zero2)
        full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)

    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to
    # support an out-of-core computing solution
    state_dict = OrderedDict()
    offset = 0
    total_numel = 0
    for name, shape in param_shapes.items():
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        if zero_stage == 2:
            # print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            state_dict[name] = full_single_fp32_vector.narrow(
                0,
                offset,
                unpartitioned_numel).view(shape)
            offset += unpartitioned_numel
        elif zero_stage == 3:
            partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
                unpartitioned_numel, world_size)
            # print(f"{name} full shape: {shape} partition0 numel {partitioned_numel} partitioned_padding_numel {partitioned_padding_numel}")

            # XXX: memory usage doubles here (zero3)
            state_dict[name] = torch.cat(
                tuple(fp32_flat_groups[i].narrow(0,
                                                 offset,
                                                 partitioned_numel)
                      for i in range(world_size)),
                0).view(shape)
            offset += partitioned_numel + partitioned_padding_numel

    # the job is done
    print(f"Saving fp32 state dict to {output_file} (total_numel={total_numel})")
    torch.save(state_dict, output_file)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="path to the deepspeed checkpoint folder, e.g., path/checkpoint-1/global_step1")
    parser.add_argument(
        "output_file",
        type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-1/pytorch_model.bin)")
    args = parser.parse_args()

    convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file)