Unverified commit 580eb52b authored by Sangkug Lym, committed by GitHub

Do not store input activations when not computing weight gradients (#739)



* Do not store input activations when not computing weight gradients

  Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix the userbuffer TP comm overlap case: disable the bulk dgrad overlap when the weight gradient is not computed

  Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------

Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
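
A minimal sketch of the pattern this commit applies, using a toy `torch.autograd.Function` rather than TransformerEngine's actual classes (all names below are illustrative): the input activation feeds only the wgrad GEMM, so it can be dropped from `ctx.save_for_backward` whenever the weight does not require a gradient.

```python
import torch

class LinearNoWgradSave(torch.autograd.Function):
    """Toy linear layer that skips stashing its input when wgrad is unneeded."""

    @staticmethod
    def forward(ctx, inp, weight):
        out = inp @ weight.t()
        # The input activation is only consumed by the weight-gradient GEMM,
        # so store None in its place when the weight is frozen; autograd then
        # frees the activation instead of holding it until backward.
        ctx.save_for_backward(inp if weight.requires_grad else None, weight)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_inp = grad_out @ weight  # dgrad: always needed by earlier layers
        # wgrad: only computed when the weight is trainable (inp was saved).
        grad_weight = grad_out.t() @ inp if ctx.needs_input_grad[1] else None
        return grad_inp, grad_weight
```

With `weight.requires_grad_(False)`, `inp` is never saved, yet `backward` still returns `grad_inp` for the layers upstream.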
parent 8e7795e1
```diff
@@ -292,7 +292,7 @@ class _LayerNormLinear(torch.autograd.Function):
             weight,
             weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
             weight_t_fp8,
-            ln_out,
+            ln_out if weight.requires_grad else None,
             fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
         )
@@ -369,7 +369,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
-            if tp_world_size == 1:
+            if tp_world_size == 1 or not weight.requires_grad:
                 ctx.ub_bulk_dgrad = False
         if ctx.ub_bulk_dgrad:
             dim_size = list(ln_out.size())
```
```diff
@@ -498,9 +498,9 @@ class _LayerNormMLP(torch.autograd.Function):
             ln_weight,
             mu,
             rsigma,
-            ln_out,
+            ln_out if fc1_weight.requires_grad else None,
             fc1_out,
-            gelu_out,
+            gelu_out if fc2_weight.requires_grad else None,
             fc1_weight,
             fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
             fc1_weight_t_fp8,
@@ -600,7 +600,7 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.ub_bulk_dgrad:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
-            if tp_world_size == 1:
+            if tp_world_size == 1 or not fc1_weight.requires_grad:
                 ctx.ub_bulk_dgrad = False
         if ctx.ub_bulk_dgrad:
             dim_size = list(ln_out.size())
```
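
A hypothetical usage sketch at the module level (the layer sizes and the freezing of every parameter are assumptions for illustration; `te.LayerNormMLP` is the module backed by the `_LayerNormMLP` autograd function patched above): once the weights are frozen, the layer no longer keeps `ln_out` and `gelu_out` alive for backward.

```python
import torch
import transformer_engine.pytorch as te

layer = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096)  # sizes are arbitrary
for p in layer.parameters():
    p.requires_grad_(False)  # freeze all weights: no wgrad GEMMs in backward

x = torch.randn(8, 1024, device="cuda", requires_grad=True)
y = layer(x)         # with this commit, ln_out/gelu_out are not stashed
y.sum().backward()   # dgrad still reaches x; weight gradients are skipped
```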