"docs/vscode:/vscode.git/clone" did not exist on "8eaaa546d89f836b716e92348786d878f883ee86"
Unverified commit 7ccc9daf authored by Olatunji Ruwase, committed by GitHub

Support loading and saving ZeRO checkpoints with changing DP degree (#240)



* Support saving and loading ZeRO checkpoints with a different data parallelism degree.

* Fix formatting

* Support checkpoint with varying GPU count in ZeRO stage 1

* Fix formatting

* Formatting fixes

* Update model tests

* Remove pprint

* Minor fix

* Fix formatting

* Update model tests
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 366d8816
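
At a high level, this change lets a ZeRO (stage 1 or 2) checkpoint saved with one data-parallel degree be resumed with a different one: the engine records the saving dp/mp world sizes, reads back every saved partition, and re-partitions optimizer state for the new data-parallel group. A minimal sketch of the intended usage, assuming the standard deepspeed.initialize / save_checkpoint / load_checkpoint entry points; the model, paths, and tag are hypothetical:

import torch
import deepspeed

model = torch.nn.Linear(10, 10)  # stand-in for the real model
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,  # the usual argparse namespace carrying the DeepSpeed config
    model=model,
    model_parameters=model.parameters())

# First run, e.g. on 4 GPUs: train and checkpoint as usual.
model_engine.save_checkpoint('checkpoints/my_model', tag='step1000')

# Second run, relaunched on a different GPU count, e.g. 2 GPUs: the engine
# loads all 4 saved ZeRO partitions and re-partitions optimizer state across
# the new data-parallel group.
load_path, client_state = model_engine.load_checkpoint('checkpoints/my_model',
                                                       tag='step1000')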
......@@ -121,6 +121,8 @@ class DeepSpeedLight(Module):
self.gradient_average = True
self.warn_unscaled_loss = True
self.config_params = config_params
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
......@@ -289,9 +291,6 @@ class DeepSpeedLight(Module):
def zero_overlap_comm(self):
return self._config.zero_config.overlap_comm
def zero_max_elements_per_comm(self):
return self._config.zero_max_elements_per_comm
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
......@@ -307,6 +306,9 @@ class DeepSpeedLight(Module):
def zero_contiguous_gradients(self):
return self._config.zero_config.contiguous_gradients
def zero_load_from_fp32_weights(self):
return self._config.zero_config.load_from_fp32_weights
def allgather_size(self):
return self._config.allgather_size
......@@ -472,10 +474,12 @@ class DeepSpeedLight(Module):
if self.mpu is None:
self.data_parallel_group = _initialize_parameter_parallel_groups()
self.dp_world_size = dist.get_world_size()
self.mp_world_size = 1
self.broadcast_src_rank = 0
else:
self.data_parallel_group = self.mpu.get_data_parallel_group()
self.dp_world_size = self.mpu.get_data_parallel_world_size()
self.mp_world_size = self.mpu.get_model_parallel_world_size()
self.broadcast_src_rank = _get_global_rank(
self.mpu.get_data_parallel_group(),
0)
......@@ -1057,18 +1061,19 @@ class DeepSpeedLight(Module):
def load_module_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
def _get_zero_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
filename = 'zero_pp_rank_{}'.format(pp_rank)
def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank):
filename = 'zero_pp_rank_{}'.format(dp_rank)
zero_ckpt_name = os.path.join(
checkpoints_path,
str(tag),
filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
return zero_ckpt_name
def _get_zero_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank)
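
For reference, the scheme above yields one ZeRO file per (mp_rank, dp_rank) pair; a hypothetical example of the name it produces (note the format strings give no separator before 'optim_states'):

# _get_rank_zero_ckpt_name('ckpts', 'step1000', mp_rank=0, dp_rank=3)
#   -> 'ckpts/step1000/zero_pp_rank_3_mp_rank_00optim_states.pt'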
def _get_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
......@@ -1144,13 +1149,17 @@ class DeepSpeedLight(Module):
self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
self.global_steps = checkpoint['global_steps']
self.skipped_steps = checkpoint['skipped_steps']
self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
deepspeed_states = [
'module',
'optimizer',
'lr_scheduler',
'csr_tensor_module_names',
'skipped_steps',
'global_steps'
'global_steps',
'dp_world_size',
'mp_world_size'
]
client_state = {
key: value
......@@ -1161,18 +1170,72 @@ class DeepSpeedLight(Module):
return load_path, client_state
def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
zero_checkpoint_name = self._get_zero_ckpt_name(load_dir, tag)
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None:
return
self.optimizer.load_state_dict(
state_dict_list=zero_sd_list,
load_optimizer_states=load_optimizer_states,
load_from_fp32_weights=self.zero_load_from_fp32_weights())
print(
f'loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}'
)
if not os.path.exists(zero_checkpoint_name):
logger.warn(
'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
.format(zero_checkpoint_name))
def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
tag=tag,
mp_rank=mp_rank,
dp_rank=dp_rank)
zero_ckpt_names.append(ckpt_name)
return zero_ckpt_names
def _get_all_zero_checkpoint_names(self,
load_dir,
tag,
mp_world_size,
dp_world_size):
zero_ckpt_names = []
for mp_rank in range(mp_world_size):
mp_rank_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
load_dir=load_dir,
tag=tag,
mp_rank=mp_rank,
dp_world_size=dp_world_size)
zero_ckpt_names += mp_rank_ckpt_names
return zero_ckpt_names
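
Together, the two helpers above enumerate every partition file a checkpoint contains; a hypothetical count for illustration:

# Hypothetical: mp_world_size=2, dp_world_size=4 -> 8 checkpoint names,
# ordered mp_rank-major (all dp ranks of mp_rank 0, then all of mp_rank 1).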
def _get_all_zero_checkpoints(self, load_dir, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
load_dir=load_dir,
tag=tag,
mp_rank=mp_rank,
dp_world_size=self.loaded_checkpoint_dp_world_size)
invalid_zero_ckpt_paths = []
for ckpt_name in zero_ckpt_names:
if not os.path.exists(ckpt_name):
invalid_zero_ckpt_paths.append(ckpt_name)
if len(invalid_zero_ckpt_paths) > 0:
logging.warn(
f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} does not exist"
)
return None
zero_sd = torch.load(zero_checkpoint_name, map_location='cpu')
self.optimizer.load_state_dict(zero_sd['optimizer_state_dict'],
load_optimizer_states=load_optimizer_states)
logger.info('loading zero checkpoint {}'.format(zero_checkpoint_name))
zero_sd_list = []
for ckpt_name in zero_ckpt_names:
zero_sd_list.append(torch.load(ckpt_name, map_location='cpu'))
zero_optimizer_sd = [sd['optimizer_state_dict'] for sd in zero_sd_list]
print(
f"successfully loaded {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}"
)
return zero_optimizer_sd
def save_checkpoint(self, save_dir, tag, client_state={}):
r"""Save training checkpoint
......@@ -1232,6 +1295,10 @@ class DeepSpeedLight(Module):
self.skipped_steps,
'global_steps':
self.global_steps,
'dp_world_size':
self.dp_world_size,
'mp_world_size':
self.mp_world_size
}
state.update(client_state)
......
......@@ -23,6 +23,7 @@ ZeRO optimization should be enabled as:
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
}
}
'''
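
For illustration, a minimal zero_optimization section that sets the new flag might look like the following; the values are placeholders mirroring the snippet above, not recommendations:

"zero_optimization": {
    "stage": 2,
    "contiguous_gradients": true,
    "overlap_comm": true,
    "reduce_bucket_size": 500000000,
    "load_from_fp32_weights": true
}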
......@@ -59,17 +60,24 @@ ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER:
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
}
......@@ -84,6 +92,7 @@ class DeepSpeedZeroConfig(object):
self.allgather_partitions = None
self.allgather_bucket_size = None
self.overlap_comm = None
self.load_from_fp32_weights = None
if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
......@@ -148,3 +157,7 @@ class DeepSpeedZeroConfig(object):
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
self.load_from_fp32_weights = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
......@@ -12,6 +12,7 @@ from torch.autograd import Variable
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS
#Toggle this to true to enable correctness test
#with gradient partitioning and without
......@@ -61,8 +62,8 @@ def lcm(x, y):
return x * y // gcd(x, y)
#create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment, pg):
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()
......@@ -83,11 +84,21 @@ def flatten_dense_tensors_aligned(tensor_list, alignment, pg):
return _flatten_dense_tensors(padded_tensor_list)
def get_alignment_padding(tensor_list, alignment):
num_elements = sum([tensor.numel() for tensor in tensor_list])
remainder = num_elements % alignment
return (alignment - remainder) if remainder else remainder
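
A quick sanity check of the helper above, with hypothetical tensor sizes (assumes torch as imported in this module):

tensors = [torch.zeros(3), torch.zeros(7)]                        # 10 elements in total
assert get_alignment_padding(tensors, alignment=4) == 2           # padded 10 -> 12
assert get_alignment_padding([torch.zeros(8)], alignment=4) == 0  # already aligned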
def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()
def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")
class FP16_DeepSpeedZeroOptimizer(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
......@@ -195,10 +206,19 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.all_reduce_print = False
# padding on each partition for alignment purposes
self.groups_padding = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
# Record padding required to align group to world size
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
padding = get_alignment_padding(self.fp16_groups[i],
self.partition_count)
else:
padding = 0
self.groups_padding.append(padding)
#not sure why apex was cloning the weights before flattening
#removing cloning here
......@@ -212,8 +232,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
self.fp16_groups_flat.append(
flatten_dense_tensors_aligned(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group),
self.dp_process_group).cuda(torch.cuda.current_device()))
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
see_memory_usage(f"After flattening and moving param group {i} to GPU")
if dist.get_rank(group=self.dp_process_group) == 0:
......@@ -1114,8 +1134,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i]),
self.dp_process_group).to(
int(self.partition_size[i])).to(
self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = _flatten_dense_tensors(
......@@ -1336,6 +1355,38 @@ class FP16_DeepSpeedZeroOptimizer(object):
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains a single flattened tensor.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
for i, group in enumerate(groups_with_padding):
lean_length = group.numel() - self.groups_padding[i]
groups_without_padding.append(group[:lean_length])
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
lean_optimizer_state = self._get_state_without_padding(
self.optimizer.state[p],
self.groups_padding[i])
optimizer_groups_state.append(lean_optimizer_state)
return optimizer_groups_state
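
A minimal sketch of what the padding removal amounts to, with hypothetical sizes:

# Hypothetical: groups_padding = [2] and a 12-element flat fp32 partition.
# _get_groups_without_padding returns a view of the first 10 elements, and
# _get_state_without_padding slices each optimizer state tensor (e.g. Adam's
# exp_avg / exp_avg_sq) the same way, so the checkpoint holds only real data.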
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
......@@ -1351,21 +1402,98 @@ class FP16_DeepSpeedZeroOptimizer(object):
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict[
'single_partition_of_fp32_groups'] = self.single_partition_of_fp32_groups
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
state_dict['partition_count'] = self.partition_count
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
self.single_partition_of_fp32_groups)
state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
return state_dict
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
partition_id = dist.get_rank(group=self.dp_process_group)
merged_single_partition_of_fp32_groups = []
for i in range(len(self.single_partition_of_fp32_groups)):
merged_partitions = [
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
]
flat_merged_partitions = flatten_dense_tensors_aligned(
merged_partitions,
dist.get_world_size(group=self.dp_process_group))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp32_partition.data.copy_(fp16_partitions[partition_id].data)
def load_state_dict(self, state_dict, load_optimizer_states=True):
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
return dp_partitions[partition_id]
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, all_state_dict):
base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)):
partition_states = {}
all_partition_group_states = [
sd['base_optimizer_state'][i] for sd in all_state_dict
]
for key in all_partition_group_states[0].keys():
all_partition_states = [
all_states[key] for all_states in all_partition_group_states
]
partition_states[key] = self._partition_base_optimizer_state(
key,
all_partition_states)
base_optimizer_group_states.append(partition_states)
for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0]
for key, saved in base_optimizer_group_states[i].items():
current = self.optimizer.state[p][key]
current.data.copy_(saved.data)
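
The restore path above follows the three numbered steps; a hedged sketch of the idea on hypothetical shapes (the real code additionally pads the merged tensor so it divides evenly across ranks):

saved = [torch.arange(6.), torch.arange(6., 12.)]   # two 6-element partitions (DP=2)
merged = torch.cat(saved)                           # step 1: merge to 12 elements
repartitioned = torch.chunk(merged, 3)              # step 2: re-split for DP=3
# step 3: rank r copies repartitioned[r] into self.optimizer.state[p][key]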
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading ZeRO checkpoint
Arguments:
state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
Note that the number of saved partitions may differ from number of loading partitions to support
changing GPU count, specifically DP world size, between saving and loading checkpoints.
load_optimizer_states: Boolean indicating whether or not to load base optimizer states
load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
"""
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
......@@ -1382,12 +1510,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
self.loss_scaler = state_dict_list[0]['loss_scaler']
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
self.overflow = state_dict_list[0]['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
self._restore_base_optimizer_state(state_dict_list)
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
......@@ -1404,16 +1532,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
if 'partition_count' in state_dict and state_dict[
'partition_count'] == self.partition_count:
# Use option 2
for current, saved in zip(self.single_partition_of_fp32_groups, state_dict['single_partition_of_fp32_groups']):
current.data.copy_(saved.data)
if load_from_fp32_weights:
self._restore_from_fp32_weights(state_dict_list)
else:
# Use option 1
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp32_partition.data.copy_(fp16_partitions[partition_id].data)
self._restore_from_fp16_weights()
def _handle_overflow(cpu_sum, x, i):
......
......@@ -8,6 +8,32 @@ from deepspeed.pt.zero_utils import _initialize_parameter_parallel_groups
from deepspeed.pt.log_utils import log_dist, logger
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
if sub_partition_high_limit <= flattened_lean_size:
return 0
else:
return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)
def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
group_paddings = []
flattened_size = sum([tensor.numel() for tensor in tensor_list])
for i in range(sub_partition_count):
padding = get_alignment_padding(flattened_size, i, sub_partition_size)
group_paddings.append(padding)
logger.info("****Padding information*****")
logger.info(f"tensor_size = {flattened_size}")
logger.info(f"sub_partition_size = {sub_partition_size}")
logger.info(f"sub_partition_count = {sub_partition_count}")
for i, padding in enumerate(group_paddings):
logger.info(f"padding[{i}] = {padding}")
return group_paddings
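
A quick worked example of the two helpers above, with hypothetical sizes:

# Flattened group of 10 elements, 3 sub-partitions of 4 elements each:
# get_alignment_padding(10, 0, 4) -> 0   (elements 0-3 are all real)
# get_alignment_padding(10, 1, 4) -> 0   (elements 4-7 are all real)
# get_alignment_padding(10, 2, 4) -> 2   (elements 8-9 are real, 2 are padding)
# so get_group_alignment_padding(...) returns [0, 0, 2] for this group.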
def flatten_dense_tensors_sub_partition_aligned(tensor_list,
......@@ -164,6 +190,9 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
local_rank = dist.get_rank(group=self.dp_process_group)
self.group_paddings = []
self.partition_count = dist.get_world_size(group=self.dp_process_group)
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
......@@ -215,6 +244,13 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
local_sub_partitions.append(fp32_sub_partition)
self.local_sub_partitions_of_fp32_groups.append(local_sub_partitions)
# Compute sub_partition paddings
sub_partition_paddings = get_group_alignment_padding(
tensor_list=self.fp16_groups[i],
sub_partition_size=sub_partition_size,
sub_partition_count=num_comm_intervals * self.partition_count)
self.group_paddings.append(sub_partition_paddings)
# modify optimizer of have flat master weight
# self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = self.local_sub_partitions_of_fp32_groups[i]
......@@ -256,6 +292,22 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
mpu=self.mpu,
zero_reduce_scatter=True)
self._initialize_optimizer_states()
def _initialize_optimizer_states(self):
for group_idx, group in enumerate(self.local_sub_partitions_of_fp32_groups):
for idx, sub_partition_param in enumerate(group):
sub_partition_grad = torch.zeros(int(
self.sub_partition_sizes[group_idx]),
dtype=sub_partition_param.dtype).cuda()
sub_partition_param.grad = sub_partition_grad
self.optimizer.step()
for group in self.local_sub_partitions_of_fp32_groups:
for idx, sub_partition_param in enumerate(group):
sub_partition_param.grad = None
@staticmethod
def get_data_parallel_sub_partitions(tensor,
max_elements_per_comm,
......@@ -721,6 +773,51 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains sub partitions.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
local_rank = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(groups_with_padding):
low_index = local_rank * len(group)
high_index = (local_rank + 1) * len(group)
group_paddings = self.group_paddings[i][low_index:high_index]
lean_sub_partitions = []
for j, sub_partition in enumerate(group):
lean_length = sub_partition.numel() - group_paddings[j]
lean_sub_partitions.append(sub_partition[:lean_length])
groups_without_padding.append(lean_sub_partitions)
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
local_rank = dist.get_rank(group=self.dp_process_group)
for group_idx, group in enumerate(self.optimizer.param_groups):
group_lean_state = []
low_index = local_rank * self.num_comm_intervals_per_group[group_idx]
high_index = (local_rank + 1) * self.num_comm_intervals_per_group[group_idx]
param_paddings = self.group_paddings[group_idx][low_index:high_index]
for param_idx, param in enumerate(group['params']):
lean_state = self._get_state_without_padding(self.optimizer.state[param],
param_paddings[param_idx])
group_lean_state.append(lean_state)
optimizer_groups_state.append(group_lean_state)
return optimizer_groups_state
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
......@@ -736,20 +833,140 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict[
'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
state_dict['partition_count'] = self.partition_count
state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
self.local_sub_partitions_of_fp32_groups)
state_dict['local_sub_partitions_of_fp32_groups'] = fp32_groups_without_padding
return state_dict
def _retrieve_group_sub_partition_weights(self, all_partition_fp32_weights):
partition_id = dist.get_rank(group=self.dp_process_group)
all_sub_partition_weights = []
for partition_weights in all_partition_fp32_weights:
for sub_partition_weights in partition_weights:
all_sub_partition_weights.append(sub_partition_weights)
flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_sub_partition_weights,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=self.max_elements_per_comm,
pg=self.dp_process_group)
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=flat_merged_weights,
max_elements_per_comm=self.max_elements_per_comm,
world_size=dist.get_world_size(group=self.dp_process_group),
dp_process_group=self.dp_process_group
)
return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
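
A hedged sketch of what the retrieval above does, with hypothetical counts:

# Hypothetical: a checkpoint written with DP=2 contributes two lists of lean
# fp32 sub-partitions. They are concatenated in saved-rank order, re-flattened
# with the stage-1 sub-partition alignment, re-split for the current world size
# and max_elements_per_comm, and this rank keeps only dp_sub_partitions[rank].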
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
sub_partition_of_fp32_groups = []
for group_idx in range(len(self.local_sub_partitions_of_fp32_groups)):
all_partition_fp32_weights = [
sd['local_sub_partitions_of_fp32_groups'][group_idx]
for sd in all_state_dict
]
sub_partition_weights = self._retrieve_group_sub_partition_weights(
all_partition_fp32_weights)
sub_partition_of_fp32_groups.append(sub_partition_weights)
for current_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, sub_partition_of_fp32_groups):
for current_sub_part, saved_sub_part in zip(current_group, saved_group):
current_sub_part.data.copy_(saved_sub_part.data)
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_partition_states,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=self.max_elements_per_comm,
pg=self.dp_process_group)
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=flat_merged_partitions,
max_elements_per_comm=self.max_elements_per_comm,
world_size=dist.get_world_size(group=self.dp_process_group),
dp_process_group=self.dp_process_group
)
return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
# Compute the optimizer state partitions for the group by
# 1) Merging state values across the previous partitioning.
# 2) Repartition state values for the new partitioning
# 3) Return state corresponding to local partition
def _retrieve_group_optimizer_states(self, all_partition_states):
merged_optimizer_states = {}
for partition_state in all_partition_states:
for sub_partition_state in partition_state:
for key, value in sub_partition_state.items():
if not key in merged_optimizer_states.keys():
merged_optimizer_states[key] = [value]
else:
merged_optimizer_states[key].append(value)
group_optimizer_states = {}
for key, value in merged_optimizer_states.items():
group_optimizer_states[key] = self._partition_base_optimizer_state(
key,
value)
return group_optimizer_states
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, state_dict_list):
base_optimizer_group_states = []
for group_idx in range(len(self.optimizer.param_groups)):
all_partition_group_states = [
sd['base_optimizer_state'][group_idx] for sd in state_dict_list
]
group_optimizer_states = self._retrieve_group_optimizer_states(
all_partition_group_states)
base_optimizer_group_states.append(group_optimizer_states)
for group_idx, group in enumerate(self.optimizer.param_groups):
for param_idx, param in enumerate(group['params']):
for key, saved in base_optimizer_group_states[group_idx].items():
current = self.optimizer.state[param][key]
current.data.copy_(saved[param_idx].data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
for fp16_sub_partition, fp32_sub_partition in zip(fp16_partitions[partition_id], fp32_partitions):
fp32_sub_partition.data.copy_(fp16_sub_partition.data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
local_sub_partition_param_fp32.data.copy_(
local_sub_partition_param_fp16.data)
self._restore_from_fp16_weights()
def load_state_dict(self, state_dict, load_optimizer_states=True):
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
......@@ -766,12 +983,14 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
self.loss_scaler = state_dict_list[0]['loss_scaler']
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
self.overflow = state_dict_list[0]['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
self._restore_base_optimizer_state(state_dict_list)
for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
for curr_param, saved_param in zip(curr_group, saved_group):
curr_param.data.copy_(saved_param.data)
if load_from_fp32_weights:
self._restore_from_fp32_weights(state_dict_list)
else:
self._restore_from_fp16_weights()
......@@ -53,7 +53,7 @@ class GPT2CheckpointTestCase(BaseTestCase):
def tearDown(self):
os.chdir(self.save_dir)
def test_mp4_gpu16_node1_with_zero1(self):
def test_mp2_gpu4_node1_with_zero1(self):
test_config = {
"mp": 2,
"gpus": 4,
......@@ -68,14 +68,14 @@ class GPT2CheckpointTestCase(BaseTestCase):
"tag": "ds_zero1",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp4_gpu16_w_zero1",
"checkpoint_name": "ckpt_mp2_gpu8_w_zero1",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp4_gpu16_node1_with_zero2(self):
def test_mp2_gpu4_node1_with_zero2(self):
test_config = {
"mp": 2,
"gpus": 4,
......@@ -90,14 +90,198 @@ class GPT2CheckpointTestCase(BaseTestCase):
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp4_gpu16_w_zero2",
"checkpoint_name": "ckpt_mp2_gpu8_w_zero2",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp4_gpu16_node1_without_zero(self):
def test_mp1_gpu2_load_gpu1_node1_with_zero1(self):
test_config = {
"mp": 1,
"gpus": 2,
"load_gpus": 1,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero1",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp1_gpu2_load_gpu4_node1_with_zero1(self):
test_config = {
"mp": 1,
"gpus": 2,
"load_gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero1",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp1_gpu2_load_gpu1_node1_with_zero2(self):
test_config = {
"mp": 1,
"gpus": 2,
"load_gpus": 1,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero2",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp1_gpu2_load_gpu4_node1_with_zero2(self):
test_config = {
"mp": 1,
"gpus": 2,
"load_gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero2",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu4_load_gpu2_node1_with_zero1(self):
test_config = {
"mp": 2,
"gpus": 4,
"load_gpus": 2,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero1",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp2_gpu4_gpu2_w_zero1",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu2_load_gpu4_node1_with_zero1(self):
test_config = {
"mp": 2,
"gpus": 2,
"load_gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero1",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp2_gpu2_gpu4_w_zero1",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero1.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu4_load_gpu2_node1_with_zero2(self):
test_config = {
"mp": 2,
"gpus": 4,
"load_gpus": 2,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp2_gpu4_gpu2_w_zero2",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu2_load_gpu4_node1_with_zero2(self):
test_config = {
"mp": 2,
"gpus": 2,
"load_gpus": 4,
"nodes": 1,
"bs": 8,
"steps": 1100,
"layers": LAYERS,
"hidden_size": HIDDEN_SIZE,
"seq_length": 256,
"heads": ATTN_HEADS,
"deepspeed": True,
"tag": "ds_zero2",
"zero": True,
"other_args": "",
"checkpoint_name": "ckpt_mp2_gpu2_gpu4_w_zero2",
"checkpoint_interval": 1000,
"json": "ds_config_func_bs8_zero2.json",
}
succ = self.run_test(test_config, 0.01)
self.assertTrue(succ)
def test_mp2_gpu4_node1_without_zero(self):
test_config = {
"mp": 2,
"gpus": 4,
......@@ -130,6 +314,14 @@ class GPT2CheckpointTestCase(BaseTestCase):
print("{0}: starting......".format(self.id()))
# Cache save and load gpu counts
save_gpus = test_config["gpus"]
if "load_gpus" in test_config:
load_gpus = test_config["load_gpus"]
del test_config["load_gpus"]
else:
load_gpus = test_config["gpus"]
# save to current directory.
checkpoint_folder = test_config["checkpoint_name"]
checkpoint_interval = test_config["checkpoint_interval"]
......@@ -178,6 +370,9 @@ class GPT2CheckpointTestCase(BaseTestCase):
prefix = "gpt2_loading_checkpoint"
# set load gpus
test_config["gpus"] = load_gpus
print("{0}: Second run loading checkpoint and continuing.".format(self.id()))
test_file = self.gen_name(test_config, prefix)
......@@ -216,9 +411,22 @@ class GPT2CheckpointTestCase(BaseTestCase):
def checkpoint_suite():
suite = unittest.TestSuite()
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero1'))
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero2'))
suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_without_zero'))
suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero1'))
suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero2'))
# Shrink DP
suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero1'))
suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero2'))
suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero1'))
suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero2'))
# Expand DP
suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero1'))
suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero2'))
suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero1'))
suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero2'))
suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_without_zero'))
return suite
......