Commit 03d28809 authored by Mostofa Patwary

Merge branch 'l2_grad_clip_fix' into 'master'

Reverting l2 grad optimization

See merge request ADLR/megatron-lm!74
parents 3c709cba d218f9cf
@@ -32,16 +32,21 @@ def l2_grad_clipper(parameters, max_norm):
     """Efficient L2 norm gradient clipping."""
     overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
+    # Make sure we have an iterable.
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
+    # Filter parameters with gradients.
+    parameters_with_grads = list(filter(
+        lambda p: p.grad is not None, parameters))
+    # Filter parameters for norm calculations.
     mp_rank_is_zero = (get_model_parallel_rank() == 0)
-    parameters = list(filter(lambda p: (p.grad is not None) and
-                             (p.model_parallel or mp_rank_is_zero),
-                             parameters))
+    parameters_for_norm = list(filter(
+        lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
+    # Calculate L2 norm.
     norm, _ = multi_tensor_applier(
         amp_C.multi_tensor_l2norm,
         overflow_buf,
-        [parameters],
+        [parameters_for_norm],
         False  # no per-parameter norm
     )
     # Sum across all model parallel GPUs.
@@ -50,10 +55,10 @@ def l2_grad_clipper(parameters, max_norm):
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=get_model_parallel_group())
     total_norm = norm_2.item() ** 0.5
-    clip_coef = max_norm / (total_norm + 1e-6)
-    grads = [p.grad for p in parameters]
-    if clip_coef < 1:
+    # Scale to get max_norm.
+    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
+    grads = [p.grad for p in parameters_with_grads]
+    if clip_coef < 1.0:
         multi_tensor_applier(
             amp_C.multi_tensor_scale,
             overflow_buf,
@@ -96,8 +101,8 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
         if clip_coef < 1:
             for p in parameters:
                 p.grad.data.mul_(clip_coef)
-    elif norm_type == 2:
-        total_norm = l2_grad_clipper(parameters, max_norm)
+    #elif norm_type == 2:
+    #    total_norm = l2_grad_clipper(parameters, max_norm)
     else:
         total_norm = 0