Unverified Commit 08c96a1b authored by Jeff Rasley, committed by GitHub

ZeRO-1 tune max-elems + bug fix (#532)

* zero-1 memory fix

* auto-tune max elems per comm to reduce padding/comm intervals

* clean-up and added previously missing reduction options

* fix testing backend to work with torch 1.7
parent fdd81c30

@@ -661,7 +661,7 @@ class DeepSpeedEngine(Module):
     def _configure_zero_optimizer(self, optimizer):
         zero_stage = self.zero_optimization_stage()
         logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))
 
+        assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
......
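
A minimal client-config sketch (hypothetical values) that satisfies the fp32_allreduce assert above: with ZeRO enabled, 'fp32_allreduce' stays at its default of false and stage 1 uses reduce-scatter.

    # Hypothetical DeepSpeed config sketch; the assert above rejects
    # "fp32_allreduce": true whenever ZeRO optimization is enabled.
    ds_config = {
        "train_batch_size": 8,          # hypothetical value
        "fp16": {"enabled": True},
        "fp32_allreduce": False,        # must stay false with ZeRO
        "zero_optimization": {
            "stage": 1,                 # ZERO_OPTIMIZATION_OPTIMIZER_STATES
            "reduce_scatter": True      # stage 1 only supports reduce-scatter mode
        },
    }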
@@ -68,12 +68,6 @@ class CheckOverflow(object):
         return bool(overflow)
 
     def check(self, param_groups=None):
-        #TODO: what's the equivalent here? do we need this?
-        # for group in self.fp32_from_fp32_groups:
-        #     for param in group:
-        #         params.append(param)
-
         params = []
         if param_groups is None:
             params = self.params
......
@@ -73,7 +73,8 @@ def flatten_dense_tensors_sub_partition_aligned(tensor_list,
                                  dtype=tensor_list[0].dtype)
         aligned_tensor_list = tensor_list + [pad_tensor]
 
-    return _flatten_dense_tensors(aligned_tensor_list)
+    flat_tensors = _flatten_dense_tensors(aligned_tensor_list)
+    return flat_tensors
 
 
 def _single_range_check(current_index, start_index, end_index, tensor_size):

@@ -144,8 +145,8 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         self.all_gather_partitions = all_gather_partitions
         self.allgather_size = allgather_size
 
-        self.max_elements_per_comm = max_elements_per_comm
-        logger.info("max_elements_per_comm={}".format(max_elements_per_comm))
+        # self.max_elements_per_comm = max_elements_per_comm
+        # logger.info("max_elements_per_comm={}".format(max_elements_per_comm))
 
         self.elastic_checkpoint = elastic_checkpoint
         logger.info(f'ZeRO Elastic Checkpoint = {elastic_checkpoint}')

@@ -189,19 +190,32 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         self.group_paddings = []
         self.partition_count = dist.get_world_size(group=self.dp_process_group)
 
+        self.default_device = self.optimizer.param_groups[0]['params'][0].device
+
+        # max elems per param group
+        self.max_elems_per_comm = []
+        self.legacy_max_elements_per_comm = max_elements_per_comm
+
         # loop to deal with groups
         for i, param_group in enumerate(self.optimizer.param_groups):
             # push this group to list before modify
             self.fp16_groups.append(param_group['params'])
 
+            # calculate best max elements per comm based to minimize padding
+            self.max_elems_per_comm.append(
+                self.best_max_elems_per_comm(
+                    num_elements=sum(t.numel() for t in self.fp16_groups[i]),
+                    max_elements_per_comm=max_elements_per_comm,
+                    dp=dist.get_world_size(group=self.dp_process_group)))
+
             # flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
             # RS: create aligned sub-partitions
-            self.fp16_groups_flat.append(
-                flatten_dense_tensors_sub_partition_aligned(
-                    tensor_list=self.fp16_groups[i],
-                    dp=dist.get_world_size(group=self.dp_process_group),
-                    max_elements_per_comm=self.max_elements_per_comm,
-                    pg=self.dp_process_group))
+            flat_aligned_params = flatten_dense_tensors_sub_partition_aligned(
+                tensor_list=self.fp16_groups[i],
+                dp=dist.get_world_size(group=self.dp_process_group),
+                max_elements_per_comm=self.max_elems_per_comm[i],
+                pg=self.dp_process_group)
+            self.fp16_groups_flat.append(flat_aligned_params)
 
             # TODO: I don't think this does anything?
             # set model fp16 weight to slices of flattened buffer

@@ -216,7 +230,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
                 self.get_data_parallel_sub_partitions(
                     tensor=self.fp16_groups_flat[i],
-                    max_elements_per_comm=self.max_elements_per_comm,
+                    max_elements_per_comm=self.max_elems_per_comm[i],
                     world_size=dist.get_world_size(
                         group=self.dp_process_group),
                     dp_process_group=self.dp_process_group

@@ -303,6 +317,34 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             for idx, sub_partition_param in enumerate(group):
                 sub_partition_param.grad = None
 
+    @staticmethod
+    def best_max_elems_per_comm(num_elements, max_elements_per_comm, dp):
+        # if we use max-elems-per-comm as is, how many comm intervals will there be
+        max_comm_intervals = math.ceil(num_elements / max_elements_per_comm)
+        padding_for_max_comm = (max_elements_per_comm *
+                                max_comm_intervals) - num_elements
+
+        # if we use 1 less comm interval how much extra comm padding would be required
+        min_comm_intervals = num_elements // max_elements_per_comm
+        if min_comm_intervals == 0:
+            log_dist(f'Using default max_elements_per_comm {max_elements_per_comm}',
+                     ranks=[0])
+            return max_elements_per_comm
+
+        padding_for_min_comm = math.ceil(num_elements / (dp * min_comm_intervals))
+
+        # choose padding that uses least amount of overhead
+        if padding_for_max_comm > padding_for_min_comm:
+            new_max_elements_per_comm = padding_for_min_comm + max_elements_per_comm
+            log_dist(
+                f'Updating max_elements_per_comm from {max_elements_per_comm} -> {new_max_elements_per_comm}',
+                ranks=[0])
+            return new_max_elements_per_comm
+        else:
+            log_dist(f'Using default max_elements_per_comm {max_elements_per_comm}',
+                     ranks=[0])
+            return max_elements_per_comm
+
     @staticmethod
     def get_data_parallel_sub_partitions(tensor,
                                          max_elements_per_comm,
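
To make the padding trade-off above concrete, here is a standalone sketch (hypothetical sizes, log_dist calls omitted) that mirrors best_max_elems_per_comm:

    import math

    # Standalone mirror of best_max_elems_per_comm above (hypothetical numbers).
    def pick_max_elems_per_comm(num_elements, max_elements_per_comm, dp):
        # comm intervals and padding if the default cap is used as-is
        max_comm_intervals = math.ceil(num_elements / max_elements_per_comm)
        padding_for_max_comm = max_elements_per_comm * max_comm_intervals - num_elements

        # padding needed if we drop to one fewer comm interval
        min_comm_intervals = num_elements // max_elements_per_comm
        if min_comm_intervals == 0:
            return max_elements_per_comm
        padding_for_min_comm = math.ceil(num_elements / (dp * min_comm_intervals))

        # grow the cap only when that reduces total padding overhead
        if padding_for_max_comm > padding_for_min_comm:
            return max_elements_per_comm + padding_for_min_comm
        return max_elements_per_comm

    # e.g. 1.2M elements, 500K cap, 4 data-parallel ranks:
    # keeping the cap means 3 intervals and 300K padded elements;
    # growing the cap to 650K needs only 2 intervals.
    assert pick_max_elems_per_comm(1_200_000, 500_000, dp=4) == 650_000

Growing the per-comm cap slightly removes one communication interval while cutting the total padding, which is what the auto-tuning above does per parameter group.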

@@ -419,9 +461,10 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                                 comm_param_offsets,
                                 sub_partition_size,
                                 dtype,
+                                default_device,
                                 num_comm_intervals=None,
-                                default_device=None,
                                 return_partition_params=False):
         partition_params = []
         final_param_offsets = []
         flat_sub_partitions = []

@@ -431,9 +474,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             my_offsets = []
             my_params = []
 
-            if dtype is None:
-                dtype = tensor_list[0].dtype
-
             for i, tensor in enumerate(tensor_list):
                 if tensor.grad is None:
                     tensor.grad = torch.zeros(tensor.size(),

@@ -538,73 +578,49 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         local_rank = dist.get_rank(group=self.dp_process_group)
 
         for i, group in enumerate(self.fp16_groups):
-            partition_param_map = {}
-            param_partition_map = {}
-            my_params = set()
-
-            # [rank] -> [comm] -> partition
             num_comm_intervals = self.num_comm_intervals_per_group[i]
             all_sub_partitions = []
             for rank in range(world_size):
                 # gsp is list of partitions indexed by comm_idx
-                #FIXME: currently hardcoding fp16, should infer dtype
-                grad_sub_partitions, partition_params, param_offsets = self.get_flat_sub_partitions(
-                    comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
-                    comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i][rank],
-                    sub_partition_size=self.sub_partition_sizes[i],
-                    dtype=torch.half, #self.params_in_rank_sub_partitions[i][rank][0][0].dtype,
-                    num_comm_intervals=self.num_comm_intervals_per_group[i],
-                    default_device='cuda', #self.params_in_rank_sub_partitions[i][rank][0][0].device,
-                    return_partition_params=True)
+                grad_sub_partitions = self.get_flat_sub_partitions(
+                    comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
+                    comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
+                    [rank],
+                    dtype=torch.half,
+                    default_device=self.default_device,
+                    sub_partition_size=self.sub_partition_sizes[i],
+                    num_comm_intervals=self.num_comm_intervals_per_group[i])
                 all_sub_partitions.append(grad_sub_partitions)
 
-                # create map from partition -> params in that partition
-                for comm_idx, part in enumerate(grad_sub_partitions):
-                    partition_param_map[part] = (partition_params[comm_idx],
-                                                 param_offsets[comm_idx])
-
-                for comm_idx, params in enumerate(partition_params):
-                    for pidx, p in enumerate(params):
-                        # store the parameters we care about locally
-                        if rank == local_rank:
-                            my_params.add(p)
-                        # map from param -> partitions
-                        if p in param_partition_map:
-                            param_partition_map[p].append(grad_sub_partitions[comm_idx])
-                        else:
-                            param_partition_map[p] = [grad_sub_partitions[comm_idx]]
-
                 assert len(grad_sub_partitions) == num_comm_intervals
 
-            if not postscale_gradients:
-                raise NotImplementedError("pre-scale_gradients is not implemented")
-
-            all_comm_partitions = []
+            local_comm_partitions = []
             for comm_idx in range(num_comm_intervals):
                 single_comm_all_partitions = []
                 for rank in range(world_size):
                     single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
-                dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
-                                    input_list=single_comm_all_partitions,
-                                    group=self.dp_process_group)
-
-                if gradient_average:
-                    for partition in single_comm_all_partitions:
-                        partition.mul_(gradient_predivide_factor / world_size)
-
-                all_comm_partitions.append(single_comm_all_partitions)
 
-            # stitch together all rank sub partitions for each comm idx
-            flat_comm_grads = []
-            for comm_idx, rank_partitions in enumerate(all_comm_partitions):
-                flat_comm_grads.append(torch.cat(rank_partitions))
-
-            flat_all_grads = torch.cat(flat_comm_grads)
-
-            # copy back reduced gradients but only those needed for this local rank
-            for param, updated_grad in zip(self.fp16_groups[i], _unflatten_dense_tensors(flat_all_grads, self.fp16_groups[i])):
-                if param in my_params:
-                    param.grad.copy_(updated_grad)
+                if postscale_gradients:
+                    if gradient_predivide_factor != 1.0:
+                        for partition in single_comm_all_partitions:
+                            partition.mul_(1. / gradient_predivide_factor)
+
+                    dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
+                                        input_list=single_comm_all_partitions,
+                                        group=self.dp_process_group)
+
+                    if gradient_average:
+                        # Only need to average our local grads in post scaling
+                        if gradient_predivide_factor != world_size:
+                            single_comm_all_partitions[local_rank].mul_(
+                                gradient_predivide_factor / world_size)
+                else:
+                    for partition in single_comm_all_partitions:
+                        partition.div_(world_size)
+
+                    dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
+                                        input_list=single_comm_all_partitions,
+                                        group=self.dp_process_group)
 
     def step(self, closure=None):
         # First compute norm for all group so we know if there is overflow
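
The postscale/prescale branch above replaces the old NotImplementedError for pre-scaled gradients. A minimal plain-Python sketch (hypothetical values, no torch or dist calls) of why the two scaling orders agree on the averaged gradient:

    # Minimal sketch: one comm interval, one element per rank's sub-partition.
    world_size = 4
    gradient_predivide_factor = 2.0
    per_rank_grads = [float(r + 1) for r in range(world_size)]  # hypothetical values

    # post-scale path: divide by the predivide factor, sum (stand-in for
    # reduce-scatter), then rescale only the local result
    post = sum(g / gradient_predivide_factor for g in per_rank_grads)
    post *= gradient_predivide_factor / world_size

    # pre-scale path: divide by world_size first, then sum (reduce-scatter)
    pre = sum(g / world_size for g in per_rank_grads)

    expected_average = sum(per_rank_grads) / world_size
    assert abs(post - expected_average) < 1e-12
    assert abs(pre - expected_average) < 1e-12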

@@ -626,7 +642,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         partition_id = dist.get_rank(group=self.dp_process_group)
 
         for i, group in enumerate(self.fp16_groups):
             #TODO RS: update get grad norm to support sub partitions
             norm_groups.append(get_grad_norm(group, mpu=self.mpu))

@@ -634,16 +649,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             #free gradients for all the parameters that are not updated by this process
             self.free_grad_in_param_list(self.params_not_local[i])
 
-            #create flat gradients for parameters updated by this process
-            #tensor_list, first_offset, partition_size, dtype
-            #single_grad_partition = self.get_flat_partition(
-            #    tensor_list=self.params_in_partition[i],
-            #    first_offset=self.first_offset[i],
-            #    partition_size=self.partition_size[i],
-            #    dtype=self.single_partition_of_fp32_groups[i].dtype
-            #)
-
-            #TODO RS: can we safely use dtype of the first sub-partition? i think so
             # create flat gradient partitions for parameters updated by this process
             local_grad_sub_partitions = self.get_flat_sub_partitions(
                 comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],

@@ -652,13 +657,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 sub_partition_size=self.sub_partition_sizes[i],
                 dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
                 num_comm_intervals=self.num_comm_intervals_per_group[i],
-                default_device=self.local_sub_partitions_of_fp32_groups[i][0].device)
+                default_device=self.default_device)
 
             #RS: update all our local params with sub-partition grads
-            #logger. info("self.local_sub_partitions_of_fp32_groups[i]={}, local_grad_sub_partitions={}".format(len(self.local_sub_partitions_of_fp32_groups[i]), len(local_grad_sub_partitions)))
             for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]):
                 sub_partition_param.grad = local_grad_sub_partitions[idx]
-            #self.single_partition_of_fp32_groups[i].grad = single_grad_partition
 
             #RS: update free grads for sub-partitions
             #release all the gradient since we have already created a necessary copy in dp_grad_partition

@@ -856,6 +859,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
         state_dict['partition_count'] = self.partition_count
         state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group
+        state_dict['max_elems_per_comm'] = self.max_elems_per_comm
 
         # Remove paddings for DP alignment to enable loading for other alignment values
         fp32_groups_without_padding = self._get_groups_without_padding(

@@ -887,7 +891,9 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
     # rank 0 = [sub_0_0, sub_0_1]
     # rank 1 = [sub_1_0, sub_1_1]
     # Merge to get [sub_0_0, sub_1_0, sub_0_1, sub_1_1] => original un-padded flattened tensor.
-    def _retrieve_group_sub_partition_weights(self, all_partition_fp32_weights):
+    def _retrieve_group_sub_partition_weights(self,
+                                              all_partition_fp32_weights,
+                                              max_elems_per_comm):
         num_partitions = len(all_partition_fp32_weights)
         num_comm_intervals = len(all_partition_fp32_weights[0])
         num_sub_partitions = num_partitions * num_comm_intervals

@@ -902,13 +908,13 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_sub_partition_weights,
             dp=dist.get_world_size(group=self.dp_process_group),
-            max_elements_per_comm=self.max_elements_per_comm,
+            max_elements_per_comm=max_elems_per_comm,
             pg=self.dp_process_group)
 
         comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
             self.get_data_parallel_sub_partitions(
                 tensor=flat_merged_weights,
-                max_elements_per_comm=self.max_elements_per_comm,
+                max_elements_per_comm=max_elems_per_comm,
                 world_size=dist.get_world_size(group=self.dp_process_group),
                 dp_process_group=self.dp_process_group
             )

@@ -927,8 +933,14 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 sd['local_sub_partitions_of_fp32_groups'][group_idx]
                 for sd in all_state_dict
             ]
+            if 'max_elems_per_comm' in all_state_dict[0]:
+                max_elems_per_comm = all_state_dict[0]['max_elems_per_comm'][group_idx]
+            else:
+                max_elems_per_comm = self.legacy_max_elements_per_comm
+
             sub_partition_weights = self._retrieve_group_sub_partition_weights(
-                all_partition_fp32_weights)
+                all_partition_fp32_weights,
+                max_elems_per_comm)
             sub_partition_of_fp32_groups.append(sub_partition_weights)
 
         for current_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, sub_partition_of_fp32_groups):

@@ -936,20 +948,23 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 current_sub_part.data.copy_(saved_sub_part.data)
 
     # Extract optimizer state for current partition from merged states of all partitions
-    def _partition_base_optimizer_state(self, state_key, all_partition_states):
+    def _partition_base_optimizer_state(self,
+                                        state_key,
+                                        all_partition_states,
+                                        max_elems_per_comm):
         partition_id = dist.get_rank(group=self.dp_process_group)
         alignment = dist.get_world_size(group=self.dp_process_group)
         flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_partition_states,
             dp=dist.get_world_size(group=self.dp_process_group),
-            max_elements_per_comm=self.max_elements_per_comm,
+            max_elements_per_comm=max_elems_per_comm,
             pg=self.dp_process_group)
 
         comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
             self.get_data_parallel_sub_partitions(
                 tensor=flat_merged_partitions,
-                max_elements_per_comm=self.max_elements_per_comm,
+                max_elements_per_comm=max_elems_per_comm,
                 world_size=dist.get_world_size(group=self.dp_process_group),
                 dp_process_group=self.dp_process_group
             )

@@ -960,7 +975,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
     # 1) Merging state values across the previous partitioning.
     # 2) Repartition state values for the new partitioning
     # 3) Return state corresponding to local partition
-    def _retrieve_group_optimizer_states(self, all_partition_states):
+    def _retrieve_group_optimizer_states(self, all_partition_states, max_elems_per_comm):
         merged_optimizer_states = {}
         num_partitions = len(all_partition_states)
         num_comm_intervals = len(all_partition_states[0])

@@ -979,7 +994,8 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         for key, value in merged_optimizer_states.items():
             group_optimizer_states[key] = self._partition_base_optimizer_state(
                 key,
-                value)
+                value,
+                max_elems_per_comm)
 
         return group_optimizer_states

@@ -993,8 +1009,13 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             all_partition_group_states = [
                 sd['base_optimizer_state'][group_idx] for sd in state_dict_list
            ]
+            if 'max_elems_per_comm' in state_dict_list[0]:
+                max_elems_per_comm = state_dict_list[0]['max_elems_per_comm'][group_idx]
+            else:
+                max_elems_per_comm = self.legacy_max_elements_per_comm
+
             group_optimizer_states = self._retrieve_group_optimizer_states(
-                all_partition_group_states)
+                all_partition_group_states,
+                max_elems_per_comm)
             base_optimizer_group_states.append(group_optimizer_states)
 
         for group_idx, group in enumerate(self.optimizer.param_groups):
......
@@ -41,6 +41,8 @@ def distributed_test(world_size=2, backend='nccl'):
         if torch.cuda.is_available():
             torch.cuda.set_device(local_rank)
 
+        if 'args' in func_kwargs:
+            func_kwargs['args'].local_rank = local_rank
         run_func(*func_args, **func_kwargs)
 
     def dist_launcher(num_procs, *func_args, **func_kwargs):
......
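
A hedged sketch of how a test might rely on the patched harness above (hypothetical test and argument names; the real suites pass richer args namespaces):

    # Hypothetical usage: the harness overwrites args.local_rank for each spawned rank.
    from types import SimpleNamespace
    from common import distributed_test  # harness defined in this file

    @distributed_test(world_size=2)
    def _test_local_rank_patched(args):
        # set by dist_init above, regardless of what the caller passed in
        assert args.local_rank in (0, 1)

    def test_local_rank_patched():
        _test_local_rank_patched(args=SimpleNamespace(local_rank=-1))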