diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index e8d0d02..369fdf6 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     # - should not be a replica due to tensor model parallelism
     grads = []
     grads_for_norm = []
+    grads_in_moe = []
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = not hasattr(param, 'shared') or not param.shared
@@ -63,7 +64,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             assert param.grad.type() == 'torch.cuda.FloatTensor'
             grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            if hasattr(param, 'dp_comm') and param.dp_comm in ('none'):
+                grads_in_moe.append(grad)
+            else:
+                grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -72,6 +76,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
 
     # Calculate norm.
     if norm_type == inf:
+        # TODO: moe
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
@@ -96,7 +101,19 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
 
+            grad_norm, _ = multi_tensor_applier(
+                amp_C.multi_tensor_l2norm,
+                dummy_overflow_buf,
+                [grads_in_moe],
+                False # no per-parameter norm
+            )
+            grad_norm = grad_norm ** norm_type
+            torch.distributed.all_reduce(grad_norm,
+                                         group=mpu.get_model_parallel_group())
+            total_norm += grad_norm
+
         else:
+            # TODO: moe
             for grad in grads_for_norm:
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
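
What the patch boils down to: gradients of parameters tagged with `dp_comm == 'none'` (the MoE expert weights, whose data-parallel gradient all-reduce is presumably skipped) are collected in `grads_in_moe` instead of `grads_for_norm`, their squared L2 norm is all-reduced over the model-parallel group, and the result is folded into `total_norm` in the same `** norm_type` form as the regular contribution. The sketch below shows that bookkeeping in plain PyTorch; the helper names `split_grads` and `total_sq_l2_norm`, the `mp_group` argument (standing in for `mpu.get_model_parallel_group()`), and the use of `torch.norm` instead of apex's fused `multi_tensor_l2norm` are illustrative assumptions, not code from the patch.

import torch


def split_grads(parameters):
    # Mirror the bucketing in the first two hunks: gradients of parameters
    # tagged dp_comm == 'none' go into their own list.
    grads_for_norm, grads_in_moe = [], []
    for param in parameters:
        if param.grad is None:
            continue
        grad = param.grad.detach()
        if getattr(param, 'dp_comm', None) == 'none':
            grads_in_moe.append(grad)
        else:
            grads_for_norm.append(grad)
    return grads_for_norm, grads_in_moe


def total_sq_l2_norm(grads_for_norm, grads_in_moe, mp_group=None):
    # Sum of squared L2 norms (norm_type == 2), kept in pow(norm_type) form
    # as in the patched branch; only the MoE contribution gets the extra
    # all-reduce over the model-parallel group before being added in.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    total_norm = torch.zeros(1, device=device)
    moe_norm = torch.zeros(1, device=device)
    for g in grads_for_norm:
        total_norm += torch.norm(g, 2) ** 2
    for g in grads_in_moe:
        moe_norm += torch.norm(g, 2) ** 2
    if torch.distributed.is_initialized() and mp_group is not None:
        torch.distributed.all_reduce(moe_norm, group=mp_group)
    return total_norm + moe_norm

Both buckets are accumulated in `** norm_type` form on purpose: the remainder of `clip_grad_norm_fp32` (outside these hunks) sums `total_norm` across model-parallel ranks and only then takes the `1.0 / norm_type` root, which is what the existing "we need the pow(norm-type)" comment refers to.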