Unverified Commit df1b16da authored by Tim Moon, committed by GitHub

[PyTorch] Fix bug in FP8 cast in LayerNormLinear/LayerNormMLP (#738)



Perform FP8 cast on gathered layernorm output in LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent c1a68f6c
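Under sequence parallelism the FP8 GEMM consumes the *gathered* layernorm output, so the FP8 cast has to happen after the all-gather: casting only the local shard leaves the gathered tensor in high precision, which is the bug this commit fixes (and which the new `uint8` asserts in `fp8_gemm` now catch early). A minimal sketch of the corrected ordering, with a hypothetical `fake_cast_to_fp8` standing in for `tex.cast_to_fp8` and made-up shapes:

```python
import torch

def fake_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for tex.cast_to_fp8: Transformer Engine keeps
    # FP8 data in uint8 storage; here we just scale, round, and narrow.
    scale = 255.0 / x.abs().max().clamp(min=1e-8).item()
    return (x.abs() * scale).round().clamp(0, 255).to(torch.uint8)

world_size, local_rows, cols = 4, 8, 16                   # made-up sizes
ln_out_total = torch.rand(world_size * local_rows, cols)  # after all-gather
rank = 1

# Fixed ordering: quantize the gathered tensor (the GEMM input) first,
# then slice this rank's shard back out for later reuse.
ln_out_total = fake_cast_to_fp8(ln_out_total)
ln_out = ln_out_total[rank * local_rows:(rank + 1) * local_rows, ...]

assert ln_out_total.dtype == torch.uint8  # what fp8_gemm now asserts
```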
```diff
@@ -44,6 +44,8 @@ def fp8_gemm(
     assert fp8_meta_tensor is not None and out_index is not None
     assert_dim_for_fp8_exec(A)
     assert_dim_for_fp8_exec(B)
+    assert A.dtype == torch.uint8
+    assert B.dtype == torch.uint8
     if out is None:
         out = torch.empty(
```
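FP8 tensors in Transformer Engine are carried in plain `torch.uint8` storage, so the two new asserts reject any operand that skipped the quantization step upstream. A standalone sketch of the same guard (the helper name `_assert_fp8_operands` is invented for illustration):

```python
import torch

def _assert_fp8_operands(A: torch.Tensor, B: torch.Tensor) -> None:
    # Transformer Engine carries FP8 data in plain uint8 buffers, so a
    # non-uint8 operand means the cast_to_fp8 step was skipped upstream.
    assert A.dtype == torch.uint8, f"expected uint8 FP8 storage, got {A.dtype}"
    assert B.dtype == torch.uint8, f"expected uint8 FP8 storage, got {B.dtype}"

_assert_fp8_operands(torch.zeros(16, 16, dtype=torch.uint8),
                     torch.zeros(16, 16, dtype=torch.uint8))  # passes
# _assert_fp8_operands(torch.zeros(16, 16), torch.zeros(16, 16))  # would raise
```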
```diff
@@ -169,12 +169,19 @@ class _LayerNormLinear(torch.autograd.Function):
                 out=ln_out_fp8)
             ln_out = ln_out_fp8
         else:
-            ln_out = tex.cast_to_fp8(
-                ln_out,
+            ln_out_total = tex.cast_to_fp8(
+                ln_out_total,
                 fp8_meta["scaling_fwd"],
                 tex.FP8FwdTensors.GEMM1_INPUT,
                 fp8_dtype_forward,
             )
+            if ln_out_gathered:
+                rank = torch.distributed.get_rank(tp_group)
+                slice_start = rank * ln_out.size(0)
+                slice_end = (rank + 1) * ln_out.size(0)
+                ln_out = ln_out_total[slice_start:slice_end, ...]
+            else:
+                ln_out = ln_out_total
     if fp8:
         bias_dtype = (
```
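The slicing arithmetic assumes the all-gather concatenated equal-sized shards along dim 0 in rank order, so rank `r`'s rows occupy `[r * n, (r + 1) * n)`. A small self-contained check of that indexing, with made-up sizes:

```python
import torch

world_size, local_rows, cols = 4, 8, 16  # made-up sizes
shards = [torch.full((local_rows, cols), float(r)) for r in range(world_size)]
ln_out_total = torch.cat(shards, dim=0)  # what the all-gather produces

for rank in range(world_size):
    slice_start = rank * local_rows
    slice_end = (rank + 1) * local_rows
    ln_out = ln_out_total[slice_start:slice_end, ...]
    assert torch.equal(ln_out, shards[rank])  # each rank recovers its own shard
```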
```diff
@@ -187,12 +187,27 @@ class _LayerNormMLP(torch.autograd.Function):
         if return_layernorm_output:
             ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
         if fp8:
+            if ub_overlap_ag:
                 ln_out = tex.cast_to_fp8(
                     ln_out,
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
                 )
+            else:
+                ln_out_total = tex.cast_to_fp8(
+                    ln_out_total,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
+                if ln_out_gathered:
+                    rank = torch.distributed.get_rank(tp_group)
+                    slice_start = rank * ln_out.size(0)
+                    slice_end = (rank + 1) * ln_out.size(0)
+                    ln_out = ln_out_total[slice_start:slice_end, ...]
+                else:
+                    ln_out = ln_out_total
     if fp8:
         bias_dtype = (
```
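For readability, here is a hedged restatement of the LayerNormMLP branch above as plain Python, with the Transformer Engine call stubbed out. The reading of `ub_overlap_ag` is an assumption based on the flag's role here: with userbuffers all-gather overlap the communication happens inside the GEMM, so only the local shard exists at this point and keeps the old cast; otherwise the gathered tensor is cast once and the local shard sliced back out, as in LayerNormLinear.

```python
def cast_ln_output(ln_out, ln_out_total, ln_out_gathered, ub_overlap_ag,
                   rank, cast_to_fp8):
    # Hedged restatement of the branch above; cast_to_fp8 stands in for
    # tex.cast_to_fp8 with its scaling arguments already bound.
    if ub_overlap_ag:
        # Userbuffers overlap path: the all-gather happens inside the GEMM,
        # so only the local shard exists here and is cast as before.
        ln_out = cast_to_fp8(ln_out)
    else:
        # Cast the gathered tensor once, then recover this rank's shard.
        ln_out_total = cast_to_fp8(ln_out_total)
        if ln_out_gathered:
            n = ln_out.size(0)
            ln_out = ln_out_total[rank * n:(rank + 1) * n, ...]
        else:
            ln_out = ln_out_total
    return ln_out, ln_out_total
```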