Unverified Commit d8f1e68f authored by Lifu Zhang's avatar Lifu Zhang Committed by GitHub
Browse files

fix gradient accumulation fusion for FSDP (#2371)


Signed-off-by: default avatarLifu Zhang <lifuz@login-lyris01.lyris.clusters.nvidia.com>
Co-authored-by: default avatarLifu Zhang <lifuz@login-lyris01.lyris.clusters.nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent c544ced2
...@@ -293,9 +293,9 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -293,9 +293,9 @@ class _GroupedLinear(torch.autograd.Function):
origin_weights[i] = ctx.weight_objects[i] origin_weights[i] = ctx.weight_objects[i]
ctx.weight_objects[i] = None ctx.weight_objects[i] = None
if ctx.fuse_wgrad_accumulation: if ctx.fuse_wgrad_accumulation:
for i in range(N): for i in range(N):
origin_weights[i].main_grad = main_grads[i] origin_weights[i].main_grad = main_grads[i]
# Preprocess grad output # Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
......
...@@ -572,8 +572,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -572,8 +572,8 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.cpu_offloading: if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad: if ctx.grad_added_to_main_grad:
origin_weight = ctx.weight_object origin_weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad origin_weight.main_grad = main_grad
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None ctx.ub_obj_gradout = None
......
...@@ -508,8 +508,8 @@ class _Linear(torch.autograd.Function): ...@@ -508,8 +508,8 @@ class _Linear(torch.autograd.Function):
if ctx.cpu_offloading: if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad: if ctx.grad_added_to_main_grad:
weight = ctx.weight_object weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
weight.main_grad = main_grad weight.main_grad = main_grad
# Gather intermediate/activation tensors if needed # Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment