Unverified Commit 580eb52b authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

Do not store input activations when not computing weight gradients (#739)



* Do not store input activations when not computing weight gradients
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix userbuffer tp comm overlap case
Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8e7795e1
......@@ -292,7 +292,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8,
ln_out,
ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
......@@ -369,7 +369,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
if tp_world_size == 1 or not weight.requires_grad:
ctx.ub_bulk_dgrad = False
if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size())
......
......@@ -498,9 +498,9 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight,
mu,
rsigma,
ln_out,
ln_out if fc1_weight.requires_grad else None,
fc1_out,
gelu_out,
gelu_out if fc2_weight.requires_grad else None,
fc1_weight,
fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
fc1_weight_t_fp8,
......@@ -600,7 +600,7 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
if tp_world_size == 1 or not fc1_weight.requires_grad:
ctx.ub_bulk_dgrad = False
if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment