Commit b0a3fdfe authored by mohammad

slight refactor of clip grads

parent 6191ff59
@@ -68,14 +68,19 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # - grad should not be none
     # - parameter should not be shared
     # - should not be a replica due to tensor model parallelism
-    filtered_parameters = []
+    params_with_grads = []
+    params_for_norm = []
     for param in parameters:
+        # Make sure the grads are in fp32
+        assert param.grad.type() == 'torch.cuda.FloatTensor'
         grad_not_none = param.grad is not None
         is_not_shared = not hasattr(param, 'shared') or not param.shared
         is_not_tp_duplicate = param.tensor_model_parallel or \
                               (mpu.get_tensor_model_parallel_rank() == 0)
+        if grad_not_none:
+            params_with_grads.append(param)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            filtered_parameters.append(param)
+            params_for_norm.append(param)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -85,7 +90,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Calculate norm.
     if norm_type == inf:
         total_norm = max(param.grad.detach().abs().max()
-                         for param in filtered_parameters)
+                         for param in params_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
@@ -94,7 +99,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0
-        for param in filtered_parameters:
+        for param in params_for_norm:
             param_norm = torch.norm(param.grad.detach(), norm_type)
             total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
@@ -105,9 +110,9 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
 
     # Scale.
-    clip_coef = max_norm / (total_norm + 1e-6)
+    clip_coef = max_norm / (total_norm + 1.0e-6)
     if clip_coef < 1.0:
-        for param in parameters:
+        for param in params_with_grads:
             param.grad.detach().mul_(clip_coef)
 
     return total_norm
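
For orientation, here is a minimal, self-contained sketch of what the refactored function does after this commit. It is an illustration, not the committed code: the `mpu` tensor-model-parallel checks and the `torch.distributed` all-reduces are dropped (single-process assumption), the fp32 check uses `dtype == torch.float32` so it also runs on CPU, and the `clip_grad_norm_sketch` name is made up. The sketch also checks `grad is not None` before the fp32 assert, which the diff's new assert does not.

```python
# Hypothetical single-process sketch of the refactored clip-grads logic.
# It omits the mpu / torch.distributed model-parallel reductions from the
# diff above and guards the fp32 assert behind the grad-is-not-None check.
from math import inf

import torch


def clip_grad_norm_sketch(parameters, max_norm, norm_type=2):
    # Split parameters once, as in the refactor:
    #   params_with_grads -> everything we will scale,
    #   params_for_norm   -> everything that contributes to the norm
    #                        (here just "grad is not None"; the real code
    #                        also excludes shared params and TP replicas).
    params_with_grads = []
    params_for_norm = []
    for param in parameters:
        if param.grad is None:
            continue
        # Make sure the grads are in fp32 (mirrors the diff's assert,
        # using dtype so the sketch also runs on CPU).
        assert param.grad.dtype == torch.float32
        params_with_grads.append(param)
        params_for_norm.append(param)

    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Calculate the norm over params_for_norm only.
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().item()
                         for p in params_for_norm)
    else:
        total_norm = 0.0
        for p in params_for_norm:
            total_norm += torch.norm(p.grad.detach(), norm_type).item() ** norm_type
        total_norm = total_norm ** (1.0 / norm_type)

    # Scale only params_with_grads, not the raw `parameters` iterable.
    clip_coef = max_norm / (total_norm + 1.0e-6)
    if clip_coef < 1.0:
        for p in params_with_grads:
            p.grad.detach().mul_(clip_coef)
    return total_norm


# Tiny usage example.
if __name__ == "__main__":
    w = torch.nn.Parameter(torch.ones(4))
    w.grad = torch.full((4,), 10.0)
    norm = clip_grad_norm_sketch([w], max_norm=1.0)
    print(norm, w.grad.norm().item())  # ~20.0 before clipping, grad norm ~1.0 after
```

The point of the refactor is visible in the last loop: scaling iterates over `params_with_grads` rather than the raw `parameters` list, so parameters without gradients are never touched, while the norm is computed over the separately filtered `params_for_norm` list, which in the committed code also excludes shared parameters and tensor-model-parallel replicas so they are not double-counted across ranks.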