Unverified commit 387397a2 authored by Deepak Narayanan, committed by GitHub

`wgrad` should be zeroed out if a weight parameter is shared among multiple layers (#545)



Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
parent 753eed31
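
Before the diff itself, a minimal self-contained sketch of why the placeholder matters. With fused wgrad accumulation, backward writes the real weight gradient into `weight.main_grad` and hands autograd only a dummy `wgrad` tensor. If the same parameter feeds several layers, autograd sums those dummies into `weight.grad`, so an uninitialized `torch.empty` buffer injects garbage; the new `zero_out_wgrad` branch returns zeros instead. The toy `FusedWgradLinear` below is illustrative only, not Transformer Engine's actual `_Linear`, though the attribute names `main_grad`, `grad_added_to_main_grad`, and `zero_out_wgrad` match the diff:

import torch

class FusedWgradLinear(torch.autograd.Function):
    """Toy stand-in for the pattern this commit touches: the real weight
    gradient is accumulated into weight.main_grad, and backward only
    returns a placeholder tensor to satisfy autograd."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        # Fused wgrad accumulation: the true gradient goes straight
        # into the main_grad buffer owned by the optimizer/DDP.
        weight.main_grad += grad_output.t() @ inp
        weight.grad_added_to_main_grad = True
        if getattr(weight, 'zero_out_wgrad', False):
            # Shared weight: autograd sums the placeholders returned by
            # every use of the parameter into weight.grad, so they must
            # be zeros rather than uninitialized memory.
            wgrad = torch.zeros_like(weight, requires_grad=False)
        else:
            # Unshared weight: the placeholder is never read, so the
            # cheaper uninitialized buffer is fine.
            wgrad = torch.empty_like(weight, requires_grad=False)
        return grad_output @ weight, wgrad

weight = torch.nn.Parameter(torch.randn(4, 4))
weight.main_grad = torch.zeros_like(weight)
weight.zero_out_wgrad = True  # the same parameter feeds both layers below

x = torch.randn(2, 4)
out = FusedWgradLinear.apply(FusedWgradLinear.apply(x, weight), weight)
out.sum().backward()

# The two returned placeholders were summed into weight.grad; with
# zero_out_wgrad they cancel to exact zeros, while the real gradient
# lives in weight.main_grad.
assert torch.all(weight.grad == 0)

Flipping `zero_out_wgrad` to `False` in this sketch can make the assert fail, since `torch.empty` may hand back nonzero memory, which is exactly the failure mode this commit closes.
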
@@ -532,6 +532,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
weight.grad_added_to_main_grad = True
if getattr(weight, 'zero_out_wgrad', False):
wgrad = torch.zeros(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
wgrad = torch.empty(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
......
@@ -910,6 +910,13 @@ class _LayerNormMLP(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'):
fc1_weight.grad_added_to_main_grad = True
if getattr(fc1_weight, 'zero_out_wgrad', False):
fc1_wgrad = torch.zeros(fc1_weight.main_grad.shape,
dtype=fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
fc1_wgrad = torch.empty(fc1_weight.main_grad.shape,
dtype=fc1_weight.dtype,
device=torch.cuda.current_device(),
@@ -924,6 +931,13 @@ class _LayerNormMLP(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'):
fc2_weight.grad_added_to_main_grad = True
if getattr(fc2_weight, 'zero_out_wgrad', False):
fc2_wgrad = torch.zeros(fc2_weight.main_grad.shape,
dtype=fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
fc2_wgrad = torch.empty(fc2_weight.main_grad.shape,
dtype=fc2_weight.dtype,
device=torch.cuda.current_device(),
......
@@ -473,6 +473,13 @@ class _Linear(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
weight.grad_added_to_main_grad = True
if getattr(weight, 'zero_out_wgrad', False):
wgrad = torch.zeros(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
wgrad = torch.empty(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
......
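
On the producer side, `zero_out_wgrad` has to be set by whichever framework ties the parameter across layers. A hedged sketch of that setup, assuming mcore-style attribute conventions as seen in this diff; the helper name and its body are illustrative, not an actual Megatron-Core API:

import torch

def mark_shared_weight(weight: torch.nn.Parameter) -> None:
    """Illustrative helper: prepare a tied parameter for fused wgrad
    accumulation with a custom DDP that reads these attributes."""
    # fp32 buffer that the fused backward accumulates the real gradient into.
    weight.main_grad = torch.zeros(weight.shape, dtype=torch.float32,
                                   device=weight.device)
    # Reset each step; backward flips it to True once it has written
    # the gradient into main_grad.
    weight.grad_added_to_main_grad = False
    # The parameter is shared among multiple layers, so backward must
    # return a zeroed (not uninitialized) dummy wgrad.
    weight.zero_out_wgrad = True
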