Commit 3c02784b authored by Thor Johnsen

Bug fix: pass self._overflow_buf instead of self._dummy_overflow_buf to the multi_tensor_applier calls in DistributedFusedLAMB (param/update norm computation, update-term kernel, and weight-update kernel).

parent 9773218c
@@ -424,20 +424,20 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     def __compute_contrib_param_norm(self):
         if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
-            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp16], True)
-            gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp32], True)
+            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)
+            gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)
             gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.bool, device='cuda')
             gnorm.masked_scatter(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
             gnorm.masked_scatter(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
         elif self._contrib_model_param_for_norm_fp16 is not None:
-            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp16], True)
+            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)
         elif self._contrib_model_param_for_norm_fp32 is not None:
-            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp32], True)
+            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)
         return gnorm
     def __compute_contrib_update_norm(self):
         l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
-        local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_update_frag_for_norm], True) ** 2
+        local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True) ** 2
         contrib_l2_norm = l2_norm[self._contrib_min_param_i:self._contrib_max_param_i+1]
         contrib_l2_norm.copy_(local_contrib_l2_norm)
         torch.distributed.allreduce(l2_norm, group=self._ag_pg[0])
@@ -453,7 +453,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         param_norm = self.__compute_contrib_param_norm()
         max_grad_norm = self.defaults['max_grad_norm']
         multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
-                self._dummy_overflow_buf,
+                self._overflow_buf,
                 self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
                 self._contrib_beta1,
                 self._contrib_beta2,
@@ -467,7 +467,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 max_grad_norm)
         upd_norm = self.__compute_contrib_update_norm()
         multi_tensor_applier(self.multi_tensor_lamb_update_weights,
-                self._dummy_overflow_buf,
+                self._overflow_buf,
                 self._contrib_update_weights_tensor_list, # u, p, p_copy
                 param_norm,
                 upd_norm,
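For context, a minimal sketch of the call pattern that every changed line above relies on, assuming an apex build that ships the amp_C extension and apex.multi_tensor_apply. The commit only changes which overflow-flag attribute is passed as the second argument of multi_tensor_applier; the buffer and tensor names below are illustrative stand-ins, not the optimizer's internal state.

# Illustrative sketch only (not part of this commit): driving the fused
# L2-norm kernel through multi_tensor_applier with an int32 overflow/no-op
# flag buffer, the same role self._overflow_buf plays in the diff above.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Fused kernels set this flag to non-zero when they encounter inf/nan.
overflow_buf = torch.zeros(1, dtype=torch.int32, device='cuda')

# Hypothetical parameter fragments standing in for the optimizer's flat views.
params = [torch.randn(1024, device='cuda') for _ in range(4)]

# The trailing True (as in the calls above) requests per-tensor norms, so the
# kernel returns the global L2 norm plus one norm per input tensor in a
# single launch.
global_norm, per_tensor_norms = multi_tensor_applier(
    amp_C.multi_tensor_l2norm, overflow_buf, [params], True)

print(float(global_norm), per_tensor_norms.shape)  # total norm, torch.Size([4])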