"vscode:/vscode.git/clone" did not exist on "0fe566a020bdb90ee7bcd673524f352a9b4f5d21"
Unverified commit 387397a2, authored by Deepak Narayanan, committed by GitHub

`wgrad` should be zeroed out if a weight parameter is shared among multiple layers (#545)



Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
parent 753eed31
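For context on why the placeholder must be zeroed rather than left uninitialized: with `fuse_wgrad_accumulation`, the real weight gradient is accumulated directly into `weight.main_grad`, and the tensor returned from backward for the weight is only a placeholder (flagged via `grad_added_to_main_grad`). If the same parameter is shared among multiple layers, autograd sums the placeholders returned by each layer, so an uninitialized `torch.empty` buffer would add garbage into the summed gradient; a zeroed buffer sums harmlessly. A minimal sketch of the selection logic, assuming a CUDA device and a `main_grad` buffer already attached to the parameter; the helper name `make_placeholder_wgrad` is illustrative, not part of the library:

import torch

def make_placeholder_wgrad(param):
    # Placeholder gradient returned from backward when the real wgrad has
    # already been accumulated into param.main_grad. For a parameter shared
    # among multiple layers, autograd sums the placeholders from every layer,
    # so they must be zeros rather than uninitialized memory.
    factory = torch.zeros if getattr(param, 'zero_out_wgrad', False) else torch.empty
    return factory(param.main_grad.shape,
                   dtype=param.dtype,
                   device=torch.cuda.current_device(),
                   requires_grad=False)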
@@ -532,11 +532,18 @@ class _LayerNormLinear(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
             weight.grad_added_to_main_grad = True
-            wgrad = torch.empty(weight.main_grad.shape,
-                                dtype=weight.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False
-                                )
+            if getattr(weight, 'zero_out_wgrad', False):
+                wgrad = torch.zeros(weight.main_grad.shape,
+                                    dtype=weight.dtype,
+                                    device=torch.cuda.current_device(),
+                                    requires_grad=False
+                                    )
+            else:
+                wgrad = torch.empty(weight.main_grad.shape,
+                                    dtype=weight.dtype,
+                                    device=torch.cuda.current_device(),
+                                    requires_grad=False
+                                    )
         elif ctx.fuse_wgrad_accumulation:
             wgrad = None
         else:
@@ -910,11 +910,18 @@ class _LayerNormMLP(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'):
             fc1_weight.grad_added_to_main_grad = True
-            fc1_wgrad = torch.empty(fc1_weight.main_grad.shape,
-                                    dtype=fc1_weight.dtype,
-                                    device=torch.cuda.current_device(),
-                                    requires_grad=False
-                                    )
+            if getattr(fc1_weight, 'zero_out_wgrad', False):
+                fc1_wgrad = torch.zeros(fc1_weight.main_grad.shape,
+                                        dtype=fc1_weight.dtype,
+                                        device=torch.cuda.current_device(),
+                                        requires_grad=False
+                                        )
+            else:
+                fc1_wgrad = torch.empty(fc1_weight.main_grad.shape,
+                                        dtype=fc1_weight.dtype,
+                                        device=torch.cuda.current_device(),
+                                        requires_grad=False
+                                        )
         elif ctx.fuse_wgrad_accumulation:
             fc1_wgrad = None
         else:
@@ -924,11 +931,18 @@ class _LayerNormMLP(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'):
             fc2_weight.grad_added_to_main_grad = True
-            fc2_wgrad = torch.empty(fc2_weight.main_grad.shape,
-                                    dtype=fc2_weight.dtype,
-                                    device=torch.cuda.current_device(),
-                                    requires_grad=False
-                                    )
+            if getattr(fc2_weight, 'zero_out_wgrad', False):
+                fc2_wgrad = torch.zeros(fc2_weight.main_grad.shape,
+                                        dtype=fc2_weight.dtype,
+                                        device=torch.cuda.current_device(),
+                                        requires_grad=False
+                                        )
+            else:
+                fc2_wgrad = torch.empty(fc2_weight.main_grad.shape,
+                                        dtype=fc2_weight.dtype,
+                                        device=torch.cuda.current_device(),
+                                        requires_grad=False
+                                        )
         elif ctx.fuse_wgrad_accumulation:
             fc2_wgrad = None
         else:
@@ -473,11 +473,18 @@ class _Linear(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
             weight.grad_added_to_main_grad = True
-            wgrad = torch.empty(weight.main_grad.shape,
-                                dtype=weight.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False
-                                )
+            if getattr(weight, 'zero_out_wgrad', False):
+                wgrad = torch.zeros(weight.main_grad.shape,
+                                    dtype=weight.dtype,
+                                    device=torch.cuda.current_device(),
+                                    requires_grad=False
+                                    )
+            else:
+                wgrad = torch.empty(weight.main_grad.shape,
+                                    dtype=weight.dtype,
+                                    device=torch.cuda.current_device(),
+                                    requires_grad=False
+                                    )
         elif ctx.fuse_wgrad_accumulation:
             wgrad = None
         else:
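As a hedged usage sketch, not part of this change: a common case of a weight parameter shared among multiple layers is a tied input/output embedding. A Megatron-Core-style training framework that owns the `main_grad` buffers could mark such a parameter so that the modules patched above return a zeroed placeholder; the attribute names below are the ones introduced in the diff, while the surrounding setup is assumed for illustration:

import torch
import torch.nn as nn

embed = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = embed.weight   # one parameter used by two layers (weight tying)

# Attributes assumed to be set by the training framework, as referenced in the diff:
embed.weight.main_grad = torch.zeros_like(embed.weight)  # fused wgrad accumulation buffer
embed.weight.zero_out_wgrad = True                        # request zeroed placeholder wgrads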