Commit 6191ff59 authored by Mohammad

Found a bug in the L2 norm calculation

parent b84d7a90
@@ -76,7 +76,6 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
                     (mpu.get_tensor_model_parallel_rank() == 0)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
             filtered_parameters.append(param)
-    parameters = filtered_parameters
     # Norm parameters.
     max_norm = float(max_norm)
@@ -86,7 +85,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Calculate norm.
     if norm_type == inf:
         total_norm = max(param.grad.detach().abs().max()
-                         for param in parameters)
+                         for param in filtered_parameters)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
@@ -95,7 +94,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item()
     else:
-        for param in parameters:
+        for param in filtered_parameters:
             param_norm = torch.norm(param.grad.detach(), norm_type)
             total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
@@ -107,7 +106,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Scale.
     clip_coef = max_norm / (total_norm + 1e-6)
-    if clip_coef < 1:
+    if clip_coef < 1.0:
         for param in parameters:
             param.grad.detach().mul_(clip_coef)
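For context, here is a minimal single-process sketch of the corrected clipping logic, not the committed code: the norm is accumulated only over the filtered parameters, while the scaling is applied to every parameter's gradient. The is_duplicate attribute is a hypothetical stand-in for the real shared/tensor-parallel-duplicate checks, and the cross-rank all-reduce over model-parallel GPUs is omitted.

import torch

def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0):
    """Sketch: compute the norm over filtered params, scale all grads."""
    # Hypothetical filter: skip missing grads and anything flagged as a
    # duplicate (stand-in for the real Megatron-style helper checks).
    filtered_parameters = [p for p in parameters
                           if p.grad is not None
                           and not getattr(p, 'is_duplicate', False)]

    max_norm = float(max_norm)
    norm_type = float(norm_type)

    if norm_type == float('inf'):
        total_norm = max(p.grad.detach().abs().max().item()
                         for p in filtered_parameters)
    else:
        total_norm = 0.0
        for p in filtered_parameters:
            param_norm = torch.norm(p.grad.detach(), norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1.0 / norm_type)

    # Scale every gradient, not only the filtered ones: the norm is taken
    # over the de-duplicated set, but all local gradients must shrink.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in parameters:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coef)
    return total_norm

# Example usage with dummy parameters:
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn(4)
clip_grad_norm_sketch(params, max_norm=1.0)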