Unverified Commit 55ed1057 authored by Jeff Rasley, committed by GitHub

fix bug related to stitching reduced grads across communication partitions (#318)

parent 91b4a93d
@@ -249,8 +249,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             # RS: divide up the sub-partitions and keep track of offsets for each param
             # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
-            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
-                params_not_local = self.get_all_sub_partition_info(
+            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
                 tensor_list=self.fp16_groups[i],
                 all_element_intervals=element_intervals,
                 local_rank=local_rank,
@@ -591,28 +590,20 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 all_comm_partitions.append(single_comm_all_partitions)
 
-            for p in my_params:
-                partitions = param_partition_map[p]
-                parts = []
-                for part in partitions:
-                    params, offsets = partition_param_map[part]
-                    found = False
-                    for p_idx, _p in enumerate(params):
-                        if p.__hash__() == _p.__hash__():
-                            found = True
-                            if offsets[p_idx][0] is not None:
-                                my_part = part.narrow(0,
-                                                      offsets[p_idx][0],
-                                                      offsets[p_idx][1])
-                                parts.append(my_part)
-                    assert found
-                if p is not None:
-                    updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p])
-                    p.grad.copy_(updated_grad[0])
+            # stitch together all rank sub partitions for each comm idx
+            flat_comm_grads = []
+            for comm_idx, rank_partitions in enumerate(all_comm_partitions):
+                flat_comm_grads.append(torch.cat(rank_partitions))
+
+            flat_all_grads = torch.cat(flat_comm_grads)
+
+            # copy back reduced gradients but only those needed for this local rank
+            for param, updated_grad in zip(self.fp16_groups[i], _unflatten_dense_tensors(flat_all_grads, self.fp16_groups[i])):
+                if param in my_params:
+                    param.grad.copy_(updated_grad)
 
     def step(self, closure=None):
         # First compute norm for all group so we know if there is overflow
         self.overflow = self.overflow_checker.check()
         prev_scale = self.loss_scale
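
For context, here is a minimal, self-contained sketch of the stitch-then-unflatten pattern the new code relies on: rank sub-partitions are concatenated per communication index, the comm chunks are concatenated into one flat buffer, and `_unflatten_dense_tensors` recovers per-parameter views so only locally owned gradients are copied back. This is not the DeepSpeed implementation; the parameter shapes, the two-rank/two-comm split, and the `my_params` subset are made up for illustration, while `_unflatten_dense_tensors` is the same `torch._utils` helper used in the diff.

```python
import torch
from torch._utils import _unflatten_dense_tensors

# Hypothetical parameter group: three params, 12 elements total.
params = [torch.nn.Parameter(torch.zeros(4)) for _ in range(3)]
for p in params:
    p.grad = torch.zeros_like(p)

# Pretend this is the reduced gradient data, split into two comm indices,
# each holding two rank sub-partitions (values chosen so the result is obvious).
reduced = torch.arange(12, dtype=torch.float32)
all_comm_partitions = [
    [reduced[0:3], reduced[3:6]],   # comm idx 0: rank 0 and rank 1 chunks
    [reduced[6:9], reduced[9:12]],  # comm idx 1: rank 0 and rank 1 chunks
]

# Stitch rank sub-partitions per comm idx, then across comm idxs.
flat_comm_grads = [torch.cat(rank_partitions) for rank_partitions in all_comm_partitions]
flat_all_grads = torch.cat(flat_comm_grads)

# Unflatten into per-parameter views; copy back only locally owned grads.
my_params = set(params[:2])  # hypothetical: this rank updates the first two params
for param, updated_grad in zip(params, _unflatten_dense_tensors(flat_all_grads, params)):
    if param in my_params:
        param.grad.copy_(updated_grad)

print(params[0].grad)  # tensor([0., 1., 2., 3.])
print(params[2].grad)  # stays zero: not owned by this rank
```

The removed code instead tried to re-slice each owned parameter out of the communication partitions through the `param_partition_map`/`partition_param_map` offset bookkeeping; the replacement operates on the whole flattened group at once, which is what fixes the stitching bug described in the commit message.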
@@ -649,6 +640,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             #)
             #TODO RS: can we safely use dtype of the first sub-partition? i think so
+            # create flat gradient partitions for parameters updated by this process
             local_grad_sub_partitions = self.get_flat_sub_partitions(
                 comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
                 comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
...
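
The added comment above `get_flat_sub_partitions` refers to building a flat gradient buffer out of only the parameter slices this rank owns. The sketch below is a rough, hypothetical illustration of that idea; the helper name `flatten_local_grad_partition`, the `(start, numel)` offset format, and the example values are assumptions, not the actual `get_flat_sub_partitions` signature (which, per the diff, receives `comm_tensor_list` and `comm_param_offsets` and also worries about the sub-partition dtype noted in the TODO).

```python
import torch

def flatten_local_grad_partition(tensor_list, param_offsets):
    """Concatenate the (start, numel) slice of each gradient that falls
    inside this rank's sub-partition into one flat buffer.

    Hypothetical helper for illustration only.
    """
    pieces = []
    for tensor, (start, numel) in zip(tensor_list, param_offsets):
        pieces.append(tensor.grad.view(-1).narrow(0, start, numel))
    return torch.cat(pieces)

# Example: this rank owns the tail of the first grad and the head of the second.
a = torch.zeros(6, requires_grad=True)
a.grad = torch.arange(6.0)
b = torch.zeros(4, requires_grad=True)
b.grad = torch.arange(6.0, 10.0)

flat = flatten_local_grad_partition([a, b], [(3, 3), (0, 2)])
print(flat)  # tensor([3., 4., 5., 6., 7.])
```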