Unverified Commit e056664f authored by Jason Wang, committed by GitHub

fix linear sequence parallel when weight is frozen (#373)



* fix linear sequence parallel when weight is frozen
Signed-off-by: jasonwan <jasonwan@nvidia.com>

* Extend fix to all APIs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: jasonwan <jasonwan@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6464ced7
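
The patched path is the backward pass of a column-parallel layer running with sequence parallelism while its weight is frozen. A rough repro sketch of that configuration, not part of this commit, assuming a multi-GPU distributed launch (e.g. torchrun) and the `transformer_engine.pytorch.Linear` constructor arguments used here (`parallel_mode`, `sequence_parallel`, `tp_group`, `tp_size`):

```python
# Hypothetical repro sketch (not from this commit): a column-parallel
# TE Linear with sequence parallelism whose weight is frozen.
# Requires GPUs and a distributed launch, e.g. torchrun --nproc_per_node=2.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
tp_group = dist.new_group()  # tensor-parallel group over all ranks

layer = te.Linear(
    1024, 4096,
    parallel_mode="column",        # column-parallel linear
    sequence_parallel=True,        # input is scattered along the sequence dim
    tp_group=tp_group,
    tp_size=dist.get_world_size(),
)
layer.weight.requires_grad_(False)  # frozen weight -> no wgrad, so no input AG is needed

x = torch.randn(512, 2, 1024, device="cuda", requires_grad=True)
out = layer(x)
out.sum().backward()  # exercises the backward branches patched below
```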
@@ -305,12 +305,16 @@ class _LayerNormLinear(torch.autograd.Function):
         # Column Parallel Linear
         # Overlap input AG with dgrad
-        if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel:
+        if (weight.requires_grad
+            and (not ctx.ub_bulk_dgrad)
+            and ctx.parallel_mode == "column"
+            and ctx.sequence_parallel):
             ln_out_total, handle = gather_along_first_dim(
                 ln_out, ctx.tp_group, async_op=True
             )
         else:
             ln_out_total = ln_out
+            handle = None
         if ctx.is_first_microbatch is not None:
             accumulate_wgrad_into_param_main_grad = (
@@ -371,7 +375,7 @@ class _LayerNormLinear(torch.autograd.Function):
         # Overlap dgrad-RS/AR with wgrad
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
-            if not ctx.ub_bulk_dgrad:
+            if not ctx.ub_bulk_dgrad and handle is not None:
                 handle.wait()
             if not ctx.ub_bulk_wgrad:
                 dgrad, handle = reduce_scatter_along_first_dim(
......
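
Reduced to its control flow, the guard added in the two hunks above works as in the stand-alone sketch below (plain Python with a stand-in for the communication handle, not TransformerEngine code): launch the asynchronous all-gather of `ln_out` only when a weight gradient will be computed, and wait on the handle only if one was created.

```python
# Stand-alone illustration of the new guard: gather only when wgrad is
# needed, and only wait on a handle that was actually created.
from typing import Optional

class _Handle:
    """Stand-in for the async work handle returned by the gather."""
    def wait(self) -> None:
        pass

def maybe_gather(weight_requires_grad: bool, sequence_parallel: bool) -> Optional[_Handle]:
    handle: Optional[_Handle] = None
    if weight_requires_grad and sequence_parallel:
        # corresponds to gather_along_first_dim(ln_out, tp_group, async_op=True)
        handle = _Handle()
    return handle

# Later, where the gathered input would be consumed (overlapped with dgrad):
handle = maybe_gather(weight_requires_grad=False, sequence_parallel=True)
if handle is not None:   # mirrors "not ctx.ub_bulk_dgrad and handle is not None"
    handle.wait()
```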
@@ -471,12 +471,16 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.ub_bulk_wgrad = False
         # Column Parallel Linear
         # Overlap input AG with dgrad
-        if (not ctx.ub_bulk_dgrad) and ctx.set_parallel_mode and ctx.sequence_parallel:
+        if (fc1_weight.requires_grad
+            and (not ctx.ub_bulk_dgrad)
+            and ctx.set_parallel_mode
+            and ctx.sequence_parallel):
             ln_out_total, handle = gather_along_first_dim(
                 ln_out, ctx.tp_group, async_op=True
             )
         else:
             ln_out_total = ln_out
+            handle = None
         if ctx.is_first_microbatch is not None:
             accumulate_wgrad_into_param_main_grad = (
@@ -687,7 +691,7 @@ class _LayerNormMLP(torch.autograd.Function):
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
         # Overlap dgrad-RS/AR with wgrad
         if ctx.set_parallel_mode and ctx.sequence_parallel:
-            if not ctx.ub_bulk_dgrad:
+            if not ctx.ub_bulk_dgrad and handle is not None:
                 handle.wait()
             if not ctx.ub_bulk_wgrad:
                 fc1_dgrad, handle = reduce_scatter_along_first_dim(
......
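
The `_LayerNormMLP` hunks above key the same guard on `fc1_weight`, the weight of the first GEMM whose wgrad consumes the gathered `ln_out`. A short sketch of the corresponding frozen-FC1 setup, assuming `te.LayerNormMLP` exposes the `fc1_weight` parameter named in the diff and reusing the distributed setup from the earlier sketch:

```python
# Hypothetical sketch (not from this commit): freeze only FC1 so the backward
# takes the new branch that skips the ln_out all-gather.
import torch.distributed as dist
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    1024, 4096,
    sequence_parallel=True,
    set_parallel_mode=True,        # column-parallel FC1 / row-parallel FC2
    tp_group=tp_group,             # tp_group created in the earlier sketch
    tp_size=dist.get_world_size(),
)
mlp.fc1_weight.requires_grad_(False)  # frozen FC1 -> handle stays None in backward
```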
@@ -294,7 +294,7 @@ class _Linear(torch.autograd.Function):
         # Column Parallel Linear
         # Overlap input AG with dgrad
-        if ctx.parallel_mode == "column" and ctx.sequence_parallel:
+        if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
             if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                 inputmat_t_total, handle = gather_along_last_dim(
                     inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
@@ -306,6 +306,7 @@ class _Linear(torch.autograd.Function):
         else:
             inputmat_t_total = inputmat_t
             inputmat_total = inputmat
+            handle = None
         if ctx.is_first_microbatch is not None:
             accumulate_wgrad_into_param_main_grad = (
@@ -353,7 +354,8 @@ class _Linear(torch.autograd.Function):
         # Overlap dgrad-RS/AR with wgrad
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
-            handle.wait()
+            if handle is not None:
+                handle.wait()
             dgrad, handle = reduce_scatter_along_first_dim(
                 dgrad, ctx.tp_group, async_op=True
             )
......
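
For reference, the primitive being guarded in all three modules is an asynchronous collective: launched with `async_op=True`, it returns a work handle whose `wait()` must only be called if the collective was actually issued. A minimal sketch of that overlap pattern in plain `torch.distributed` (assumes an initialized process group; the helper name is made up for illustration):

```python
# Sketch of the overlapped async-gather pattern the guard protects
# (plain torch.distributed, not TransformerEngine code).
import torch
import torch.distributed as dist

def gather_if_needed(local: torch.Tensor, need_full_input: bool):
    """Launch an async all-gather along dim 0, or skip it entirely."""
    handle = None
    full = local
    if need_full_input:
        full = torch.empty(
            local.shape[0] * dist.get_world_size(), *local.shape[1:],
            dtype=local.dtype, device=local.device,
        )
        handle = dist.all_gather_into_tensor(full, local, async_op=True)
    # ... overlap: compute dgrad here while the gather is (maybe) in flight ...
    if handle is not None:   # same guard as the patch: no handle, no wait
        handle.wait()
    return full
```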