Unverified Commit d8f1e68f authored by Lifu Zhang, committed by GitHub

fix gradient accumulation fusion for FSDP (#2371)
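The fused wgrad path accumulates weight gradients directly into each parameter's `main_grad` buffer, but that buffer is only reachable from the original parameter object. Under FSDP, the weight visible in backward can be a re-gathered tensor that no longer carries the attribute, so `main_grad` must be re-attached before the fused accumulation runs, whether or not CPU offloading is active. A minimal sketch of the invariant involved (illustrative only; `wgrad_accumulate` is a hypothetical stand-in for the fused wgrad GEMM, not a Transformer Engine API):

```python
import torch

# Persistent accumulation buffer attached by the training framework.
weight = torch.nn.Parameter(torch.randn(4, 4))
weight.main_grad = torch.zeros_like(weight)

def wgrad_accumulate(param: torch.nn.Parameter, wgrad: torch.Tensor) -> None:
    # With fusion enabled, the wgrad kernel accumulates into main_grad
    # in place instead of materializing a fresh .grad every microbatch.
    param.main_grad += wgrad

wgrad_accumulate(weight, torch.ones(4, 4))

# Under FSDP, the unsharded weight used in backward can be a different
# tensor object without this attribute, hence the re-attachment that
# this commit moves out of the cpu_offloading-only branch.
```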


Signed-off-by: Lifu Zhang <lifuz@login-lyris01.lyris.clusters.nvidia.com>
Co-authored-by: Lifu Zhang <lifuz@login-lyris01.lyris.clusters.nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent c544ced2
@@ -293,9 +293,9 @@ class _GroupedLinear(torch.autograd.Function):
                     origin_weights[i] = ctx.weight_objects[i]
                     ctx.weight_objects[i] = None
-            if ctx.fuse_wgrad_accumulation:
-                for i in range(N):
-                    origin_weights[i].main_grad = main_grads[i]
+        if ctx.fuse_wgrad_accumulation:
+            for i in range(N):
+                origin_weights[i].main_grad = main_grads[i]
         # Preprocess grad output
         grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
@@ -572,8 +572,8 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.cpu_offloading:
             if ctx.grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
-                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
-                    origin_weight.main_grad = main_grad
+        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+            origin_weight.main_grad = main_grad
         # Configure Userbuffers communication (comm+GEMM overlap)
         ctx.ub_obj_gradout = None
@@ -508,8 +508,8 @@ class _Linear(torch.autograd.Function):
         if ctx.cpu_offloading:
             if ctx.grad_added_to_main_grad:
                 weight = ctx.weight_object
-                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
-                    weight.main_grad = main_grad
+        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
+            weight.main_grad = main_grad
         # Gather intermediate/activation tensors if needed
         # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
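For context, a hedged sketch of how a module opts into the fused path: `fuse_wgrad_accumulation` is the Transformer Engine module argument, while the `main_grad` buffer setup shown is the caller's responsibility (as in Megatron-style training loops) and is illustrative rather than part of this commit.

```python
import torch
import transformer_engine.pytorch as te

# Opt into fused wgrad accumulation; the training framework must
# provide the main_grad buffers that the fused GEMM accumulates into.
layer = te.Linear(1024, 1024, fuse_wgrad_accumulation=True)
for param in layer.parameters():
    param.main_grad = torch.zeros_like(param)
```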