Commit e68ea218 authored by mohammad's avatar mohammad
Browse files

further refactor, matching old results

parent b0a3fdfe
......@@ -68,8 +68,8 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params_with_grads = []
params_for_norm = []
grads = []
grads_for_norm = []
for param in parameters:
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
......@@ -77,10 +77,11 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
is_not_shared = not hasattr(param, 'shared') or not param.shared
is_not_tp_duplicate = param.tensor_model_parallel or \
(mpu.get_tensor_model_parallel_rank() == 0)
grad = param.grad.detach()
if grad_not_none:
params_with_grads.append(param)
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
params_for_norm.append(param)
grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
......@@ -89,8 +90,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
# Calculate norm.
if norm_type == inf:
total_norm = max(param.grad.detach().abs().max()
for param in params_for_norm)
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
......@@ -99,9 +99,9 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
total_norm = total_norm_cuda[0].item()
else:
for param in params_for_norm:
param_norm = torch.norm(param.grad.detach(), norm_type)
total_norm += param_norm.item() ** norm_type
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm.item() ** norm_type
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
......@@ -112,8 +112,8 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
# Scale.
clip_coef = max_norm / (total_norm + 1.0e-6)
if clip_coef < 1.0:
for param in params_with_grads:
param.grad.detach().mul_(clip_coef)
for grad in grads:
grad.mul_(clip_coef)
return total_norm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment