Commit 56650eb8 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Bug fix in update norm calculation

parent fb2d0f48
......@@ -245,6 +245,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_model_param_for_norm_fp16 = []
self._contrib_model_param_for_norm_fp32 = []
self._contrib_model_param_for_norm_is_fp16 = []
self._model_param_is_contrib = [False]*self._model_params_num
self._contrib_group_properties = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
......@@ -267,6 +268,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
else:
self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
self._model_param_is_contrib[param_i] = True
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
......@@ -427,8 +429,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.bool, device='cuda')
gnorm.masked_scatter(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
gnorm.masked_scatter(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
elif self._contrib_model_param_for_norm_fp16 is not None:
gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
elif self._contrib_model_param_for_norm_fp32 is not None:
......@@ -438,10 +440,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
def __compute_contrib_update_norm(self):
l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
contrib_l2_norm = l2_norm[self._contrib_min_param_i:self._contrib_max_param_i+1]
contrib_l2_norm.copy_(local_contrib_l2_norm)
l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
torch.distributed.allreduce(l2_norm, group=self._ag_pg[0])
return l2_norm
return l2_norm.masked_select(self._model_param_is_contrib)
def _pipeline_step(self):
# Call step kernel once per step
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment