"...models/git@developer.sourcefind.cn:jerrrrry/infinilm.git" did not exist on "bf74389d2fab76be0e7a82b85fbd13ee3f36cef7"
Unverified commit c706ff8d, authored by Selvaraj Anandaraj and committed by GitHub

Returning an empty tensor of param dtype for wgrad (#507)

* Returning an empty tensor of param dtype for wgrad
Signed-off-by: Selvaraj Anandaraj <selvaraja@computelab-frontend-4-ub22.nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@computelab-frontend-4-ub22.nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@computelab-frontend-4-ub22.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 50ff8116
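The change is the same in all three autograd functions: when fused wgrad accumulation is active and Megatron-core's custom DDP has attached a main_grad buffer to the weight, the backward sets grad_added_to_main_grad = True and returns a placeholder wgrad with main_grad's shape but the parameter's dtype, instead of None, since the actual gradient has already been accumulated into main_grad. The sketch below is a minimal, self-contained illustration of that pattern, not the TransformerEngine code; ToyLinear and its forward/backward are hypothetical and exist only to show where the placeholder fits.

# Minimal sketch; ToyLinear is hypothetical, not a TransformerEngine class.
import torch


class ToyLinear(torch.autograd.Function):
    """Toy linear op whose backward can fuse wgrad into weight.main_grad."""

    @staticmethod
    def forward(ctx, inp, weight, fuse_wgrad_accumulation):
        ctx.save_for_backward(inp, weight)
        ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        dgrad = grad_output @ weight

        # Handle custom DDP from mcore: the wrapper owns a main_grad buffer.
        if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
            weight.main_grad.add_(grad_output.t() @ inp)
            weight.grad_added_to_main_grad = True
            # Placeholder in the *param* dtype (not main_grad's dtype) so code
            # downstream that inspects the returned grad sees the expected
            # dtype and shape; the real values live in weight.main_grad.
            wgrad = torch.empty(weight.main_grad.shape,
                                dtype=weight.dtype,
                                device=weight.device,
                                requires_grad=False)
        elif ctx.fuse_wgrad_accumulation:
            wgrad = None
        else:
            wgrad = grad_output.t() @ inp

        return dgrad, wgrad, None


# Usage sketch: a DDP wrapper (not shown) would normally attach these.
w = torch.nn.Parameter(torch.randn(4, 3, dtype=torch.bfloat16))
w.main_grad = torch.zeros(4, 3, dtype=torch.float32)
w.grad_added_to_main_grad = False
x = torch.randn(2, 3, dtype=torch.bfloat16, requires_grad=True)
ToyLinear.apply(x, w, True).sum().backward()
assert w.grad.dtype == w.dtype      # placeholder carries the param dtype
assert w.main_grad.abs().sum() > 0  # real wgrad accumulated into main_grad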
@@ -525,6 +525,11 @@ class _LayerNormLinear(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
             weight.grad_added_to_main_grad = True
+            wgrad = torch.empty(weight.main_grad.shape,
+                                dtype=weight.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False
+                                )
         elif ctx.fuse_wgrad_accumulation:
             wgrad = None
         else:
@@ -879,6 +879,11 @@ class _LayerNormMLP(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'):
             fc1_weight.grad_added_to_main_grad = True
+            fc1_wgrad = torch.empty(fc1_weight.main_grad.shape,
+                                    dtype=fc1_weight.dtype,
+                                    device=torch.cuda.current_device(),
+                                    requires_grad=False
+                                    )
         elif ctx.fuse_wgrad_accumulation:
             fc1_wgrad = None
         else:
@@ -888,6 +893,11 @@ class _LayerNormMLP(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'):
             fc2_weight.grad_added_to_main_grad = True
+            fc2_wgrad = torch.empty(fc2_weight.main_grad.shape,
+                                    dtype=fc2_weight.dtype,
+                                    device=torch.cuda.current_device(),
+                                    requires_grad=False
+                                    )
         elif ctx.fuse_wgrad_accumulation:
             fc2_wgrad = None
         else:
@@ -465,6 +465,11 @@ class _Linear(torch.autograd.Function):
         # Handle custom DDP from mcore.
         if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
             weight.grad_added_to_main_grad = True
+            wgrad = torch.empty(weight.main_grad.shape,
+                                dtype=weight.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False
+                                )
         elif ctx.fuse_wgrad_accumulation:
             wgrad = None
         else:
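For context on the "Handle custom DDP from mcore" comment: a Megatron-core-style data-parallel wrapper keeps a persistent (typically fp32) main_grad buffer per parameter and folds each step's .grad into it from a gradient hook, skipping the fold when the module has already accumulated directly and set grad_added_to_main_grad. The hook below is a hypothetical illustration of that consumer side, not Megatron-core's actual implementation; attach_main_grad and the use of register_post_accumulate_grad_hook (PyTorch 2.1+) are assumptions made for the sketch.

# Hypothetical consumer-side sketch, not Megatron-core code: illustrates the
# handshake with grad_added_to_main_grad used in the hunks above.
import torch


def attach_main_grad(param: torch.nn.Parameter) -> None:
    """Give `param` a persistent fp32 accumulation buffer plus a grad hook."""
    param.main_grad = torch.zeros_like(param, dtype=torch.float32)
    param.grad_added_to_main_grad = False

    def _on_grad_ready(*_):
        if param.grad is None:
            return
        if not param.grad_added_to_main_grad:
            # Normal path: fold this step's grad into the fp32 buffer.
            param.main_grad.add_(param.grad.to(torch.float32))
        # Fused path: the module already added into main_grad and returned only
        # a placeholder; because that placeholder uses the param dtype, code
        # that assumes param.grad.dtype == param.dtype still holds here.
        param.grad = None
        param.grad_added_to_main_grad = False

    # Requires PyTorch >= 2.1; older stacks hook the grad accumulation node.
    param.register_post_accumulate_grad_hook(_on_grad_ready)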