Unverified commit 580eb52b authored by Sangkug Lym, committed by GitHub

Do not store input activations when not computing weight gradients (#739)



* Do not store input activations when not computing weight gradients
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix userbuffer tp comm overlap case
Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8e7795e1
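
The pattern this commit applies is to pass an activation to ctx.save_for_backward() only when the corresponding weight actually needs a gradient, so the tensor can be freed right after the forward pass instead of being held for backward. Below is a minimal, self-contained sketch of that pattern (toy code, not Transformer Engine's implementation; the class and variable names are made up for illustration):

import torch

class LinearNoWgradSave(torch.autograd.Function):
    """Toy linear op y = x @ w.t() that stores x only if w.requires_grad."""

    @staticmethod
    def forward(ctx, inp, weight):
        out = inp @ weight.t()
        # Saving None instead of `inp` lets the activation be freed after
        # forward, since nothing keeps it alive for the backward pass.
        ctx.save_for_backward(
            weight,
            inp if weight.requires_grad else None,
        )
        return out

    @staticmethod
    def backward(ctx, grad_out):
        weight, inp = ctx.saved_tensors
        grad_inp = grad_out @ weight        # dgrad is always computed
        grad_w = None
        if inp is not None:                 # wgrad only if the activation was kept
            grad_w = grad_out.t() @ inp
        return grad_inp, grad_w

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=False)   # frozen weight
y = LinearNoWgradSave.apply(x, w)
y.sum().backward()
print(x.grad.shape, w.grad)                    # torch.Size([8, 16]) None
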
@@ -292,7 +292,7 @@ class _LayerNormLinear(torch.autograd.Function):
             weight,
             weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
             weight_t_fp8,
-            ln_out,
+            ln_out if weight.requires_grad else None,
             fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
         )
@@ -369,7 +369,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
-            if tp_world_size == 1:
+            if tp_world_size == 1 or not weight.requires_grad:
                 ctx.ub_bulk_dgrad = False
         if ctx.ub_bulk_dgrad:
             dim_size = list(ln_out.size())
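
With the change above, freezing the GEMM weight both drops ln_out from the saved tensors and disables the userbuffer bulk-dgrad overlap, since the gathered ln_out is only needed for the wgrad GEMM. A hedged usage sketch with Transformer Engine's LayerNormLinear module (the constructor arguments and parameter name shown here are assumptions and may differ across TE versions):

import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(1024, 1024).cuda()   # assumed (in_features, out_features)
layer.weight.requires_grad_(False)              # freeze the GEMM weight -> no wgrad

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
y = layer(x)
y.sum().backward()    # dgrad only; ln_out is no longer kept for backward
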
@@ -498,9 +498,9 @@ class _LayerNormMLP(torch.autograd.Function):
             ln_weight,
             mu,
             rsigma,
-            ln_out,
+            ln_out if fc1_weight.requires_grad else None,
             fc1_out,
-            gelu_out,
+            gelu_out if fc2_weight.requires_grad else None,
             fc1_weight,
             fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
             fc1_weight_t_fp8,
@@ -600,7 +600,7 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
-            if tp_world_size == 1:
+            if tp_world_size == 1 or not fc1_weight.requires_grad:
                 ctx.ub_bulk_dgrad = False
         if ctx.ub_bulk_dgrad:
             dim_size = list(ln_out.size())
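
In the MLP case the two weights are checked independently: ln_out is dropped only when fc1_weight is frozen, and gelu_out only when fc2_weight is frozen. A hedged usage sketch with Transformer Engine's LayerNormMLP (constructor arguments and parameter names are assumptions; check the TE version in use):

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(1024, 4096).cuda()   # assumed (hidden_size, ffn_hidden_size)
mlp.fc2_weight.requires_grad_(False)       # frozen: gelu_out is not stored
# fc1_weight still requires grad, so ln_out is still stored for its wgrad

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
y = mlp(x)
y.sum().backward()
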