Unverified commit c1a68f6c, authored by Jaemin Choi, committed by GitHub

Enable TP-AG overlap with return_layernorm_output (#727)

* Enable TP-AG overlap with return_layernorm_output

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Use ub_overlap_ag

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

---------

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
parent 8e672ff0
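
Review note (not part of the commit): a minimal usage sketch of the combination this change enables, assuming the transformer_engine.pytorch module API of this release. Userbuffer initialization and tensor-parallel group setup are omitted, and all sizes are illustrative; note that with a tensor-parallel world size of 1 the forward pass simply disables the overlap, per the first hunk below.

    import torch
    import transformer_engine.pytorch as te

    # return_layernorm_output=True previously forced the TP all-gather
    # overlap off; after this change both can be requested together.
    layer = te.LayerNormLinear(
        in_features=4096,
        out_features=4096,
        params_dtype=torch.bfloat16,
        return_layernorm_output=True,
        ub_overlap_ag=True,
        ub_name="qkv",  # the forward pass looks up the "qkv_fprop" userbuffer
    )

    inp = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    out, ln_out = layer(inp)  # layernorm output returned alongside the GEMM output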
@@ -107,12 +107,17 @@ class _LayerNormLinear(torch.autograd.Function):
         if ub_overlap_ag:
             tp_world_size = get_distributed_world_size(tp_group)
-            if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
+            if tp_world_size == 1 or (not is_grad_enabled):
                 ub_overlap_ag = False
         if ub_overlap_ag:
             dim_size = list(inputmat.size())
             dim_size[0] = dim_size[0] * tp_world_size
             ub_obj_lnout = get_ub(ub_name+"_fprop")
-            ln_out = ub_obj_lnout.get_ubuf_output(0)
+            if return_layernorm_output:
+                # First prepare LN output in higher precision,
+                # which will be later copied to a FP8 UB
+                ln_out = torch.empty_like(inputmat)
+            else:
+                ln_out = ub_obj_lnout.get_ubuf_output(0)
         else:
             ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
@@ -136,6 +141,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ln_out_gathered = False
         if ub_overlap_ag:
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
-            ln_out = torch.empty_like(ln_out)
+            if not return_layernorm_output:
+                ln_out = torch.empty_like(ln_out)
             if ub_obj_lnout.is_atomic_gemm():
                 ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
@@ -153,6 +159,16 @@ class _LayerNormLinear(torch.autograd.Function):
         if return_layernorm_output:
             ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
             if fp8:
-                ln_out = tex.cast_to_fp8(
-                    ln_out,
-                    fp8_meta["scaling_fwd"],
+                if ub_overlap_ag:
+                    ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0)
+                    tex.cast_to_fp8(
+                        ln_out,
+                        fp8_meta["scaling_fwd"],
+                        tex.FP8FwdTensors.GEMM1_INPUT,
+                        fp8_dtype_forward,
+                        out=ln_out_fp8)
+                    ln_out = ln_out_fp8
+                else:
+                    ln_out = tex.cast_to_fp8(
+                        ln_out,
+                        fp8_meta["scaling_fwd"],
...
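
For orientation, a condensed standalone sketch of the pattern the last hunk introduces (illustrative only, not the TE source: the userbuffer chunk is modeled as a plain pre-allocated tensor, and the role of tex.cast_to_fp8 is played by a dtype-converting copy):

    import torch

    def fill_userbuffer_chunk(ln_out_hp: torch.Tensor,
                              ub_chunk: torch.Tensor) -> torch.Tensor:
        # Stands in for tex.cast_to_fp8(ln_out, ..., out=ln_out_fp8): the
        # high-precision LN output stays alive for the caller, while a cast
        # copy is written into the communication buffer that the overlapped
        # all-gather GEMM reads from.
        ub_chunk.copy_(ln_out_hp.to(ub_chunk.dtype))
        return ub_chunk

This is also plausibly why the first hunk allocates ln_out with torch.empty_like(inputmat) instead of aliasing the userbuffer when return_layernorm_output is set: the tensor handed back to the caller keeps the input's precision, and only the copy placed in the FP8 userbuffer feeds the GEMM.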