Unverified Commit 92c1e500 authored by Tim Moon, committed by GitHub

[PyTorch] Fix incorrect variable name in LayerNormMLP backward (#548)



Fix incorrect variable name in LayerNormMLP backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 4f1d70fb
@@ -914,7 +914,7 @@ class _LayerNormMLP(torch.autograd.Function):
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'):
                fc1_weight.grad_added_to_main_grad = True
-               if getattr(weight, 'zero_out_wgrad', False):
+               if getattr(fc1_weight, 'zero_out_wgrad', False):
                    fc1_wgrad = torch.zeros(fc1_weight.main_grad.shape,
                                            dtype=fc1_weight.dtype,
                                            device=torch.cuda.current_device(),
@@ -935,7 +935,7 @@ class _LayerNormMLP(torch.autograd.Function):
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'):
                fc2_weight.grad_added_to_main_grad = True
-               if getattr(weight, 'zero_out_wgrad', False):
+               if getattr(fc2_weight, 'zero_out_wgrad', False):
                    fc2_wgrad = torch.zeros(fc2_weight.main_grad.shape,
                                            dtype=fc2_weight.dtype,
                                            device=torch.cuda.current_device(),
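
For readers unfamiliar with the pattern being fixed: when wgrad accumulation is fused into a Megatron-core-style main_grad buffer, the real gradient accumulates into weight.main_grad and the tensor returned to autograd is only a placeholder; the zero_out_wgrad attribute (attached externally, e.g. by mcore's DDP wrapper) decides whether that placeholder is zero-filled or left uninitialized. Before this commit the attribute was read from a variable named weight (apparently carried over from the single-GEMM code path), so the check did not apply to fc1_weight or fc2_weight. Below is a minimal sketch of the corrected logic, not the library's actual code; the helper name dummy_wgrad is hypothetical, and it assumes main_grad and zero_out_wgrad have been attached to the parameter.

import torch

def dummy_wgrad(weight: torch.nn.Parameter) -> torch.Tensor:
    # Hypothetical helper mirroring the corrected check. Assumes the caller
    # (e.g. an mcore-style DDP wrapper) attached `main_grad` and, optionally,
    # `zero_out_wgrad` to the parameter; the real gradient lives in main_grad.
    if getattr(weight, 'zero_out_wgrad', False):
        # Zero-filled placeholder, safe if downstream code ever reads it.
        return torch.zeros(weight.main_grad.shape,
                           dtype=weight.dtype,
                           device=weight.device,
                           requires_grad=False)
    # Uninitialized placeholder, cheaper when the values are never read.
    return torch.empty(weight.main_grad.shape,
                       dtype=weight.dtype,
                       device=weight.device,
                       requires_grad=False)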