"configs/vscode:/vscode.git/clone" did not exist on "1401de15d079af4d9d9f995f2d57ddb6d930d7f0"
Unverified commit 387397a2, authored by Deepak Narayanan, committed by GitHub
Browse files

`wgrad` should be zero'ed out if a weight parameter is shared among multiple layers (#545)



wgrad should be zero'ed out if a weight parameter is shared among multiple layers
Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
parent 753eed31
......@@ -532,11 +532,18 @@ class _LayerNormLinear(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
weight.grad_added_to_main_grad = True
wgrad = torch.empty(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
if getattr(weight, 'zero_out_wgrad', False):
wgrad = torch.zeros(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
wgrad = torch.empty(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
......
......@@ -910,11 +910,18 @@ class _LayerNormMLP(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'):
fc1_weight.grad_added_to_main_grad = True
fc1_wgrad = torch.empty(fc1_weight.main_grad.shape,
dtype=fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
if getattr(fc1_weight, 'zero_out_wgrad', False):
fc1_wgrad = torch.zeros(fc1_weight.main_grad.shape,
dtype=fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
fc1_wgrad = torch.empty(fc1_weight.main_grad.shape,
dtype=fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
elif ctx.fuse_wgrad_accumulation:
fc1_wgrad = None
else:
......@@ -924,11 +931,18 @@ class _LayerNormMLP(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'):
fc2_weight.grad_added_to_main_grad = True
fc2_wgrad = torch.empty(fc2_weight.main_grad.shape,
dtype=fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
if getattr(fc2_weight, 'zero_out_wgrad', False):
fc2_wgrad = torch.zeros(fc2_weight.main_grad.shape,
dtype=fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
fc2_wgrad = torch.empty(fc2_weight.main_grad.shape,
dtype=fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
elif ctx.fuse_wgrad_accumulation:
fc2_wgrad = None
else:
......
......@@ -473,11 +473,18 @@ class _Linear(torch.autograd.Function):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
weight.grad_added_to_main_grad = True
wgrad = torch.empty(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
if getattr(weight, 'zero_out_wgrad', False):
wgrad = torch.zeros(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
else:
wgrad = torch.empty(weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.