Commit 9a010310 authored by mohammad
Browse files

add multi-tensor-apply to clip grad

parent 345f5d0d
...@@ -99,15 +99,26 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2): ...@@ -99,15 +99,26 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
total_norm = total_norm_cuda[0].item() total_norm = total_norm_cuda[0].item()
else: else:
for grad in grads_for_norm: if norm_type == 2.0:
grad_norm = torch.norm(grad, norm_type) dummy_overflow_buf = torch.cuda.IntTensor([0])
total_norm += grad_norm.item() ** norm_type grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False # no per-parameter norm
)
total_norm = grad_norm ** norm_type
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
# Sum across all model-parallel GPUs. # Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm,
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group()) group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) total_norm = total_norm.item() ** (1.0 / norm_type)
# Scale. # Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6) clip_coeff = max_norm / (total_norm + 1.0e-6)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment